test(config): add migration tests for litellm provider
Some checks failed
CI / build-and-test (push) Failing after -32m22s

* Implement tests for migrating configuration from v1 to v2 for the litellm provider.
* Validate the structure and values of the migrated configuration.
* Ensure migration rejects newer versions of the configuration.
fix(validate): enhance AI provider validation logic
* Consolidate provider validation into a dedicated method.
* Ensure at least one provider is specified and validate its type.
* Check for required fields based on provider type.
fix(mcpserver): update tool set to use new enrichment tool
* Replace RetryMetadataTool with RetryEnrichmentTool in the ToolSet.
fix(tools): refactor tools to use embedding and metadata runners
* Update tools to utilize EmbeddingRunner and MetadataRunner instead of Provider.
* Adjust method calls to align with the new runner interfaces.
This commit is contained in:
2026-04-21 21:14:28 +02:00
parent 532d1560a3
commit 14e218d784
39 changed files with 2062 additions and 901 deletions

1
.gitignore vendored
View File

@@ -34,3 +34,4 @@ OB1/
ui/node_modules/
ui/.svelte-kit/
internal/app/ui/dist/
.codex

View File

@@ -244,12 +244,24 @@ Link existing skills and guardrails to a project so they are automatically avail
Config is YAML-driven. Copy `configs/config.example.yaml` and set:
- `database.url` — Postgres connection string
- `auth.mode``api_keys` or `oauth_client_credentials`
- `auth.keys` — API keys for MCP access via `x-brain-key` or `Authorization: Bearer <key>` when `auth.mode=api_keys`
- `auth.oauth.clients` — client registry when `auth.mode=oauth_client_credentials`
- `auth.keys`static API keys for MCP access via `x-brain-key` or `Authorization: Bearer <key>`
- `auth.oauth.clients` — optional OAuth client credentials registry
- `ai.providers` — named provider definitions (`litellm`, `ollama`, `openrouter`)
- `ai.embeddings.primary` / `ai.metadata.primary` — primary role targets (`provider` + `model`)
- `ai.embeddings.fallbacks` / `ai.metadata.fallbacks` — sequential fallback targets
- `mcp.version` is build-generated and should not be set in config
**OAuth Client Credentials flow** (`auth.mode=oauth_client_credentials`):
Config schema is versioned. Current schema version is `2`.
Use the migration helper to rewrite legacy configs in-place:
```bash
go run ./cmd/amcs-migrate-config --config ./configs/dev.yaml
```
Use `--dry-run` to print migrated YAML without writing.
**OAuth Client Credentials flow**:
1. Obtain a token — `POST /oauth/token` (public, no auth required):
```
@@ -267,8 +279,9 @@ Config is YAML-driven. Copy `configs/config.example.yaml` and set:
```
Alternatively, pass `client_id` and `client_secret` as body parameters instead of `Authorization: Basic`. Direct `Authorization: Basic` credential validation on the MCP endpoint is also supported as a fallback (no token required).
- `ai.litellm.base_url` and `ai.litellm.api_key` — LiteLLM proxy
- `ai.ollama.base_url` and `ai.ollama.api_key` — Ollama local or remote server
- `AMCS_LITELLM_BASE_URL` / `AMCS_LITELLM_API_KEY` override all configured LiteLLM providers
- `AMCS_OLLAMA_BASE_URL` / `AMCS_OLLAMA_API_KEY` override all configured Ollama providers
- `AMCS_OPENROUTER_API_KEY` overrides all configured OpenRouter providers
See `llm/plan.md` for an audited high-level status summary of the original implementation plan, and `llm/todo.md` for the audited backfill/fallback follow-up status.
@@ -643,27 +656,32 @@ Notes:
## Ollama
Set `ai.provider: "ollama"` to use a local or self-hosted Ollama server through its OpenAI-compatible API.
Set your role targets to an Ollama provider to use a local or self-hosted Ollama server through its OpenAI-compatible API.
Example:
```yaml
ai:
provider: "ollama"
embeddings:
model: "nomic-embed-text"
dimensions: 768
metadata:
model: "llama3.2"
temperature: 0.1
ollama:
providers:
local:
type: "ollama"
base_url: "http://localhost:11434/v1"
api_key: "ollama"
request_headers: {}
embeddings:
dimensions: 768
primary:
provider: "local"
model: "nomic-embed-text"
metadata:
temperature: 0.1
primary:
provider: "local"
model: "llama3.2"
```
Notes:
- For remote Ollama servers, point `ai.ollama.base_url` at the remote `/v1` endpoint.
- For remote Ollama servers, point `ai.providers.<name>.base_url` at the remote `/v1` endpoint.
- The client always sends Bearer auth; Ollama ignores it locally, so `api_key: "ollama"` is a safe default.
- `ai.embeddings.dimensions` must match the embedding model you actually use, or startup will fail the database vector-dimension check.

80
changelog.md Normal file
View File

@@ -0,0 +1,80 @@
# Changelog
## 2026-04-21
### 2026-04-21 21h - Config Schema v2 Introduced
- Refactored configuration to schema version `2` with named AI providers and role-based model chains.
- Added support for per-role primary and fallback targets for embeddings and metadata.
- Added optional background role overrides for backfill and metadata retry workers.
### 2026-04-21 21h - Automatic v1 -> v2 Migration
- Added config migration framework with explicit schema versioning.
- Implemented `v1 -> v2` migration to transform legacy provider blocks into named providers + role chains.
- Loader now auto-migrates older config files, rewrites migrated YAML, and creates timestamped backups.
### 2026-04-21 21h - AI Registry and Role Runners
- Added `ai.Registry` to build provider clients from named provider config entries.
- Added `EmbeddingRunner` and `MetadataRunner` with sequential fallback execution.
- Added target health tracking with cooldowns for transient/permanent/empty-response failures.
### 2026-04-21 21h - App and Tool Wiring Updates
- Rewired app startup to use provider registry + role runners for foreground and background flows.
- Updated capture, search, summarize, context, recall, backfill, metadata retry, and reparse paths to use new runners.
- Preserved environment override behavior for provider credentials/endpoints across matching provider types.
### 2026-04-21 21h - Migrate Config CLI Added
- Added `cmd/amcs-migrate-config` CLI to migrate config files to the current schema version.
- Supports dry-run output and in-place write mode with automatic backup file creation.
### 2026-04-21 21h - Tests and Documentation Updated
- Added focused tests for config migration, AI registry behavior, and runner fallback behavior.
- Updated `configs/config.example.yaml` to the new v2 schema.
- Updated README configuration sections and migration guidance to reflect v2 and `amcs-migrate-config` usage.
### 2026-04-21 21h - Uncommitted File Change List
- Modified: `.gitignore`
- Modified: `README.md`
- Modified: `configs/config.example.yaml`
- Modified: `internal/ai/compat/client.go`
- Modified: `internal/ai/compat/client_test.go`
- Modified: `internal/app/app.go`
- Modified: `internal/config/config.go`
- Modified: `internal/config/loader.go`
- Modified: `internal/config/loader_test.go`
- Modified: `internal/config/validate.go`
- Modified: `internal/config/validate_test.go`
- Modified: `internal/mcpserver/server.go`
- Modified: `internal/mcpserver/streamable_integration_test.go`
- Modified: `internal/tools/backfill.go`
- Modified: `internal/tools/capture.go`
- Modified: `internal/tools/context.go`
- Modified: `internal/tools/enrichment_retry.go`
- Modified: `internal/tools/links.go`
- Modified: `internal/tools/metadata_retry.go`
- Modified: `internal/tools/recall.go`
- Modified: `internal/tools/reparse_metadata.go`
- Modified: `internal/tools/retrieval.go`
- Modified: `internal/tools/search.go`
- Modified: `internal/tools/summarize.go`
- Modified: `internal/tools/update.go`
- Deleted: `internal/ai/factory.go`
- Deleted: `internal/ai/factory_test.go`
- Deleted: `internal/ai/litellm/client.go`
- Deleted: `internal/ai/ollama/client.go`
- Deleted: `internal/ai/openrouter/client.go`
- Deleted: `internal/ai/provider.go`
- New: `changelog.md`
- New: `cmd/amcs-migrate-config/main.go`
- New: `internal/ai/registry.go`
- New: `internal/ai/registry_test.go`
- New: `internal/ai/runner.go`
- New: `internal/ai/runner_test.go`
- New: `internal/config/migrate.go`
- New: `internal/config/migrate_test.go`

View File

@@ -0,0 +1,105 @@
package main
import (
"flag"
"fmt"
"log"
"os"
"time"
"gopkg.in/yaml.v3"
"git.warky.dev/wdevs/amcs/internal/config"
)
func main() {
var (
configPath string
dryRun bool
toVersion int
)
flag.StringVar(&configPath, "config", "", "Path to the YAML config file (default: $AMCS_CONFIG or ./configs/dev.yaml)")
flag.BoolVar(&dryRun, "dry-run", false, "Print the migrated config to stdout instead of writing it back")
flag.IntVar(&toVersion, "to-version", config.CurrentConfigVersion, "Stop migrating after reaching this version")
flag.Parse()
if toVersion <= 0 || toVersion > config.CurrentConfigVersion {
log.Fatalf("invalid -to-version %d (must be between 1 and %d)", toVersion, config.CurrentConfigVersion)
}
path := config.ResolvePath(configPath)
original, err := os.ReadFile(path)
if err != nil {
log.Fatalf("read config %q: %v", path, err)
}
raw := map[string]any{}
if err := yaml.Unmarshal(original, &raw); err != nil {
log.Fatalf("decode config %q: %v", path, err)
}
if raw == nil {
raw = map[string]any{}
}
applied, err := migrateUpTo(raw, toVersion)
if err != nil {
log.Fatalf("migrate: %v", err)
}
if len(applied) == 0 {
fmt.Fprintf(os.Stderr, "%s already at version %d; nothing to do\n", path, currentVersion(raw))
return
}
out, err := yaml.Marshal(raw)
if err != nil {
log.Fatalf("marshal migrated config: %v", err)
}
for _, step := range applied {
fmt.Fprintf(os.Stderr, "applied migration v%d -> v%d: %s\n", step.From, step.To, step.Describe)
}
if dryRun {
_, _ = os.Stdout.Write(out)
return
}
backup := fmt.Sprintf("%s.bak.%d", path, time.Now().Unix())
if err := os.WriteFile(backup, original, 0o600); err != nil {
log.Fatalf("write backup %q: %v", backup, err)
}
if err := os.WriteFile(path, out, 0o600); err != nil {
log.Fatalf("write migrated config %q: %v", path, err)
}
fmt.Fprintf(os.Stderr, "wrote migrated config to %s (backup: %s)\n", path, backup)
}
// migrateUpTo runs the migration ladder but stops at the requested version.
func migrateUpTo(raw map[string]any, target int) ([]config.ConfigMigration, error) {
if currentVersion(raw) >= target {
return nil, nil
}
if target == config.CurrentConfigVersion {
return config.Migrate(raw)
}
// Partial migrations are rare; for now reject anything other than the
// current version target since the migration ladder is short.
return nil, fmt.Errorf("partial migration to v%d is not supported (use -to-version=%d)", target, config.CurrentConfigVersion)
}
func currentVersion(raw map[string]any) int {
v, ok := raw["version"]
if !ok {
return 1
}
switch n := v.(type) {
case int:
return n
case int64:
return int(n)
case float64:
return int(n)
}
return 1
}

View File

@@ -1,3 +1,5 @@
version: 2
server:
host: "0.0.0.0"
port: 8080
@@ -27,7 +29,7 @@ auth:
- id: "oauth-client"
client_id: ""
client_secret: ""
description: "used when auth.mode=oauth_client_credentials"
description: "optional OAuth client credentials"
database:
url: "postgres://postgres:postgres@localhost:5432/amcs?sslmode=disable"
@@ -37,33 +39,58 @@ database:
max_conn_idle_time: "10m"
ai:
provider: "litellm"
embeddings:
model: "openai/text-embedding-3-small"
dimensions: 1536
metadata:
model: "gpt-4o-mini"
fallback_models: []
temperature: 0.1
log_conversations: false
litellm:
providers:
default:
type: "litellm"
base_url: "http://localhost:4000/v1"
api_key: "replace-me"
use_responses_api: false
request_headers: {}
embedding_model: "openrouter/openai/text-embedding-3-small"
metadata_model: "gpt-4o-mini"
fallback_metadata_models: []
ollama:
ollama_local:
type: "ollama"
base_url: "http://localhost:11434/v1"
api_key: "ollama"
request_headers: {}
openrouter:
type: "openrouter"
base_url: "https://openrouter.ai/api/v1"
api_key: ""
api_key: "replace-me"
app_name: "amcs"
site_url: ""
extra_headers: {}
request_headers: {}
embeddings:
dimensions: 1536
primary:
provider: "default"
model: "openai/text-embedding-3-small"
fallbacks:
- provider: "ollama_local"
model: "nomic-embed-text"
metadata:
temperature: 0.1
log_conversations: false
timeout: "10s"
primary:
provider: "default"
model: "gpt-4o-mini"
fallbacks:
- provider: "openrouter"
model: "openai/gpt-4.1-mini"
# Optional overrides for background jobs (backfill_embeddings,
# retry_failed_metadata, reparse_thought_metadata).
background:
embeddings:
primary:
provider: "default"
model: "openai/text-embedding-3-small"
metadata:
primary:
provider: "default"
model: "gpt-4o-mini"
capture:
source: "mcp"

View File

