feat(backfill): implement backfill tool for generating missing embeddings
This commit is contained in:
84
README.md
84
README.md
@@ -41,6 +41,7 @@ A Go MCP server for capturing and retrieving thoughts, memory, and project conte
|
|||||||
| `recall_context` | Semantic + recency context block for injection |
|
| `recall_context` | Semantic + recency context block for injection |
|
||||||
| `link_thoughts` | Create a typed relationship between thoughts |
|
| `link_thoughts` | Create a typed relationship between thoughts |
|
||||||
| `related_thoughts` | Explicit links + semantic neighbours |
|
| `related_thoughts` | Explicit links + semantic neighbours |
|
||||||
|
| `backfill_embeddings` | Generate missing embeddings for stored thoughts |
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
@@ -74,6 +75,89 @@ Alternatively, pass `client_id` and `client_secret` as body parameters instead o
|
|||||||
|
|
||||||
See `llm/plan.md` for full architecture and implementation plan.
|
See `llm/plan.md` for full architecture and implementation plan.
|
||||||
|
|
||||||
|
## Backfill
|
||||||
|
|
||||||
|
Run `backfill_embeddings` after switching embedding models or importing thoughts without vectors.
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"project": "optional-project-name",
|
||||||
|
"limit": 100,
|
||||||
|
"include_archived": false,
|
||||||
|
"older_than_days": 0,
|
||||||
|
"dry_run": false
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
- `dry_run: true` — report counts without calling the embedding provider
|
||||||
|
- `limit` — max thoughts per call (default 100)
|
||||||
|
- Embeddings are generated in parallel (4 workers) and upserted; one failure does not abort the run
|
||||||
|
|
||||||
|
**Automatic backfill** (optional, config-gated):
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
backfill:
|
||||||
|
enabled: true
|
||||||
|
run_on_startup: true # run once on server start
|
||||||
|
interval: "15m" # repeat every 15 minutes
|
||||||
|
batch_size: 20
|
||||||
|
max_per_run: 100
|
||||||
|
include_archived: false
|
||||||
|
```
|
||||||
|
|
||||||
|
**Search fallback**: when no embeddings exist for the active model in scope, `search_thoughts`, `recall_context`, `get_project_context`, `summarize_thoughts`, and `related_thoughts` automatically fall back to Postgres full-text search so results are never silently empty.
|
||||||
|
|
||||||
|
## Client Setup
|
||||||
|
|
||||||
|
### Claude Code
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# API key auth
|
||||||
|
claude mcp add --transport http amcs http://localhost:8080/mcp --header "x-brain-key: <key>"
|
||||||
|
|
||||||
|
# Bearer token auth
|
||||||
|
claude mcp add --transport http amcs http://localhost:8080/mcp --header "Authorization: Bearer <token>"
|
||||||
|
```
|
||||||
|
|
||||||
|
### OpenAI Codex
|
||||||
|
|
||||||
|
Add to `~/.codex/config.toml`:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[[mcp_servers]]
|
||||||
|
name = "amcs"
|
||||||
|
url = "http://localhost:8080/mcp"
|
||||||
|
|
||||||
|
[mcp_servers.headers]
|
||||||
|
x-brain-key = "<key>"
|
||||||
|
```
|
||||||
|
|
||||||
|
### OpenCode
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# API key auth
|
||||||
|
opencode mcp add --name amcs --type remote --url http://localhost:8080/mcp --header "x-brain-key=<key>"
|
||||||
|
|
||||||
|
# Bearer token auth
|
||||||
|
opencode mcp add --name amcs --type remote --url http://localhost:8080/mcp --header "Authorization=Bearer <token>"
|
||||||
|
```
|
||||||
|
|
||||||
|
Or add directly to `opencode.json` / `~/.config/opencode/config.json`:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"mcp": {
|
||||||
|
"amcs": {
|
||||||
|
"type": "remote",
|
||||||
|
"url": "http://localhost:8080/mcp",
|
||||||
|
"headers": {
|
||||||
|
"x-brain-key": "<key>"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
## Development
|
## Development
|
||||||
|
|
||||||
Run the SQL migrations against a local database with:
|
Run the SQL migrations against a local database with:
|
||||||
|
|||||||
@@ -79,3 +79,11 @@ logging:
|
|||||||
observability:
|
observability:
|
||||||
metrics_enabled: true
|
metrics_enabled: true
|
||||||
pprof_enabled: false
|
pprof_enabled: false
|
||||||
|
|
||||||
|
backfill:
|
||||||
|
enabled: false
|
||||||
|
run_on_startup: false
|
||||||
|
interval: "15m"
|
||||||
|
batch_size: 20
|
||||||
|
max_per_run: 100
|
||||||
|
include_archived: false
|
||||||
|
|||||||
@@ -74,6 +74,25 @@ func Run(ctx context.Context, configPath string) error {
|
|||||||
slog.String("provider", provider.Name()),
|
slog.String("provider", provider.Name()),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.Backfill.Enabled && cfg.Backfill.RunOnStartup {
|
||||||
|
go runBackfillPass(ctx, db, provider, cfg.Backfill, logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Backfill.Enabled && cfg.Backfill.Interval > 0 {
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(cfg.Backfill.Interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
runBackfillPass(ctx, db, provider, cfg.Backfill, logger)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
server := &http.Server{
|
server := &http.Server{
|
||||||
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
|
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
|
||||||
Handler: routes(logger, cfg, db, provider, keyring, oauthRegistry, tokenStore, authCodes, dynClients, activeProjects),
|
Handler: routes(logger, cfg, db, provider, keyring, oauthRegistry, tokenStore, authCodes, dynClients, activeProjects),
|
||||||
@@ -123,6 +142,7 @@ func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.P
|
|||||||
Recall: tools.NewRecallTool(db, provider, cfg.Search, activeProjects),
|
Recall: tools.NewRecallTool(db, provider, cfg.Search, activeProjects),
|
||||||
Summarize: tools.NewSummarizeTool(db, provider, cfg.Search, activeProjects),
|
Summarize: tools.NewSummarizeTool(db, provider, cfg.Search, activeProjects),
|
||||||
Links: tools.NewLinksTool(db, provider, cfg.Search),
|
Links: tools.NewLinksTool(db, provider, cfg.Search),
|
||||||
|
Backfill: tools.NewBackfillTool(db, provider, activeProjects, logger),
|
||||||
}
|
}
|
||||||
|
|
||||||
mcpHandler := mcpserver.New(cfg.MCP, toolSet)
|
mcpHandler := mcpserver.New(cfg.MCP, toolSet)
|
||||||
@@ -136,6 +156,7 @@ func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.P
|
|||||||
mux.HandleFunc("/oauth/token", oauthTokenHandler(oauthRegistry, tokenStore, authCodes, logger))
|
mux.HandleFunc("/oauth/token", oauthTokenHandler(oauthRegistry, tokenStore, authCodes, logger))
|
||||||
}
|
}
|
||||||
mux.HandleFunc("/favicon.ico", serveFavicon)
|
mux.HandleFunc("/favicon.ico", serveFavicon)
|
||||||
|
mux.HandleFunc("/images/project.jpg", serveHomeImage)
|
||||||
mux.HandleFunc("/llm", serveLLMInstructions)
|
mux.HandleFunc("/llm", serveLLMInstructions)
|
||||||
|
|
||||||
mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -155,8 +176,49 @@ func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.P
|
|||||||
})
|
})
|
||||||
|
|
||||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/" {
|
||||||
|
http.NotFound(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Method != http.MethodGet && r.Method != http.MethodHead {
|
||||||
|
w.Header().Set("Allow", "GET, HEAD")
|
||||||
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const homePage = `<!doctype html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="utf-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||||
|
<title>AMCS</title>
|
||||||
|
<style>
|
||||||
|
body { margin: 0; font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif; background: #f5f7fb; color: #172033; }
|
||||||
|
main { max-width: 860px; margin: 48px auto; background: #fff; border-radius: 12px; box-shadow: 0 10px 28px rgba(23, 32, 51, 0.12); overflow: hidden; }
|
||||||
|
.content { padding: 28px; }
|
||||||
|
h1 { margin: 0 0 12px 0; font-size: 2rem; }
|
||||||
|
p { margin: 0; line-height: 1.5; color: #334155; }
|
||||||
|
img { display: block; width: 100%; height: auto; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<main>
|
||||||
|
<img src="/images/project.jpg" alt="Avelon Memory Crystal project image">
|
||||||
|
<div class="content">
|
||||||
|
<h1>Avelon Memory Crystal Server (AMCS)</h1>
|
||||||
|
<p>AMCS is a memory server that captures, links, and retrieves structured project thoughts for AI assistants using semantic search, summaries, and MCP tools.</p>
|
||||||
|
</div>
|
||||||
|
</main>
|
||||||
|
</body>
|
||||||
|
</html>`
|
||||||
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
_, _ = w.Write([]byte("amcs is running"))
|
if r.Method == http.MethodHead {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = w.Write([]byte(homePage))
|
||||||
})
|
})
|
||||||
|
|
||||||
return observability.Chain(
|
return observability.Chain(
|
||||||
@@ -167,3 +229,39 @@ func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.P
|
|||||||
observability.Timeout(cfg.Server.WriteTimeout),
|
observability.Timeout(cfg.Server.WriteTimeout),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func runBackfillPass(ctx context.Context, db *store.DB, provider ai.Provider, cfg config.BackfillConfig, logger *slog.Logger) {
|
||||||
|
backfiller := tools.NewBackfillTool(db, provider, nil, logger)
|
||||||
|
_, out, err := backfiller.Handle(ctx, nil, tools.BackfillInput{
|
||||||
|
Limit: cfg.MaxPerRun,
|
||||||
|
IncludeArchived: cfg.IncludeArchived,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("auto backfill failed", slog.String("error", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Info("auto backfill pass",
|
||||||
|
slog.String("model", out.Model),
|
||||||
|
slog.Int("scanned", out.Scanned),
|
||||||
|
slog.Int("embedded", out.Embedded),
|
||||||
|
slog.Int("failed", out.Failed),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func serveHomeImage(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet && r.Method != http.MethodHead {
|
||||||
|
w.Header().Set("Allow", "GET, HEAD")
|
||||||
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "image/jpeg")
|
||||||
|
w.Header().Set("Cache-Control", "public, max-age=31536000, immutable")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
|
if r.Method == http.MethodHead {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = w.Write(homeImage)
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,15 +1,9 @@
|
|||||||
package app
|
package app
|
||||||
|
|
||||||
import (
|
import (
|
||||||
_ "embed"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
//go:embed static/favicon.ico
|
|
||||||
faviconICO []byte
|
|
||||||
)
|
|
||||||
|
|
||||||
func serveFavicon(w http.ResponseWriter, r *http.Request) {
|
func serveFavicon(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "image/x-icon")
|
w.Header().Set("Content-Type", "image/x-icon")
|
||||||
w.Header().Set("Cache-Control", "public, max-age=31536000, immutable")
|
w.Header().Set("Cache-Control", "public, max-age=31536000, immutable")
|
||||||
|
|||||||
BIN
internal/app/static/avelonmemorycrystal.jpg
Normal file
BIN
internal/app/static/avelonmemorycrystal.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 260 KiB |
24
internal/app/static_assets.go
Normal file
24
internal/app/static_assets.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"embed"
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
//go:embed static/*
|
||||||
|
staticFiles embed.FS
|
||||||
|
|
||||||
|
faviconICO = mustReadStaticFile("favicon.ico")
|
||||||
|
homeImage = mustReadStaticFile("avelonmemorycrystal.jpg")
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustReadStaticFile(name string) []byte {
|
||||||
|
data, err := fs.ReadFile(staticFiles, "static/"+name)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("failed to read embedded static file %q: %v", name, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return data
|
||||||
|
}
|
||||||
@@ -17,6 +17,7 @@ type Config struct {
|
|||||||
Search SearchConfig `yaml:"search"`
|
Search SearchConfig `yaml:"search"`
|
||||||
Logging LoggingConfig `yaml:"logging"`
|
Logging LoggingConfig `yaml:"logging"`
|
||||||
Observability ObservabilityConfig `yaml:"observability"`
|
Observability ObservabilityConfig `yaml:"observability"`
|
||||||
|
Backfill BackfillConfig `yaml:"backfill"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ServerConfig struct {
|
type ServerConfig struct {
|
||||||
@@ -135,3 +136,12 @@ type ObservabilityConfig struct {
|
|||||||
MetricsEnabled bool `yaml:"metrics_enabled"`
|
MetricsEnabled bool `yaml:"metrics_enabled"`
|
||||||
PprofEnabled bool `yaml:"pprof_enabled"`
|
PprofEnabled bool `yaml:"pprof_enabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type BackfillConfig struct {
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
RunOnStartup bool `yaml:"run_on_startup"`
|
||||||
|
Interval time.Duration `yaml:"interval"`
|
||||||
|
BatchSize int `yaml:"batch_size"`
|
||||||
|
MaxPerRun int `yaml:"max_per_run"`
|
||||||
|
IncludeArchived bool `yaml:"include_archived"`
|
||||||
|
}
|
||||||
|
|||||||
@@ -95,6 +95,13 @@ func defaultConfig() Config {
|
|||||||
Level: "info",
|
Level: "info",
|
||||||
Format: "json",
|
Format: "json",
|
||||||
},
|
},
|
||||||
|
Backfill: BackfillConfig{
|
||||||
|
Enabled: false,
|
||||||
|
RunOnStartup: false,
|
||||||
|
Interval: 15 * time.Minute,
|
||||||
|
BatchSize: 20,
|
||||||
|
MaxPerRun: 100,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -81,5 +81,14 @@ func (c Config) Validate() error {
|
|||||||
return fmt.Errorf("invalid config: logging.level is required")
|
return fmt.Errorf("invalid config: logging.level is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.Backfill.Enabled {
|
||||||
|
if c.Backfill.BatchSize <= 0 {
|
||||||
|
return fmt.Errorf("invalid config: backfill.batch_size must be greater than zero when backfill is enabled")
|
||||||
|
}
|
||||||
|
if c.Backfill.MaxPerRun < c.Backfill.BatchSize {
|
||||||
|
return fmt.Errorf("invalid config: backfill.max_per_run must be >= backfill.batch_size")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ type ToolSet struct {
|
|||||||
Recall *tools.RecallTool
|
Recall *tools.RecallTool
|
||||||
Summarize *tools.SummarizeTool
|
Summarize *tools.SummarizeTool
|
||||||
Links *tools.LinksTool
|
Links *tools.LinksTool
|
||||||
|
Backfill *tools.BackfillTool
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(cfg config.MCPConfig, toolSet ToolSet) http.Handler {
|
func New(cfg config.MCPConfig, toolSet ToolSet) http.Handler {
|
||||||
@@ -117,6 +118,11 @@ func New(cfg config.MCPConfig, toolSet ToolSet) http.Handler {
|
|||||||
Description: "Retrieve explicit links and semantic neighbors for a thought.",
|
Description: "Retrieve explicit links and semantic neighbors for a thought.",
|
||||||
}, toolSet.Links.Related)
|
}, toolSet.Links.Related)
|
||||||
|
|
||||||
|
addTool(server, &mcp.Tool{
|
||||||
|
Name: "backfill_embeddings",
|
||||||
|
Description: "Generate missing embeddings for stored thoughts using the active embedding model.",
|
||||||
|
}, toolSet.Backfill.Handle)
|
||||||
|
|
||||||
return mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server {
|
return mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server {
|
||||||
return server
|
return server
|
||||||
}, &mcp.StreamableHTTPOptions{
|
}, &mcp.StreamableHTTPOptions{
|
||||||
|
|||||||
@@ -358,6 +358,140 @@ func (db *DB) SearchSimilarThoughts(ctx context.Context, embedding []float32, em
|
|||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db *DB) HasEmbeddingsForModel(ctx context.Context, model string, projectID *uuid.UUID) (bool, error) {
|
||||||
|
args := []any{model}
|
||||||
|
conditions := []string{
|
||||||
|
"e.model = $1",
|
||||||
|
"t.archived_at is null",
|
||||||
|
}
|
||||||
|
if projectID != nil {
|
||||||
|
args = append(args, *projectID)
|
||||||
|
conditions = append(conditions, fmt.Sprintf("t.project_id = $%d", len(args)))
|
||||||
|
}
|
||||||
|
|
||||||
|
query := `select exists(select 1 from embeddings e join thoughts t on t.guid = e.thought_id where ` +
|
||||||
|
strings.Join(conditions, " and ") + `)`
|
||||||
|
|
||||||
|
var exists bool
|
||||||
|
if err := db.pool.QueryRow(ctx, query, args...).Scan(&exists); err != nil {
|
||||||
|
return false, fmt.Errorf("check embeddings for model: %w", err)
|
||||||
|
}
|
||||||
|
return exists, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) ListThoughtsMissingEmbedding(ctx context.Context, model string, limit int, projectID *uuid.UUID, includeArchived bool, olderThanDays int) ([]thoughttypes.Thought, error) {
|
||||||
|
args := []any{model}
|
||||||
|
conditions := []string{"e.id is null"}
|
||||||
|
|
||||||
|
if !includeArchived {
|
||||||
|
conditions = append(conditions, "t.archived_at is null")
|
||||||
|
}
|
||||||
|
if projectID != nil {
|
||||||
|
args = append(args, *projectID)
|
||||||
|
conditions = append(conditions, fmt.Sprintf("t.project_id = $%d", len(args)))
|
||||||
|
}
|
||||||
|
if olderThanDays > 0 {
|
||||||
|
args = append(args, time.Now().Add(-time.Duration(olderThanDays)*24*time.Hour))
|
||||||
|
conditions = append(conditions, fmt.Sprintf("t.created_at < $%d", len(args)))
|
||||||
|
}
|
||||||
|
args = append(args, limit)
|
||||||
|
|
||||||
|
query := `
|
||||||
|
select t.guid, t.content, t.metadata, t.project_id, t.archived_at, t.created_at, t.updated_at
|
||||||
|
from thoughts t
|
||||||
|
left join embeddings e on e.thought_id = t.guid and e.model = $1
|
||||||
|
where ` + strings.Join(conditions, " and ") + `
|
||||||
|
order by t.created_at asc
|
||||||
|
limit $` + fmt.Sprintf("%d", len(args))
|
||||||
|
|
||||||
|
rows, err := db.pool.Query(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("list thoughts missing embedding: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
thoughts := make([]thoughttypes.Thought, 0, limit)
|
||||||
|
for rows.Next() {
|
||||||
|
var thought thoughttypes.Thought
|
||||||
|
var metadataBytes []byte
|
||||||
|
if err := rows.Scan(&thought.ID, &thought.Content, &metadataBytes, &thought.ProjectID, &thought.ArchivedAt, &thought.CreatedAt, &thought.UpdatedAt); err != nil {
|
||||||
|
return nil, fmt.Errorf("scan missing-embedding thought: %w", err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(metadataBytes, &thought.Metadata); err != nil {
|
||||||
|
return nil, fmt.Errorf("decode missing-embedding metadata: %w", err)
|
||||||
|
}
|
||||||
|
thoughts = append(thoughts, thought)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("iterate missing-embedding thoughts: %w", err)
|
||||||
|
}
|
||||||
|
return thoughts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) UpsertEmbedding(ctx context.Context, thoughtID uuid.UUID, model string, embedding []float32) error {
|
||||||
|
_, err := db.pool.Exec(ctx, `
|
||||||
|
insert into embeddings (thought_id, model, dim, embedding)
|
||||||
|
values ($1, $2, $3, $4)
|
||||||
|
on conflict (thought_id, model) do update
|
||||||
|
set embedding = excluded.embedding,
|
||||||
|
dim = excluded.dim,
|
||||||
|
updated_at = now()
|
||||||
|
`, thoughtID, model, len(embedding), pgvector.NewVector(embedding))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("upsert embedding: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) SearchThoughtsText(ctx context.Context, query string, limit int, projectID *uuid.UUID, excludeID *uuid.UUID) ([]thoughttypes.SearchResult, error) {
|
||||||
|
args := []any{query}
|
||||||
|
conditions := []string{
|
||||||
|
"t.archived_at is null",
|
||||||
|
"to_tsvector('simple', t.content) @@ websearch_to_tsquery('simple', $1)",
|
||||||
|
}
|
||||||
|
if projectID != nil {
|
||||||
|
args = append(args, *projectID)
|
||||||
|
conditions = append(conditions, fmt.Sprintf("t.project_id = $%d", len(args)))
|
||||||
|
}
|
||||||
|
if excludeID != nil {
|
||||||
|
args = append(args, *excludeID)
|
||||||
|
conditions = append(conditions, fmt.Sprintf("t.guid <> $%d", len(args)))
|
||||||
|
}
|
||||||
|
args = append(args, limit)
|
||||||
|
|
||||||
|
q := `
|
||||||
|
select t.guid, t.content, t.metadata,
|
||||||
|
ts_rank_cd(to_tsvector('simple', t.content), websearch_to_tsquery('simple', $1)) as similarity,
|
||||||
|
t.created_at
|
||||||
|
from thoughts t
|
||||||
|
where ` + strings.Join(conditions, " and ") + `
|
||||||
|
order by similarity desc
|
||||||
|
limit $` + fmt.Sprintf("%d", len(args))
|
||||||
|
|
||||||
|
rows, err := db.pool.Query(ctx, q, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("text search thoughts: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
results := make([]thoughttypes.SearchResult, 0, limit)
|
||||||
|
for rows.Next() {
|
||||||
|
var result thoughttypes.SearchResult
|
||||||
|
var metadataBytes []byte
|
||||||
|
if err := rows.Scan(&result.ID, &result.Content, &metadataBytes, &result.Similarity, &result.CreatedAt); err != nil {
|
||||||
|
return nil, fmt.Errorf("scan text search result: %w", err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(metadataBytes, &result.Metadata); err != nil {
|
||||||
|
return nil, fmt.Errorf("decode text search metadata: %w", err)
|
||||||
|
}
|
||||||
|
results = append(results, result)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("iterate text search results: %w", err)
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
func topCounts(in map[string]int, limit int) []thoughttypes.KeyCount {
|
func topCounts(in map[string]int, limit int) []thoughttypes.KeyCount {
|
||||||
type kv struct {
|
type kv struct {
|
||||||
key string
|
key string
|
||||||
|
|||||||
141
internal/tools/backfill.go
Normal file
141
internal/tools/backfill.go
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
|
"golang.org/x/sync/semaphore"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/ai"
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/session"
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
const backfillConcurrency = 4
|
||||||
|
|
||||||
|
type BackfillTool struct {
|
||||||
|
store *store.DB
|
||||||
|
provider ai.Provider
|
||||||
|
sessions *session.ActiveProjects
|
||||||
|
logger *slog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
type BackfillInput struct {
|
||||||
|
Project string `json:"project,omitempty" jsonschema:"optional project name or id to scope the backfill"`
|
||||||
|
Limit int `json:"limit,omitempty" jsonschema:"maximum number of thoughts to process in one call; defaults to 100"`
|
||||||
|
IncludeArchived bool `json:"include_archived,omitempty" jsonschema:"whether to include archived thoughts; defaults to false"`
|
||||||
|
OlderThanDays int `json:"older_than_days,omitempty" jsonschema:"only backfill thoughts older than N days; 0 means no restriction"`
|
||||||
|
DryRun bool `json:"dry_run,omitempty" jsonschema:"report counts and sample ids without generating embeddings"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BackfillFailure struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type BackfillOutput struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Scanned int `json:"scanned"`
|
||||||
|
Embedded int `json:"embedded"`
|
||||||
|
Skipped int `json:"skipped"`
|
||||||
|
Failed int `json:"failed"`
|
||||||
|
DryRun bool `json:"dry_run"`
|
||||||
|
Failures []BackfillFailure `json:"failures,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBackfillTool(db *store.DB, provider ai.Provider, sessions *session.ActiveProjects, logger *slog.Logger) *BackfillTool {
|
||||||
|
return &BackfillTool{store: db, provider: provider, sessions: sessions, logger: logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *BackfillTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in BackfillInput) (*mcp.CallToolResult, BackfillOutput, error) {
|
||||||
|
limit := in.Limit
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 100
|
||||||
|
}
|
||||||
|
|
||||||
|
project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, BackfillOutput{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var projectID *uuid.UUID
|
||||||
|
if project != nil {
|
||||||
|
projectID = &project.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
model := t.provider.EmbeddingModel()
|
||||||
|
|
||||||
|
thoughts, err := t.store.ListThoughtsMissingEmbedding(ctx, model, limit, projectID, in.IncludeArchived, in.OlderThanDays)
|
||||||
|
if err != nil {
|
||||||
|
return nil, BackfillOutput{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
out := BackfillOutput{
|
||||||
|
Model: model,
|
||||||
|
Scanned: len(thoughts),
|
||||||
|
DryRun: in.DryRun,
|
||||||
|
}
|
||||||
|
|
||||||
|
if in.DryRun || len(thoughts) == 0 {
|
||||||
|
return nil, out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
sem := semaphore.NewWeighted(backfillConcurrency)
|
||||||
|
var mu sync.Mutex
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
for _, thought := range thoughts {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err := sem.Acquire(ctx, 1); err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
wg.Add(1)
|
||||||
|
go func(id uuid.UUID, content string) {
|
||||||
|
defer wg.Done()
|
||||||
|
defer sem.Release(1)
|
||||||
|
|
||||||
|
vec, embedErr := t.provider.Embed(ctx, content)
|
||||||
|
if embedErr != nil {
|
||||||
|
mu.Lock()
|
||||||
|
out.Failures = append(out.Failures, BackfillFailure{ID: id.String(), Error: embedErr.Error()})
|
||||||
|
mu.Unlock()
|
||||||
|
t.logger.Warn("backfill embed failed", slog.String("thought_id", id.String()), slog.String("error", embedErr.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if upsertErr := t.store.UpsertEmbedding(ctx, id, model, vec); upsertErr != nil {
|
||||||
|
mu.Lock()
|
||||||
|
out.Failures = append(out.Failures, BackfillFailure{ID: id.String(), Error: upsertErr.Error()})
|
||||||
|
mu.Unlock()
|
||||||
|
t.logger.Warn("backfill upsert failed", slog.String("thought_id", id.String()), slog.String("error", upsertErr.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
out.Embedded++
|
||||||
|
mu.Unlock()
|
||||||
|
}(thought.ID, thought.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
out.Failed = len(out.Failures)
|
||||||
|
out.Skipped = out.Scanned - out.Embedded - out.Failed
|
||||||
|
|
||||||
|
t.logger.Info("backfill completed",
|
||||||
|
slog.String("model", model),
|
||||||
|
slog.Int("scanned", out.Scanned),
|
||||||
|
slog.Int("embedded", out.Embedded),
|
||||||
|
slog.Int("failed", out.Failed),
|
||||||
|
slog.Duration("duration", time.Since(start)),
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil, out, nil
|
||||||
|
}
|
||||||
@@ -72,11 +72,7 @@ func (t *ContextTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in P
|
|||||||
|
|
||||||
query := strings.TrimSpace(in.Query)
|
query := strings.TrimSpace(in.Query)
|
||||||
if query != "" {
|
if query != "" {
|
||||||
embedding, err := t.provider.Embed(ctx, query)
|
semantic, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, t.search.DefaultThreshold, &project.ID, nil)
|
||||||
if err != nil {
|
|
||||||
return nil, ProjectContextOutput{}, err
|
|
||||||
}
|
|
||||||
semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, limit, &project.ID, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ProjectContextOutput{}, err
|
return nil, ProjectContextOutput{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -117,11 +117,7 @@ func (t *LinksTool) Related(ctx context.Context, _ *mcp.CallToolRequest, in Rela
|
|||||||
}
|
}
|
||||||
|
|
||||||
if includeSemantic {
|
if includeSemantic {
|
||||||
embedding, err := t.provider.Embed(ctx, thought.Content)
|
semantic, err := semanticSearch(ctx, t.store, t.provider, t.search, thought.Content, t.search.DefaultLimit, t.search.DefaultThreshold, thought.ProjectID, &thought.ID)
|
||||||
if err != nil {
|
|
||||||
return nil, RelatedOutput{}, err
|
|
||||||
}
|
|
||||||
semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, t.search.DefaultLimit, thought.ProjectID, &thought.ID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, RelatedOutput{}, err
|
return nil, RelatedOutput{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,17 +48,13 @@ func (t *RecallTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in Re
|
|||||||
}
|
}
|
||||||
|
|
||||||
limit := normalizeLimit(in.Limit, t.search)
|
limit := normalizeLimit(in.Limit, t.search)
|
||||||
embedding, err := t.provider.Embed(ctx, query)
|
|
||||||
if err != nil {
|
|
||||||
return nil, RecallOutput{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var projectID *uuid.UUID
|
var projectID *uuid.UUID
|
||||||
if project != nil {
|
if project != nil {
|
||||||
projectID = &project.ID
|
projectID = &project.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
semantic, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, limit, projectID, nil)
|
semantic, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, t.search.DefaultThreshold, projectID, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, RecallOutput{}, err
|
return nil, RecallOutput{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
41
internal/tools/retrieval.go
Normal file
41
internal/tools/retrieval.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/ai"
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/config"
|
||||||
|
"git.warky.dev/wdevs/amcs/internal/store"
|
||||||
|
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
// semanticSearch runs vector similarity search if embeddings exist for the active model
|
||||||
|
// in the given scope, otherwise falls back to Postgres full-text search.
|
||||||
|
func semanticSearch(
|
||||||
|
ctx context.Context,
|
||||||
|
db *store.DB,
|
||||||
|
provider ai.Provider,
|
||||||
|
search config.SearchConfig,
|
||||||
|
query string,
|
||||||
|
limit int,
|
||||||
|
threshold float64,
|
||||||
|
projectID *uuid.UUID,
|
||||||
|
excludeID *uuid.UUID,
|
||||||
|
) ([]thoughttypes.SearchResult, error) {
|
||||||
|
hasEmbeddings, err := db.HasEmbeddingsForModel(ctx, provider.EmbeddingModel(), projectID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasEmbeddings {
|
||||||
|
embedding, err := provider.Embed(ctx, query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return db.SearchSimilarThoughts(ctx, embedding, provider.EmbeddingModel(), threshold, limit, projectID, excludeID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return db.SearchThoughtsText(ctx, query, limit, projectID, excludeID)
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||||
|
|
||||||
"git.warky.dev/wdevs/amcs/internal/ai"
|
"git.warky.dev/wdevs/amcs/internal/ai"
|
||||||
@@ -44,24 +45,18 @@ func (t *SearchTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in Se
|
|||||||
limit := normalizeLimit(in.Limit, t.search)
|
limit := normalizeLimit(in.Limit, t.search)
|
||||||
threshold := normalizeThreshold(in.Threshold, t.search.DefaultThreshold)
|
threshold := normalizeThreshold(in.Threshold, t.search.DefaultThreshold)
|
||||||
|
|
||||||
embedding, err := t.provider.Embed(ctx, query)
|
|
||||||
if err != nil {
|
|
||||||
return nil, SearchOutput{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false)
|
project, err := resolveProject(ctx, t.store, t.sessions, req, in.Project, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, SearchOutput{}, err
|
return nil, SearchOutput{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
model := t.provider.EmbeddingModel()
|
var projectID *uuid.UUID
|
||||||
var results []thoughttypes.SearchResult
|
|
||||||
if project != nil {
|
if project != nil {
|
||||||
results, err = t.store.SearchSimilarThoughts(ctx, embedding, model, threshold, limit, &project.ID, nil)
|
projectID = &project.ID
|
||||||
_ = t.store.TouchProject(ctx, project.ID)
|
_ = t.store.TouchProject(ctx, project.ID)
|
||||||
} else {
|
|
||||||
results, err = t.store.SearchThoughts(ctx, embedding, model, threshold, limit, map[string]any{})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
results, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, threshold, projectID, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, SearchOutput{}, err
|
return nil, SearchOutput{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,15 +48,11 @@ func (t *SummarizeTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
|
|||||||
count := 0
|
count := 0
|
||||||
|
|
||||||
if query != "" {
|
if query != "" {
|
||||||
embedding, err := t.provider.Embed(ctx, query)
|
|
||||||
if err != nil {
|
|
||||||
return nil, SummarizeOutput{}, err
|
|
||||||
}
|
|
||||||
var projectID *uuid.UUID
|
var projectID *uuid.UUID
|
||||||
if project != nil {
|
if project != nil {
|
||||||
projectID = &project.ID
|
projectID = &project.ID
|
||||||
}
|
}
|
||||||
results, err := t.store.SearchSimilarThoughts(ctx, embedding, t.provider.EmbeddingModel(), t.search.DefaultThreshold, limit, projectID, nil)
|
results, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, t.search.DefaultThreshold, projectID, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, SummarizeOutput{}, err
|
return nil, SummarizeOutput{}, err
|
||||||
}
|
}
|
||||||
|
|||||||
3
migrations/010_fulltext_index.sql
Normal file
3
migrations/010_fulltext_index.sql
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
-- Full-text search index on thought content for semantic fallback when no embeddings exist.
|
||||||
|
create index if not exists thoughts_content_fts_idx
|
||||||
|
on thoughts using gin(to_tsvector('simple', content));
|
||||||
Reference in New Issue
Block a user