feat(backfill): implement backfill tool for generating missing embeddings
This commit is contained in:
84
README.md
84
README.md
@@ -41,6 +41,7 @@ A Go MCP server for capturing and retrieving thoughts, memory, and project conte
|
||||
| `recall_context` | Semantic + recency context block for injection |
|
||||
| `link_thoughts` | Create a typed relationship between thoughts |
|
||||
| `related_thoughts` | Explicit links + semantic neighbors |
|
||||
| `backfill_embeddings` | Generate missing embeddings for stored thoughts |
|
||||
|
||||
## Configuration
|
||||
|
||||
@@ -74,6 +75,89 @@ Alternatively, pass `client_id` and `client_secret` as body parameters instead o
|
||||
|
||||
See `llm/plan.md` for full architecture and implementation plan.
|
||||
|
||||
## Backfill
|
||||
|
||||
Run `backfill_embeddings` after switching embedding models or importing thoughts without vectors.
|
||||
|
||||
```json
|
||||
{
|
||||
"project": "optional-project-name",
|
||||
"limit": 100,
|
||||
"include_archived": false,
|
||||
"older_than_days": 0,
|
||||
"dry_run": false
|
||||
}
|
||||
```
|
||||
|
||||
- `dry_run: true` — report counts without calling the embedding provider
|
||||
- `limit` — max thoughts per call (default 100)
|
||||
- Embeddings are generated in parallel (4 workers) and upserted; one failure does not abort the run
|
||||
|
||||
**Automatic backfill** (optional, config-gated):
|
||||
|
||||
```yaml
|
||||
backfill:
|
||||
enabled: true
|
||||
run_on_startup: true # run once on server start
|
||||
interval: "15m" # repeat every 15 minutes
|
||||
batch_size: 20
|
||||
max_per_run: 100
|
||||
include_archived: false
|
||||
```
|
||||
|
||||
**Search fallback**: when no embeddings exist for the active model in scope, `search_thoughts`, `recall_context`, `get_project_context`, `summarize_thoughts`, and `related_thoughts` automatically fall back to Postgres full-text search so results are never silently empty.
|
||||
|
||||
## Client Setup
|
||||
|
||||
### Claude Code
|
||||
|
||||
```bash
|
||||
# API key auth
|
||||
claude mcp add --transport http amcs http://localhost:8080/mcp --header "x-brain-key: <key>"
|
||||
|
||||
# Bearer token auth
|
||||
claude mcp add --transport http amcs http://localhost:8080/mcp --header "Authorization: Bearer <token>"
|
||||
```
|
||||
|
||||
### OpenAI Codex
|
||||
|
||||
Add to `~/.codex/config.toml`:
|
||||
|
||||
```toml
|
||||
[[mcp_servers]]
|
||||
name = "amcs"
|
||||
url = "http://localhost:8080/mcp"
|
||||
|
||||
[mcp_servers.headers]
|
||||
x-brain-key = "<key>"
|
||||
```
|
||||
|
||||
### OpenCode
|
||||
|
||||
```bash
|
||||
# API key auth
|
||||
opencode mcp add --name amcs --type remote --url http://localhost:8080/mcp --header "x-brain-key=<key>"
|
||||
|
||||
# Bearer token auth
|
||||
opencode mcp add --name amcs --type remote --url http://localhost:8080/mcp --header "Authorization=Bearer <token>"
|
||||
```
|
||||
|
||||
Or add directly to `opencode.json` / `~/.config/opencode/config.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcp": {
|
||||
"amcs": {
|
||||
"type": "remote",
|
||||
"url": "http://localhost:8080/mcp",
|
||||
"headers": {
|
||||
"x-brain-key": "<key>"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
Run the SQL migrations against a local database with:
|
||||
|
||||
@@ -79,3 +79,11 @@ logging:
|
||||
observability:
|
||||
metrics_enabled: true
|
||||
pprof_enabled: false
|
||||
|
||||
backfill:
|
||||
enabled: false
|
||||
run_on_startup: false
|
||||
interval: "15m"
|
||||
batch_size: 20
|
||||
max_per_run: 100
|
||||
include_archived: false
|
||||
|
||||
@@ -74,6 +74,25 @@ func Run(ctx context.Context, configPath string) error {
|
||||
slog.String("provider", provider.Name()),
|
||||
)
|
||||
|
||||
if cfg.Backfill.Enabled && cfg.Backfill.RunOnStartup {
|
||||
go runBackfillPass(ctx, db, provider, cfg.Backfill, logger)
|
||||
}
|
||||
|
||||
if cfg.Backfill.Enabled && cfg.Backfill.Interval > 0 {
|
||||
go func() {
|
||||
ticker := time.NewTicker(cfg.Backfill.Interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
runBackfillPass(ctx, db, provider, cfg.Backfill, logger)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
server := &http.Server{
|
||||
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
|
||||
Handler: routes(logger, cfg, db, provider, keyring, oauthRegistry, tokenStore, authCodes, dynClients, activeProjects),
|
||||
@@ -123,6 +142,7 @@ func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.P
|
||||
Recall: tools.NewRecallTool(db, provider, cfg.Search, activeProjects),
|
||||
Summarize: tools.NewSummarizeTool(db, provider, cfg.Search, activeProjects),
|
||||
Links: tools.NewLinksTool(db, provider, cfg.Search),
|
||||
Backfill: tools.NewBackfillTool(db, provider, activeProjects, logger),
|
||||
}
|
||||
|
||||
mcpHandler := mcpserver.New(cfg.MCP, toolSet)
|
||||
@@ -136,6 +156,7 @@ func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.P
|
||||
mux.HandleFunc("/oauth/token", oauthTokenHandler(oauthRegistry, tokenStore, authCodes, logger))
|
||||
}
|
||||
mux.HandleFunc("/favicon.ico", serveFavicon)
|
||||
mux.HandleFunc("/images/project.jpg", serveHomeImage)
|
||||
mux.HandleFunc("/llm", serveLLMInstructions)
|
||||
|
||||
mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -155,8 +176,49 @@ func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.P
|
||||
})
|
||||
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Method != http.MethodGet && r.Method != http.MethodHead {
|
||||
w.Header().Set("Allow", "GET, HEAD")
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
const homePage = `<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>AMCS</title>
|
||||
<style>
|
||||
body { margin: 0; font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif; background: #f5f7fb; color: #172033; }
|
||||
main { max-width: 860px; margin: 48px auto; background: #fff; border-radius: 12px; box-shadow: 0 10px 28px rgba(23, 32, 51, 0.12); overflow: hidden; }
|
||||
.content { padding: 28px; }
|
||||
h1 { margin: 0 0 12px 0; font-size: 2rem; }
|
||||
p { margin: 0; line-height: 1.5; color: #334155; }
|
||||
img { display: block; width: 100%; height: auto; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<main>
|
||||
<img src="/images/project.jpg" alt="Avelon Memory Crystal project image">
|
||||
<div class="content">
|
||||
<h1>Avelon Memory Crystal Server (AMCS)</h1>
|
||||
<p>AMCS is a memory server that captures, links, and retrieves structured project thoughts for AI assistants using semantic search, summaries, and MCP tools.</p>
|
||||
</div>
|
||||
</main>
|
||||
</body>
|
||||
</html>`
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("amcs is running"))
|
||||
if r.Method == http.MethodHead {
|
||||
return
|
||||
}
|
||||
|
||||
_, _ = w.Write([]byte(homePage))
|
||||
})
|
||||
|
||||
return observability.Chain(
|
||||
@@ -167,3 +229,39 @@ func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.P
|
||||
observability.Timeout(cfg.Server.WriteTimeout),
|
||||
)
|
||||
}
|
||||
|
||||
func runBackfillPass(ctx context.Context, db *store.DB, provider ai.Provider, cfg config.BackfillConfig, logger *slog.Logger) {
|
||||
backfiller := tools.NewBackfillTool(db, provider, nil, logger)
|
||||
_, out, err := backfiller.Handle(ctx, nil, tools.BackfillInput{
|
||||
Limit: cfg.MaxPerRun,
|
||||
IncludeArchived: cfg.IncludeArchived,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("auto backfill failed", slog.String("error", err.Error()))
|
||||
return
|
||||
}
|
||||
logger.Info("auto backfill pass",
|
||||
slog.String("model", out.Model),
|
||||
slog.Int("scanned", out.Scanned),
|
||||
slog.Int("embedded", out.Embedded),
|
||||
slog.Int("failed", out.Failed),
|
||||
)
|
||||
}
|
||||
|
||||
func serveHomeImage(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet && r.Method != http.MethodHead {
|
||||
w.Header().Set("Allow", "GET, HEAD")
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "image/jpeg")
|
||||
w.Header().Set("Cache-Control", "public, max-age=31536000, immutable")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
if r.Method == http.MethodHead {
|
||||
return
|
||||
}
|
||||
|
||||
_, _ = w.Write(homeImage)
|
||||
}
|
||||
|
||||
@@ -1,15 +1,9 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
var (
|
||||
//go:embed static/favicon.ico
|
||||
faviconICO []byte
|
||||
)
|
||||
|
||||
func serveFavicon(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "image/x-icon")
|
||||
w.Header().Set("Cache-Control", "public, max-age=31536000, immutable")
|
||||
|
||||
BIN
internal/app/static/avelonmemorycrystal.jpg
Normal file
BIN
internal/app/static/avelonmemorycrystal.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 260 KiB |
24
internal/app/static_assets.go
Normal file
24
internal/app/static_assets.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
)
|
||||
|
||||
// Embedded static assets, materialized once at package initialization.
// mustReadStaticFile panics on a missing asset, so a bad path fails fast at
// process start rather than on first request.
var (
	//go:embed static/*
	staticFiles embed.FS

	// faviconICO backs serveFavicon.
	faviconICO = mustReadStaticFile("favicon.ico")
	// homeImage backs serveHomeImage (the landing-page JPEG).
	homeImage = mustReadStaticFile("avelonmemorycrystal.jpg")
)
|
||||
|
||||
func mustReadStaticFile(name string) []byte {
|
||||
data, err := fs.ReadFile(staticFiles, "static/"+name)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to read embedded static file %q: %v", name, err))
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
@@ -17,6 +17,7 @@ type Config struct {
|
||||
Search SearchConfig `yaml:"search"`
|
||||
Logging LoggingConfig `yaml:"logging"`
|
||||
Observability ObservabilityConfig `yaml:"observability"`
|
||||
Backfill BackfillConfig `yaml:"backfill"`
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
@@ -135,3 +136,12 @@ type ObservabilityConfig struct {
|
||||
MetricsEnabled bool `yaml:"metrics_enabled"`
|
||||
PprofEnabled bool `yaml:"pprof_enabled"`
|
||||
}
|
||||
|
||||
// BackfillConfig controls the optional automatic embedding backfill
// (startup pass and periodic ticker) configured under the `backfill:` key.
type BackfillConfig struct {
	Enabled         bool          `yaml:"enabled"`          // master switch for all automatic backfill
	RunOnStartup    bool          `yaml:"run_on_startup"`   // run one pass when the server starts
	Interval        time.Duration `yaml:"interval"`         // repeat period; <= 0 disables the ticker
	BatchSize       int           `yaml:"batch_size"`       // validated > 0 when Enabled
	MaxPerRun       int           `yaml:"max_per_run"`      // cap per pass; validated >= BatchSize
	IncludeArchived bool          `yaml:"include_archived"` // also backfill archived thoughts
}
|
||||
|
||||
@@ -95,6 +95,13 @@ func defaultConfig() Config {
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
},
|
||||
Backfill: BackfillConfig{
|
||||
Enabled: false,
|
||||
RunOnStartup: false,
|
||||
Interval: 15 * time.Minute,
|
||||
BatchSize: 20,
|
||||
MaxPerRun: 100,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -81,5 +81,14 @@ func (c Config) Validate() error {
|
||||
return fmt.Errorf("invalid config: logging.level is required")
|
||||
}
|
||||
|
||||
if c.Backfill.Enabled {
|
||||
if c.Backfill.BatchSize <= 0 {
|
||||
return fmt.Errorf("invalid config: backfill.batch_size must be greater than zero when backfill is enabled")
|
||||
}
|
||||
if c.Backfill.MaxPerRun < c.Backfill.BatchSize {
|
||||
return fmt.Errorf("invalid config: backfill.max_per_run must be >= backfill.batch_size")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ type ToolSet struct {
|
||||
Recall *tools.RecallTool
|
||||
Summarize *tools.SummarizeTool
|
||||
Links *tools.LinksTool
|
||||
Backfill *tools.BackfillTool
|
||||
}
|
||||
|
||||
func New(cfg config.MCPConfig, toolSet ToolSet) http.Handler {
|
||||
@@ -117,6 +118,11 @@ func New(cfg config.MCPConfig, toolSet ToolSet) http.Handler {
|
||||
Description: "Retrieve explicit links and semantic neighbors for a thought.",
|
||||
}, toolSet.Links.Related)
|
||||
|
||||
addTool(server, &mcp.Tool{
|
||||
Name: "backfill_embeddings",
|
||||
Description: "Generate missing embeddings for stored thoughts using the active embedding model.",
|
||||
}, toolSet.Backfill.Handle)
|
||||
|
||||
return mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server {
|
||||
return server
|
||||
}, &mcp.StreamableHTTPOptions{
|
||||
|
||||
@@ -358,6 +358,140 @@ func (db *DB) SearchSimilarThoughts(ctx context.Context, embedding []float32, em
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (db *DB) HasEmbeddingsForModel(ctx context.Context, model string, projectID *uuid.UUID) (bool, error) {
|
||||
args := []any{model}
|
||||
conditions := []string{
|
||||
"e.model = $1",
|
||||
"t.archived_at is null",
|
||||
}
|
||||
if projectID != nil {
|
||||
args = append(args, *projectID)
|
||||
conditions = append(conditions, fmt.Sprintf("t.project_id = $%d", len(args)))
|
||||
}
|
||||
|
||||
query := `select exists(select 1 from embeddings e join thoughts t on t.guid = e.thought_id where ` +
|
||||
strings.Join(conditions, " and ") + `)`
|
||||
|
||||
var exists bool
|
||||
if err := db.pool.QueryRow(ctx, query, args...).Scan(&exists); err != nil {
|
||||
return false, fmt.Errorf("check embeddings for model: %w", err)
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// ListThoughtsMissingEmbedding returns up to limit thoughts that have no
// embedding row for the given model, oldest first. Archived thoughts are
// excluded unless includeArchived is set; a non-nil projectID scopes the scan
// to one project; olderThanDays > 0 restricts to thoughts created before
// now minus that many days.
func (db *DB) ListThoughtsMissingEmbedding(ctx context.Context, model string, limit int, projectID *uuid.UUID, includeArchived bool, olderThanDays int) ([]thoughttypes.Thought, error) {
	// Positional args are appended in lock-step with the condition that
	// references them ($N == len(args) at append time). $1 is always the
	// model, consumed by the join below rather than a where condition.
	args := []any{model}
	conditions := []string{"e.id is null"}

	if !includeArchived {
		conditions = append(conditions, "t.archived_at is null")
	}
	if projectID != nil {
		args = append(args, *projectID)
		conditions = append(conditions, fmt.Sprintf("t.project_id = $%d", len(args)))
	}
	if olderThanDays > 0 {
		args = append(args, time.Now().Add(-time.Duration(olderThanDays)*24*time.Hour))
		conditions = append(conditions, fmt.Sprintf("t.created_at < $%d", len(args)))
	}
	args = append(args, limit)

	// Left join on (thought, model): thoughts with no embedding for this
	// model surface with e.id null, which is the base condition above.
	query := `
select t.guid, t.content, t.metadata, t.project_id, t.archived_at, t.created_at, t.updated_at
from thoughts t
left join embeddings e on e.thought_id = t.guid and e.model = $1
where ` + strings.Join(conditions, " and ") + `
order by t.created_at asc
limit $` + fmt.Sprintf("%d", len(args))

	rows, err := db.pool.Query(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("list thoughts missing embedding: %w", err)
	}
	defer rows.Close()

	thoughts := make([]thoughttypes.Thought, 0, limit)
	for rows.Next() {
		var thought thoughttypes.Thought
		var metadataBytes []byte
		if err := rows.Scan(&thought.ID, &thought.Content, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil {
			return nil, fmt.Errorf("scan missing-embedding thought: %w", err)
		}
		// metadata is stored as JSON and decoded into the typed field.
		if err := json.Unmarshal(metadataBytes, &thought.Metadata); err != nil {
			return nil, fmt.Errorf("decode missing-embedding metadata: %w", err)
		}
		thoughts = append(thoughts, thought)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate missing-embedding thoughts: %w", err)
	}
	return thoughts, nil
}
|
||||
|
||||
func (db *DB) UpsertEmbedding(ctx context.Context, thoughtID uuid.UUID, model string, embedding []float32) error {
|
||||
_, err := db.pool.Exec(ctx, `
|
||||
insert into embeddings (thought_id, model, dim, embedding)
|
||||
values ($1, $2, $3, $4)
|
||||
on conflict (thought_id, model) do update
|
||||
set embedding = excluded.embedding,
|
||||
dim = excluded.dim,
|
||||
updated_at = now()
|
||||
`, thoughtID, model, len(embedding), pgvector.NewVector(embedding))
|
||||
if err != nil {
|
||||
return fmt.Errorf("upsert embedding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SearchThoughtsText runs a Postgres full-text search over non-archived
// thought content, used as the fallback when no embeddings exist for the
// active model (see semanticSearch). Results are ranked by ts_rank_cd and
// the rank is reported in the Similarity field so callers can treat the two
// search paths uniformly. projectID and excludeID optionally scope the query.
func (db *DB) SearchThoughtsText(ctx context.Context, query string, limit int, projectID *uuid.UUID, excludeID *uuid.UUID) ([]thoughttypes.SearchResult, error) {
	// Positional args are appended in lock-step with the condition that
	// references them ($N == len(args) at append time).
	args := []any{query}
	conditions := []string{
		"t.archived_at is null",
		// 'simple' configuration matches the migration's GIN index expression.
		"to_tsvector('simple', t.content) @@ websearch_to_tsquery('simple', $1)",
	}
	if projectID != nil {
		args = append(args, *projectID)
		conditions = append(conditions, fmt.Sprintf("t.project_id = $%d", len(args)))
	}
	if excludeID != nil {
		args = append(args, *excludeID)
		conditions = append(conditions, fmt.Sprintf("t.guid <> $%d", len(args)))
	}
	args = append(args, limit)

	q := `
select t.guid, t.content, t.metadata,
       ts_rank_cd(to_tsvector('simple', t.content), websearch_to_tsquery('simple', $1)) as similarity,
       t.created_at
from thoughts t
where ` + strings.Join(conditions, " and ") + `
order by similarity desc
limit $` + fmt.Sprintf("%d", len(args))

	rows, err := db.pool.Query(ctx, q, args...)
	if err != nil {
		return nil, fmt.Errorf("text search thoughts: %w", err)
	}
	defer rows.Close()

	results := make([]thoughttypes.SearchResult, 0, limit)
	for rows.Next() {
		var result thoughttypes.SearchResult
		var metadataBytes []byte
		if err := rows.Scan(&result.ID, &result.Content, &metadataBytes, &result.Similarity, &result.CreatedAt); err != nil {
			return nil, fmt.Errorf("scan text search result: %w", err)
		}
		// metadata is stored as JSON and decoded into the typed field.
		if err := json.Unmarshal(metadataBytes, &result.Metadata); err != nil {
			return nil, fmt.Errorf("decode text search metadata: %w", err)
		}
		results = append(results, result)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate text search results: %w", err)
	}
	return results, nil
}
|
||||
|
||||
func topCounts(in map[string]int, limit int) []thoughttypes.KeyCount {
|
||||
type kv struct {
|
||||
key string
|
||||
|
||||
141
internal/tools/backfill.go
Normal file
141
internal/tools/backfill.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
"golang.org/x/sync/semaphore"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/ai"
|
||||
"git.warky.dev/wdevs/amcs/internal/session"
|
||||
"git.warky.dev/wdevs/amcs/internal/store"
|
||||
)
|
||||
|
||||
// backfillConcurrency bounds how many embedding requests Handle keeps in
// flight at once.
const backfillConcurrency = 4

// BackfillTool implements the backfill_embeddings MCP tool: it generates
// missing embeddings for stored thoughts using the active embedding model.
type BackfillTool struct {
	store    *store.DB
	provider ai.Provider
	sessions *session.ActiveProjects // nil for automated (non-MCP) runs; see runBackfillPass
	logger   *slog.Logger
}

// BackfillInput is the tool's request payload; all fields are optional.
type BackfillInput struct {
	Project         string `json:"project,omitempty" jsonschema:"optional project name or id to scope the backfill"`
	Limit           int    `json:"limit,omitempty" jsonschema:"maximum number of thoughts to process in one call; defaults to 100"`
	IncludeArchived bool   `json:"include_archived,omitempty" jsonschema:"whether to include archived thoughts; defaults to false"`
	OlderThanDays   int    `json:"older_than_days,omitempty" jsonschema:"only backfill thoughts older than N days; 0 means no restriction"`
	DryRun          bool   `json:"dry_run,omitempty" jsonschema:"report counts and sample ids without generating embeddings"`
}

// BackfillFailure records one thought that could not be embedded or stored.
type BackfillFailure struct {
	ID    string `json:"id"`
	Error string `json:"error"`
}

// BackfillOutput summarizes a single backfill pass.
// Invariant: Scanned == Embedded + Failed + Skipped after a non-dry run.
type BackfillOutput struct {
	Model    string            `json:"model"`   // embedding model the pass targeted
	Scanned  int               `json:"scanned"` // thoughts found missing an embedding
	Embedded int               `json:"embedded"`
	Skipped  int               `json:"skipped"` // scanned but never attempted (e.g. cancellation)
	Failed   int               `json:"failed"`
	DryRun   bool              `json:"dry_run"`
	Failures []BackfillFailure `json:"failures,omitempty"`
}

// NewBackfillTool wires a BackfillTool; sessions may be nil for automated runs.
func NewBackfillTool(db *store.DB, provider ai.Provider, sessions *session.ActiveProjects, logger *slog.Logger) *BackfillTool {
	return &BackfillTool{store: db, provider: provider, sessions: sessions, logger: logger}
}
|
||||
|
||||
// Handle runs one backfill pass: it lists thoughts missing an embedding for
// the provider's active model, then embeds and upserts them with bounded
// concurrency. Individual failures are collected in the output, not returned
// as errors, so one bad thought does not abort the run.
func (t *BackfillTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in BackfillInput) (*mcp.CallToolResult, BackfillOutput, error) {
	limit := in.Limit
	if limit <= 0 {
		limit = 100 // documented default cap per call
	}

	// Optional project scoping; false = project not required, so an empty
	// in.Project (and nil sessions for automated runs) is acceptable.
	project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false)
	if err != nil {
		return nil, BackfillOutput{}, err
	}

	var projectID *uuid.UUID
	if project != nil {
		projectID = &project.ID
	}

	model := t.provider.EmbeddingModel()

	thoughts, err := t.store.ListThoughtsMissingEmbedding(ctx, model, limit, projectID, in.IncludeArchived, in.OlderThanDays)
	if err != nil {
		return nil, BackfillOutput{}, err
	}

	out := BackfillOutput{
		Model:   model,
		Scanned: len(thoughts),
		DryRun:  in.DryRun,
	}

	// Dry runs (and empty scans) report counts without calling the provider.
	if in.DryRun || len(thoughts) == 0 {
		return nil, out, nil
	}

	start := time.Now()
	// The semaphore caps in-flight embedding calls at backfillConcurrency;
	// mu guards all mutation of out by the worker goroutines.
	sem := semaphore.NewWeighted(backfillConcurrency)
	var mu sync.Mutex
	var wg sync.WaitGroup

	for _, thought := range thoughts {
		// Stop scheduling new work on cancellation; workers already started
		// are drained by wg.Wait below. Unscheduled thoughts count as Skipped.
		if ctx.Err() != nil {
			break
		}
		if err := sem.Acquire(ctx, 1); err != nil {
			break
		}
		wg.Add(1)
		// ID and content are passed as arguments so each goroutine owns its
		// copies rather than sharing the loop variable.
		go func(id uuid.UUID, content string) {
			defer wg.Done()
			defer sem.Release(1)

			vec, embedErr := t.provider.Embed(ctx, content)
			if embedErr != nil {
				mu.Lock()
				out.Failures = append(out.Failures, BackfillFailure{ID: id.String(), Error: embedErr.Error()})
				mu.Unlock()
				t.logger.Warn("backfill embed failed", slog.String("thought_id", id.String()), slog.String("error", embedErr.Error()))
				return
			}

			if upsertErr := t.store.UpsertEmbedding(ctx, id, model, vec); upsertErr != nil {
				mu.Lock()
				out.Failures = append(out.Failures, BackfillFailure{ID: id.String(), Error: upsertErr.Error()})
				mu.Unlock()
				t.logger.Warn("backfill upsert failed", slog.String("thought_id", id.String()), slog.String("error", upsertErr.Error()))
				return
			}

			mu.Lock()
			out.Embedded++
			mu.Unlock()
		}(thought.ID, thought.Content)
	}

	wg.Wait()

	// All workers have finished; out is safe to read without the mutex.
	out.Failed = len(out.Failures)
	out.Skipped = out.Scanned - out.Embedded - out.Failed

	t.logger.Info("backfill completed",
		slog.String("model", model),
		slog.Int("scanned", out.Scanned),
		slog.Int("embedded", out.Embedded),
		slog.Int("failed", out.Failed),
		slog.Duration("duration", time.Since(start)),
	)

	return nil, out, nil
}
|
||||
@@ -72,11 +72,7 @@ func (t *ContextTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in P
|
||||
|
||||
query := strings.TrimSpace(in.Query)
|
||||
if query != "" {
|
||||
embedding, err := t.provider.Embed(ctx, query)
|
||||
if err != nil {
|
||||
return nil, ProjectContextOutput{}, err
|
||||
}
|
||||
semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, limit, &project.ID, nil)
|
||||
semantic, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, t.search.DefaultThreshold, &project.ID, nil)
|
||||
if err != nil {
|
||||
return nil, ProjectContextOutput{}, err
|
||||
}
|
||||
|
||||
@@ -117,11 +117,7 @@ func (t *LinksTool) Related(ctx context.Context, _ *mcp.CallToolRequest, in Rela
|
||||
}
|
||||
|
||||
if includeSemantic {
|
||||
embedding, err := t.provider.Embed(ctx, thought.Content)
|
||||
if err != nil {
|
||||
return nil, RelatedOutput{}, err
|
||||
}
|
||||
semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, t.search.DefaultLimit, thought.ProjectID, &thought.ID)
|
||||
semantic, err := semanticSearch(ctx, t.store, t.provider, t.search, thought.Content, t.search.DefaultLimit, t.search.DefaultThreshold, thought.ProjectID, &thought.ID)
|
||||
if err != nil {
|
||||
return nil, RelatedOutput{}, err
|
||||
}
|
||||
|
||||
@@ -48,17 +48,13 @@ func (t *RecallTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in Re
|
||||
}
|
||||
|
||||
limit := normalizeLimit(in.Limit, t.search)
|
||||
embedding, err := t.provider.Embed(ctx, query)
|
||||
if err != nil {
|
||||
return nil, RecallOutput{}, err
|
||||
}
|
||||
|
||||
var projectID *uuid.UUID
|
||||
if project != nil {
|
||||
projectID = &project.ID
|
||||
}
|
||||
|
||||
semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, limit, projectID, nil)
|
||||
semantic, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, t.search.DefaultThreshold, projectID, nil)
|
||||
if err != nil {
|
||||
return nil, RecallOutput{}, err
|
||||
}
|
||||
|
||||
41
internal/tools/retrieval.go
Normal file
41
internal/tools/retrieval.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/ai"
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
"git.warky.dev/wdevs/amcs/internal/store"
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
// semanticSearch runs vector similarity search if embeddings exist for the active model
|
||||
// in the given scope, otherwise falls back to Postgres full-text search.
|
||||
func semanticSearch(
|
||||
ctx context.Context,
|
||||
db *store.DB,
|
||||
provider ai.Provider,
|
||||
search config.SearchConfig,
|
||||
query string,
|
||||
limit int,
|
||||
threshold float64,
|
||||
projectID *uuid.UUID,
|
||||
excludeID *uuid.UUID,
|
||||
) ([]thoughttypes.SearchResult, error) {
|
||||
hasEmbeddings, err := db.HasEmbeddingsForModel(ctx, provider.EmbeddingModel(), projectID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if hasEmbeddings {
|
||||
embedding, err := provider.Embed(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db.SearchSimilarThoughts(ctx, embedding, provider.EmbeddingModel(), threshold, limit, projectID, excludeID)
|
||||
}
|
||||
|
||||
return db.SearchThoughtsText(ctx, query, limit, projectID, excludeID)
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/ai"
|
||||
@@ -44,24 +45,18 @@ func (t *SearchTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in Se
|
||||
limit := normalizeLimit(in.Limit, t.search)
|
||||
threshold := normalizeThreshold(in.Threshold, t.search.DefaultThreshold)
|
||||
|
||||
embedding, err := t.provider.Embed(ctx, query)
|
||||
if err != nil {
|
||||
return nil, SearchOutput{}, err
|
||||
}
|
||||
|
||||
project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false)
|
||||
if err != nil {
|
||||
return nil, SearchOutput{}, err
|
||||
}
|
||||
|
||||
model := t.provider.EmbeddingModel()
|
||||
var results []thoughttypes.SearchResult
|
||||
var projectID *uuid.UUID
|
||||
if project != nil {
|
||||
results, err = t.store.SearchSimilarThoughts(ctx, embedding, model, threshold, limit, &project.ID, nil)
|
||||
projectID = &project.ID
|
||||
_ = t.store.TouchProject(ctx, project.ID)
|
||||
} else {
|
||||
results, err = t.store.SearchThoughts(ctx, embedding, model, threshold, limit, map[string]any{})
|
||||
}
|
||||
|
||||
results, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, threshold, projectID, nil)
|
||||
if err != nil {
|
||||
return nil, SearchOutput{}, err
|
||||
}
|
||||
|
||||
@@ -48,15 +48,11 @@ func (t *SummarizeTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
|
||||
count := 0
|
||||
|
||||
if query != "" {
|
||||
embedding, err := t.provider.Embed(ctx, query)
|
||||
if err != nil {
|
||||
return nil, SummarizeOutput{}, err
|
||||
}
|
||||
var projectID *uuid.UUID
|
||||
if project != nil {
|
||||
projectID = &project.ID
|
||||
}
|
||||
results, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, limit, projectID, nil)
|
||||
results, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, t.search.DefaultThreshold, projectID, nil)
|
||||
if err != nil {
|
||||
return nil, SummarizeOutput{}, err
|
||||
}
|
||||
|
||||
3
migrations/010_fulltext_index.sql
Normal file
3
migrations/010_fulltext_index.sql
Normal file
@@ -0,0 +1,3 @@
|
||||
-- Full-text search index on thought content for semantic fallback when no embeddings exist.
-- Uses the 'simple' configuration to match websearch_to_tsquery('simple', ...) in
-- SearchThoughtsText; the configurations must agree for the planner to use this index.
create index if not exists thoughts_content_fts_idx
on thoughts using gin(to_tsvector('simple', content));
|
||||
Reference in New Issue
Block a user