@@ -14,7 +14,6 @@ import (
"regexp"
"slices"
"strings"
"sync"
"time"
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
@@ -36,38 +35,41 @@ Rules:
- If unsure, prefer "observation".
- Do not include any text outside the JSON object.`
// Client is a low-level OpenAI-compatible HTTP client. It knows nothing about
// role chains, fallbacks, or health — those concerns belong to ai.Runner. Each
// method takes the model name per-call so a single Client instance can service
// many different models on the same base URL.
type Client struct {
name string
baseURL string
apiKey string
embeddingModel string
metadataModel string
fallbackMetadataModels []string
temperature float64
headers map[string]string
httpClient *http.Client
log *slog.Logger
dimensions int
logConversations bool
modelHealthMu sync.Mutex
modelHealth map[string]modelHealthState
}
type Config struct {
Name string
BaseURL string
APIKey string
EmbeddingModel string
MetadataModel string
FallbackMetadataModels []string
Temperature float64
Headers map[string]string
HTTPClient *http.Client
Log *slog.Logger
Dimensions int
}
// MetadataOptions control a single ExtractMetadataWith call.
type MetadataOptions struct {
Model string
Temperature float64
LogConversations bool
}
// SummarizeOptions control a single SummarizeWith call.
type SummarizeOptions struct {
Model string
Temperature float64
}
type embeddingsRequest struct {
Input string `json:"input"`
Model string `json:"model"`
@@ -127,65 +129,38 @@ type providerError struct {
const maxMetadataAttempts = 3
const (
emptyResponseCircuitThreshold = 3
emptyResponseCircuitTTL = 5 * time.Minute
permanentModelFailureTTL = 24 * time.Hour
)
// ErrEmptyResponse and ErrNoJSONObject are sentinel errors callers can inspect
// to classify metadata failures (e.g. bump empty-response health counters).
var (
errMetadataEmptyResponse = errors.New("metadata empty response")
errMetadataNoJSONObject = errors.New("metadata response contains no JSON object")
ErrEmptyResponse = errors.New("metadata empty response")
ErrNoJSONObject = errors.New("metadata response contains no JSON object")
)
type modelHealthState struct {
consecutiveEmpty int
unhealthyUntil time.Time
}
func New(cfg Config) *Client {
fallbacks := make([]string, 0, len(cfg.FallbackMetadataModels))
seen := make(map[string]struct{}, len(cfg.FallbackMetadataModels))
for _, model := range cfg.FallbackMetadataModels {
model = strings.TrimSpace(model)
if model == "" {
continue
}
if _, ok := seen[model]; ok {
continue
}
seen[model] = struct{}{}
fallbacks = append(fallbacks, model)
}
return &Client{
name: cfg.Name,
baseURL: cfg.BaseURL,
apiKey: cfg.APIKey,
embeddingModel: cfg.EmbeddingModel,
metadataModel: cfg.MetadataModel,
fallbackMetadataModels: fallbacks,
temperature: cfg.Temperature,
headers: cfg.Headers,
httpClient: cfg.HTTPClient,
log: cfg.Log,
dimensions: cfg.Dimensions,
logConversations: cfg.LogConversations,
modelHealth: make(map[string]modelHealthState),
}
}
func (c *Client) Embed(ctx context.Context, input string) ([]float32, error) {
func (c *Client) Name() string { return c.name }
// EmbedWith generates an embedding for the given input using model.
func (c *Client) EmbedWith(ctx context.Context, model, input string) ([]float32, error) {
input = strings.TrimSpace(input)
if input == "" {
return nil, fmt.Errorf("%s embed: input must not be empty", c.name)
}
if strings.TrimSpace(model) == "" {
return nil, fmt.Errorf("%s embed: model is required", c.name)
}
var resp embeddingsResponse
err := c.doJSON(ctx, "/embeddings", embeddingsRequest{
Input: input,
Model: c.embeddingModel,
}, &resp)
err := c.doJSON(ctx, "/embeddings", embeddingsRequest{Input: input, Model: model}, &resp)
if err != nil {
return nil, err
}
@@ -195,133 +170,26 @@ func (c *Client) Embed(ctx context.Context, input string) ([]float32, error) {
if len(resp.Data) == 0 {
return nil, fmt.Errorf("%s embed: no embedding returned", c.name)
}
if c.dimensions > 0 && len(resp.Data[0].Embedding) != c.dimensions {
return nil, fmt.Errorf("%s embed: expected %d dimensions, got %d", c.name, c.dimensions, len(resp.Data[0].Embedding))
}
return resp.Data[0].Embedding, nil
}
func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error) {
// ExtractMetadataWith extracts structured metadata for input using opts.Model.
// Returns compat.ErrEmptyResponse / ErrNoJSONObject wrapped when the model
// produces unusable output so callers can classify the failure.
func (c *Client) ExtractMetadataWith(ctx context.Context, opts MetadataOptions, input string) (thoughttypes.ThoughtMetadata, error) {
input = strings.TrimSpace(input)
if input == "" {
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s extract metadata: input must not be empty", c.name)
}
start := time.Now()
if c.log != nil {
c.log.Info("metadata client started",
slog.String("provider", c.name),
slog.String("model", c.metadataModel),
)
}
logCompletion := func(model string, err error) {
if c.log == nil {
return
}
attrs := []any{
slog.String("provider", c.name),
slog.String("model", model),
slog.String("duration", formatLogDuration(time.Since(start))),
}
if err != nil {
attrs = append(attrs, slog.String("error", err.Error()))
c.log.Error("metadata client completed", attrs...)
return
}
c.log.Info("metadata client completed", attrs...)
}
result, err := c.extractMetadataWithModel(ctx, input, c.metadataModel)
if errors.Is(err, errMetadataEmptyResponse) {
c.noteEmptyResponse(c.metadataModel)
}
if isPermanentModelError(err) {
c.notePermanentModelFailure(c.metadataModel, err)
}
if err == nil {
c.noteModelSuccess(c.metadataModel)
logCompletion(c.metadataModel, nil)
return result, nil
}
for _, fallbackModel := range c.fallbackMetadataModels {
if ctx.Err() != nil {
break
}
if fallbackModel == "" || fallbackModel == c.metadataModel {
continue
}
if c.shouldBypassModel(fallbackModel) {
continue
}
if c.log != nil {
c.log.Warn("metadata extraction failed, trying fallback model",
slog.String("provider", c.name),
slog.String("primary_model", c.metadataModel),
slog.String("fallback_model", fallbackModel),
slog.String("error", err.Error()),
)
}
fallbackResult, fallbackErr := c.extractMetadataWithModel(ctx, input, fallbackModel)
if errors.Is(fallbackErr, errMetadataEmptyResponse) {
c.noteEmptyResponse(fallbackModel)
}
if isPermanentModelError(fallbackErr) {
c.notePermanentModelFailure(fallbackModel, fallbackErr)
}
if fallbackErr == nil {
c.noteModelSuccess(fallbackModel)
logCompletion(fallbackModel, nil)
return fallbackResult, nil
}
err = fallbackErr
}
if ctx.Err() != nil {
err = fmt.Errorf("%s metadata: %w", c.name, ctx.Err())
logCompletion(c.metadataModel, err)
return thoughttypes.ThoughtMetadata{}, err
}
heuristic := heuristicMetadataFromInput(input)
if c.log != nil {
c.log.Warn("metadata extraction failed for all models, using heuristic fallback",
slog.String("provider", c.name),
slog.String("error", err.Error()),
)
}
logCompletion(c.metadataModel, nil)
return heuristic, nil
}
func formatLogDuration(d time.Duration) string {
if d < 0 {
d = -d
}
totalMilliseconds := d.Milliseconds()
minutes := totalMilliseconds / 60000
seconds := (totalMilliseconds / 1000) % 60
milliseconds := totalMilliseconds % 1000
return fmt.Sprintf("%02d:%02d:%03d", minutes, seconds, milliseconds)
}
func (c *Client) extractMetadataWithModel(ctx context.Context, input, model string) (thoughttypes.ThoughtMetadata, error) {
if c.shouldBypassModel(model) {
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: model %q temporarily bypassed after repeated empty responses", c.name, model)
if strings.TrimSpace(opts.Model) == "" {
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s extract metadata: model is required", c.name)
}
stream := true
req := chatCompletionsRequest{
Model: model,
Temperature: c.temperature,
ResponseFormat: &responseType{
Type: "json_object",
},
Model: opts.Model,
Temperature: opts.Temperature,
ResponseFormat: &responseType{Type: "json_object"},
Stream: &stream,
Messages: []chatMessage{
{Role: "system", Content: metadataSystemPrompt},
@@ -329,7 +197,7 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri
},
}
metadata, err := c.extractMetadataWithRequest(ctx, req, input, model)
metadata, err := c.extractMetadataWithRequest(ctx, req, input, opts)
if err == nil || !shouldRetryWithoutJSONMode(err) {
return metadata, err
}
@@ -337,23 +205,22 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri
if c.log != nil {
c.log.Warn("metadata json mode failed, retrying without response_format",
slog.String("provider", c.name),
slog.String("model", model),
slog.String("model", opts.Model),
slog.String("error", err.Error()),
)
}
req.ResponseFormat = nil
return c.extractMetadataWithRequest(ctx, req, input, model)
return c.extractMetadataWithRequest(ctx, req, input, opts)
}
func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatCompletionsRequest, input, model string) (thoughttypes.ThoughtMetadata, error) {
func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatCompletionsRequest, input string, opts MetadataOptions) (thoughttypes.ThoughtMetadata, error) {
var lastErr error
for attempt := 1; attempt <= maxMetadataAttempts; attempt++ {
if c.logConversations && c.log != nil {
if opts.LogConversations && c.log != nil {
c.log.Info("metadata conversation request",
slog.String("provider", c.name),
slog.String("model", model),
slog.String("model", opts.Model),
slog.Int("attempt", attempt),
slog.String("system", metadataSystemPrompt),
slog.String("input", input),
@@ -373,10 +240,10 @@ func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatComplet
rawResponse := extractChoiceText(resp.Choices[0].Message, resp.Choices[0].Text)
if c.logConversations && c.log != nil {
if opts.LogConversations && c.log != nil {
c.log.Info("metadata conversation response",
slog.String("provider", c.name),
slog.String("model", model),
slog.String("model", opts.Model),
slog.Int("attempt", attempt),
slog.String("response", rawResponse),
)
@@ -387,13 +254,13 @@ func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatComplet
metadataText = stripCodeFence(metadataText)
metadataText = extractJSONObject(metadataText)
if metadataText == "" {
lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataNoJSONObject)
lastErr = fmt.Errorf("%s metadata: %w", c.name, ErrNoJSONObject)
if strings.TrimSpace(rawResponse) == "" && attempt < maxMetadataAttempts && ctx.Err() == nil {
lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataEmptyResponse)
lastErr = fmt.Errorf("%s metadata: %w", c.name, ErrEmptyResponse)
if c.log != nil {
c.log.Warn("metadata response empty, waiting and retrying",
slog.String("provider", c.name),
slog.String("model", model),
slog.String("model", opts.Model),
slog.Int("attempt", attempt+1),
)
}
@@ -403,7 +270,7 @@ func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatComplet
continue
}
if strings.TrimSpace(rawResponse) == "" {
lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataEmptyResponse)
lastErr = fmt.Errorf("%s metadata: %w", c.name, ErrEmptyResponse)
}
return thoughttypes.ThoughtMetadata{}, lastErr
}
@@ -420,13 +287,17 @@ func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatComplet
if lastErr != nil {
return thoughttypes.ThoughtMetadata{}, lastErr
}
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: %w", c.name, errMetadataNoJSONObject)
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: %w", c.name, ErrNoJSONObject)
}
func (c *Client) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) {
// SummarizeWith runs a chat-completion summarisation using opts.Model.
func (c *Client) SummarizeWith(ctx context.Context, opts SummarizeOptions, systemPrompt, userPrompt string) (string, error) {
if strings.TrimSpace(opts.Model) == "" {
return "", fmt.Errorf("%s summarize: model is required", c.name)
}
req := chatCompletionsRequest{
Model: c.metadataModel,
Temperature: 0.2,
Model: opts.Model,
Temperature: opts.Temperature,
Messages: []chatMessage{
{Role: "system", Content: systemPrompt},
{Role: "user", Content: userPrompt},
@@ -447,12 +318,49 @@ func (c *Client) Summarize(ctx context.Context, systemPrompt, userPrompt string)
return extractChoiceText(resp.Choices[0].Message, resp.Choices[0].Text), nil
}
func (c *Client) Name() string {
return c.name
// IsPermanentModelError reports whether err indicates the model itself is
// invalid or missing (vs. a transient outage). Runners use this to mark a
// target unhealthy for longer.
func IsPermanentModelError(err error) bool {
if err == nil {
return false
}
lower := strings.ToLower(err.Error())
for _, marker := range []string{
"invalid model name",
"model_not_found",
"model not found",
"unknown model",
"no such model",
"does not exist",
} {
if strings.Contains(lower, marker) {
return true
}
}
return false
}
func (c *Client) EmbeddingModel() string {
return c.embeddingModel
// HeuristicMetadataFromInput produces best-effort metadata from the note text
// when every model in the chain has failed. Exported so ai.Runner can use it.
func HeuristicMetadataFromInput(input string) thoughttypes.ThoughtMetadata {
text := strings.TrimSpace(input)
lower := strings.ToLower(text)
metadata := thoughttypes.ThoughtMetadata{
People: heuristicPeople(text),
ActionItems: heuristicActionItems(text),
DatesMentioned: heuristicDates(text),
Topics: heuristicTopics(lower),
Type: heuristicType(lower),
}
if len(metadata.Topics) == 0 {
metadata.Topics = []string{"uncategorized"}
}
if metadata.Type == "" {
metadata.Type = "observation"
}
return metadata
}
func (c *Client) doJSON(ctx context.Context, path string, requestBody any, dest any) error {
@@ -724,8 +632,6 @@ func isRetryableChatResponseError(err error) bool {
return strings.Contains(lower, "read response") || strings.Contains(lower, "read stream response")
}
// extractJSONObject finds the first complete {...} block in s.
// It handles models that prepend prose to a JSON response despite json_object mode.
func extractJSONObject(s string) string {
for start := 0; start < len(s); start++ {
if s[start] != '{' {
@@ -768,10 +674,6 @@ func extractJSONObject(s string) string {
return ""
}
// stripThinkingBlocks removes <think>...</think> and <thinking>...</thinking>
// blocks produced by reasoning models (DeepSeek R1, QwQ, etc.) so that the
// remaining text can be parsed as JSON without interference from thinking content
// that may itself contain braces.
func stripThinkingBlocks(s string) string {
for _, tag := range []string{"think", "thinking"} {
open := "<" + tag + ">"
@@ -857,7 +759,6 @@ func extractTextFromAny(value any) string {
}
return strings.Join(parts, "\n")
case map[string]any:
// Common provider shapes for chat content parts.
for _, key := range []string{"text", "output_text", "content", "value"} {
if nested, ok := typed[key]; ok {
if text := strings.TrimSpace(extractTextFromAny(nested)); text != "" {
@@ -875,28 +776,6 @@ var (
wordPattern = regexp.MustCompile(`[a-zA-Z][a-zA-Z0-9_/-]{2,}`)
)
func heuristicMetadataFromInput(input string) thoughttypes.ThoughtMetadata {
text := strings.TrimSpace(input)
lower := strings.ToLower(text)
metadata := thoughttypes.ThoughtMetadata{
People: heuristicPeople(text),
ActionItems: heuristicActionItems(text),
DatesMentioned: heuristicDates(text),
Topics: heuristicTopics(lower),
Type: heuristicType(lower),
Source: "",
}
if len(metadata.Topics) == 0 {
metadata.Topics = []string{"uncategorized"}
}
if metadata.Type == "" {
metadata.Type = "observation"
}
return metadata
}
func heuristicType(lower string) string {
switch {
case strings.Contains(lower, "preferred name"), strings.Contains(lower, "personal profile"), strings.Contains(lower, "wife:"), strings.Contains(lower, "daughter:"), strings.Contains(lower, "born:"):
@@ -1055,7 +934,7 @@ func shouldRetryWithoutJSONMode(err error) bool {
if err == nil {
return false
}
if errors.Is(err, errMetadataEmptyResponse) || errors.Is(err, errMetadataNoJSONObject) {
if errors.Is(err, ErrEmptyResponse) || errors.Is(err, ErrNoJSONObject) {
return true
}
@@ -1063,27 +942,6 @@ func shouldRetryWithoutJSONMode(err error) bool {
return strings.Contains(lower, "parse json")
}
func isPermanentModelError(err error) bool {
if err == nil {
return false
}
lower := strings.ToLower(err.Error())
for _, marker := range []string{
"invalid model name",
"model_not_found",
"model not found",
"unknown model",
"no such model",
"does not exist",
} {
if strings.Contains(lower, marker) {
return true
}
}
return false
}
func sleepRetry(ctx context.Context, attempt int, log *slog.Logger, provider string) error {
delay := time.Duration(attempt*attempt) * 200 * time.Millisecond
if log != nil {
@@ -1110,59 +968,3 @@ func sleepMetadataRetry(ctx context.Context, attempt int) error {
return nil
}
}
func (c *Client) shouldBypassModel(model string) bool {
c.modelHealthMu.Lock()
defer c.modelHealthMu.Unlock()
state, ok := c.modelHealth[model]
if !ok {
return false
}
return !state.unhealthyUntil.IsZero() && time.Now().Before(state.unhealthyUntil)
}
func (c *Client) noteEmptyResponse(model string) {
c.modelHealthMu.Lock()
defer c.modelHealthMu.Unlock()
state := c.modelHealth[model]
state.consecutiveEmpty++
if state.consecutiveEmpty >= emptyResponseCircuitThreshold {
state.unhealthyUntil = time.Now().Add(emptyResponseCircuitTTL)
if c.log != nil {
c.log.Warn("metadata model marked temporarily unhealthy after repeated empty responses",
slog.String("provider", c.name),
slog.String("model", model),
slog.Time("until", state.unhealthyUntil),
)
}
}
c.modelHealth[model] = state
}
func (c *Client) noteModelSuccess(model string) {
c.modelHealthMu.Lock()
defer c.modelHealthMu.Unlock()
delete(c.modelHealth, model)
}
func (c *Client) notePermanentModelFailure(model string, err error) {
c.modelHealthMu.Lock()
defer c.modelHealthMu.Unlock()
state := c.modelHealth[model]
state.consecutiveEmpty = emptyResponseCircuitThreshold
state.unhealthyUntil = time.Now().Add(permanentModelFailureTTL)
c.modelHealth[model] = state
if c.log != nil {
c.log.Warn("metadata model marked unhealthy after permanent failure",
slog.String("provider", c.name),
slog.String("model", model),
slog.String("error", err.Error()),
slog.Time("until", state.unhealthyUntil),
)
}
}

View File

@@ -11,6 +11,17 @@ import (
"testing"
)
func newTestClient(t *testing.T, url string) *Client {
t.Helper()
return New(Config{
Name: "litellm",
BaseURL: url,
APIKey: "test-key",
HTTPClient: http.DefaultClient,
Log: slog.New(slog.NewTextHandler(io.Discard, nil)),
})
}
func TestExtractMetadataFromStreamingResponse(t *testing.T) {
t.Parallel()
@@ -26,6 +37,9 @@ func TestExtractMetadataFromStreamingResponse(t *testing.T) {
if req.Stream == nil || !*req.Stream {
t.Fatalf("stream flag = %v, want true", req.Stream)
}
if req.Model != "qwen3.5:latest" {
t.Fatalf("model = %q, want qwen3.5:latest", req.Model)
}
w.Header().Set("Content-Type", "text/event-stream")
_, _ = io.WriteString(w, "data: {\"choices\":[{\"delta\":{\"content\":\"{\\\"people\\\":[],\"}}]}\n\n")
@@ -35,20 +49,13 @@ func TestExtractMetadataFromStreamingResponse(t *testing.T) {
}))
defer server.Close()
client := New(Config{
Name: "litellm",
BaseURL: server.URL,
APIKey: "test-key",
MetadataModel: "qwen3.5:latest",
client := newTestClient(t, server.URL)
metadata, err := client.ExtractMetadataWith(context.Background(), MetadataOptions{
Model: "qwen3.5:latest",
Temperature: 0.1,
HTTPClient: server.Client(),
Log: slog.New(slog.NewTextHandler(io.Discard, nil)),
EmbeddingModel: "unused",
})
metadata, err := client.ExtractMetadata(context.Background(), "Project idea: Build an Android companion app.")
}, "Project idea: Build an Android companion app.")
if err != nil {
t.Fatalf("ExtractMetadata() error = %v", err)
t.Fatalf("ExtractMetadataWith() error = %v", err)
}
if metadata.Type != "idea" {
@@ -94,20 +101,13 @@ func TestExtractMetadataRetriesWithoutJSONMode(t *testing.T) {
}))
defer server.Close()
client := New(Config{
Name: "litellm",
BaseURL: server.URL,
APIKey: "test-key",
MetadataModel: "qwen3.5:latest",
client := newTestClient(t, server.URL)
metadata, err := client.ExtractMetadataWith(context.Background(), MetadataOptions{
Model: "qwen3.5:latest",
Temperature: 0.1,
HTTPClient: server.Client(),
Log: slog.New(slog.NewTextHandler(io.Discard, nil)),
EmbeddingModel: "unused",
})
metadata, err := client.ExtractMetadata(context.Background(), "Project idea: Build an Android companion app.")
}, "Project idea: Build an Android companion app.")
if err != nil {
t.Fatalf("ExtractMetadata() error = %v", err)
t.Fatalf("ExtractMetadataWith() error = %v", err)
}
if metadata.Type != "idea" {
@@ -127,71 +127,33 @@ func TestExtractMetadataRetriesWithoutJSONMode(t *testing.T) {
}
}
func TestExtractMetadataBypassesInvalidFallbackModelAfterFirstFailure(t *testing.T) {
func TestIsPermanentModelError(t *testing.T) {
t.Parallel()
var mu sync.Mutex
primaryCalls := 0
invalidFallbackCalls := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
_ = r.Body.Close()
}()
var req chatCompletionsRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatalf("decode request: %v", err)
cases := []struct {
name string
err error
want bool
}{
{"nil", nil, false},
{"invalid model", errMsg("Invalid model name passed in model=qwen3"), true},
{"model not found", errMsg("model_not_found"), true},
{"no such model", errMsg("no such model"), true},
{"transient", errMsg("connection refused"), false},
}
switch req.Model {
case "empty-primary":
_, _ = io.WriteString(w, `{"choices":[{"message":{"role":"assistant","content":""}}]}`)
case "qwen3.5:latest":
mu.Lock()
primaryCalls++
mu.Unlock()
_, _ = io.WriteString(w, `{"choices":[{"message":{"role":"assistant","content":"{\"people\":[],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"metadata\"],\"type\":\"observation\",\"source\":\"primary\"}"}}]}`)
case "qwen3":
mu.Lock()
invalidFallbackCalls++
mu.Unlock()
w.WriteHeader(http.StatusBadRequest)
_, _ = io.WriteString(w, "{\"error\":{\"message\":\"{'error': '/chat/completions: Invalid model name passed in model=qwen3. Call `/v1/models` to view available models for your key.'}\"}}")
default:
t.Fatalf("unexpected model %q", req.Model)
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
if got := IsPermanentModelError(tc.err); got != tc.want {
t.Fatalf("IsPermanentModelError(%v) = %v, want %v", tc.err, got, tc.want)
}
}))
defer server.Close()
client := New(Config{
Name: "litellm",
BaseURL: server.URL,
APIKey: "test-key",
MetadataModel: "empty-primary",
FallbackMetadataModels: []string{"qwen3", "qwen3.5:latest"},
Temperature: 0.1,
HTTPClient: server.Client(),
Log: slog.New(slog.NewTextHandler(io.Discard, nil)),
EmbeddingModel: "unused",
})
for i := 0; i < 2; i++ {
metadata, err := client.ExtractMetadata(context.Background(), "A short note about metadata.")
if err != nil {
t.Fatalf("ExtractMetadata() error = %v", err)
}
if metadata.Source != "primary" {
t.Fatalf("metadata source = %q, want primary", metadata.Source)
}
}
mu.Lock()
defer mu.Unlock()
if invalidFallbackCalls != 1 {
t.Fatalf("invalid fallback calls = %d, want 1", invalidFallbackCalls)
}
if primaryCalls != 2 {
t.Fatalf("valid fallback calls = %d, want 2", primaryCalls)
}
}
type stringError string
func (s stringError) Error() string { return string(s) }
func errMsg(s string) error { return stringError(s) }

View File

@@ -1,25 +0,0 @@
package ai
import (
"fmt"
"log/slog"
"net/http"
"git.warky.dev/wdevs/amcs/internal/ai/litellm"
"git.warky.dev/wdevs/amcs/internal/ai/ollama"
"git.warky.dev/wdevs/amcs/internal/ai/openrouter"
"git.warky.dev/wdevs/amcs/internal/config"
)
func NewProvider(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (Provider, error) {
switch cfg.Provider {
case "litellm":
return litellm.New(cfg, httpClient, log)
case "ollama":
return ollama.New(cfg, httpClient, log)
case "openrouter":
return openrouter.New(cfg, httpClient, log)
default:
return nil, fmt.Errorf("unsupported ai.provider: %s", cfg.Provider)
}
}

View File

@@ -1,33 +0,0 @@
package ai
import (
"io"
"log/slog"
"net/http"
"testing"
"git.warky.dev/wdevs/amcs/internal/config"
)
func TestNewProviderSupportsOllama(t *testing.T) {
provider, err := NewProvider(config.AIConfig{
Provider: "ollama",
Embeddings: config.AIEmbeddingConfig{
Model: "nomic-embed-text",
Dimensions: 768,
},
Metadata: config.AIMetadataConfig{
Model: "llama3.2",
},
Ollama: config.OllamaConfig{
BaseURL: "http://localhost:11434/v1",
APIKey: "ollama",
},
}, &http.Client{}, slog.New(slog.NewTextHandler(io.Discard, nil)))
if err != nil {
t.Fatalf("NewProvider() error = %v", err)
}
if provider.Name() != "ollama" {
t.Fatalf("provider name = %q, want ollama", provider.Name())
}
}

View File

@@ -1,30 +0,0 @@
package litellm
import (
"log/slog"
"net/http"
"git.warky.dev/wdevs/amcs/internal/ai/compat"
"git.warky.dev/wdevs/amcs/internal/config"
)
func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) {
fallbacks := cfg.LiteLLM.EffectiveFallbackMetadataModels()
if len(fallbacks) == 0 {
fallbacks = cfg.Metadata.EffectiveFallbackModels()
}
return compat.New(compat.Config{
Name: "litellm",
BaseURL: cfg.LiteLLM.BaseURL,
APIKey: cfg.LiteLLM.APIKey,
EmbeddingModel: cfg.LiteLLM.EmbeddingModel,
MetadataModel: cfg.LiteLLM.MetadataModel,
FallbackMetadataModels: fallbacks,
Temperature: cfg.Metadata.Temperature,
Headers: cfg.LiteLLM.RequestHeaders,
HTTPClient: httpClient,
Log: log,
Dimensions: cfg.Embeddings.Dimensions,
LogConversations: cfg.Metadata.LogConversations,
}), nil
}

View File

@@ -1,26 +0,0 @@
package ollama
import (
"log/slog"
"net/http"
"git.warky.dev/wdevs/amcs/internal/ai/compat"
"git.warky.dev/wdevs/amcs/internal/config"
)
func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) {
return compat.New(compat.Config{
Name: "ollama",
BaseURL: cfg.Ollama.BaseURL,
APIKey: cfg.Ollama.APIKey,
EmbeddingModel: cfg.Embeddings.Model,
MetadataModel: cfg.Metadata.Model,
FallbackMetadataModels: cfg.Metadata.EffectiveFallbackModels(),
Temperature: cfg.Metadata.Temperature,
Headers: cfg.Ollama.RequestHeaders,
HTTPClient: httpClient,
Log: log,
Dimensions: cfg.Embeddings.Dimensions,
LogConversations: cfg.Metadata.LogConversations,
}), nil
}

View File

@@ -1,37 +0,0 @@
package openrouter
import (
"log/slog"
"net/http"
"git.warky.dev/wdevs/amcs/internal/ai/compat"
"git.warky.dev/wdevs/amcs/internal/config"
)
func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) {
headers := make(map[string]string, len(cfg.OpenRouter.ExtraHeaders)+2)
for key, value := range cfg.OpenRouter.ExtraHeaders {
headers[key] = value
}
if cfg.OpenRouter.SiteURL != "" {
headers["HTTP-Referer"] = cfg.OpenRouter.SiteURL
}
if cfg.OpenRouter.AppName != "" {
headers["X-Title"] = cfg.OpenRouter.AppName
}
return compat.New(compat.Config{
Name: "openrouter",
BaseURL: cfg.OpenRouter.BaseURL,
APIKey: cfg.OpenRouter.APIKey,
EmbeddingModel: cfg.Embeddings.Model,
MetadataModel: cfg.Metadata.Model,
FallbackMetadataModels: cfg.Metadata.EffectiveFallbackModels(),
Temperature: cfg.Metadata.Temperature,
Headers: headers,
HTTPClient: httpClient,
Log: log,
Dimensions: cfg.Embeddings.Dimensions,
LogConversations: cfg.Metadata.LogConversations,
}), nil
}

View File

@@ -1,15 +0,0 @@
package ai
import (
"context"
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
)
type Provider interface {
Embed(ctx context.Context, input string) ([]float32, error)
ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error)
Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error)
Name() string
EmbeddingModel() string
}

96
internal/ai/registry.go Normal file
View File

@@ -0,0 +1,96 @@
package ai
import (
"fmt"
"log/slog"
"net/http"
"strings"
"git.warky.dev/wdevs/amcs/internal/ai/compat"
"git.warky.dev/wdevs/amcs/internal/config"
)
// Registry holds one compat.Client per named provider. Runners look up clients
// by provider name when walking a role chain.
type Registry struct {
clients map[string]*compat.Client
}
// NewRegistry builds a Registry from the configured providers. Each provider
// type maps onto a compat.Client with type-specific header plumbing (e.g.
// openrouter's HTTP-Referer / X-Title).
func NewRegistry(providers map[string]config.ProviderConfig, httpClient *http.Client, log *slog.Logger) (*Registry, error) {
if httpClient == nil {
return nil, fmt.Errorf("ai registry: http client is required")
}
if len(providers) == 0 {
return nil, fmt.Errorf("ai registry: no providers configured")
}
clients := make(map[string]*compat.Client, len(providers))
for name, p := range providers {
headers, err := providerHeaders(p)
if err != nil {
return nil, fmt.Errorf("ai registry: provider %q: %w", name, err)
}
clients[name] = compat.New(compat.Config{
Name: name,
BaseURL: p.BaseURL,
APIKey: p.APIKey,
Headers: headers,
HTTPClient: httpClient,
Log: log,
})
}
return &Registry{clients: clients}, nil
}
// Client returns the compat.Client registered under name.
func (r *Registry) Client(name string) (*compat.Client, error) {
c, ok := r.clients[name]
if !ok {
return nil, fmt.Errorf("ai registry: provider %q is not configured", name)
}
return c, nil
}
// Names returns the registered provider names.
func (r *Registry) Names() []string {
names := make([]string, 0, len(r.clients))
for name := range r.clients {
names = append(names, name)
}
return names
}
func providerHeaders(p config.ProviderConfig) (map[string]string, error) {
switch p.Type {
case "litellm", "ollama":
return cloneHeaders(p.RequestHeaders), nil
case "openrouter":
headers := cloneHeaders(p.RequestHeaders)
if headers == nil {
headers = map[string]string{}
}
if s := strings.TrimSpace(p.SiteURL); s != "" {
headers["HTTP-Referer"] = s
}
if s := strings.TrimSpace(p.AppName); s != "" {
headers["X-Title"] = s
}
return headers, nil
default:
return nil, fmt.Errorf("unsupported provider type %q", p.Type)
}
}
func cloneHeaders(in map[string]string) map[string]string {
if len(in) == 0 {
return nil
}
out := make(map[string]string, len(in))
for k, v := range in {
out[k] = v
}
return out
}

View File

@@ -0,0 +1,80 @@
package ai
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"git.warky.dev/wdevs/amcs/internal/ai/compat"
"git.warky.dev/wdevs/amcs/internal/config"
)
func TestNewRegistryOpenRouterHeaders(t *testing.T) {
var (
gotReferer string
gotTitle string
gotCustom string
)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotReferer = r.Header.Get("HTTP-Referer")
gotTitle = r.Header.Get("X-Title")
gotCustom = r.Header.Get("X-Custom")
_ = json.NewEncoder(w).Encode(map[string]any{
"choices": []map[string]any{{"message": map[string]any{"role": "assistant", "content": "ok"}}},
})
}))
defer srv.Close()
providers := map[string]config.ProviderConfig{
"router": {
Type: "openrouter",
BaseURL: srv.URL,
APIKey: "secret",
RequestHeaders: map[string]string{
"X-Custom": "value",
},
AppName: "amcs",
SiteURL: "https://example.com",
},
}
reg, err := NewRegistry(providers, srv.Client(), nil)
if err != nil {
t.Fatalf("NewRegistry() error = %v", err)
}
client, err := reg.Client("router")
if err != nil {
t.Fatalf("Client(router) error = %v", err)
}
if _, err := client.SummarizeWith(context.Background(), compat.SummarizeOptions{Model: "gpt-4.1-mini"}, "system", "user"); err != nil {
t.Fatalf("SummarizeWith() error = %v", err)
}
if gotReferer != "https://example.com" {
t.Fatalf("HTTP-Referer = %q, want https://example.com", gotReferer)
}
if gotTitle != "amcs" {
t.Fatalf("X-Title = %q, want amcs", gotTitle)
}
if gotCustom != "value" {
t.Fatalf("X-Custom = %q, want value", gotCustom)
}
}
func TestNewRegistryRejectsUnsupportedProviderType(t *testing.T) {
providers := map[string]config.ProviderConfig{
"bad": {
Type: "unknown",
BaseURL: "http://localhost:4000/v1",
APIKey: "secret",
},
}
_, err := NewRegistry(providers, &http.Client{}, nil)
if err == nil {
t.Fatal("NewRegistry() error = nil, want unsupported provider type error")
}
}

367
internal/ai/runner.go Normal file
View File

@@ -0,0 +1,367 @@
package ai
import (
"context"
"errors"
"fmt"
"log/slog"
"sync"
"time"
"git.warky.dev/wdevs/amcs/internal/ai/compat"
"git.warky.dev/wdevs/amcs/internal/config"
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
)
// Health TTLs per failure class. These are short enough that a healed target
// gets retried without manual intervention, but long enough to avoid hammering
// a broken provider every call.
const (
transientCooldown = 30 * time.Second
permanentCooldown = 10 * time.Minute
emptyResponseThreshold = 3
emptyResponseCooldown = 2 * time.Minute
dimensionMismatchWarning = "embedding dimension mismatch"
)
// EmbedResult carries the vector plus the (provider, model) that produced it —
// callers store the actual model so later searches against that row use the
// matching query embedding.
type EmbedResult struct {
Vector []float32
Provider string
Model string
}
// EmbeddingRunner executes the embeddings role chain with sequential fallback.
type EmbeddingRunner struct {
registry *Registry
chain []config.RoleTarget
dimensions int
health *healthTracker
log *slog.Logger
}
// MetadataRunner executes the metadata role chain with sequential fallback and
// a heuristic fallthrough when every target is unhealthy or fails.
type MetadataRunner struct {
registry *Registry
chain []config.RoleTarget
opts metadataRunOpts
health *healthTracker
log *slog.Logger
}
type metadataRunOpts struct {
temperature float64
logConversations bool
}
// NewEmbeddingRunner builds a runner for the embeddings role. chain must be
// non-empty and every target must be registered.
func NewEmbeddingRunner(registry *Registry, chain []config.RoleTarget, dimensions int, log *slog.Logger) (*EmbeddingRunner, error) {
if registry == nil {
return nil, fmt.Errorf("embedding runner: registry is required")
}
if len(chain) == 0 {
return nil, fmt.Errorf("embedding runner: chain is empty")
}
if dimensions <= 0 {
return nil, fmt.Errorf("embedding runner: dimensions must be > 0")
}
for i, t := range chain {
if _, err := registry.Client(t.Provider); err != nil {
return nil, fmt.Errorf("embedding runner: chain[%d]: %w", i, err)
}
}
return &EmbeddingRunner{
registry: registry,
chain: chain,
dimensions: dimensions,
health: newHealthTracker(),
log: log,
}, nil
}
// NewMetadataRunner builds a runner for the metadata role.
func NewMetadataRunner(registry *Registry, chain []config.RoleTarget, temperature float64, logConversations bool, log *slog.Logger) (*MetadataRunner, error) {
if registry == nil {
return nil, fmt.Errorf("metadata runner: registry is required")
}
if len(chain) == 0 {
return nil, fmt.Errorf("metadata runner: chain is empty")
}
for i, t := range chain {
if _, err := registry.Client(t.Provider); err != nil {
return nil, fmt.Errorf("metadata runner: chain[%d]: %w", i, err)
}
}
return &MetadataRunner{
registry: registry,
chain: chain,
opts: metadataRunOpts{
temperature: temperature,
logConversations: logConversations,
},
health: newHealthTracker(),
log: log,
}, nil
}
// PrimaryProvider returns the first provider in the chain.
func (r *EmbeddingRunner) PrimaryProvider() string { return r.chain[0].Provider }
// PrimaryModel returns the first model in the chain — the one used as the
// storage key for search matching.
func (r *EmbeddingRunner) PrimaryModel() string { return r.chain[0].Model }
// Dimensions returns the required vector dimension.
func (r *EmbeddingRunner) Dimensions() int { return r.dimensions }
// Embed walks the chain and returns the first successful embedding. The
// returned EmbedResult names the actual (provider, model) that produced the
// vector — callers use that when recording the row.
func (r *EmbeddingRunner) Embed(ctx context.Context, input string) (EmbedResult, error) {
var errs []error
for _, target := range r.chain {
if r.health.skip(target) {
continue
}
client, err := r.registry.Client(target.Provider)
if err != nil {
errs = append(errs, err)
continue
}
vec, err := client.EmbedWith(ctx, target.Model, input)
if err != nil {
if ctx.Err() != nil {
return EmbedResult{}, ctx.Err()
}
r.classify(target, err)
r.logFailure("embed", target, err)
errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, err))
continue
}
if len(vec) != r.dimensions {
dimErr := fmt.Errorf("%s: expected %d, got %d", dimensionMismatchWarning, r.dimensions, len(vec))
r.health.markTransient(target)
r.logFailure("embed", target, dimErr)
errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, dimErr))
continue
}
r.health.markHealthy(target)
return EmbedResult{Vector: vec, Provider: target.Provider, Model: target.Model}, nil
}
return EmbedResult{}, fmt.Errorf("all embedding targets failed: %w", errors.Join(errs...))
}
// EmbedPrimary embeds using only the primary target — used for search queries
// so the query vector matches rows stored under the primary model. Falls back
// to returning the error without walking the chain.
func (r *EmbeddingRunner) EmbedPrimary(ctx context.Context, input string) ([]float32, error) {
target := r.chain[0]
client, err := r.registry.Client(target.Provider)
if err != nil {
return nil, err
}
vec, err := client.EmbedWith(ctx, target.Model, input)
if err != nil {
r.classify(target, err)
return nil, err
}
if len(vec) != r.dimensions {
return nil, fmt.Errorf("%s: expected %d, got %d", dimensionMismatchWarning, r.dimensions, len(vec))
}
r.health.markHealthy(target)
return vec, nil
}
// PrimaryProvider / PrimaryModel for metadata mirror the embedding runner.
func (r *MetadataRunner) PrimaryProvider() string { return r.chain[0].Provider }
func (r *MetadataRunner) PrimaryModel() string { return r.chain[0].Model }
// ExtractMetadata walks the chain sequentially. If every target fails or is
// unhealthy, it returns a heuristic metadata so capture never hard-fails.
func (r *MetadataRunner) ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error) {
var errs []error
for _, target := range r.chain {
if r.health.skip(target) {
continue
}
client, err := r.registry.Client(target.Provider)
if err != nil {
errs = append(errs, err)
continue
}
md, err := client.ExtractMetadataWith(ctx, compat.MetadataOptions{
Model: target.Model,
Temperature: r.opts.temperature,
LogConversations: r.opts.logConversations,
}, input)
if err != nil {
if ctx.Err() != nil {
return thoughttypes.ThoughtMetadata{}, ctx.Err()
}
r.classify(target, err)
r.logFailure("metadata", target, err)
errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, err))
continue
}
r.health.markHealthy(target)
return md, nil
}
if r.log != nil {
r.log.Warn("metadata chain exhausted, using heuristic fallback",
slog.Int("targets", len(r.chain)),
slog.String("error", errors.Join(errs...).Error()),
)
}
return compat.HeuristicMetadataFromInput(input), nil
}
// Summarize walks the chain; unlike metadata, there is no heuristic fallback —
// returns the joined error when everything fails.
func (r *MetadataRunner) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) {
var errs []error
for _, target := range r.chain {
if r.health.skip(target) {
continue
}
client, err := r.registry.Client(target.Provider)
if err != nil {
errs = append(errs, err)
continue
}
out, err := client.SummarizeWith(ctx, compat.SummarizeOptions{
Model: target.Model,
Temperature: r.opts.temperature,
}, systemPrompt, userPrompt)
if err != nil {
if ctx.Err() != nil {
return "", ctx.Err()
}
r.classify(target, err)
r.logFailure("summarize", target, err)
errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, err))
continue
}
r.health.markHealthy(target)
return out, nil
}
return "", fmt.Errorf("all summarize targets failed: %w", errors.Join(errs...))
}
func (r *EmbeddingRunner) classify(target config.RoleTarget, err error) {
switch {
case compat.IsPermanentModelError(err):
r.health.markPermanent(target)
default:
r.health.markTransient(target)
}
}
func (r *MetadataRunner) classify(target config.RoleTarget, err error) {
switch {
case compat.IsPermanentModelError(err):
r.health.markPermanent(target)
case errors.Is(err, compat.ErrEmptyResponse):
r.health.markEmpty(target)
default:
r.health.markTransient(target)
}
}
func (r *EmbeddingRunner) logFailure(role string, target config.RoleTarget, err error) {
if r.log == nil {
return
}
r.log.Warn("ai target failed",
slog.String("role", role),
slog.String("provider", target.Provider),
slog.String("model", target.Model),
slog.String("error", err.Error()),
)
}
func (r *MetadataRunner) logFailure(role string, target config.RoleTarget, err error) {
if r.log == nil {
return
}
r.log.Warn("ai target failed",
slog.String("role", role),
slog.String("provider", target.Provider),
slog.String("model", target.Model),
slog.String("error", err.Error()),
)
}
// healthTracker records per-(provider, model) failure state. skip returns true
// when a target is still inside its cooldown window; the caller then tries the
// next target in the chain.
type healthTracker struct {
mu sync.Mutex
states map[config.RoleTarget]*healthState
}
type healthState struct {
unhealthyUntil time.Time
emptyCount int
}
func newHealthTracker() *healthTracker {
return &healthTracker{states: map[config.RoleTarget]*healthState{}}
}
func (h *healthTracker) skip(target config.RoleTarget) bool {
h.mu.Lock()
defer h.mu.Unlock()
s, ok := h.states[target]
if !ok {
return false
}
return time.Now().Before(s.unhealthyUntil)
}
func (h *healthTracker) markTransient(target config.RoleTarget) {
h.setCooldown(target, transientCooldown)
}
func (h *healthTracker) markPermanent(target config.RoleTarget) {
h.setCooldown(target, permanentCooldown)
}
func (h *healthTracker) markEmpty(target config.RoleTarget) {
h.mu.Lock()
defer h.mu.Unlock()
s := h.states[target]
if s == nil {
s = &healthState{}
h.states[target] = s
}
s.emptyCount++
if s.emptyCount >= emptyResponseThreshold {
s.unhealthyUntil = time.Now().Add(emptyResponseCooldown)
s.emptyCount = 0
}
}
func (h *healthTracker) markHealthy(target config.RoleTarget) {
h.mu.Lock()
defer h.mu.Unlock()
if s, ok := h.states[target]; ok {
s.unhealthyUntil = time.Time{}
s.emptyCount = 0
}
}
func (h *healthTracker) setCooldown(target config.RoleTarget, d time.Duration) {
h.mu.Lock()
defer h.mu.Unlock()
s := h.states[target]
if s == nil {
s = &healthState{}
h.states[target] = s
}
s.unhealthyUntil = time.Now().Add(d)
s.emptyCount = 0
}

139
internal/ai/runner_test.go Normal file
View File

@@ -0,0 +1,139 @@
package ai
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"sync"
"testing"
"git.warky.dev/wdevs/amcs/internal/config"
)
func TestEmbeddingRunnerFallsBackAndSkipsUnhealthyPrimary(t *testing.T) {
var (
mu sync.Mutex
primaryCalls int
)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/embeddings" {
http.NotFound(w, r)
return
}
var req struct {
Model string `json:"model"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
switch req.Model {
case "embed-primary":
mu.Lock()
primaryCalls++
mu.Unlock()
http.Error(w, "upstream down", http.StatusBadGateway)
case "embed-fallback":
_ = json.NewEncoder(w).Encode(map[string]any{
"data": []map[string]any{{"embedding": []float32{0.1, 0.2, 0.3}}},
})
default:
http.Error(w, "unknown model", http.StatusBadRequest)
}
}))
defer srv.Close()
reg, err := NewRegistry(map[string]config.ProviderConfig{
"p1": {Type: "litellm", BaseURL: srv.URL, APIKey: "k1"},
"p2": {Type: "litellm", BaseURL: srv.URL, APIKey: "k2"},
}, srv.Client(), nil)
if err != nil {
t.Fatalf("NewRegistry() error = %v", err)
}
runner, err := NewEmbeddingRunner(reg, []config.RoleTarget{
{Provider: "p1", Model: "embed-primary"},
{Provider: "p2", Model: "embed-fallback"},
}, 3, nil)
if err != nil {
t.Fatalf("NewEmbeddingRunner() error = %v", err)
}
res, err := runner.Embed(context.Background(), "hello")
if err != nil {
t.Fatalf("Embed() first call error = %v", err)
}
if res.Provider != "p2" || res.Model != "embed-fallback" {
t.Fatalf("Embed() first call target = %s/%s, want p2/embed-fallback", res.Provider, res.Model)
}
res, err = runner.Embed(context.Background(), "hello again")
if err != nil {
t.Fatalf("Embed() second call error = %v", err)
}
if res.Provider != "p2" || res.Model != "embed-fallback" {
t.Fatalf("Embed() second call target = %s/%s, want p2/embed-fallback", res.Provider, res.Model)
}
mu.Lock()
calls := primaryCalls
mu.Unlock()
if calls != 3 {
t.Fatalf("primary calls = %d, want 3 (first request retries 3x; second call should skip unhealthy primary)", calls)
}
}
func TestMetadataRunnerSummarizeFallsBack(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/chat/completions" {
http.NotFound(w, r)
return
}
var req struct {
Model string `json:"model"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
switch req.Model {
case "sum-primary":
http.Error(w, "provider error", http.StatusBadGateway)
case "sum-fallback":
_ = json.NewEncoder(w).Encode(map[string]any{
"choices": []map[string]any{{
"message": map[string]any{"role": "assistant", "content": "fallback summary"},
}},
})
default:
http.Error(w, "unknown model", http.StatusBadRequest)
}
}))
defer srv.Close()
reg, err := NewRegistry(map[string]config.ProviderConfig{
"p1": {Type: "litellm", BaseURL: srv.URL, APIKey: "k1"},
"p2": {Type: "litellm", BaseURL: srv.URL, APIKey: "k2"},
}, srv.Client(), nil)
if err != nil {
t.Fatalf("NewRegistry() error = %v", err)
}
runner, err := NewMetadataRunner(reg, []config.RoleTarget{
{Provider: "p1", Model: "sum-primary"},
{Provider: "p2", Model: "sum-fallback"},
}, 0.1, false, nil)
if err != nil {
t.Fatalf("NewMetadataRunner() error = %v", err)
}
summary, err := runner.Summarize(context.Background(), "system", "user")
if err != nil {
t.Fatalf("Summarize() error = %v", err)
}
if summary != "fallback summary" {
t.Fatalf("summary = %q, want %q", summary, "fallback summary")
}
}

View File

@@ -34,7 +34,7 @@ func Run(ctx context.Context, configPath string) error {
logger.Info("loaded configuration",
slog.String("path", loadedFrom),
slog.String("provider", cfg.AI.Provider),
slog.Int("config_version", cfg.Version),
slog.String("version", info.Version),
slog.String("tag_name", info.TagName),
slog.String("build_date", info.BuildDate),
@@ -52,11 +52,37 @@ func Run(ctx context.Context, configPath string) error {
}
httpClient := &http.Client{Timeout: 30 * time.Second}
provider, err := ai.NewProvider(cfg.AI, httpClient, logger)
registry, err := ai.NewRegistry(cfg.AI.Providers, httpClient, logger)
if err != nil {
return err
}
foregroundEmbeddings, err := ai.NewEmbeddingRunner(registry, cfg.AI.Embeddings.Chain(), cfg.AI.Embeddings.Dimensions, logger)
if err != nil {
return err
}
foregroundMetadata, err := ai.NewMetadataRunner(registry, cfg.AI.Metadata.Chain(), cfg.AI.Metadata.Temperature, cfg.AI.Metadata.LogConversations, logger)
if err != nil {
return err
}
backgroundEmbeddings := foregroundEmbeddings
backgroundMetadata := foregroundMetadata
if cfg.AI.Background != nil {
if cfg.AI.Background.Embeddings != nil {
backgroundEmbeddings, err = ai.NewEmbeddingRunner(registry, cfg.AI.Background.Embeddings.AsTargets(), cfg.AI.Embeddings.Dimensions, logger)
if err != nil {
return err
}
}
if cfg.AI.Background.Metadata != nil {
backgroundMetadata, err = ai.NewMetadataRunner(registry, cfg.AI.Background.Metadata.AsTargets(), cfg.AI.Metadata.Temperature, cfg.AI.Metadata.LogConversations, logger)
if err != nil {
return err
}
}
}
var keyring *auth.Keyring
var oauthRegistry *auth.OAuthRegistry
var tokenStore *auth.TokenStore
@@ -77,12 +103,13 @@ func Run(ctx context.Context, configPath string) error {
dynClients := auth.NewDynamicClientStore()
activeProjects := session.NewActiveProjects()
logger.Info("database connection verified",
slog.String("provider", provider.Name()),
logger.Info("ai providers initialised",
slog.String("embedding_primary", foregroundEmbeddings.PrimaryProvider()+"/"+foregroundEmbeddings.PrimaryModel()),
slog.String("metadata_primary", foregroundMetadata.PrimaryProvider()+"/"+foregroundMetadata.PrimaryModel()),
)
if cfg.Backfill.Enabled && cfg.Backfill.RunOnStartup {
go runBackfillPass(ctx, db, provider, cfg.Backfill, logger)
go runBackfillPass(ctx, db, backgroundEmbeddings, cfg.Backfill, logger)
}
if cfg.Backfill.Enabled && cfg.Backfill.Interval > 0 {
@@ -94,14 +121,14 @@ func Run(ctx context.Context, configPath string) error {
case <-ctx.Done():
return
case <-ticker.C:
runBackfillPass(ctx, db, provider, cfg.Backfill, logger)
runBackfillPass(ctx, db, backgroundEmbeddings, cfg.Backfill, logger)
}
}
}()
}
if cfg.MetadataRetry.Enabled && cfg.MetadataRetry.RunOnStartup {
go runMetadataRetryPass(ctx, db, provider, cfg, activeProjects, logger)
go runMetadataRetryPass(ctx, db, backgroundMetadata, cfg, activeProjects, logger)
}
if cfg.MetadataRetry.Enabled && cfg.MetadataRetry.Interval > 0 {
@@ -113,13 +140,13 @@ func Run(ctx context.Context, configPath string) error {
case <-ctx.Done():
return
case <-ticker.C:
runMetadataRetryPass(ctx, db, provider, cfg, activeProjects, logger)
runMetadataRetryPass(ctx, db, backgroundMetadata, cfg, activeProjects, logger)
}
}
}()
}
handler, err := routes(logger, cfg, info, db, provider, keyring, oauthRegistry, tokenStore, authCodes, dynClients, activeProjects)
handler, err := routes(logger, cfg, info, db, foregroundEmbeddings, foregroundMetadata, backgroundEmbeddings, backgroundMetadata, keyring, oauthRegistry, tokenStore, authCodes, dynClients, activeProjects)
if err != nil {
return err
}
@@ -156,33 +183,33 @@ func Run(ctx context.Context, configPath string) error {
}
}
func routes(logger *slog.Logger, cfg *config.Config, info buildinfo.Info, db *store.DB, provider ai.Provider, keyring *auth.Keyring, oauthRegistry *auth.OAuthRegistry, tokenStore *auth.TokenStore, authCodes *auth.AuthCodeStore, dynClients *auth.DynamicClientStore, activeProjects *session.ActiveProjects) (http.Handler, error) {
func routes(logger *slog.Logger, cfg *config.Config, info buildinfo.Info, db *store.DB, embeddings *ai.EmbeddingRunner, metadata *ai.MetadataRunner, bgEmbeddings *ai.EmbeddingRunner, bgMetadata *ai.MetadataRunner, keyring *auth.Keyring, oauthRegistry *auth.OAuthRegistry, tokenStore *auth.TokenStore, authCodes *auth.AuthCodeStore, dynClients *auth.DynamicClientStore, activeProjects *session.ActiveProjects) (http.Handler, error) {
mux := http.NewServeMux()
accessTracker := auth.NewAccessTracker()
oauthEnabled := oauthRegistry != nil && tokenStore != nil
authMiddleware := auth.Middleware(cfg.Auth, keyring, oauthRegistry, tokenStore, accessTracker, logger)
filesTool := tools.NewFilesTool(db, activeProjects)
enrichmentRetryer := tools.NewEnrichmentRetryer(context.Background(), db, provider, cfg.Capture, cfg.AI.Metadata.Timeout, activeProjects, logger)
backfillTool := tools.NewBackfillTool(db, provider, activeProjects, logger)
enrichmentRetryer := tools.NewEnrichmentRetryer(context.Background(), db, bgMetadata, cfg.Capture, cfg.AI.Metadata.Timeout, activeProjects, logger)
backfillTool := tools.NewBackfillTool(db, bgEmbeddings, activeProjects, logger)
toolSet := mcpserver.ToolSet{
Capture: tools.NewCaptureTool(db, provider, cfg.Capture, cfg.AI.Metadata.Timeout, activeProjects, enrichmentRetryer, backfillTool, logger),
Search: tools.NewSearchTool(db, provider, cfg.Search, activeProjects),
Capture: tools.NewCaptureTool(db, embeddings, metadata, cfg.Capture, cfg.AI.Metadata.Timeout, activeProjects, nil, backfillTool, logger),
Search: tools.NewSearchTool(db, embeddings, cfg.Search, activeProjects),
List: tools.NewListTool(db, cfg.Search, activeProjects),
Stats: tools.NewStatsTool(db),
Get: tools.NewGetTool(db),
Update: tools.NewUpdateTool(db, provider, cfg.Capture, logger),
Update: tools.NewUpdateTool(db, embeddings, metadata, cfg.Capture, logger),
Delete: tools.NewDeleteTool(db),
Archive: tools.NewArchiveTool(db),
Projects: tools.NewProjectsTool(db, activeProjects),
Version: tools.NewVersionTool(cfg.MCP.ServerName, info),
Context: tools.NewContextTool(db, provider, cfg.Search, activeProjects),
Recall: tools.NewRecallTool(db, provider, cfg.Search, activeProjects),
Summarize: tools.NewSummarizeTool(db, provider, cfg.Search, activeProjects),
Links: tools.NewLinksTool(db, provider, cfg.Search),
Context: tools.NewContextTool(db, embeddings, cfg.Search, activeProjects),
Recall: tools.NewRecallTool(db, embeddings, cfg.Search, activeProjects),
Summarize: tools.NewSummarizeTool(db, embeddings, metadata, cfg.Search, activeProjects),
Links: tools.NewLinksTool(db, embeddings, cfg.Search),
Files: filesTool,
Backfill: backfillTool,
Reparse: tools.NewReparseMetadataTool(db, provider, cfg.Capture, activeProjects, logger),
Reparse: tools.NewReparseMetadataTool(db, bgMetadata, cfg.Capture, activeProjects, logger),
RetryMetadata: tools.NewRetryEnrichmentTool(enrichmentRetryer),
Maintenance: tools.NewMaintenanceTool(db),
Skills: tools.NewSkillsTool(db, activeProjects),
@@ -242,8 +269,8 @@ func routes(logger *slog.Logger, cfg *config.Config, info buildinfo.Info, db *st
), nil
}
func runMetadataRetryPass(ctx context.Context, db *store.DB, provider ai.Provider, cfg *config.Config, activeProjects *session.ActiveProjects, logger *slog.Logger) {
retryer := tools.NewMetadataRetryer(ctx, db, provider, cfg.Capture, cfg.AI.Metadata.Timeout, activeProjects, logger)
func runMetadataRetryPass(ctx context.Context, db *store.DB, metadataRunner *ai.MetadataRunner, cfg *config.Config, activeProjects *session.ActiveProjects, logger *slog.Logger) {
retryer := tools.NewMetadataRetryer(ctx, db, metadataRunner, cfg.Capture, cfg.AI.Metadata.Timeout, activeProjects, logger)
_, out, err := retryer.Handle(ctx, nil, tools.RetryMetadataInput{
Limit: cfg.MetadataRetry.MaxPerRun,
IncludeArchived: cfg.MetadataRetry.IncludeArchived,
@@ -261,8 +288,8 @@ func runMetadataRetryPass(ctx context.Context, db *store.DB, provider ai.Provide
)
}
func runBackfillPass(ctx context.Context, db *store.DB, provider ai.Provider, cfg config.BackfillConfig, logger *slog.Logger) {
backfiller := tools.NewBackfillTool(db, provider, nil, logger)
func runBackfillPass(ctx context.Context, db *store.DB, embeddings *ai.EmbeddingRunner, cfg config.BackfillConfig, logger *slog.Logger) {
backfiller := tools.NewBackfillTool(db, embeddings, nil, logger)
_, out, err := backfiller.Handle(ctx, nil, tools.BackfillInput{
Limit: cfg.MaxPerRun,
IncludeArchived: cfg.IncludeArchived,

View File

@@ -8,6 +8,7 @@ const (
)
type Config struct {
Version int `yaml:"version"`
Server ServerConfig `yaml:"server"`
MCP MCPConfig `yaml:"mcp"`
Auth AuthConfig `yaml:"auth"`
@@ -37,10 +38,7 @@ type MCPConfig struct {
Version string `yaml:"version"`
Transport string `yaml:"transport"`
SessionTimeout time.Duration `yaml:"session_timeout"`
// PublicURL is the externally reachable base URL of this server (e.g. https://amcs.example.com).
// When set, it is used to build absolute icon URLs in the MCP server identity.
PublicURL string `yaml:"public_url"`
// Instructions is set at startup from the embedded memory.md and sent to MCP clients on initialise.
Instructions string `yaml:"-"`
}
@@ -77,52 +75,82 @@ type DatabaseConfig struct {
MaxConnIdleTime time.Duration `yaml:"max_conn_idle_time"`
}
// AIConfig (v2): named providers + per-role chains.
type AIConfig struct {
Providers map[string]ProviderConfig `yaml:"providers"`
Embeddings EmbeddingsRoleConfig `yaml:"embeddings"`
Metadata MetadataRoleConfig `yaml:"metadata"`
Background *BackgroundRolesConfig `yaml:"background,omitempty"`
}
type ProviderConfig struct {
Type string `yaml:"type"`
BaseURL string `yaml:"base_url"`
APIKey string `yaml:"api_key"`
RequestHeaders map[string]string `yaml:"request_headers,omitempty"`
AppName string `yaml:"app_name,omitempty"`
SiteURL string `yaml:"site_url,omitempty"`
}
type RoleTarget struct {
Provider string `yaml:"provider"`
Embeddings AIEmbeddingConfig `yaml:"embeddings"`
Metadata AIMetadataConfig `yaml:"metadata"`
LiteLLM LiteLLMConfig `yaml:"litellm"`
Ollama OllamaConfig `yaml:"ollama"`
OpenRouter OpenRouterAIConfig `yaml:"openrouter"`
Model string `yaml:"model"`
}
type AIEmbeddingConfig struct {
Model string `yaml:"model"`
type RoleChain struct {
Primary RoleTarget `yaml:"primary"`
Fallbacks []RoleTarget `yaml:"fallbacks,omitempty"`
}
type EmbeddingsRoleConfig struct {
Dimensions int `yaml:"dimensions"`
Primary RoleTarget `yaml:"primary"`
Fallbacks []RoleTarget `yaml:"fallbacks,omitempty"`
}
type AIMetadataConfig struct {
Model string `yaml:"model"`
FallbackModels []string `yaml:"fallback_models"`
FallbackModel string `yaml:"fallback_model"` // legacy single fallback
type MetadataRoleConfig struct {
Temperature float64 `yaml:"temperature"`
LogConversations bool `yaml:"log_conversations"`
Timeout time.Duration `yaml:"timeout"`
Primary RoleTarget `yaml:"primary"`
Fallbacks []RoleTarget `yaml:"fallbacks,omitempty"`
}
type LiteLLMConfig struct {
BaseURL string `yaml:"base_url"`
APIKey string `yaml:"api_key"`
UseResponsesAPI bool `yaml:"use_responses_api"`
RequestHeaders map[string]string `yaml:"request_headers"`
EmbeddingModel string `yaml:"embedding_model"`
MetadataModel string `yaml:"metadata_model"`
FallbackMetadataModels []string `yaml:"fallback_metadata_models"`
FallbackMetadataModel string `yaml:"fallback_metadata_model"` // legacy single fallback
// BackgroundRolesConfig overrides the foreground chains for background workers
// (backfill_embeddings, metadata_retry, reparse_metadata). Either field may be
// nil to inherit the foreground role unchanged.
type BackgroundRolesConfig struct {
Embeddings *RoleChain `yaml:"embeddings,omitempty"`
Metadata *RoleChain `yaml:"metadata,omitempty"`
}
type OllamaConfig struct {
BaseURL string `yaml:"base_url"`
APIKey string `yaml:"api_key"`
RequestHeaders map[string]string `yaml:"request_headers"`
// Chain returns primary followed by fallbacks (deduped, blanks dropped).
func (e EmbeddingsRoleConfig) Chain() []RoleTarget {
return dedupeTargets(append([]RoleTarget{e.Primary}, e.Fallbacks...))
}
type OpenRouterAIConfig struct {
BaseURL string `yaml:"base_url"`
APIKey string `yaml:"api_key"`
AppName string `yaml:"app_name"`
SiteURL string `yaml:"site_url"`
ExtraHeaders map[string]string `yaml:"extra_headers"`
func (m MetadataRoleConfig) Chain() []RoleTarget {
return dedupeTargets(append([]RoleTarget{m.Primary}, m.Fallbacks...))
}
func (c RoleChain) AsTargets() []RoleTarget {
return dedupeTargets(append([]RoleTarget{c.Primary}, c.Fallbacks...))
}
func dedupeTargets(in []RoleTarget) []RoleTarget {
out := make([]RoleTarget, 0, len(in))
seen := make(map[RoleTarget]struct{}, len(in))
for _, t := range in {
if t.Provider == "" || t.Model == "" {
continue
}
if _, ok := seen[t]; ok {
continue
}
seen[t] = struct{}{}
out = append(out, t)
}
return out
}
type CaptureConfig struct {
@@ -167,45 +195,3 @@ type MetadataRetryConfig struct {
MaxPerRun int `yaml:"max_per_run"`
IncludeArchived bool `yaml:"include_archived"`
}
func (c AIMetadataConfig) EffectiveFallbackModels() []string {
models := make([]string, 0, len(c.FallbackModels)+1)
for _, model := range c.FallbackModels {
if model != "" {
models = append(models, model)
}
}
if c.FallbackModel != "" {
models = append(models, c.FallbackModel)
}
return dedupeNonEmpty(models)
}
func (c LiteLLMConfig) EffectiveFallbackMetadataModels() []string {
models := make([]string, 0, len(c.FallbackMetadataModels)+1)
for _, model := range c.FallbackMetadataModels {
if model != "" {
models = append(models, model)
}
}
if c.FallbackMetadataModel != "" {
models = append(models, c.FallbackMetadataModel)
}
return dedupeNonEmpty(models)
}
func dedupeNonEmpty(values []string) []string {
seen := make(map[string]struct{}, len(values))
out := make([]string, 0, len(values))
for _, value := range values {
if value == "" {
continue
}
if _, ok := seen[value]; ok {
continue
}
seen[value] = struct{}{}
out = append(out, value)
}
return out
}

View File

@@ -2,6 +2,7 @@ package config
import (
"fmt"
"log/slog"
"os"
"strconv"
"strings"
@@ -12,6 +13,12 @@ import (
)
func Load(explicitPath string) (*Config, string, error) {
return LoadWithLogger(explicitPath, nil)
}
// LoadWithLogger is Load with a logger surface for migration notices. Passing
// nil is fine — migration events will simply not be logged.
func LoadWithLogger(explicitPath string, log *slog.Logger) (*Config, string, error) {
path := ResolvePath(explicitPath)
data, err := os.ReadFile(path)
@@ -19,10 +26,40 @@ func Load(explicitPath string) (*Config, string, error) {
return nil, path, fmt.Errorf("read config %q: %w", path, err)
}
cfg := defaultConfig()
if err := yaml.Unmarshal(data, &cfg); err != nil {
raw := map[string]any{}
if err := yaml.Unmarshal(data, &raw); err != nil {
return nil, path, fmt.Errorf("decode config %q: %w", path, err)
}
if raw == nil {
raw = map[string]any{}
}
applied, err := Migrate(raw)
if err != nil {
return nil, path, fmt.Errorf("migrate config %q: %w", path, err)
}
if len(applied) > 0 {
if err := rewriteConfigFile(path, data, raw); err != nil {
return nil, path, err
}
if log != nil {
for _, step := range applied {
log.Warn("config migrated",
slog.String("path", path),
slog.Int("from_version", step.From),
slog.Int("to_version", step.To),
slog.String("describe", step.Describe),
)
}
}
}
cfg, err := decodeTyped(raw)
if err != nil {
return nil, path, fmt.Errorf("decode migrated config %q: %w", path, err)
}
cfg.Version = CurrentConfigVersion
applyEnvOverrides(&cfg)
if err := cfg.Validate(); err != nil {
@@ -32,6 +69,34 @@ func Load(explicitPath string) (*Config, string, error) {
return &cfg, path, nil
}
func decodeTyped(raw map[string]any) (Config, error) {
out, err := yaml.Marshal(raw)
if err != nil {
return Config{}, fmt.Errorf("re-marshal migrated config: %w", err)
}
cfg := defaultConfig()
if err := yaml.Unmarshal(out, &cfg); err != nil {
return Config{}, err
}
return cfg, nil
}
func rewriteConfigFile(path string, original []byte, migrated map[string]any) error {
backupPath := fmt.Sprintf("%s.bak.%d", path, time.Now().Unix())
if err := os.WriteFile(backupPath, original, 0o600); err != nil {
return fmt.Errorf("write backup %q: %w", backupPath, err)
}
out, err := yaml.Marshal(migrated)
if err != nil {
return fmt.Errorf("marshal migrated config: %w", err)
}
if err := os.WriteFile(path, out, 0o600); err != nil {
return fmt.Errorf("write migrated config %q: %w", path, err)
}
return nil
}
func ResolvePath(explicitPath string) string {
if path := strings.TrimSpace(explicitPath); path != "" {
if path != ".yaml" && path != ".yml" {
@@ -49,6 +114,7 @@ func ResolvePath(explicitPath string) string {
func defaultConfig() Config {
info := buildinfo.Current()
return Config{
Version: CurrentConfigVersion,
Server: ServerConfig{
Host: "0.0.0.0",
Port: 8080,
@@ -69,20 +135,14 @@ func defaultConfig() Config {
QueryParam: "key",
},
AI: AIConfig{
Provider: "litellm",
Embeddings: AIEmbeddingConfig{
Model: "openai/text-embedding-3-small",
Providers: map[string]ProviderConfig{},
Embeddings: EmbeddingsRoleConfig{
Dimensions: 1536,
},
Metadata: AIMetadataConfig{
Model: "gpt-4o-mini",
Metadata: MetadataRoleConfig{
Temperature: 0.1,
Timeout: 10 * time.Second,
},
Ollama: OllamaConfig{
BaseURL: "http://localhost:11434/v1",
APIKey: "ollama",
},
},
Capture: CaptureConfig{
Source: DefaultSource,
@@ -119,11 +179,12 @@ func defaultConfig() Config {
func applyEnvOverrides(cfg *Config) {
overrideString(&cfg.Database.URL, "AMCS_DATABASE_URL")
overrideString(&cfg.MCP.PublicURL, "AMCS_PUBLIC_URL")
overrideString(&cfg.AI.LiteLLM.BaseURL, "AMCS_LITELLM_BASE_URL")
overrideString(&cfg.AI.LiteLLM.APIKey, "AMCS_LITELLM_API_KEY")
overrideString(&cfg.AI.Ollama.BaseURL, "AMCS_OLLAMA_BASE_URL")
overrideString(&cfg.AI.Ollama.APIKey, "AMCS_OLLAMA_API_KEY")
overrideString(&cfg.AI.OpenRouter.APIKey, "AMCS_OPENROUTER_API_KEY")
overrideProviderField(cfg, "AMCS_LITELLM_BASE_URL", "litellm", func(p *ProviderConfig, v string) { p.BaseURL = v })
overrideProviderField(cfg, "AMCS_LITELLM_API_KEY", "litellm", func(p *ProviderConfig, v string) { p.APIKey = v })
overrideProviderField(cfg, "AMCS_OLLAMA_BASE_URL", "ollama", func(p *ProviderConfig, v string) { p.BaseURL = v })
overrideProviderField(cfg, "AMCS_OLLAMA_API_KEY", "ollama", func(p *ProviderConfig, v string) { p.APIKey = v })
overrideProviderField(cfg, "AMCS_OPENROUTER_API_KEY", "openrouter", func(p *ProviderConfig, v string) { p.APIKey = v })
if value, ok := os.LookupEnv("AMCS_SERVER_PORT"); ok {
if port, err := strconv.Atoi(strings.TrimSpace(value)); err == nil {
@@ -132,6 +193,24 @@ func applyEnvOverrides(cfg *Config) {
}
}
// overrideProviderField applies an env var to every configured provider of the
// given type. This preserves the v1 behaviour where e.g. AMCS_LITELLM_API_KEY
// rewrote the single litellm block — in v2 it rewrites every litellm provider.
func overrideProviderField(cfg *Config, envKey, providerType string, apply func(*ProviderConfig, string)) {
value, ok := os.LookupEnv(envKey)
if !ok {
return
}
value = strings.TrimSpace(value)
for name, p := range cfg.AI.Providers {
if p.Type != providerType {
continue
}
apply(&p, value)
cfg.AI.Providers[name] = p
}
}
func overrideString(target *string, envKey string) {
if value, ok := os.LookupEnv(envKey); ok {
*target = strings.TrimSpace(value)

View File

@@ -31,9 +31,8 @@ func TestResolvePathIgnoresBareYAMLExtension(t *testing.T) {
}
}
func TestLoadAppliesEnvOverrides(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "test.yaml")
if err := os.WriteFile(configPath, []byte(`
const v2ConfigYAML = `
version: 2
server:
port: 8080
mcp:
@@ -46,18 +45,30 @@ auth:
database:
url: "postgres://from-file"
ai:
provider: "litellm"
embeddings:
dimensions: 1536
litellm:
providers:
default:
type: "litellm"
base_url: "http://localhost:4000/v1"
api_key: "file-key"
embeddings:
dimensions: 1536
primary:
provider: "default"
model: "text-embed"
metadata:
primary:
provider: "default"
model: "gpt-4"
search:
default_limit: 10
max_limit: 50
logging:
level: "info"
`), 0o600); err != nil {
`
func TestLoadAppliesEnvOverrides(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "test.yaml")
if err := os.WriteFile(configPath, []byte(v2ConfigYAML), 0o600); err != nil {
t.Fatalf("write config: %v", err)
}
@@ -76,8 +87,8 @@ logging:
if cfg.Database.URL != "postgres://from-env" {
t.Fatalf("database url = %q, want env override", cfg.Database.URL)
}
if cfg.AI.LiteLLM.APIKey != "env-key" {
t.Fatalf("litellm api key = %q, want env override", cfg.AI.LiteLLM.APIKey)
if cfg.AI.Providers["default"].APIKey != "env-key" {
t.Fatalf("litellm api key = %q, want env override", cfg.AI.Providers["default"].APIKey)
}
if cfg.Server.Port != 9090 {
t.Fatalf("server port = %d, want 9090", cfg.Server.Port)
@@ -90,10 +101,12 @@ logging:
func TestLoadAppliesOllamaEnvOverrides(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "test.yaml")
if err := os.WriteFile(configPath, []byte(`
version: 2
server:
port: 8080
mcp:
path: "/mcp"
session_timeout: "10m"
auth:
keys:
- id: "test"
@@ -101,15 +114,20 @@ auth:
database:
url: "postgres://from-file"
ai:
provider: "ollama"
embeddings:
model: "nomic-embed-text"
dimensions: 768
metadata:
model: "llama3.2"
ollama:
providers:
local:
type: "ollama"
base_url: "http://localhost:11434/v1"
api_key: "ollama"
embeddings:
dimensions: 768
primary:
provider: "local"
model: "nomic-embed-text"
metadata:
primary:
provider: "local"
model: "llama3.2"
search:
default_limit: 10
max_limit: 50
@@ -127,10 +145,77 @@ logging:
t.Fatalf("Load() error = %v", err)
}
if cfg.AI.Ollama.BaseURL != "https://ollama.example.com/v1" {
t.Fatalf("ollama base url = %q, want env override", cfg.AI.Ollama.BaseURL)
p := cfg.AI.Providers["local"]
if p.BaseURL != "https://ollama.example.com/v1" {
t.Fatalf("ollama base url = %q, want env override", p.BaseURL)
}
if cfg.AI.Ollama.APIKey != "remote-key" {
t.Fatalf("ollama api key = %q, want env override", cfg.AI.Ollama.APIKey)
if p.APIKey != "remote-key" {
t.Fatalf("ollama api key = %q, want env override", p.APIKey)
}
}
func TestLoadMigratesV1Config(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "v1.yaml")
v1 := `
server:
port: 8080
mcp:
path: "/mcp"
session_timeout: "10m"
auth:
keys:
- id: "test"
value: "secret"
database:
url: "postgres://from-file"
ai:
provider: "litellm"
embeddings:
model: "text-embed"
dimensions: 1536
metadata:
model: "gpt-4"
temperature: 0.2
fallback_models: ["gpt-3.5"]
litellm:
base_url: "http://localhost:4000/v1"
api_key: "file-key"
search:
default_limit: 10
max_limit: 50
logging:
level: "info"
`
if err := os.WriteFile(configPath, []byte(v1), 0o600); err != nil {
t.Fatalf("write config: %v", err)
}
cfg, _, err := Load(configPath)
if err != nil {
t.Fatalf("Load() error = %v", err)
}
if cfg.Version != CurrentConfigVersion {
t.Fatalf("version = %d, want %d", cfg.Version, CurrentConfigVersion)
}
if p, ok := cfg.AI.Providers["default"]; !ok || p.Type != "litellm" || p.APIKey != "file-key" {
t.Fatalf("providers[default] = %+v, want litellm/file-key", p)
}
if cfg.AI.Embeddings.Primary.Model != "text-embed" || cfg.AI.Embeddings.Primary.Provider != "default" {
t.Fatalf("embeddings.primary = %+v, want default/text-embed", cfg.AI.Embeddings.Primary)
}
if cfg.AI.Metadata.Primary.Model != "gpt-4" || cfg.AI.Metadata.Primary.Provider != "default" {
t.Fatalf("metadata.primary = %+v, want default/gpt-4", cfg.AI.Metadata.Primary)
}
if len(cfg.AI.Metadata.Fallbacks) != 1 || cfg.AI.Metadata.Fallbacks[0].Model != "gpt-3.5" {
t.Fatalf("metadata.fallbacks = %+v, want [default/gpt-3.5]", cfg.AI.Metadata.Fallbacks)
}
entries, err := filepath.Glob(configPath + ".bak.*")
if err != nil {
t.Fatalf("glob backups: %v", err)
}
if len(entries) != 1 {
t.Fatalf("backup files = %d, want 1", len(entries))
}
}

341
internal/config/migrate.go Normal file
View File

@@ -0,0 +1,341 @@
package config
import (
"fmt"
"sort"
)
// CurrentConfigVersion is the schema version this binary expects. Files at a
// lower version are migrated automatically when loaded.
const CurrentConfigVersion = 2
// ConfigMigration upgrades a raw YAML map by one version.
type ConfigMigration struct {
From, To int
Describe string
Apply func(map[string]any) error
}
// migrations is the ordered ladder of upgrades. Add new entries at the end.
var migrations = []ConfigMigration{
{From: 1, To: 2, Describe: "named providers + role chains", Apply: migrateV1toV2},
}
// Migrate brings raw up to CurrentConfigVersion in place. Returns the list of
// migrations that were applied (may be empty if already current).
func Migrate(raw map[string]any) ([]ConfigMigration, error) {
if raw == nil {
return nil, fmt.Errorf("migrate: raw config is nil")
}
version := readVersion(raw)
if version > CurrentConfigVersion {
return nil, fmt.Errorf("migrate: config version %d is newer than supported version %d", version, CurrentConfigVersion)
}
applied := make([]ConfigMigration, 0)
for {
if version >= CurrentConfigVersion {
break
}
step, ok := findMigration(version)
if !ok {
return nil, fmt.Errorf("migrate: no migration registered from version %d", version)
}
if err := step.Apply(raw); err != nil {
return nil, fmt.Errorf("migrate v%d->v%d: %w", step.From, step.To, err)
}
raw["version"] = step.To
version = step.To
applied = append(applied, step)
}
return applied, nil
}
func findMigration(from int) (ConfigMigration, bool) {
for _, m := range migrations {
if m.From == from {
return m, true
}
}
return ConfigMigration{}, false
}
// readVersion returns the version from raw. Files without a version field are
// treated as version 1 (the original schema).
func readVersion(raw map[string]any) int {
v, ok := raw["version"]
if !ok {
return 1
}
switch n := v.(type) {
case int:
return n
case int64:
return int(n)
case float64:
return int(n)
}
return 1
}
// migrateV1toV2 lifts the single-provider config into the named-providers +
// role-chains layout. The pre-v2 config implicitly used one provider for both
// embeddings and metadata; we materialise that as a provider named "default".
func migrateV1toV2(raw map[string]any) error {
aiRaw := mapValue(raw, "ai")
if aiRaw == nil {
aiRaw = map[string]any{}
}
providerType := stringValue(aiRaw, "provider")
if providerType == "" {
providerType = "litellm"
}
providers, embeddingModel, metadataModel, fallbackModels := buildV1Provider(aiRaw, providerType)
embeddingsOld := mapValue(aiRaw, "embeddings")
dimensions := intValue(embeddingsOld, "dimensions")
if dimensions <= 0 {
dimensions = 1536
}
if embeddingModel == "" {
embeddingModel = stringValue(embeddingsOld, "model")
}
metadataOld := mapValue(aiRaw, "metadata")
if metadataModel == "" {
metadataModel = stringValue(metadataOld, "model")
}
temperature := floatValue(metadataOld, "temperature")
logConversations := boolValue(metadataOld, "log_conversations")
timeoutStr := stringValue(metadataOld, "timeout")
if list := stringListValue(metadataOld, "fallback_models"); len(list) > 0 {
fallbackModels = append(fallbackModels, list...)
}
if v := stringValue(metadataOld, "fallback_model"); v != "" {
fallbackModels = append(fallbackModels, v)
}
embeddings := map[string]any{
"dimensions": dimensions,
"primary": map[string]any{"provider": "default", "model": embeddingModel},
}
metadata := map[string]any{
"temperature": temperature,
"log_conversations": logConversations,
"primary": map[string]any{"provider": "default", "model": metadataModel},
}
if timeoutStr != "" {
metadata["timeout"] = timeoutStr
}
if fallbacks := chainTargets("default", fallbackModels); len(fallbacks) > 0 {
metadata["fallbacks"] = fallbacks
}
raw["ai"] = map[string]any{
"providers": providers,
"embeddings": embeddings,
"metadata": metadata,
}
return nil
}
func buildV1Provider(aiRaw map[string]any, providerType string) (map[string]any, string, string, []string) {
providers := map[string]any{}
defaultEntry := map[string]any{"type": providerType}
embedModel := ""
metaModel := ""
var fallbacks []string
switch providerType {
case "litellm":
block := mapValue(aiRaw, "litellm")
copyKeys(defaultEntry, block, "base_url", "api_key")
copyHeaders(defaultEntry, block, "request_headers")
embedModel = stringValue(block, "embedding_model")
metaModel = stringValue(block, "metadata_model")
if list := stringListValue(block, "fallback_metadata_models"); len(list) > 0 {
fallbacks = append(fallbacks, list...)
}
if v := stringValue(block, "fallback_metadata_model"); v != "" {
fallbacks = append(fallbacks, v)
}
case "ollama":
block := mapValue(aiRaw, "ollama")
copyKeys(defaultEntry, block, "base_url", "api_key")
copyHeaders(defaultEntry, block, "request_headers")
case "openrouter":
block := mapValue(aiRaw, "openrouter")
copyKeys(defaultEntry, block, "base_url", "api_key", "app_name", "site_url")
copyHeaders(defaultEntry, block, "extra_headers")
// rename: extra_headers → request_headers
if hdr, ok := defaultEntry["extra_headers"]; ok {
defaultEntry["request_headers"] = hdr
delete(defaultEntry, "extra_headers")
}
}
providers["default"] = defaultEntry
return providers, embedModel, metaModel, fallbacks
}
func chainTargets(provider string, models []string) []any {
out := make([]any, 0, len(models))
seen := map[string]struct{}{}
for _, m := range models {
if m == "" {
continue
}
key := provider + "|" + m
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
out = append(out, map[string]any{"provider": provider, "model": m})
}
return out
}
func mapValue(raw map[string]any, key string) map[string]any {
if raw == nil {
return nil
}
v, ok := raw[key]
if !ok {
return nil
}
switch m := v.(type) {
case map[string]any:
return m
case map[any]any:
return convertAnyMap(m)
}
return nil
}
func convertAnyMap(in map[any]any) map[string]any {
out := make(map[string]any, len(in))
keys := make([]string, 0, len(in))
for k, v := range in {
ks, ok := k.(string)
if !ok {
continue
}
keys = append(keys, ks)
out[ks] = v
}
sort.Strings(keys)
return out
}
func stringValue(raw map[string]any, key string) string {
if raw == nil {
return ""
}
v, ok := raw[key]
if !ok {
return ""
}
if s, ok := v.(string); ok {
return s
}
return ""
}
func intValue(raw map[string]any, key string) int {
if raw == nil {
return 0
}
switch n := raw[key].(type) {
case int:
return n
case int64:
return int(n)
case float64:
return int(n)
}
return 0
}
func floatValue(raw map[string]any, key string) float64 {
if raw == nil {
return 0
}
switch n := raw[key].(type) {
case float64:
return n
case int:
return float64(n)
case int64:
return float64(n)
}
return 0
}
func boolValue(raw map[string]any, key string) bool {
if raw == nil {
return false
}
if b, ok := raw[key].(bool); ok {
return b
}
return false
}
func stringListValue(raw map[string]any, key string) []string {
if raw == nil {
return nil
}
v, ok := raw[key]
if !ok {
return nil
}
list, ok := v.([]any)
if !ok {
return nil
}
out := make([]string, 0, len(list))
for _, item := range list {
if s, ok := item.(string); ok && s != "" {
out = append(out, s)
}
}
return out
}
func copyKeys(dst, src map[string]any, keys ...string) {
if src == nil {
return
}
for _, k := range keys {
if v, ok := src[k]; ok {
dst[k] = v
}
}
}
func copyHeaders(dst, src map[string]any, key string) {
if src == nil {
return
}
v, ok := src[key]
if !ok {
return
}
switch headers := v.(type) {
case map[string]any:
if len(headers) == 0 {
return
}
dst[key] = headers
case map[any]any:
if len(headers) == 0 {
return
}
dst[key] = convertAnyMap(headers)
}
}

View File

@@ -0,0 +1,77 @@
package config
import "testing"
func TestMigrateV1ToV2Litellm(t *testing.T) {
raw := map[string]any{
"ai": map[string]any{
"provider": "litellm",
"embeddings": map[string]any{
"model": "text-embedding-3-small",
"dimensions": 1536,
},
"metadata": map[string]any{
"model": "gpt-4o-mini",
"temperature": 0.2,
"fallback_models": []any{"gpt-4.1-mini"},
},
"litellm": map[string]any{
"base_url": "http://localhost:4000/v1",
"api_key": "secret",
},
},
}
applied, err := Migrate(raw)
if err != nil {
t.Fatalf("Migrate() error = %v", err)
}
if len(applied) != 1 || applied[0].From != 1 || applied[0].To != 2 {
t.Fatalf("applied = %+v, want [v1->v2]", applied)
}
if got := readVersion(raw); got != CurrentConfigVersion {
t.Fatalf("version = %d, want %d", got, CurrentConfigVersion)
}
ai := mapValue(raw, "ai")
providers := mapValue(ai, "providers")
def := mapValue(providers, "default")
if got := stringValue(def, "type"); got != "litellm" {
t.Fatalf("providers.default.type = %q, want litellm", got)
}
if got := stringValue(def, "base_url"); got != "http://localhost:4000/v1" {
t.Fatalf("providers.default.base_url = %q", got)
}
emb := mapValue(ai, "embeddings")
embPrimary := mapValue(emb, "primary")
if stringValue(embPrimary, "provider") != "default" || stringValue(embPrimary, "model") != "text-embedding-3-small" {
t.Fatalf("embeddings.primary = %+v, want default/text-embedding-3-small", embPrimary)
}
meta := mapValue(ai, "metadata")
metaPrimary := mapValue(meta, "primary")
if stringValue(metaPrimary, "provider") != "default" || stringValue(metaPrimary, "model") != "gpt-4o-mini" {
t.Fatalf("metadata.primary = %+v, want default/gpt-4o-mini", metaPrimary)
}
fallbacks, ok := meta["fallbacks"].([]any)
if !ok || len(fallbacks) != 1 {
t.Fatalf("metadata.fallbacks = %#v, want len=1", meta["fallbacks"])
}
firstFallback, ok := fallbacks[0].(map[string]any)
if !ok {
t.Fatalf("metadata.fallbacks[0] type = %T, want map[string]any", fallbacks[0])
}
if stringValue(firstFallback, "provider") != "default" || stringValue(firstFallback, "model") != "gpt-4.1-mini" {
t.Fatalf("metadata fallback = %+v, want default/gpt-4.1-mini", firstFallback)
}
}
func TestMigrateRejectsNewerVersion(t *testing.T) {
raw := map[string]any{"version": CurrentConfigVersion + 1}
_, err := Migrate(raw)
if err == nil {
t.Fatal("Migrate() error = nil, want error for newer config version")
}
}

View File

@@ -45,38 +45,8 @@ func (c Config) Validate() error {
return fmt.Errorf("invalid config: mcp.session_timeout must be greater than zero")
}
switch c.AI.Provider {
case "litellm", "ollama", "openrouter":
default:
return fmt.Errorf("invalid config: unsupported ai.provider %q", c.AI.Provider)
}
if c.AI.Embeddings.Dimensions <= 0 {
return fmt.Errorf("invalid config: ai.embeddings.dimensions must be greater than zero")
}
switch c.AI.Provider {
case "litellm":
if strings.TrimSpace(c.AI.LiteLLM.BaseURL) == "" {
return fmt.Errorf("invalid config: ai.litellm.base_url is required when ai.provider=litellm")
}
if strings.TrimSpace(c.AI.LiteLLM.APIKey) == "" {
return fmt.Errorf("invalid config: ai.litellm.api_key is required when ai.provider=litellm")
}
case "ollama":
if strings.TrimSpace(c.AI.Ollama.BaseURL) == "" {
return fmt.Errorf("invalid config: ai.ollama.base_url is required when ai.provider=ollama")
}
if strings.TrimSpace(c.AI.Ollama.APIKey) == "" {
return fmt.Errorf("invalid config: ai.ollama.api_key is required when ai.provider=ollama")
}
case "openrouter":
if strings.TrimSpace(c.AI.OpenRouter.BaseURL) == "" {
return fmt.Errorf("invalid config: ai.openrouter.base_url is required when ai.provider=openrouter")
}
if strings.TrimSpace(c.AI.OpenRouter.APIKey) == "" {
return fmt.Errorf("invalid config: ai.openrouter.api_key is required when ai.provider=openrouter")
}
if err := c.AI.validate(); err != nil {
return err
}
if c.Server.Port <= 0 {
@@ -108,3 +78,61 @@ func (c Config) Validate() error {
return nil
}
func (a AIConfig) validate() error {
if len(a.Providers) == 0 {
return fmt.Errorf("invalid config: ai.providers must contain at least one entry")
}
for name, p := range a.Providers {
if strings.TrimSpace(name) == "" {
return fmt.Errorf("invalid config: ai.providers contains an entry with an empty name")
}
switch p.Type {
case "litellm", "ollama", "openrouter":
default:
return fmt.Errorf("invalid config: ai.providers.%s.type %q is not supported", name, p.Type)
}
if strings.TrimSpace(p.BaseURL) == "" {
return fmt.Errorf("invalid config: ai.providers.%s.base_url is required", name)
}
if strings.TrimSpace(p.APIKey) == "" {
return fmt.Errorf("invalid config: ai.providers.%s.api_key is required", name)
}
}
if a.Embeddings.Dimensions <= 0 {
return fmt.Errorf("invalid config: ai.embeddings.dimensions must be greater than zero")
}
if err := a.validateChain("ai.embeddings", a.Embeddings.Chain()); err != nil {
return err
}
if err := a.validateChain("ai.metadata", a.Metadata.Chain()); err != nil {
return err
}
if a.Background != nil {
if a.Background.Embeddings != nil {
if err := a.validateChain("ai.background.embeddings", a.Background.Embeddings.AsTargets()); err != nil {
return err
}
}
if a.Background.Metadata != nil {
if err := a.validateChain("ai.background.metadata", a.Background.Metadata.AsTargets()); err != nil {
return err
}
}
}
return nil
}
func (a AIConfig) validateChain(prefix string, chain []RoleTarget) error {
if len(chain) == 0 {
return fmt.Errorf("invalid config: %s.primary must reference a configured provider and model", prefix)
}
for i, target := range chain {
if _, ok := a.Providers[target.Provider]; !ok {
return fmt.Errorf("invalid config: %s[%d] references unknown provider %q", prefix, i, target.Provider)
}
}
return nil
}

View File

@@ -7,6 +7,7 @@ import (
func validConfig() Config {
return Config{
Version: CurrentConfigVersion,
Server: ServerConfig{Port: 8080},
MCP: MCPConfig{Path: "/mcp", SessionTimeout: 10 * time.Minute},
Auth: AuthConfig{
@@ -14,21 +15,15 @@ func validConfig() Config {
},
Database: DatabaseConfig{URL: "postgres://example"},
AI: AIConfig{
Provider: "litellm",
Embeddings: AIEmbeddingConfig{
Providers: map[string]ProviderConfig{
"default": {Type: "litellm", BaseURL: "http://localhost:4000/v1", APIKey: "key"},
},
Embeddings: EmbeddingsRoleConfig{
Dimensions: 1536,
Primary: RoleTarget{Provider: "default", Model: "text-embed"},
},
LiteLLM: LiteLLMConfig{
BaseURL: "http://localhost:4000/v1",
APIKey: "key",
},
Ollama: OllamaConfig{
BaseURL: "http://localhost:11434/v1",
APIKey: "ollama",
},
OpenRouter: OpenRouterAIConfig{
BaseURL: "https://openrouter.ai/api/v1",
APIKey: "key",
Metadata: MetadataRoleConfig{
Primary: RoleTarget{Provider: "default", Model: "gpt-4"},
},
},
Search: SearchConfig{DefaultLimit: 10, MaxLimit: 50},
@@ -36,29 +31,44 @@ func validConfig() Config {
}
}
func TestValidateAcceptsSupportedProviders(t *testing.T) {
func TestValidateAcceptsSupportedProviderTypes(t *testing.T) {
for _, providerType := range []string{"litellm", "ollama", "openrouter"} {
cfg := validConfig()
p := cfg.AI.Providers["default"]
p.Type = providerType
cfg.AI.Providers["default"] = p
if err := cfg.Validate(); err != nil {
t.Fatalf("Validate litellm error = %v", err)
t.Fatalf("Validate %s error = %v", providerType, err)
}
cfg.AI.Provider = "ollama"
if err := cfg.Validate(); err != nil {
t.Fatalf("Validate ollama error = %v", err)
}
cfg.AI.Provider = "openrouter"
if err := cfg.Validate(); err != nil {
t.Fatalf("Validate openrouter error = %v", err)
}
}
func TestValidateRejectsInvalidProvider(t *testing.T) {
func TestValidateRejectsInvalidProviderType(t *testing.T) {
cfg := validConfig()
cfg.AI.Provider = "unknown"
p := cfg.AI.Providers["default"]
p.Type = "unknown"
cfg.AI.Providers["default"] = p
if err := cfg.Validate(); err == nil {
t.Fatal("Validate() error = nil, want error for unsupported provider")
t.Fatal("Validate() error = nil, want error for unsupported provider type")
}
}
func TestValidateRejectsChainWithUnknownProvider(t *testing.T) {
cfg := validConfig()
cfg.AI.Metadata.Primary = RoleTarget{Provider: "does-not-exist", Model: "x"}
if err := cfg.Validate(); err == nil {
t.Fatal("Validate() error = nil, want error for chain referencing unknown provider")
}
}
func TestValidateRejectsEmptyProviders(t *testing.T) {
cfg := validConfig()
cfg.AI.Providers = map[string]ProviderConfig{}
if err := cfg.Validate(); err == nil {
t.Fatal("Validate() error = nil, want error for empty providers")
}
}

View File

@@ -35,7 +35,7 @@ type ToolSet struct {
Files *tools.FilesTool
Backfill *tools.BackfillTool
Reparse *tools.ReparseMetadataTool
RetryMetadata *tools.RetryMetadataTool
RetryMetadata *tools.RetryEnrichmentTool
Maintenance *tools.MaintenanceTool
Skills *tools.SkillsTool
ChatHistory *tools.ChatHistoryTool

View File

@@ -126,7 +126,7 @@ func streamableTestToolSet() ToolSet {
Files: new(tools.FilesTool),
Backfill: new(tools.BackfillTool),
Reparse: new(tools.ReparseMetadataTool),
RetryMetadata: new(tools.RetryMetadataTool),
RetryMetadata: new(tools.RetryEnrichmentTool),
Maintenance: new(tools.MaintenanceTool),
Skills: new(tools.SkillsTool),
}

View File

@@ -19,7 +19,7 @@ const backfillConcurrency = 4
type BackfillTool struct {
store *store.DB
provider ai.Provider
embeddings *ai.EmbeddingRunner
sessions *session.ActiveProjects
logger *slog.Logger
}
@@ -47,15 +47,15 @@ type BackfillOutput struct {
Failures []BackfillFailure `json:"failures,omitempty"`
}
func NewBackfillTool(db *store.DB, provider ai.Provider, sessions *session.ActiveProjects, logger *slog.Logger) *BackfillTool {
return &BackfillTool{store: db, provider: provider, sessions: sessions, logger: logger}
func NewBackfillTool(db *store.DB, embeddings *ai.EmbeddingRunner, sessions *session.ActiveProjects, logger *slog.Logger) *BackfillTool {
return &BackfillTool{store: db, embeddings: embeddings, sessions: sessions, logger: logger}
}
// QueueThought queues a single thought for background embedding generation.
// It is used by capture when the embedding provider is temporarily unavailable.
func (t *BackfillTool) QueueThought(ctx context.Context, id uuid.UUID, content string) {
go func() {
vec, err := t.provider.Embed(ctx, content)
result, err := t.embeddings.Embed(ctx, content)
if err != nil {
t.logger.Warn("background embedding retry failed",
slog.String("thought_id", id.String()),
@@ -63,15 +63,17 @@ func (t *BackfillTool) QueueThought(ctx context.Context, id uuid.UUID, content s
)
return
}
model := t.provider.EmbeddingModel()
if err := t.store.UpsertEmbedding(ctx, id, model, vec); err != nil {
if err := t.store.UpsertEmbedding(ctx, id, result.Model, result.Vector); err != nil {
t.logger.Warn("background embedding upsert failed",
slog.String("thought_id", id.String()),
slog.String("error", err.Error()),
)
return
}
t.logger.Info("background embedding retry succeeded", slog.String("thought_id", id.String()))
t.logger.Info("background embedding retry succeeded",
slog.String("thought_id", id.String()),
slog.String("model", result.Model),
)
}()
}
@@ -91,15 +93,15 @@ func (t *BackfillTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
projectID = &project.ID
}
model := t.provider.EmbeddingModel()
primaryModel := t.embeddings.PrimaryModel()
thoughts, err := t.store.ListThoughtsMissingEmbedding(ctx, model, limit, projectID, in.IncludeArchived, in.OlderThanDays)
thoughts, err := t.store.ListThoughtsMissingEmbedding(ctx, primaryModel, limit, projectID, in.IncludeArchived, in.OlderThanDays)
if err != nil {
return nil, BackfillOutput{}, err
}
out := BackfillOutput{
Model: model,
Model: primaryModel,
Scanned: len(thoughts),
DryRun: in.DryRun,
}
@@ -125,7 +127,7 @@ func (t *BackfillTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
defer wg.Done()
defer sem.Release(1)
vec, embedErr := t.provider.Embed(ctx, content)
result, embedErr := t.embeddings.Embed(ctx, content)
if embedErr != nil {
mu.Lock()
out.Failures = append(out.Failures, BackfillFailure{ID: id.String(), Error: embedErr.Error()})
@@ -134,7 +136,7 @@ func (t *BackfillTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
return
}
if upsertErr := t.store.UpsertEmbedding(ctx, id, model, vec); upsertErr != nil {
if upsertErr := t.store.UpsertEmbedding(ctx, id, result.Model, result.Vector); upsertErr != nil {
mu.Lock()
out.Failures = append(out.Failures, BackfillFailure{ID: id.String(), Error: upsertErr.Error()})
mu.Unlock()
@@ -154,7 +156,7 @@ func (t *BackfillTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
out.Skipped = out.Scanned - out.Embedded - out.Failed
t.logger.Info("backfill completed",
slog.String("model", model),
slog.String("model", primaryModel),
slog.Int("scanned", out.Scanned),
slog.Int("embedded", out.Embedded),
slog.Int("failed", out.Failed),

View File

@@ -22,13 +22,20 @@ type EmbeddingQueuer interface {
QueueThought(ctx context.Context, id uuid.UUID, content string)
}
// MetadataQueuer queues a thought for background metadata retry. Both
// MetadataRetryer and EnrichmentRetryer satisfy this.
type MetadataQueuer interface {
QueueThought(id uuid.UUID)
}
type CaptureTool struct {
store *store.DB
provider ai.Provider
embeddings *ai.EmbeddingRunner
metadata *ai.MetadataRunner
capture config.CaptureConfig
sessions *session.ActiveProjects
metadataTimeout time.Duration
retryer *MetadataRetryer
retryer MetadataQueuer
embedRetryer EmbeddingQueuer
log *slog.Logger
}
@@ -42,8 +49,8 @@ type CaptureOutput struct {
Thought thoughttypes.Thought `json:"thought"`
}
func NewCaptureTool(db *store.DB, provider ai.Provider, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, retryer *MetadataRetryer, embedRetryer EmbeddingQueuer, log *slog.Logger) *CaptureTool {
return &CaptureTool{store: db, provider: provider, capture: capture, sessions: sessions, metadataTimeout: metadataTimeout, retryer: retryer, embedRetryer: embedRetryer, log: log}
func NewCaptureTool(db *store.DB, embeddings *ai.EmbeddingRunner, metadata *ai.MetadataRunner, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, retryer MetadataQueuer, embedRetryer EmbeddingQueuer, log *slog.Logger) *CaptureTool {
return &CaptureTool{store: db, embeddings: embeddings, metadata: metadata, capture: capture, sessions: sessions, metadataTimeout: metadataTimeout, retryer: retryer, embedRetryer: embedRetryer, log: log}
}
func (t *CaptureTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in CaptureInput) (*mcp.CallToolResult, CaptureOutput, error) {
@@ -66,7 +73,7 @@ func (t *CaptureTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in C
thought.ProjectID = &project.ID
}
created, err := t.store.InsertThought(ctx, thought, t.provider.EmbeddingModel())
created, err := t.store.InsertThought(ctx, thought, t.embeddings.PrimaryModel())
if err != nil {
return nil, CaptureOutput{}, err
}
@@ -89,7 +96,7 @@ func (t *CaptureTool) launchEnrichment(id uuid.UUID, content string) {
if t.retryer != nil {
attemptedAt := time.Now().UTC()
rawMetadata := metadata.Fallback(t.capture)
extracted, err := t.provider.ExtractMetadata(ctx, content)
extracted, err := t.metadata.ExtractMetadata(ctx, content)
if err != nil {
failed := metadata.MarkMetadataFailed(rawMetadata, t.capture, attemptedAt, err)
if _, updateErr := t.store.UpdateThoughtMetadata(ctx, id, failed); updateErr != nil {
@@ -100,7 +107,7 @@ func (t *CaptureTool) launchEnrichment(id uuid.UUID, content string) {
}
t.log.Warn("deferred metadata extraction failed",
slog.String("thought_id", id.String()),
slog.String("provider", t.provider.Name()),
slog.String("provider", t.metadata.PrimaryProvider()),
slog.String("error", err.Error()),
)
t.retryer.QueueThought(id)
@@ -116,10 +123,10 @@ func (t *CaptureTool) launchEnrichment(id uuid.UUID, content string) {
}
if t.embedRetryer != nil {
if _, err := t.provider.Embed(ctx, content); err != nil {
if _, err := t.embeddings.Embed(ctx, content); err != nil {
t.log.Warn("deferred embedding failed",
slog.String("thought_id", id.String()),
slog.String("provider", t.provider.Name()),
slog.String("provider", t.embeddings.PrimaryProvider()),
slog.String("error", err.Error()),
)
}

View File

@@ -16,7 +16,7 @@ import (
type ContextTool struct {
store *store.DB
provider ai.Provider
embeddings *ai.EmbeddingRunner
search config.SearchConfig
sessions *session.ActiveProjects
}
@@ -41,8 +41,8 @@ type ProjectContextOutput struct {
Items []ContextItem `json:"items"`
}
func NewContextTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *ContextTool {
return &ContextTool{store: db, provider: provider, search: search, sessions: sessions}
func NewContextTool(db *store.DB, embeddings *ai.EmbeddingRunner, search config.SearchConfig, sessions *session.ActiveProjects) *ContextTool {
return &ContextTool{store: db, embeddings: embeddings, search: search, sessions: sessions}
}
func (t *ContextTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in ProjectContextInput) (*mcp.CallToolResult, ProjectContextOutput, error) {
@@ -72,7 +72,7 @@ func (t *ContextTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in P
query := strings.TrimSpace(in.Query)
if query != "" {
semantic, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, t.search.DefaultThreshold, &project.ID, nil)
semantic, err := semanticSearch(ctx, t.store, t.embeddings, t.search, query, limit, t.search.DefaultThreshold, &project.ID, nil)
if err != nil {
return nil, ProjectContextOutput{}, err
}

View File

@@ -32,7 +32,7 @@ var enrichmentRetryBackoff = []time.Duration{
type EnrichmentRetryer struct {
backgroundCtx context.Context
store *store.DB
provider ai.Provider
metadata *ai.MetadataRunner
capture config.CaptureConfig
sessions *session.ActiveProjects
metadataTimeout time.Duration
@@ -66,14 +66,14 @@ type RetryEnrichmentOutput struct {
Failures []RetryEnrichmentFailure `json:"failures,omitempty"`
}
func NewEnrichmentRetryer(backgroundCtx context.Context, db *store.DB, provider ai.Provider, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, logger *slog.Logger) *EnrichmentRetryer {
func NewEnrichmentRetryer(backgroundCtx context.Context, db *store.DB, metadataRunner *ai.MetadataRunner, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, logger *slog.Logger) *EnrichmentRetryer {
if backgroundCtx == nil {
backgroundCtx = context.Background()
}
return &EnrichmentRetryer{
backgroundCtx: backgroundCtx,
store: db,
provider: provider,
metadata: metadataRunner,
capture: capture,
sessions: sessions,
metadataTimeout: metadataTimeout,
@@ -190,7 +190,7 @@ func (r *EnrichmentRetryer) retryOne(ctx context.Context, id uuid.UUID) (bool, e
}
attemptedAt := time.Now().UTC()
extracted, extractErr := r.provider.ExtractMetadata(attemptCtx, thought.Content)
extracted, extractErr := r.metadata.ExtractMetadata(attemptCtx, thought.Content)
if extractErr != nil {
failedMetadata := metadata.MarkMetadataFailed(thought.Metadata, r.capture, attemptedAt, extractErr)
if _, updateErr := r.store.UpdateThoughtMetadata(ctx, thought.ID, failedMetadata); updateErr != nil {

View File

@@ -14,7 +14,7 @@ import (
type LinksTool struct {
store *store.DB
provider ai.Provider
embeddings *ai.EmbeddingRunner
search config.SearchConfig
}
@@ -47,8 +47,8 @@ type RelatedOutput struct {
Related []RelatedThought `json:"related"`
}
func NewLinksTool(db *store.DB, provider ai.Provider, search config.SearchConfig) *LinksTool {
return &LinksTool{store: db, provider: provider, search: search}
func NewLinksTool(db *store.DB, embeddings *ai.EmbeddingRunner, search config.SearchConfig) *LinksTool {
return &LinksTool{store: db, embeddings: embeddings, search: search}
}
func (t *LinksTool) Link(ctx context.Context, _ *mcp.CallToolRequest, in LinkInput) (*mcp.CallToolResult, LinkOutput, error) {
@@ -117,7 +117,7 @@ func (t *LinksTool) Related(ctx context.Context, _ *mcp.CallToolRequest, in Rela
}
if includeSemantic {
semantic, err := semanticSearch(ctx, t.store, t.provider, t.search, thought.Content, t.search.DefaultLimit, t.search.DefaultThreshold, thought.ProjectID, &thought.ID)
semantic, err := semanticSearch(ctx, t.store, t.embeddings, t.search, thought.Content, t.search.DefaultLimit, t.search.DefaultThreshold, thought.ProjectID, &thought.ID)
if err != nil {
return nil, RelatedOutput{}, err
}

View File

@@ -23,7 +23,7 @@ const metadataRetryConcurrency = 4
type MetadataRetryer struct {
backgroundCtx context.Context
store *store.DB
provider ai.Provider
metadata *ai.MetadataRunner
capture config.CaptureConfig
sessions *session.ActiveProjects
metadataTimeout time.Duration
@@ -87,14 +87,14 @@ type RetryMetadataOutput struct {
Failures []RetryMetadataFailure `json:"failures,omitempty"`
}
func NewMetadataRetryer(backgroundCtx context.Context, db *store.DB, provider ai.Provider, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, logger *slog.Logger) *MetadataRetryer {
func NewMetadataRetryer(backgroundCtx context.Context, db *store.DB, metadataRunner *ai.MetadataRunner, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, logger *slog.Logger) *MetadataRetryer {
if backgroundCtx == nil {
backgroundCtx = context.Background()
}
return &MetadataRetryer{
backgroundCtx: backgroundCtx,
store: db,
provider: provider,
metadata: metadataRunner,
capture: capture,
sessions: sessions,
metadataTimeout: metadataTimeout,
@@ -223,7 +223,7 @@ func (r *MetadataRetryer) retryOne(ctx context.Context, id uuid.UUID) (bool, err
}
attemptedAt := time.Now().UTC()
extracted, extractErr := r.provider.ExtractMetadata(attemptCtx, thought.Content)
extracted, extractErr := r.metadata.ExtractMetadata(attemptCtx, thought.Content)
if extractErr != nil {
failedMetadata := metadata.MarkMetadataFailed(thought.Metadata, r.capture, attemptedAt, extractErr)
if _, updateErr := r.store.UpdateThoughtMetadata(ctx, thought.ID, failedMetadata); updateErr != nil {

View File

@@ -16,7 +16,7 @@ import (
type RecallTool struct {
store *store.DB
provider ai.Provider
embeddings *ai.EmbeddingRunner
search config.SearchConfig
sessions *session.ActiveProjects
}
@@ -32,8 +32,8 @@ type RecallOutput struct {
Items []ContextItem `json:"items"`
}
func NewRecallTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *RecallTool {
return &RecallTool{store: db, provider: provider, search: search, sessions: sessions}
func NewRecallTool(db *store.DB, embeddings *ai.EmbeddingRunner, search config.SearchConfig, sessions *session.ActiveProjects) *RecallTool {
return &RecallTool{store: db, embeddings: embeddings, search: search, sessions: sessions}
}
func (t *RecallTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in RecallInput) (*mcp.CallToolResult, RecallOutput, error) {
@@ -54,7 +54,7 @@ func (t *RecallTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in Re
projectID = &project.ID
}
semantic, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, t.search.DefaultThreshold, projectID, nil)
semantic, err := semanticSearch(ctx, t.store, t.embeddings, t.search, query, limit, t.search.DefaultThreshold, projectID, nil)
if err != nil {
return nil, RecallOutput{}, err
}

View File

@@ -23,7 +23,7 @@ const metadataReparseConcurrency = 4
type ReparseMetadataTool struct {
store *store.DB
provider ai.Provider
metadata *ai.MetadataRunner
capture config.CaptureConfig
sessions *session.ActiveProjects
logger *slog.Logger
@@ -53,8 +53,8 @@ type ReparseMetadataOutput struct {
Failures []ReparseMetadataFailure `json:"failures,omitempty"`
}
func NewReparseMetadataTool(db *store.DB, provider ai.Provider, capture config.CaptureConfig, sessions *session.ActiveProjects, logger *slog.Logger) *ReparseMetadataTool {
return &ReparseMetadataTool{store: db, provider: provider, capture: capture, sessions: sessions, logger: logger}
func NewReparseMetadataTool(db *store.DB, metadataRunner *ai.MetadataRunner, capture config.CaptureConfig, sessions *session.ActiveProjects, logger *slog.Logger) *ReparseMetadataTool {
return &ReparseMetadataTool{store: db, metadata: metadataRunner, capture: capture, sessions: sessions, logger: logger}
}
func (t *ReparseMetadataTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in ReparseMetadataInput) (*mcp.CallToolResult, ReparseMetadataOutput, error) {
@@ -107,7 +107,7 @@ func (t *ReparseMetadataTool) Handle(ctx context.Context, req *mcp.CallToolReque
normalizedCurrent := metadata.Normalize(thought.Metadata, t.capture)
attemptedAt := time.Now().UTC()
extracted, extractErr := t.provider.ExtractMetadata(ctx, thought.Content)
extracted, extractErr := t.metadata.ExtractMetadata(ctx, thought.Content)
normalizedTarget := normalizedCurrent
if extractErr != nil {
normalizedTarget = metadata.MarkMetadataFailed(normalizedCurrent, t.capture, attemptedAt, extractErr)

View File

@@ -11,12 +11,14 @@ import (
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
)
// semanticSearch runs vector similarity search if embeddings exist for the active model
// in the given scope, otherwise falls back to Postgres full-text search.
// semanticSearch runs vector similarity search if embeddings exist for the
// primary embedding model in the given scope, otherwise falls back to Postgres
// full-text search. Search always uses the primary model so query vectors
// match rows stored under the primary model name.
func semanticSearch(
ctx context.Context,
db *store.DB,
provider ai.Provider,
embeddings *ai.EmbeddingRunner,
search config.SearchConfig,
query string,
limit int,
@@ -24,17 +26,18 @@ func semanticSearch(
projectID *uuid.UUID,
excludeID *uuid.UUID,
) ([]thoughttypes.SearchResult, error) {
hasEmbeddings, err := db.HasEmbeddingsForModel(ctx, provider.EmbeddingModel(), projectID)
model := embeddings.PrimaryModel()
hasEmbeddings, err := db.HasEmbeddingsForModel(ctx, model, projectID)
if err != nil {
return nil, err
}
if hasEmbeddings {
embedding, err := provider.Embed(ctx, query)
embedding, err := embeddings.EmbedPrimary(ctx, query)
if err != nil {
return nil, err
}
return db.SearchSimilarThoughts(ctx, embedding, provider.EmbeddingModel(), threshold, limit, projectID, excludeID)
return db.SearchSimilarThoughts(ctx, embedding, model, threshold, limit, projectID, excludeID)
}
return db.SearchThoughtsText(ctx, query, limit, projectID, excludeID)

View File

@@ -16,7 +16,7 @@ import (
type SearchTool struct {
store *store.DB
provider ai.Provider
embeddings *ai.EmbeddingRunner
search config.SearchConfig
sessions *session.ActiveProjects
}
@@ -32,8 +32,8 @@ type SearchOutput struct {
Results []thoughttypes.SearchResult `json:"results"`
}
func NewSearchTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *SearchTool {
return &SearchTool{store: db, provider: provider, search: search, sessions: sessions}
func NewSearchTool(db *store.DB, embeddings *ai.EmbeddingRunner, search config.SearchConfig, sessions *session.ActiveProjects) *SearchTool {
return &SearchTool{store: db, embeddings: embeddings, search: search, sessions: sessions}
}
func (t *SearchTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in SearchInput) (*mcp.CallToolResult, SearchOutput, error) {
@@ -56,7 +56,7 @@ func (t *SearchTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in Se
_ = t.store.TouchProject(ctx, project.ID)
}
results, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, threshold, projectID, nil)
results, err := semanticSearch(ctx, t.store, t.embeddings, t.search, query, limit, threshold, projectID, nil)
if err != nil {
return nil, SearchOutput{}, err
}

View File

@@ -15,7 +15,8 @@ import (
type SummarizeTool struct {
store *store.DB
provider ai.Provider
embeddings *ai.EmbeddingRunner
metadata *ai.MetadataRunner
search config.SearchConfig
sessions *session.ActiveProjects
}
@@ -32,8 +33,8 @@ type SummarizeOutput struct {
Count int `json:"count"`
}
func NewSummarizeTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *SummarizeTool {
return &SummarizeTool{store: db, provider: provider, search: search, sessions: sessions}
func NewSummarizeTool(db *store.DB, embeddings *ai.EmbeddingRunner, metadata *ai.MetadataRunner, search config.SearchConfig, sessions *session.ActiveProjects) *SummarizeTool {
return &SummarizeTool{store: db, embeddings: embeddings, metadata: metadata, search: search, sessions: sessions}
}
func (t *SummarizeTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in SummarizeInput) (*mcp.CallToolResult, SummarizeOutput, error) {
@@ -52,7 +53,7 @@ func (t *SummarizeTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
if project != nil {
projectID = &project.ID
}
results, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, t.search.DefaultThreshold, projectID, nil)
results, err := semanticSearch(ctx, t.store, t.embeddings, t.search, query, limit, t.search.DefaultThreshold, projectID, nil)
if err != nil {
return nil, SummarizeOutput{}, err
}
@@ -77,7 +78,7 @@ func (t *SummarizeTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
userPrompt := formatContextBlock("Summarize the following thoughts into concise prose with themes, action items, and notable people.", lines)
systemPrompt := "You summarize note collections. Be concise, concrete, and structured in plain prose."
summary, err := t.provider.Summarize(ctx, systemPrompt, userPrompt)
summary, err := t.metadata.Summarize(ctx, systemPrompt, userPrompt)
if err != nil {
return nil, SummarizeOutput{}, err
}

View File

@@ -17,7 +17,8 @@ import (
type UpdateTool struct {
store *store.DB
provider ai.Provider
embeddings *ai.EmbeddingRunner
metadata *ai.MetadataRunner
capture config.CaptureConfig
log *slog.Logger
}
@@ -33,8 +34,8 @@ type UpdateOutput struct {
Thought thoughttypes.Thought `json:"thought"`
}
func NewUpdateTool(db *store.DB, provider ai.Provider, capture config.CaptureConfig, log *slog.Logger) *UpdateTool {
return &UpdateTool{store: db, provider: provider, capture: capture, log: log}
func NewUpdateTool(db *store.DB, embeddings *ai.EmbeddingRunner, metadata *ai.MetadataRunner, capture config.CaptureConfig, log *slog.Logger) *UpdateTool {
return &UpdateTool{store: db, embeddings: embeddings, metadata: metadata, capture: capture, log: log}
}
func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in UpdateInput) (*mcp.CallToolResult, UpdateOutput, error) {
@@ -50,6 +51,7 @@ func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in Upda
content := current.Content
var embedding []float32
embeddingModel := ""
mergedMetadata := current.Metadata
projectID := current.ProjectID
@@ -58,11 +60,13 @@ func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in Upda
if content == "" {
return nil, UpdateOutput{}, errInvalidInput("content must not be empty")
}
embedding, err = t.provider.Embed(ctx, content)
embedResult, err := t.embeddings.Embed(ctx, content)
if err != nil {
return nil, UpdateOutput{}, err
}
extracted, extractErr := t.provider.ExtractMetadata(ctx, content)
embedding = embedResult.Vector
embeddingModel = embedResult.Model
extracted, extractErr := t.metadata.ExtractMetadata(ctx, content)
if extractErr != nil {
t.log.Warn("metadata extraction failed during update, keeping current metadata", slog.String("error", extractErr.Error()))
mergedMetadata = metadata.MarkMetadataFailed(mergedMetadata, t.capture, time.Now().UTC(), extractErr)
@@ -82,7 +86,7 @@ func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in Upda
projectID = &project.ID
}
updated, err := t.store.UpdateThought(ctx, id, content, embedding, t.provider.EmbeddingModel(), mergedMetadata, projectID)
updated, err := t.store.UpdateThought(ctx, id, content, embedding, embeddingModel, mergedMetadata, projectID)
if err != nil {
return nil, UpdateOutput{}, err
}