From 14e218d78405b9d9f32bb2dbb6b7304dbc2836bc Mon Sep 17 00:00:00 2001 From: Hein Date: Tue, 21 Apr 2026 21:14:28 +0200 Subject: [PATCH] test(config): add migration tests for litellm provider * Implement tests for migrating configuration from v1 to v2 for the litellm provider. * Validate the structure and values of the migrated configuration. * Ensure migration rejects newer versions of the configuration. fix(validate): enhance AI provider validation logic * Consolidate provider validation into a dedicated method. * Ensure at least one provider is specified and validate its type. * Check for required fields based on provider type. fix(mcpserver): update tool set to use new enrichment tool * Replace RetryMetadataTool with RetryEnrichmentTool in the ToolSet. fix(tools): refactor tools to use embedding and metadata runners * Update tools to utilize EmbeddingRunner and MetadataRunner instead of Provider. * Adjust method calls to align with the new runner interfaces. --- .gitignore | 1 + README.md | 48 +- changelog.md | 80 ++++ cmd/amcs-migrate-config/main.go | 105 +++++ configs/config.example.yaml | 73 ++- internal/ai/compat/client.go | 436 +++++------------- internal/ai/compat/client_test.go | 138 ++---- internal/ai/factory.go | 25 - internal/ai/factory_test.go | 33 -- internal/ai/litellm/client.go | 30 -- internal/ai/ollama/client.go | 26 -- internal/ai/openrouter/client.go | 37 -- internal/ai/provider.go | 15 - internal/ai/registry.go | 96 ++++ internal/ai/registry_test.go | 80 ++++ internal/ai/runner.go | 367 +++++++++++++++ internal/ai/runner_test.go | 139 ++++++ internal/app/app.go | 75 ++- internal/config/config.go | 144 +++--- internal/config/loader.go | 111 ++++- internal/config/loader_test.go | 125 ++++- internal/config/migrate.go | 341 ++++++++++++++ internal/config/migrate_test.go | 77 ++++ internal/config/validate.go | 92 ++-- internal/config/validate_test.go | 74 +-- internal/mcpserver/server.go | 2 +- .../mcpserver/streamable_integration_test.go | 2 +- 
internal/tools/backfill.go | 34 +- internal/tools/capture.go | 25 +- internal/tools/context.go | 14 +- internal/tools/enrichment_retry.go | 8 +- internal/tools/links.go | 12 +- internal/tools/metadata_retry.go | 8 +- internal/tools/recall.go | 14 +- internal/tools/reparse_metadata.go | 8 +- internal/tools/retrieval.go | 15 +- internal/tools/search.go | 14 +- internal/tools/summarize.go | 17 +- internal/tools/update.go | 22 +- 39 files changed, 2062 insertions(+), 901 deletions(-) create mode 100644 changelog.md create mode 100644 cmd/amcs-migrate-config/main.go delete mode 100644 internal/ai/factory.go delete mode 100644 internal/ai/factory_test.go delete mode 100644 internal/ai/litellm/client.go delete mode 100644 internal/ai/ollama/client.go delete mode 100644 internal/ai/openrouter/client.go delete mode 100644 internal/ai/provider.go create mode 100644 internal/ai/registry.go create mode 100644 internal/ai/registry_test.go create mode 100644 internal/ai/runner.go create mode 100644 internal/ai/runner_test.go create mode 100644 internal/config/migrate.go create mode 100644 internal/config/migrate_test.go diff --git a/.gitignore b/.gitignore index e1f4b2a..11d4489 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,4 @@ OB1/ ui/node_modules/ ui/.svelte-kit/ internal/app/ui/dist/ +.codex diff --git a/README.md b/README.md index 75d2ba9..2f626c8 100644 --- a/README.md +++ b/README.md @@ -244,12 +244,24 @@ Link existing skills and guardrails to a project so they are automatically avail Config is YAML-driven. 
Copy `configs/config.example.yaml` and set: - `database.url` — Postgres connection string -- `auth.mode` — `api_keys` or `oauth_client_credentials` -- `auth.keys` — API keys for MCP access via `x-brain-key` or `Authorization: Bearer ` when `auth.mode=api_keys` -- `auth.oauth.clients` — client registry when `auth.mode=oauth_client_credentials` +- `auth.keys` — static API keys for MCP access via `x-brain-key` or `Authorization: Bearer ` +- `auth.oauth.clients` — optional OAuth client credentials registry +- `ai.providers` — named provider definitions (`litellm`, `ollama`, `openrouter`) +- `ai.embeddings.primary` / `ai.metadata.primary` — primary role targets (`provider` + `model`) +- `ai.embeddings.fallbacks` / `ai.metadata.fallbacks` — sequential fallback targets - `mcp.version` is build-generated and should not be set in config -**OAuth Client Credentials flow** (`auth.mode=oauth_client_credentials`): +Config schema is versioned. Current schema version is `2`. + +Use the migration helper to rewrite legacy configs in-place: + +```bash +go run ./cmd/amcs-migrate-config --config ./configs/dev.yaml +``` + +Use `--dry-run` to print migrated YAML without writing. + +**OAuth Client Credentials flow**: 1. Obtain a token — `POST /oauth/token` (public, no auth required): ``` @@ -267,8 +279,9 @@ Config is YAML-driven. Copy `configs/config.example.yaml` and set: ``` Alternatively, pass `client_id` and `client_secret` as body parameters instead of `Authorization: Basic`. Direct `Authorization: Basic` credential validation on the MCP endpoint is also supported as a fallback (no token required). 
-- `ai.litellm.base_url` and `ai.litellm.api_key` — LiteLLM proxy -- `ai.ollama.base_url` and `ai.ollama.api_key` — Ollama local or remote server +- `AMCS_LITELLM_BASE_URL` / `AMCS_LITELLM_API_KEY` override all configured LiteLLM providers +- `AMCS_OLLAMA_BASE_URL` / `AMCS_OLLAMA_API_KEY` override all configured Ollama providers +- `AMCS_OPENROUTER_API_KEY` overrides all configured OpenRouter providers See `llm/plan.md` for an audited high-level status summary of the original implementation plan, and `llm/todo.md` for the audited backfill/fallback follow-up status. @@ -643,27 +656,32 @@ Notes: ## Ollama -Set `ai.provider: "ollama"` to use a local or self-hosted Ollama server through its OpenAI-compatible API. +Set your role targets to an Ollama provider to use a local or self-hosted Ollama server through its OpenAI-compatible API. Example: ```yaml ai: - provider: "ollama" + providers: + local: + type: "ollama" + base_url: "http://localhost:11434/v1" + api_key: "ollama" + request_headers: {} embeddings: - model: "nomic-embed-text" dimensions: 768 + primary: + provider: "local" + model: "nomic-embed-text" metadata: - model: "llama3.2" temperature: 0.1 - ollama: - base_url: "http://localhost:11434/v1" - api_key: "ollama" - request_headers: {} + primary: + provider: "local" + model: "llama3.2" ``` Notes: -- For remote Ollama servers, point `ai.ollama.base_url` at the remote `/v1` endpoint. +- For remote Ollama servers, point `ai.providers.<name>.base_url` at the remote `/v1` endpoint. - The client always sends Bearer auth; Ollama ignores it locally, so `api_key: "ollama"` is a safe default. - `ai.embeddings.dimensions` must match the embedding model you actually use, or startup will fail the database vector-dimension check. 
diff --git a/changelog.md b/changelog.md new file mode 100644 index 0000000..5e64c3a --- /dev/null +++ b/changelog.md @@ -0,0 +1,80 @@ +# Changelog + +## 2026-04-21 + +### 2026-04-21 21h - Config Schema v2 Introduced + +- Refactored configuration to schema version `2` with named AI providers and role-based model chains. +- Added support for per-role primary and fallback targets for embeddings and metadata. +- Added optional background role overrides for backfill and metadata retry workers. + +### 2026-04-21 21h - Automatic v1 -> v2 Migration + +- Added config migration framework with explicit schema versioning. +- Implemented `v1 -> v2` migration to transform legacy provider blocks into named providers + role chains. +- Loader now auto-migrates older config files, rewrites migrated YAML, and creates timestamped backups. + +### 2026-04-21 21h - AI Registry and Role Runners + +- Added `ai.Registry` to build provider clients from named provider config entries. +- Added `EmbeddingRunner` and `MetadataRunner` with sequential fallback execution. +- Added target health tracking with cooldowns for transient/permanent/empty-response failures. + +### 2026-04-21 21h - App and Tool Wiring Updates + +- Rewired app startup to use provider registry + role runners for foreground and background flows. +- Updated capture, search, summarize, context, recall, backfill, metadata retry, and reparse paths to use new runners. +- Preserved environment override behavior for provider credentials/endpoints across matching provider types. + +### 2026-04-21 21h - Migrate Config CLI Added + +- Added `cmd/amcs-migrate-config` CLI to migrate config files to the current schema version. +- Supports dry-run output and in-place write mode with automatic backup file creation. + +### 2026-04-21 21h - Tests and Documentation Updated + +- Added focused tests for config migration, AI registry behavior, and runner fallback behavior. +- Updated `configs/config.example.yaml` to the new v2 schema. 
+- Updated README configuration sections and migration guidance to reflect v2 and `amcs-migrate-config` usage. + +### 2026-04-21 21h - Uncommitted File Change List + +- Modified: `.gitignore` +- Modified: `README.md` +- Modified: `configs/config.example.yaml` +- Modified: `internal/ai/compat/client.go` +- Modified: `internal/ai/compat/client_test.go` +- Modified: `internal/app/app.go` +- Modified: `internal/config/config.go` +- Modified: `internal/config/loader.go` +- Modified: `internal/config/loader_test.go` +- Modified: `internal/config/validate.go` +- Modified: `internal/config/validate_test.go` +- Modified: `internal/mcpserver/server.go` +- Modified: `internal/mcpserver/streamable_integration_test.go` +- Modified: `internal/tools/backfill.go` +- Modified: `internal/tools/capture.go` +- Modified: `internal/tools/context.go` +- Modified: `internal/tools/enrichment_retry.go` +- Modified: `internal/tools/links.go` +- Modified: `internal/tools/metadata_retry.go` +- Modified: `internal/tools/recall.go` +- Modified: `internal/tools/reparse_metadata.go` +- Modified: `internal/tools/retrieval.go` +- Modified: `internal/tools/search.go` +- Modified: `internal/tools/summarize.go` +- Modified: `internal/tools/update.go` +- Deleted: `internal/ai/factory.go` +- Deleted: `internal/ai/factory_test.go` +- Deleted: `internal/ai/litellm/client.go` +- Deleted: `internal/ai/ollama/client.go` +- Deleted: `internal/ai/openrouter/client.go` +- Deleted: `internal/ai/provider.go` +- New: `changelog.md` +- New: `cmd/amcs-migrate-config/main.go` +- New: `internal/ai/registry.go` +- New: `internal/ai/registry_test.go` +- New: `internal/ai/runner.go` +- New: `internal/ai/runner_test.go` +- New: `internal/config/migrate.go` +- New: `internal/config/migrate_test.go` diff --git a/cmd/amcs-migrate-config/main.go b/cmd/amcs-migrate-config/main.go new file mode 100644 index 0000000..53d7b66 --- /dev/null +++ b/cmd/amcs-migrate-config/main.go @@ -0,0 +1,105 @@ +package main + +import ( + "flag" + 
"fmt" + "log" + "os" + "time" + + "gopkg.in/yaml.v3" + + "git.warky.dev/wdevs/amcs/internal/config" +) + +func main() { + var ( + configPath string + dryRun bool + toVersion int + ) + flag.StringVar(&configPath, "config", "", "Path to the YAML config file (default: $AMCS_CONFIG or ./configs/dev.yaml)") + flag.BoolVar(&dryRun, "dry-run", false, "Print the migrated config to stdout instead of writing it back") + flag.IntVar(&toVersion, "to-version", config.CurrentConfigVersion, "Stop migrating after reaching this version") + flag.Parse() + + if toVersion <= 0 || toVersion > config.CurrentConfigVersion { + log.Fatalf("invalid -to-version %d (must be between 1 and %d)", toVersion, config.CurrentConfigVersion) + } + + path := config.ResolvePath(configPath) + original, err := os.ReadFile(path) + if err != nil { + log.Fatalf("read config %q: %v", path, err) + } + + raw := map[string]any{} + if err := yaml.Unmarshal(original, &raw); err != nil { + log.Fatalf("decode config %q: %v", path, err) + } + if raw == nil { + raw = map[string]any{} + } + + applied, err := migrateUpTo(raw, toVersion) + if err != nil { + log.Fatalf("migrate: %v", err) + } + + if len(applied) == 0 { + fmt.Fprintf(os.Stderr, "%s already at version %d; nothing to do\n", path, currentVersion(raw)) + return + } + + out, err := yaml.Marshal(raw) + if err != nil { + log.Fatalf("marshal migrated config: %v", err) + } + + for _, step := range applied { + fmt.Fprintf(os.Stderr, "applied migration v%d -> v%d: %s\n", step.From, step.To, step.Describe) + } + + if dryRun { + _, _ = os.Stdout.Write(out) + return + } + + backup := fmt.Sprintf("%s.bak.%d", path, time.Now().Unix()) + if err := os.WriteFile(backup, original, 0o600); err != nil { + log.Fatalf("write backup %q: %v", backup, err) + } + if err := os.WriteFile(path, out, 0o600); err != nil { + log.Fatalf("write migrated config %q: %v", path, err) + } + fmt.Fprintf(os.Stderr, "wrote migrated config to %s (backup: %s)\n", path, backup) +} + +// migrateUpTo 
runs the migration ladder but stops at the requested version. +func migrateUpTo(raw map[string]any, target int) ([]config.ConfigMigration, error) { + if currentVersion(raw) >= target { + return nil, nil + } + if target == config.CurrentConfigVersion { + return config.Migrate(raw) + } + // Partial migrations are rare; for now reject anything other than the + // current version target since the migration ladder is short. + return nil, fmt.Errorf("partial migration to v%d is not supported (use -to-version=%d)", target, config.CurrentConfigVersion) +} + +func currentVersion(raw map[string]any) int { + v, ok := raw["version"] + if !ok { + return 1 + } + switch n := v.(type) { + case int: + return n + case int64: + return int(n) + case float64: + return int(n) + } + return 1 +} diff --git a/configs/config.example.yaml b/configs/config.example.yaml index abd34dc..d13fb5d 100644 --- a/configs/config.example.yaml +++ b/configs/config.example.yaml @@ -1,3 +1,5 @@ +version: 2 + server: host: "0.0.0.0" port: 8080 @@ -27,7 +29,7 @@ auth: - id: "oauth-client" client_id: "" client_secret: "" - description: "used when auth.mode=oauth_client_credentials" + description: "optional OAuth client credentials" database: url: "postgres://postgres:postgres@localhost:5432/amcs?sslmode=disable" @@ -37,33 +39,58 @@ database: max_conn_idle_time: "10m" ai: - provider: "litellm" + providers: + default: + type: "litellm" + base_url: "http://localhost:4000/v1" + api_key: "replace-me" + request_headers: {} + + ollama_local: + type: "ollama" + base_url: "http://localhost:11434/v1" + api_key: "ollama" + request_headers: {} + + openrouter: + type: "openrouter" + base_url: "https://openrouter.ai/api/v1" + api_key: "replace-me" + app_name: "amcs" + site_url: "" + request_headers: {} + embeddings: - model: "openai/text-embedding-3-small" dimensions: 1536 + primary: + provider: "default" + model: "openai/text-embedding-3-small" + fallbacks: + - provider: "ollama_local" + model: "nomic-embed-text" + 
metadata: - model: "gpt-4o-mini" - fallback_models: [] temperature: 0.1 log_conversations: false - litellm: - base_url: "http://localhost:4000/v1" - api_key: "replace-me" - use_responses_api: false - request_headers: {} - embedding_model: "openrouter/openai/text-embedding-3-small" - metadata_model: "gpt-4o-mini" - fallback_metadata_models: [] - ollama: - base_url: "http://localhost:11434/v1" - api_key: "ollama" - request_headers: {} - openrouter: - base_url: "https://openrouter.ai/api/v1" - api_key: "" - app_name: "amcs" - site_url: "" - extra_headers: {} + timeout: "10s" + primary: + provider: "default" + model: "gpt-4o-mini" + fallbacks: + - provider: "openrouter" + model: "openai/gpt-4.1-mini" + + # Optional overrides for background jobs (backfill_embeddings, + # retry_failed_metadata, reparse_thought_metadata). + background: + embeddings: + primary: + provider: "default" + model: "openai/text-embedding-3-small" + metadata: + primary: + provider: "default" + model: "gpt-4o-mini" capture: source: "mcp" diff --git a/internal/ai/compat/client.go b/internal/ai/compat/client.go index 37c55f6..a8b129b 100644 --- a/internal/ai/compat/client.go +++ b/internal/ai/compat/client.go @@ -14,7 +14,6 @@ import ( "regexp" "slices" "strings" - "sync" "time" thoughttypes "git.warky.dev/wdevs/amcs/internal/types" @@ -36,36 +35,39 @@ Rules: - If unsure, prefer "observation". - Do not include any text outside the JSON object.` +// Client is a low-level OpenAI-compatible HTTP client. It knows nothing about +// role chains, fallbacks, or health — those concerns belong to ai.Runner. Each +// method takes the model name per-call so a single Client instance can service +// many different models on the same base URL. 
type Client struct { - name string - baseURL string - apiKey string - embeddingModel string - metadataModel string - fallbackMetadataModels []string - temperature float64 - headers map[string]string - httpClient *http.Client - log *slog.Logger - dimensions int - logConversations bool - modelHealthMu sync.Mutex - modelHealth map[string]modelHealthState + name string + baseURL string + apiKey string + headers map[string]string + httpClient *http.Client + log *slog.Logger } type Config struct { - Name string - BaseURL string - APIKey string - EmbeddingModel string - MetadataModel string - FallbackMetadataModels []string - Temperature float64 - Headers map[string]string - HTTPClient *http.Client - Log *slog.Logger - Dimensions int - LogConversations bool + Name string + BaseURL string + APIKey string + Headers map[string]string + HTTPClient *http.Client + Log *slog.Logger +} + +// MetadataOptions control a single ExtractMetadataWith call. +type MetadataOptions struct { + Model string + Temperature float64 + LogConversations bool +} + +// SummarizeOptions control a single SummarizeWith call. +type SummarizeOptions struct { + Model string + Temperature float64 } type embeddingsRequest struct { @@ -127,65 +129,38 @@ type providerError struct { const maxMetadataAttempts = 3 -const ( - emptyResponseCircuitThreshold = 3 - emptyResponseCircuitTTL = 5 * time.Minute - permanentModelFailureTTL = 24 * time.Hour -) - +// ErrEmptyResponse and ErrNoJSONObject are sentinel errors callers can inspect +// to classify metadata failures (e.g. bump empty-response health counters). 
var ( - errMetadataEmptyResponse = errors.New("metadata empty response") - errMetadataNoJSONObject = errors.New("metadata response contains no JSON object") + ErrEmptyResponse = errors.New("metadata empty response") + ErrNoJSONObject = errors.New("metadata response contains no JSON object") ) -type modelHealthState struct { - consecutiveEmpty int - unhealthyUntil time.Time -} - func New(cfg Config) *Client { - fallbacks := make([]string, 0, len(cfg.FallbackMetadataModels)) - seen := make(map[string]struct{}, len(cfg.FallbackMetadataModels)) - for _, model := range cfg.FallbackMetadataModels { - model = strings.TrimSpace(model) - if model == "" { - continue - } - if _, ok := seen[model]; ok { - continue - } - seen[model] = struct{}{} - fallbacks = append(fallbacks, model) - } - return &Client{ - name: cfg.Name, - baseURL: cfg.BaseURL, - apiKey: cfg.APIKey, - embeddingModel: cfg.EmbeddingModel, - metadataModel: cfg.MetadataModel, - fallbackMetadataModels: fallbacks, - temperature: cfg.Temperature, - headers: cfg.Headers, - httpClient: cfg.HTTPClient, - log: cfg.Log, - dimensions: cfg.Dimensions, - logConversations: cfg.LogConversations, - modelHealth: make(map[string]modelHealthState), + name: cfg.Name, + baseURL: cfg.BaseURL, + apiKey: cfg.APIKey, + headers: cfg.Headers, + httpClient: cfg.HTTPClient, + log: cfg.Log, } } -func (c *Client) Embed(ctx context.Context, input string) ([]float32, error) { +func (c *Client) Name() string { return c.name } + +// EmbedWith generates an embedding for the given input using model. 
+func (c *Client) EmbedWith(ctx context.Context, model, input string) ([]float32, error) { input = strings.TrimSpace(input) if input == "" { return nil, fmt.Errorf("%s embed: input must not be empty", c.name) } + if strings.TrimSpace(model) == "" { + return nil, fmt.Errorf("%s embed: model is required", c.name) + } var resp embeddingsResponse - err := c.doJSON(ctx, "/embeddings", embeddingsRequest{ - Input: input, - Model: c.embeddingModel, - }, &resp) + err := c.doJSON(ctx, "/embeddings", embeddingsRequest{Input: input, Model: model}, &resp) if err != nil { return nil, err } @@ -195,141 +170,34 @@ func (c *Client) Embed(ctx context.Context, input string) ([]float32, error) { if len(resp.Data) == 0 { return nil, fmt.Errorf("%s embed: no embedding returned", c.name) } - if c.dimensions > 0 && len(resp.Data[0].Embedding) != c.dimensions { - return nil, fmt.Errorf("%s embed: expected %d dimensions, got %d", c.name, c.dimensions, len(resp.Data[0].Embedding)) - } - return resp.Data[0].Embedding, nil } -func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error) { +// ExtractMetadataWith extracts structured metadata for input using opts.Model. +// Returns compat.ErrEmptyResponse / ErrNoJSONObject wrapped when the model +// produces unusable output so callers can classify the failure. 
+func (c *Client) ExtractMetadataWith(ctx context.Context, opts MetadataOptions, input string) (thoughttypes.ThoughtMetadata, error) { input = strings.TrimSpace(input) if input == "" { return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s extract metadata: input must not be empty", c.name) } - - start := time.Now() - if c.log != nil { - c.log.Info("metadata client started", - slog.String("provider", c.name), - slog.String("model", c.metadataModel), - ) - } - - logCompletion := func(model string, err error) { - if c.log == nil { - return - } - - attrs := []any{ - slog.String("provider", c.name), - slog.String("model", model), - slog.String("duration", formatLogDuration(time.Since(start))), - } - if err != nil { - attrs = append(attrs, slog.String("error", err.Error())) - c.log.Error("metadata client completed", attrs...) - return - } - - c.log.Info("metadata client completed", attrs...) - } - - result, err := c.extractMetadataWithModel(ctx, input, c.metadataModel) - if errors.Is(err, errMetadataEmptyResponse) { - c.noteEmptyResponse(c.metadataModel) - } - if isPermanentModelError(err) { - c.notePermanentModelFailure(c.metadataModel, err) - } - if err == nil { - c.noteModelSuccess(c.metadataModel) - logCompletion(c.metadataModel, nil) - return result, nil - } - - for _, fallbackModel := range c.fallbackMetadataModels { - if ctx.Err() != nil { - break - } - if fallbackModel == "" || fallbackModel == c.metadataModel { - continue - } - if c.shouldBypassModel(fallbackModel) { - continue - } - if c.log != nil { - c.log.Warn("metadata extraction failed, trying fallback model", - slog.String("provider", c.name), - slog.String("primary_model", c.metadataModel), - slog.String("fallback_model", fallbackModel), - slog.String("error", err.Error()), - ) - } - fallbackResult, fallbackErr := c.extractMetadataWithModel(ctx, input, fallbackModel) - if errors.Is(fallbackErr, errMetadataEmptyResponse) { - c.noteEmptyResponse(fallbackModel) - } - if isPermanentModelError(fallbackErr) { 
- c.notePermanentModelFailure(fallbackModel, fallbackErr) - } - if fallbackErr == nil { - c.noteModelSuccess(fallbackModel) - logCompletion(fallbackModel, nil) - return fallbackResult, nil - } - err = fallbackErr - } - - if ctx.Err() != nil { - err = fmt.Errorf("%s metadata: %w", c.name, ctx.Err()) - logCompletion(c.metadataModel, err) - return thoughttypes.ThoughtMetadata{}, err - } - - heuristic := heuristicMetadataFromInput(input) - if c.log != nil { - c.log.Warn("metadata extraction failed for all models, using heuristic fallback", - slog.String("provider", c.name), - slog.String("error", err.Error()), - ) - } - logCompletion(c.metadataModel, nil) - return heuristic, nil -} - -func formatLogDuration(d time.Duration) string { - if d < 0 { - d = -d - } - - totalMilliseconds := d.Milliseconds() - minutes := totalMilliseconds / 60000 - seconds := (totalMilliseconds / 1000) % 60 - milliseconds := totalMilliseconds % 1000 - return fmt.Sprintf("%02d:%02d:%03d", minutes, seconds, milliseconds) -} - -func (c *Client) extractMetadataWithModel(ctx context.Context, input, model string) (thoughttypes.ThoughtMetadata, error) { - if c.shouldBypassModel(model) { - return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: model %q temporarily bypassed after repeated empty responses", c.name, model) + if strings.TrimSpace(opts.Model) == "" { + return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s extract metadata: model is required", c.name) } stream := true req := chatCompletionsRequest{ - Model: model, - Temperature: c.temperature, - ResponseFormat: &responseType{ - Type: "json_object", - }, - Stream: &stream, + Model: opts.Model, + Temperature: opts.Temperature, + ResponseFormat: &responseType{Type: "json_object"}, + Stream: &stream, Messages: []chatMessage{ {Role: "system", Content: metadataSystemPrompt}, {Role: "user", Content: input}, }, } - metadata, err := c.extractMetadataWithRequest(ctx, req, input, model) + metadata, err := c.extractMetadataWithRequest(ctx, req, 
input, opts) if err == nil || !shouldRetryWithoutJSONMode(err) { return metadata, err } @@ -337,23 +205,22 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri if c.log != nil { c.log.Warn("metadata json mode failed, retrying without response_format", slog.String("provider", c.name), - slog.String("model", model), + slog.String("model", opts.Model), slog.String("error", err.Error()), ) } req.ResponseFormat = nil - return c.extractMetadataWithRequest(ctx, req, input, model) + return c.extractMetadataWithRequest(ctx, req, input, opts) } -func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatCompletionsRequest, input, model string) (thoughttypes.ThoughtMetadata, error) { - +func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatCompletionsRequest, input string, opts MetadataOptions) (thoughttypes.ThoughtMetadata, error) { var lastErr error for attempt := 1; attempt <= maxMetadataAttempts; attempt++ { - if c.logConversations && c.log != nil { + if opts.LogConversations && c.log != nil { c.log.Info("metadata conversation request", slog.String("provider", c.name), - slog.String("model", model), + slog.String("model", opts.Model), slog.Int("attempt", attempt), slog.String("system", metadataSystemPrompt), slog.String("input", input), @@ -373,10 +240,10 @@ func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatComplet rawResponse := extractChoiceText(resp.Choices[0].Message, resp.Choices[0].Text) - if c.logConversations && c.log != nil { + if opts.LogConversations && c.log != nil { c.log.Info("metadata conversation response", slog.String("provider", c.name), - slog.String("model", model), + slog.String("model", opts.Model), slog.Int("attempt", attempt), slog.String("response", rawResponse), ) @@ -387,13 +254,13 @@ func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatComplet metadataText = stripCodeFence(metadataText) metadataText = extractJSONObject(metadataText) if 
metadataText == "" { - lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataNoJSONObject) + lastErr = fmt.Errorf("%s metadata: %w", c.name, ErrNoJSONObject) if strings.TrimSpace(rawResponse) == "" && attempt < maxMetadataAttempts && ctx.Err() == nil { - lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataEmptyResponse) + lastErr = fmt.Errorf("%s metadata: %w", c.name, ErrEmptyResponse) if c.log != nil { c.log.Warn("metadata response empty, waiting and retrying", slog.String("provider", c.name), - slog.String("model", model), + slog.String("model", opts.Model), slog.Int("attempt", attempt+1), ) } @@ -403,7 +270,7 @@ func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatComplet continue } if strings.TrimSpace(rawResponse) == "" { - lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataEmptyResponse) + lastErr = fmt.Errorf("%s metadata: %w", c.name, ErrEmptyResponse) } return thoughttypes.ThoughtMetadata{}, lastErr } @@ -420,13 +287,17 @@ func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatComplet if lastErr != nil { return thoughttypes.ThoughtMetadata{}, lastErr } - return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: %w", c.name, errMetadataNoJSONObject) + return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: %w", c.name, ErrNoJSONObject) } -func (c *Client) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) { +// SummarizeWith runs a chat-completion summarisation using opts.Model. 
+func (c *Client) SummarizeWith(ctx context.Context, opts SummarizeOptions, systemPrompt, userPrompt string) (string, error) { + if strings.TrimSpace(opts.Model) == "" { + return "", fmt.Errorf("%s summarize: model is required", c.name) + } req := chatCompletionsRequest{ - Model: c.metadataModel, - Temperature: 0.2, + Model: opts.Model, + Temperature: opts.Temperature, Messages: []chatMessage{ {Role: "system", Content: systemPrompt}, {Role: "user", Content: userPrompt}, @@ -447,12 +318,49 @@ func (c *Client) Summarize(ctx context.Context, systemPrompt, userPrompt string) return extractChoiceText(resp.Choices[0].Message, resp.Choices[0].Text), nil } -func (c *Client) Name() string { - return c.name +// IsPermanentModelError reports whether err indicates the model itself is +// invalid or missing (vs. a transient outage). Runners use this to mark a +// target unhealthy for longer. +func IsPermanentModelError(err error) bool { + if err == nil { + return false + } + lower := strings.ToLower(err.Error()) + for _, marker := range []string{ + "invalid model name", + "model_not_found", + "model not found", + "unknown model", + "no such model", + "does not exist", + } { + if strings.Contains(lower, marker) { + return true + } + } + return false } -func (c *Client) EmbeddingModel() string { - return c.embeddingModel +// HeuristicMetadataFromInput produces best-effort metadata from the note text +// when every model in the chain has failed. Exported so ai.Runner can use it. 
+func HeuristicMetadataFromInput(input string) thoughttypes.ThoughtMetadata { + text := strings.TrimSpace(input) + lower := strings.ToLower(text) + + metadata := thoughttypes.ThoughtMetadata{ + People: heuristicPeople(text), + ActionItems: heuristicActionItems(text), + DatesMentioned: heuristicDates(text), + Topics: heuristicTopics(lower), + Type: heuristicType(lower), + } + if len(metadata.Topics) == 0 { + metadata.Topics = []string{"uncategorized"} + } + if metadata.Type == "" { + metadata.Type = "observation" + } + return metadata } func (c *Client) doJSON(ctx context.Context, path string, requestBody any, dest any) error { @@ -724,8 +632,6 @@ func isRetryableChatResponseError(err error) bool { return strings.Contains(lower, "read response") || strings.Contains(lower, "read stream response") } -// extractJSONObject finds the first complete {...} block in s. -// It handles models that prepend prose to a JSON response despite json_object mode. func extractJSONObject(s string) string { for start := 0; start < len(s); start++ { if s[start] != '{' { @@ -768,10 +674,6 @@ func extractJSONObject(s string) string { return "" } -// stripThinkingBlocks removes ... and ... -// blocks produced by reasoning models (DeepSeek R1, QwQ, etc.) so that the -// remaining text can be parsed as JSON without interference from thinking content -// that may itself contain braces. func stripThinkingBlocks(s string) string { for _, tag := range []string{"think", "thinking"} { open := "<" + tag + ">" @@ -857,7 +759,6 @@ func extractTextFromAny(value any) string { } return strings.Join(parts, "\n") case map[string]any: - // Common provider shapes for chat content parts. 
for _, key := range []string{"text", "output_text", "content", "value"} { if nested, ok := typed[key]; ok { if text := strings.TrimSpace(extractTextFromAny(nested)); text != "" { @@ -875,28 +776,6 @@ var ( wordPattern = regexp.MustCompile(`[a-zA-Z][a-zA-Z0-9_/-]{2,}`) ) -func heuristicMetadataFromInput(input string) thoughttypes.ThoughtMetadata { - text := strings.TrimSpace(input) - lower := strings.ToLower(text) - - metadata := thoughttypes.ThoughtMetadata{ - People: heuristicPeople(text), - ActionItems: heuristicActionItems(text), - DatesMentioned: heuristicDates(text), - Topics: heuristicTopics(lower), - Type: heuristicType(lower), - Source: "", - } - - if len(metadata.Topics) == 0 { - metadata.Topics = []string{"uncategorized"} - } - if metadata.Type == "" { - metadata.Type = "observation" - } - return metadata -} - func heuristicType(lower string) string { switch { case strings.Contains(lower, "preferred name"), strings.Contains(lower, "personal profile"), strings.Contains(lower, "wife:"), strings.Contains(lower, "daughter:"), strings.Contains(lower, "born:"): @@ -1055,7 +934,7 @@ func shouldRetryWithoutJSONMode(err error) bool { if err == nil { return false } - if errors.Is(err, errMetadataEmptyResponse) || errors.Is(err, errMetadataNoJSONObject) { + if errors.Is(err, ErrEmptyResponse) || errors.Is(err, ErrNoJSONObject) { return true } @@ -1063,27 +942,6 @@ func shouldRetryWithoutJSONMode(err error) bool { return strings.Contains(lower, "parse json") } -func isPermanentModelError(err error) bool { - if err == nil { - return false - } - - lower := strings.ToLower(err.Error()) - for _, marker := range []string{ - "invalid model name", - "model_not_found", - "model not found", - "unknown model", - "no such model", - "does not exist", - } { - if strings.Contains(lower, marker) { - return true - } - } - return false -} - func sleepRetry(ctx context.Context, attempt int, log *slog.Logger, provider string) error { delay := time.Duration(attempt*attempt) * 200 * 
time.Millisecond if log != nil { @@ -1110,59 +968,3 @@ func sleepMetadataRetry(ctx context.Context, attempt int) error { return nil } } - -func (c *Client) shouldBypassModel(model string) bool { - c.modelHealthMu.Lock() - defer c.modelHealthMu.Unlock() - - state, ok := c.modelHealth[model] - if !ok { - return false - } - return !state.unhealthyUntil.IsZero() && time.Now().Before(state.unhealthyUntil) -} - -func (c *Client) noteEmptyResponse(model string) { - c.modelHealthMu.Lock() - defer c.modelHealthMu.Unlock() - - state := c.modelHealth[model] - state.consecutiveEmpty++ - if state.consecutiveEmpty >= emptyResponseCircuitThreshold { - state.unhealthyUntil = time.Now().Add(emptyResponseCircuitTTL) - if c.log != nil { - c.log.Warn("metadata model marked temporarily unhealthy after repeated empty responses", - slog.String("provider", c.name), - slog.String("model", model), - slog.Time("until", state.unhealthyUntil), - ) - } - } - c.modelHealth[model] = state -} - -func (c *Client) noteModelSuccess(model string) { - c.modelHealthMu.Lock() - defer c.modelHealthMu.Unlock() - - delete(c.modelHealth, model) -} - -func (c *Client) notePermanentModelFailure(model string, err error) { - c.modelHealthMu.Lock() - defer c.modelHealthMu.Unlock() - - state := c.modelHealth[model] - state.consecutiveEmpty = emptyResponseCircuitThreshold - state.unhealthyUntil = time.Now().Add(permanentModelFailureTTL) - c.modelHealth[model] = state - - if c.log != nil { - c.log.Warn("metadata model marked unhealthy after permanent failure", - slog.String("provider", c.name), - slog.String("model", model), - slog.String("error", err.Error()), - slog.Time("until", state.unhealthyUntil), - ) - } -} diff --git a/internal/ai/compat/client_test.go b/internal/ai/compat/client_test.go index 5d3c93a..9b627f6 100644 --- a/internal/ai/compat/client_test.go +++ b/internal/ai/compat/client_test.go @@ -11,6 +11,17 @@ import ( "testing" ) +func newTestClient(t *testing.T, url string) *Client { + t.Helper() + 
return New(Config{ + Name: "litellm", + BaseURL: url, + APIKey: "test-key", + HTTPClient: http.DefaultClient, + Log: slog.New(slog.NewTextHandler(io.Discard, nil)), + }) +} + func TestExtractMetadataFromStreamingResponse(t *testing.T) { t.Parallel() @@ -26,6 +37,9 @@ func TestExtractMetadataFromStreamingResponse(t *testing.T) { if req.Stream == nil || !*req.Stream { t.Fatalf("stream flag = %v, want true", req.Stream) } + if req.Model != "qwen3.5:latest" { + t.Fatalf("model = %q, want qwen3.5:latest", req.Model) + } w.Header().Set("Content-Type", "text/event-stream") _, _ = io.WriteString(w, "data: {\"choices\":[{\"delta\":{\"content\":\"{\\\"people\\\":[],\"}}]}\n\n") @@ -35,20 +49,13 @@ func TestExtractMetadataFromStreamingResponse(t *testing.T) { })) defer server.Close() - client := New(Config{ - Name: "litellm", - BaseURL: server.URL, - APIKey: "test-key", - MetadataModel: "qwen3.5:latest", - Temperature: 0.1, - HTTPClient: server.Client(), - Log: slog.New(slog.NewTextHandler(io.Discard, nil)), - EmbeddingModel: "unused", - }) - - metadata, err := client.ExtractMetadata(context.Background(), "Project idea: Build an Android companion app.") + client := newTestClient(t, server.URL) + metadata, err := client.ExtractMetadataWith(context.Background(), MetadataOptions{ + Model: "qwen3.5:latest", + Temperature: 0.1, + }, "Project idea: Build an Android companion app.") if err != nil { - t.Fatalf("ExtractMetadata() error = %v", err) + t.Fatalf("ExtractMetadataWith() error = %v", err) } if metadata.Type != "idea" { @@ -94,20 +101,13 @@ func TestExtractMetadataRetriesWithoutJSONMode(t *testing.T) { })) defer server.Close() - client := New(Config{ - Name: "litellm", - BaseURL: server.URL, - APIKey: "test-key", - MetadataModel: "qwen3.5:latest", - Temperature: 0.1, - HTTPClient: server.Client(), - Log: slog.New(slog.NewTextHandler(io.Discard, nil)), - EmbeddingModel: "unused", - }) - - metadata, err := client.ExtractMetadata(context.Background(), "Project idea: Build an 
Android companion app.") + client := newTestClient(t, server.URL) + metadata, err := client.ExtractMetadataWith(context.Background(), MetadataOptions{ + Model: "qwen3.5:latest", + Temperature: 0.1, + }, "Project idea: Build an Android companion app.") if err != nil { - t.Fatalf("ExtractMetadata() error = %v", err) + t.Fatalf("ExtractMetadataWith() error = %v", err) } if metadata.Type != "idea" { @@ -127,71 +127,33 @@ func TestExtractMetadataRetriesWithoutJSONMode(t *testing.T) { } } -func TestExtractMetadataBypassesInvalidFallbackModelAfterFirstFailure(t *testing.T) { +func TestIsPermanentModelError(t *testing.T) { t.Parallel() - var mu sync.Mutex - primaryCalls := 0 - invalidFallbackCalls := 0 - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - defer func() { - _ = r.Body.Close() - }() - - var req chatCompletionsRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - t.Fatalf("decode request: %v", err) - } - - switch req.Model { - case "empty-primary": - _, _ = io.WriteString(w, `{"choices":[{"message":{"role":"assistant","content":""}}]}`) - case "qwen3.5:latest": - mu.Lock() - primaryCalls++ - mu.Unlock() - _, _ = io.WriteString(w, `{"choices":[{"message":{"role":"assistant","content":"{\"people\":[],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"metadata\"],\"type\":\"observation\",\"source\":\"primary\"}"}}]}`) - case "qwen3": - mu.Lock() - invalidFallbackCalls++ - mu.Unlock() - w.WriteHeader(http.StatusBadRequest) - _, _ = io.WriteString(w, "{\"error\":{\"message\":\"{'error': '/chat/completions: Invalid model name passed in model=qwen3. 
Call `/v1/models` to view available models for your key.'}\"}}") - default: - t.Fatalf("unexpected model %q", req.Model) - } - })) - defer server.Close() - - client := New(Config{ - Name: "litellm", - BaseURL: server.URL, - APIKey: "test-key", - MetadataModel: "empty-primary", - FallbackMetadataModels: []string{"qwen3", "qwen3.5:latest"}, - Temperature: 0.1, - HTTPClient: server.Client(), - Log: slog.New(slog.NewTextHandler(io.Discard, nil)), - EmbeddingModel: "unused", - }) - - for i := 0; i < 2; i++ { - metadata, err := client.ExtractMetadata(context.Background(), "A short note about metadata.") - if err != nil { - t.Fatalf("ExtractMetadata() error = %v", err) - } - if metadata.Source != "primary" { - t.Fatalf("metadata source = %q, want primary", metadata.Source) - } + cases := []struct { + name string + err error + want bool + }{ + {"nil", nil, false}, + {"invalid model", errMsg("Invalid model name passed in model=qwen3"), true}, + {"model not found", errMsg("model_not_found"), true}, + {"no such model", errMsg("no such model"), true}, + {"transient", errMsg("connection refused"), false}, } - mu.Lock() - defer mu.Unlock() - if invalidFallbackCalls != 1 { - t.Fatalf("invalid fallback calls = %d, want 1", invalidFallbackCalls) - } - if primaryCalls != 2 { - t.Fatalf("valid fallback calls = %d, want 2", primaryCalls) + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + if got := IsPermanentModelError(tc.err); got != tc.want { + t.Fatalf("IsPermanentModelError(%v) = %v, want %v", tc.err, got, tc.want) + } + }) } } + +type stringError string + +func (s stringError) Error() string { return string(s) } + +func errMsg(s string) error { return stringError(s) } diff --git a/internal/ai/factory.go b/internal/ai/factory.go deleted file mode 100644 index b6ee360..0000000 --- a/internal/ai/factory.go +++ /dev/null @@ -1,25 +0,0 @@ -package ai - -import ( - "fmt" - "log/slog" - "net/http" - - "git.warky.dev/wdevs/amcs/internal/ai/litellm" - 
"git.warky.dev/wdevs/amcs/internal/ai/ollama" - "git.warky.dev/wdevs/amcs/internal/ai/openrouter" - "git.warky.dev/wdevs/amcs/internal/config" -) - -func NewProvider(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (Provider, error) { - switch cfg.Provider { - case "litellm": - return litellm.New(cfg, httpClient, log) - case "ollama": - return ollama.New(cfg, httpClient, log) - case "openrouter": - return openrouter.New(cfg, httpClient, log) - default: - return nil, fmt.Errorf("unsupported ai.provider: %s", cfg.Provider) - } -} diff --git a/internal/ai/factory_test.go b/internal/ai/factory_test.go deleted file mode 100644 index 02d2837..0000000 --- a/internal/ai/factory_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package ai - -import ( - "io" - "log/slog" - "net/http" - "testing" - - "git.warky.dev/wdevs/amcs/internal/config" -) - -func TestNewProviderSupportsOllama(t *testing.T) { - provider, err := NewProvider(config.AIConfig{ - Provider: "ollama", - Embeddings: config.AIEmbeddingConfig{ - Model: "nomic-embed-text", - Dimensions: 768, - }, - Metadata: config.AIMetadataConfig{ - Model: "llama3.2", - }, - Ollama: config.OllamaConfig{ - BaseURL: "http://localhost:11434/v1", - APIKey: "ollama", - }, - }, &http.Client{}, slog.New(slog.NewTextHandler(io.Discard, nil))) - if err != nil { - t.Fatalf("NewProvider() error = %v", err) - } - if provider.Name() != "ollama" { - t.Fatalf("provider name = %q, want ollama", provider.Name()) - } -} diff --git a/internal/ai/litellm/client.go b/internal/ai/litellm/client.go deleted file mode 100644 index 3c9f1b0..0000000 --- a/internal/ai/litellm/client.go +++ /dev/null @@ -1,30 +0,0 @@ -package litellm - -import ( - "log/slog" - "net/http" - - "git.warky.dev/wdevs/amcs/internal/ai/compat" - "git.warky.dev/wdevs/amcs/internal/config" -) - -func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) { - fallbacks := cfg.LiteLLM.EffectiveFallbackMetadataModels() - if len(fallbacks) == 0 
{ - fallbacks = cfg.Metadata.EffectiveFallbackModels() - } - return compat.New(compat.Config{ - Name: "litellm", - BaseURL: cfg.LiteLLM.BaseURL, - APIKey: cfg.LiteLLM.APIKey, - EmbeddingModel: cfg.LiteLLM.EmbeddingModel, - MetadataModel: cfg.LiteLLM.MetadataModel, - FallbackMetadataModels: fallbacks, - Temperature: cfg.Metadata.Temperature, - Headers: cfg.LiteLLM.RequestHeaders, - HTTPClient: httpClient, - Log: log, - Dimensions: cfg.Embeddings.Dimensions, - LogConversations: cfg.Metadata.LogConversations, - }), nil -} diff --git a/internal/ai/ollama/client.go b/internal/ai/ollama/client.go deleted file mode 100644 index 69abf8e..0000000 --- a/internal/ai/ollama/client.go +++ /dev/null @@ -1,26 +0,0 @@ -package ollama - -import ( - "log/slog" - "net/http" - - "git.warky.dev/wdevs/amcs/internal/ai/compat" - "git.warky.dev/wdevs/amcs/internal/config" -) - -func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) { - return compat.New(compat.Config{ - Name: "ollama", - BaseURL: cfg.Ollama.BaseURL, - APIKey: cfg.Ollama.APIKey, - EmbeddingModel: cfg.Embeddings.Model, - MetadataModel: cfg.Metadata.Model, - FallbackMetadataModels: cfg.Metadata.EffectiveFallbackModels(), - Temperature: cfg.Metadata.Temperature, - Headers: cfg.Ollama.RequestHeaders, - HTTPClient: httpClient, - Log: log, - Dimensions: cfg.Embeddings.Dimensions, - LogConversations: cfg.Metadata.LogConversations, - }), nil -} diff --git a/internal/ai/openrouter/client.go b/internal/ai/openrouter/client.go deleted file mode 100644 index b2fe6d0..0000000 --- a/internal/ai/openrouter/client.go +++ /dev/null @@ -1,37 +0,0 @@ -package openrouter - -import ( - "log/slog" - "net/http" - - "git.warky.dev/wdevs/amcs/internal/ai/compat" - "git.warky.dev/wdevs/amcs/internal/config" -) - -func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) { - headers := make(map[string]string, len(cfg.OpenRouter.ExtraHeaders)+2) - for key, value := range 
cfg.OpenRouter.ExtraHeaders { - headers[key] = value - } - if cfg.OpenRouter.SiteURL != "" { - headers["HTTP-Referer"] = cfg.OpenRouter.SiteURL - } - if cfg.OpenRouter.AppName != "" { - headers["X-Title"] = cfg.OpenRouter.AppName - } - - return compat.New(compat.Config{ - Name: "openrouter", - BaseURL: cfg.OpenRouter.BaseURL, - APIKey: cfg.OpenRouter.APIKey, - EmbeddingModel: cfg.Embeddings.Model, - MetadataModel: cfg.Metadata.Model, - FallbackMetadataModels: cfg.Metadata.EffectiveFallbackModels(), - Temperature: cfg.Metadata.Temperature, - Headers: headers, - HTTPClient: httpClient, - Log: log, - Dimensions: cfg.Embeddings.Dimensions, - LogConversations: cfg.Metadata.LogConversations, - }), nil -} diff --git a/internal/ai/provider.go b/internal/ai/provider.go deleted file mode 100644 index e547757..0000000 --- a/internal/ai/provider.go +++ /dev/null @@ -1,15 +0,0 @@ -package ai - -import ( - "context" - - thoughttypes "git.warky.dev/wdevs/amcs/internal/types" -) - -type Provider interface { - Embed(ctx context.Context, input string) ([]float32, error) - ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error) - Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) - Name() string - EmbeddingModel() string -} diff --git a/internal/ai/registry.go b/internal/ai/registry.go new file mode 100644 index 0000000..e82b724 --- /dev/null +++ b/internal/ai/registry.go @@ -0,0 +1,96 @@ +package ai + +import ( + "fmt" + "log/slog" + "net/http" + "strings" + + "git.warky.dev/wdevs/amcs/internal/ai/compat" + "git.warky.dev/wdevs/amcs/internal/config" +) + +// Registry holds one compat.Client per named provider. Runners look up clients +// by provider name when walking a role chain. +type Registry struct { + clients map[string]*compat.Client +} + +// NewRegistry builds a Registry from the configured providers. Each provider +// type maps onto a compat.Client with type-specific header plumbing (e.g. 
+// openrouter's HTTP-Referer / X-Title). +func NewRegistry(providers map[string]config.ProviderConfig, httpClient *http.Client, log *slog.Logger) (*Registry, error) { + if httpClient == nil { + return nil, fmt.Errorf("ai registry: http client is required") + } + if len(providers) == 0 { + return nil, fmt.Errorf("ai registry: no providers configured") + } + + clients := make(map[string]*compat.Client, len(providers)) + for name, p := range providers { + headers, err := providerHeaders(p) + if err != nil { + return nil, fmt.Errorf("ai registry: provider %q: %w", name, err) + } + clients[name] = compat.New(compat.Config{ + Name: name, + BaseURL: p.BaseURL, + APIKey: p.APIKey, + Headers: headers, + HTTPClient: httpClient, + Log: log, + }) + } + return &Registry{clients: clients}, nil +} + +// Client returns the compat.Client registered under name. +func (r *Registry) Client(name string) (*compat.Client, error) { + c, ok := r.clients[name] + if !ok { + return nil, fmt.Errorf("ai registry: provider %q is not configured", name) + } + return c, nil +} + +// Names returns the registered provider names. 
+func (r *Registry) Names() []string { + names := make([]string, 0, len(r.clients)) + for name := range r.clients { + names = append(names, name) + } + return names +} + +func providerHeaders(p config.ProviderConfig) (map[string]string, error) { + switch p.Type { + case "litellm", "ollama": + return cloneHeaders(p.RequestHeaders), nil + case "openrouter": + headers := cloneHeaders(p.RequestHeaders) + if headers == nil { + headers = map[string]string{} + } + if s := strings.TrimSpace(p.SiteURL); s != "" { + headers["HTTP-Referer"] = s + } + if s := strings.TrimSpace(p.AppName); s != "" { + headers["X-Title"] = s + } + return headers, nil + default: + return nil, fmt.Errorf("unsupported provider type %q", p.Type) + } +} + +func cloneHeaders(in map[string]string) map[string]string { + if len(in) == 0 { + return nil + } + out := make(map[string]string, len(in)) + for k, v := range in { + out[k] = v + } + return out +} diff --git a/internal/ai/registry_test.go b/internal/ai/registry_test.go new file mode 100644 index 0000000..eaf16fd --- /dev/null +++ b/internal/ai/registry_test.go @@ -0,0 +1,80 @@ +package ai + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "git.warky.dev/wdevs/amcs/internal/ai/compat" + "git.warky.dev/wdevs/amcs/internal/config" +) + +func TestNewRegistryOpenRouterHeaders(t *testing.T) { + var ( + gotReferer string + gotTitle string + gotCustom string + ) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotReferer = r.Header.Get("HTTP-Referer") + gotTitle = r.Header.Get("X-Title") + gotCustom = r.Header.Get("X-Custom") + _ = json.NewEncoder(w).Encode(map[string]any{ + "choices": []map[string]any{{"message": map[string]any{"role": "assistant", "content": "ok"}}}, + }) + })) + defer srv.Close() + + providers := map[string]config.ProviderConfig{ + "router": { + Type: "openrouter", + BaseURL: srv.URL, + APIKey: "secret", + RequestHeaders: map[string]string{ + 
"X-Custom": "value", + }, + AppName: "amcs", + SiteURL: "https://example.com", + }, + } + + reg, err := NewRegistry(providers, srv.Client(), nil) + if err != nil { + t.Fatalf("NewRegistry() error = %v", err) + } + + client, err := reg.Client("router") + if err != nil { + t.Fatalf("Client(router) error = %v", err) + } + + if _, err := client.SummarizeWith(context.Background(), compat.SummarizeOptions{Model: "gpt-4.1-mini"}, "system", "user"); err != nil { + t.Fatalf("SummarizeWith() error = %v", err) + } + if gotReferer != "https://example.com" { + t.Fatalf("HTTP-Referer = %q, want https://example.com", gotReferer) + } + if gotTitle != "amcs" { + t.Fatalf("X-Title = %q, want amcs", gotTitle) + } + if gotCustom != "value" { + t.Fatalf("X-Custom = %q, want value", gotCustom) + } +} + +func TestNewRegistryRejectsUnsupportedProviderType(t *testing.T) { + providers := map[string]config.ProviderConfig{ + "bad": { + Type: "unknown", + BaseURL: "http://localhost:4000/v1", + APIKey: "secret", + }, + } + + _, err := NewRegistry(providers, &http.Client{}, nil) + if err == nil { + t.Fatal("NewRegistry() error = nil, want unsupported provider type error") + } +} diff --git a/internal/ai/runner.go b/internal/ai/runner.go new file mode 100644 index 0000000..e852bde --- /dev/null +++ b/internal/ai/runner.go @@ -0,0 +1,367 @@ +package ai + +import ( + "context" + "errors" + "fmt" + "log/slog" + "sync" + "time" + + "git.warky.dev/wdevs/amcs/internal/ai/compat" + "git.warky.dev/wdevs/amcs/internal/config" + thoughttypes "git.warky.dev/wdevs/amcs/internal/types" +) + +// Health TTLs per failure class. These are short enough that a healed target +// gets retried without manual intervention, but long enough to avoid hammering +// a broken provider every call. 
+const ( + transientCooldown = 30 * time.Second + permanentCooldown = 10 * time.Minute + emptyResponseThreshold = 3 + emptyResponseCooldown = 2 * time.Minute + dimensionMismatchWarning = "embedding dimension mismatch" +) + +// EmbedResult carries the vector plus the (provider, model) that produced it — +// callers store the actual model so later searches against that row use the +// matching query embedding. +type EmbedResult struct { + Vector []float32 + Provider string + Model string +} + +// EmbeddingRunner executes the embeddings role chain with sequential fallback. +type EmbeddingRunner struct { + registry *Registry + chain []config.RoleTarget + dimensions int + health *healthTracker + log *slog.Logger +} + +// MetadataRunner executes the metadata role chain with sequential fallback and +// a heuristic fallthrough when every target is unhealthy or fails. +type MetadataRunner struct { + registry *Registry + chain []config.RoleTarget + opts metadataRunOpts + health *healthTracker + log *slog.Logger +} + +type metadataRunOpts struct { + temperature float64 + logConversations bool +} + +// NewEmbeddingRunner builds a runner for the embeddings role. chain must be +// non-empty and every target must be registered. 
+func NewEmbeddingRunner(registry *Registry, chain []config.RoleTarget, dimensions int, log *slog.Logger) (*EmbeddingRunner, error) { + if registry == nil { + return nil, fmt.Errorf("embedding runner: registry is required") + } + if len(chain) == 0 { + return nil, fmt.Errorf("embedding runner: chain is empty") + } + if dimensions <= 0 { + return nil, fmt.Errorf("embedding runner: dimensions must be > 0") + } + for i, t := range chain { + if _, err := registry.Client(t.Provider); err != nil { + return nil, fmt.Errorf("embedding runner: chain[%d]: %w", i, err) + } + } + return &EmbeddingRunner{ + registry: registry, + chain: chain, + dimensions: dimensions, + health: newHealthTracker(), + log: log, + }, nil +} + +// NewMetadataRunner builds a runner for the metadata role. +func NewMetadataRunner(registry *Registry, chain []config.RoleTarget, temperature float64, logConversations bool, log *slog.Logger) (*MetadataRunner, error) { + if registry == nil { + return nil, fmt.Errorf("metadata runner: registry is required") + } + if len(chain) == 0 { + return nil, fmt.Errorf("metadata runner: chain is empty") + } + for i, t := range chain { + if _, err := registry.Client(t.Provider); err != nil { + return nil, fmt.Errorf("metadata runner: chain[%d]: %w", i, err) + } + } + return &MetadataRunner{ + registry: registry, + chain: chain, + opts: metadataRunOpts{ + temperature: temperature, + logConversations: logConversations, + }, + health: newHealthTracker(), + log: log, + }, nil +} + +// PrimaryProvider returns the first provider in the chain. +func (r *EmbeddingRunner) PrimaryProvider() string { return r.chain[0].Provider } + +// PrimaryModel returns the first model in the chain — the one used as the +// storage key for search matching. +func (r *EmbeddingRunner) PrimaryModel() string { return r.chain[0].Model } + +// Dimensions returns the required vector dimension. 
+func (r *EmbeddingRunner) Dimensions() int { return r.dimensions } + +// Embed walks the chain and returns the first successful embedding. The +// returned EmbedResult names the actual (provider, model) that produced the +// vector — callers use that when recording the row. +func (r *EmbeddingRunner) Embed(ctx context.Context, input string) (EmbedResult, error) { + var errs []error + for _, target := range r.chain { + if r.health.skip(target) { + continue + } + client, err := r.registry.Client(target.Provider) + if err != nil { + errs = append(errs, err) + continue + } + vec, err := client.EmbedWith(ctx, target.Model, input) + if err != nil { + if ctx.Err() != nil { + return EmbedResult{}, ctx.Err() + } + r.classify(target, err) + r.logFailure("embed", target, err) + errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, err)) + continue + } + if len(vec) != r.dimensions { + dimErr := fmt.Errorf("%s: expected %d, got %d", dimensionMismatchWarning, r.dimensions, len(vec)) + r.health.markTransient(target) + r.logFailure("embed", target, dimErr) + errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, dimErr)) + continue + } + r.health.markHealthy(target) + return EmbedResult{Vector: vec, Provider: target.Provider, Model: target.Model}, nil + } + return EmbedResult{}, fmt.Errorf("all embedding targets failed: %w", errors.Join(errs...)) +} + +// EmbedPrimary embeds using only the primary target — used for search queries +// so the query vector matches rows stored under the primary model. Falls back +// to returning the error without walking the chain. 
+func (r *EmbeddingRunner) EmbedPrimary(ctx context.Context, input string) ([]float32, error) { + target := r.chain[0] + client, err := r.registry.Client(target.Provider) + if err != nil { + return nil, err + } + vec, err := client.EmbedWith(ctx, target.Model, input) + if err != nil { + r.classify(target, err) + return nil, err + } + if len(vec) != r.dimensions { + return nil, fmt.Errorf("%s: expected %d, got %d", dimensionMismatchWarning, r.dimensions, len(vec)) + } + r.health.markHealthy(target) + return vec, nil +} + +// PrimaryProvider / PrimaryModel for metadata mirror the embedding runner. +func (r *MetadataRunner) PrimaryProvider() string { return r.chain[0].Provider } +func (r *MetadataRunner) PrimaryModel() string { return r.chain[0].Model } + +// ExtractMetadata walks the chain sequentially. If every target fails or is +// unhealthy, it returns a heuristic metadata so capture never hard-fails. +func (r *MetadataRunner) ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error) { + var errs []error + for _, target := range r.chain { + if r.health.skip(target) { + continue + } + client, err := r.registry.Client(target.Provider) + if err != nil { + errs = append(errs, err) + continue + } + md, err := client.ExtractMetadataWith(ctx, compat.MetadataOptions{ + Model: target.Model, + Temperature: r.opts.temperature, + LogConversations: r.opts.logConversations, + }, input) + if err != nil { + if ctx.Err() != nil { + return thoughttypes.ThoughtMetadata{}, ctx.Err() + } + r.classify(target, err) + r.logFailure("metadata", target, err) + errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, err)) + continue + } + r.health.markHealthy(target) + return md, nil + } + if r.log != nil { + r.log.Warn("metadata chain exhausted, using heuristic fallback", + slog.Int("targets", len(r.chain)), + slog.String("error", errors.Join(errs...).Error()), + ) + } + return compat.HeuristicMetadataFromInput(input), nil +} + +// 
Summarize walks the chain; unlike metadata, there is no heuristic fallback — +// returns the joined error when everything fails. +func (r *MetadataRunner) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) { + var errs []error + for _, target := range r.chain { + if r.health.skip(target) { + continue + } + client, err := r.registry.Client(target.Provider) + if err != nil { + errs = append(errs, err) + continue + } + out, err := client.SummarizeWith(ctx, compat.SummarizeOptions{ + Model: target.Model, + Temperature: r.opts.temperature, + }, systemPrompt, userPrompt) + if err != nil { + if ctx.Err() != nil { + return "", ctx.Err() + } + r.classify(target, err) + r.logFailure("summarize", target, err) + errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, err)) + continue + } + r.health.markHealthy(target) + return out, nil + } + return "", fmt.Errorf("all summarize targets failed: %w", errors.Join(errs...)) +} + +func (r *EmbeddingRunner) classify(target config.RoleTarget, err error) { + switch { + case compat.IsPermanentModelError(err): + r.health.markPermanent(target) + default: + r.health.markTransient(target) + } +} + +func (r *MetadataRunner) classify(target config.RoleTarget, err error) { + switch { + case compat.IsPermanentModelError(err): + r.health.markPermanent(target) + case errors.Is(err, compat.ErrEmptyResponse): + r.health.markEmpty(target) + default: + r.health.markTransient(target) + } +} + +func (r *EmbeddingRunner) logFailure(role string, target config.RoleTarget, err error) { + if r.log == nil { + return + } + r.log.Warn("ai target failed", + slog.String("role", role), + slog.String("provider", target.Provider), + slog.String("model", target.Model), + slog.String("error", err.Error()), + ) +} + +func (r *MetadataRunner) logFailure(role string, target config.RoleTarget, err error) { + if r.log == nil { + return + } + r.log.Warn("ai target failed", + slog.String("role", role), + 
slog.String("provider", target.Provider), + slog.String("model", target.Model), + slog.String("error", err.Error()), + ) +} + +// healthTracker records per-(provider, model) failure state. skip returns true +// when a target is still inside its cooldown window; the caller then tries the +// next target in the chain. +type healthTracker struct { + mu sync.Mutex + states map[config.RoleTarget]*healthState +} + +type healthState struct { + unhealthyUntil time.Time + emptyCount int +} + +func newHealthTracker() *healthTracker { + return &healthTracker{states: map[config.RoleTarget]*healthState{}} +} + +func (h *healthTracker) skip(target config.RoleTarget) bool { + h.mu.Lock() + defer h.mu.Unlock() + s, ok := h.states[target] + if !ok { + return false + } + return time.Now().Before(s.unhealthyUntil) +} + +func (h *healthTracker) markTransient(target config.RoleTarget) { + h.setCooldown(target, transientCooldown) +} + +func (h *healthTracker) markPermanent(target config.RoleTarget) { + h.setCooldown(target, permanentCooldown) +} + +func (h *healthTracker) markEmpty(target config.RoleTarget) { + h.mu.Lock() + defer h.mu.Unlock() + s := h.states[target] + if s == nil { + s = &healthState{} + h.states[target] = s + } + s.emptyCount++ + if s.emptyCount >= emptyResponseThreshold { + s.unhealthyUntil = time.Now().Add(emptyResponseCooldown) + s.emptyCount = 0 + } +} + +func (h *healthTracker) markHealthy(target config.RoleTarget) { + h.mu.Lock() + defer h.mu.Unlock() + if s, ok := h.states[target]; ok { + s.unhealthyUntil = time.Time{} + s.emptyCount = 0 + } +} + +func (h *healthTracker) setCooldown(target config.RoleTarget, d time.Duration) { + h.mu.Lock() + defer h.mu.Unlock() + s := h.states[target] + if s == nil { + s = &healthState{} + h.states[target] = s + } + s.unhealthyUntil = time.Now().Add(d) + s.emptyCount = 0 +} diff --git a/internal/ai/runner_test.go b/internal/ai/runner_test.go new file mode 100644 index 0000000..7a148e8 --- /dev/null +++ 
b/internal/ai/runner_test.go @@ -0,0 +1,139 @@ +package ai + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sync" + "testing" + + "git.warky.dev/wdevs/amcs/internal/config" +) + +func TestEmbeddingRunnerFallsBackAndSkipsUnhealthyPrimary(t *testing.T) { + var ( + mu sync.Mutex + primaryCalls int + ) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/embeddings" { + http.NotFound(w, r) + return + } + var req struct { + Model string `json:"model"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + switch req.Model { + case "embed-primary": + mu.Lock() + primaryCalls++ + mu.Unlock() + http.Error(w, "upstream down", http.StatusBadGateway) + case "embed-fallback": + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": []map[string]any{{"embedding": []float32{0.1, 0.2, 0.3}}}, + }) + default: + http.Error(w, "unknown model", http.StatusBadRequest) + } + })) + defer srv.Close() + + reg, err := NewRegistry(map[string]config.ProviderConfig{ + "p1": {Type: "litellm", BaseURL: srv.URL, APIKey: "k1"}, + "p2": {Type: "litellm", BaseURL: srv.URL, APIKey: "k2"}, + }, srv.Client(), nil) + if err != nil { + t.Fatalf("NewRegistry() error = %v", err) + } + + runner, err := NewEmbeddingRunner(reg, []config.RoleTarget{ + {Provider: "p1", Model: "embed-primary"}, + {Provider: "p2", Model: "embed-fallback"}, + }, 3, nil) + if err != nil { + t.Fatalf("NewEmbeddingRunner() error = %v", err) + } + + res, err := runner.Embed(context.Background(), "hello") + if err != nil { + t.Fatalf("Embed() first call error = %v", err) + } + if res.Provider != "p2" || res.Model != "embed-fallback" { + t.Fatalf("Embed() first call target = %s/%s, want p2/embed-fallback", res.Provider, res.Model) + } + + res, err = runner.Embed(context.Background(), "hello again") + if err != nil { + t.Fatalf("Embed() second call error = %v", err) 
+ } + if res.Provider != "p2" || res.Model != "embed-fallback" { + t.Fatalf("Embed() second call target = %s/%s, want p2/embed-fallback", res.Provider, res.Model) + } + + mu.Lock() + calls := primaryCalls + mu.Unlock() + if calls != 3 { + t.Fatalf("primary calls = %d, want 3 (first request retries 3x; second call should skip unhealthy primary)", calls) + } +} + +func TestMetadataRunnerSummarizeFallsBack(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/chat/completions" { + http.NotFound(w, r) + return + } + var req struct { + Model string `json:"model"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + switch req.Model { + case "sum-primary": + http.Error(w, "provider error", http.StatusBadGateway) + case "sum-fallback": + _ = json.NewEncoder(w).Encode(map[string]any{ + "choices": []map[string]any{{ + "message": map[string]any{"role": "assistant", "content": "fallback summary"}, + }}, + }) + default: + http.Error(w, "unknown model", http.StatusBadRequest) + } + })) + defer srv.Close() + + reg, err := NewRegistry(map[string]config.ProviderConfig{ + "p1": {Type: "litellm", BaseURL: srv.URL, APIKey: "k1"}, + "p2": {Type: "litellm", BaseURL: srv.URL, APIKey: "k2"}, + }, srv.Client(), nil) + if err != nil { + t.Fatalf("NewRegistry() error = %v", err) + } + + runner, err := NewMetadataRunner(reg, []config.RoleTarget{ + {Provider: "p1", Model: "sum-primary"}, + {Provider: "p2", Model: "sum-fallback"}, + }, 0.1, false, nil) + if err != nil { + t.Fatalf("NewMetadataRunner() error = %v", err) + } + + summary, err := runner.Summarize(context.Background(), "system", "user") + if err != nil { + t.Fatalf("Summarize() error = %v", err) + } + if summary != "fallback summary" { + t.Fatalf("summary = %q, want %q", summary, "fallback summary") + } +} diff --git a/internal/app/app.go b/internal/app/app.go index 
4ae9359..e4471b6 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -34,7 +34,7 @@ func Run(ctx context.Context, configPath string) error { logger.Info("loaded configuration", slog.String("path", loadedFrom), - slog.String("provider", cfg.AI.Provider), + slog.Int("config_version", cfg.Version), slog.String("version", info.Version), slog.String("tag_name", info.TagName), slog.String("build_date", info.BuildDate), @@ -52,11 +52,37 @@ func Run(ctx context.Context, configPath string) error { } httpClient := &http.Client{Timeout: 30 * time.Second} - provider, err := ai.NewProvider(cfg.AI, httpClient, logger) + registry, err := ai.NewRegistry(cfg.AI.Providers, httpClient, logger) if err != nil { return err } + foregroundEmbeddings, err := ai.NewEmbeddingRunner(registry, cfg.AI.Embeddings.Chain(), cfg.AI.Embeddings.Dimensions, logger) + if err != nil { + return err + } + foregroundMetadata, err := ai.NewMetadataRunner(registry, cfg.AI.Metadata.Chain(), cfg.AI.Metadata.Temperature, cfg.AI.Metadata.LogConversations, logger) + if err != nil { + return err + } + + backgroundEmbeddings := foregroundEmbeddings + backgroundMetadata := foregroundMetadata + if cfg.AI.Background != nil { + if cfg.AI.Background.Embeddings != nil { + backgroundEmbeddings, err = ai.NewEmbeddingRunner(registry, cfg.AI.Background.Embeddings.AsTargets(), cfg.AI.Embeddings.Dimensions, logger) + if err != nil { + return err + } + } + if cfg.AI.Background.Metadata != nil { + backgroundMetadata, err = ai.NewMetadataRunner(registry, cfg.AI.Background.Metadata.AsTargets(), cfg.AI.Metadata.Temperature, cfg.AI.Metadata.LogConversations, logger) + if err != nil { + return err + } + } + } + var keyring *auth.Keyring var oauthRegistry *auth.OAuthRegistry var tokenStore *auth.TokenStore @@ -77,12 +103,13 @@ func Run(ctx context.Context, configPath string) error { dynClients := auth.NewDynamicClientStore() activeProjects := session.NewActiveProjects() - logger.Info("database connection verified", - 
slog.String("provider", provider.Name()), + logger.Info("ai providers initialised", + slog.String("embedding_primary", foregroundEmbeddings.PrimaryProvider()+"/"+foregroundEmbeddings.PrimaryModel()), + slog.String("metadata_primary", foregroundMetadata.PrimaryProvider()+"/"+foregroundMetadata.PrimaryModel()), ) if cfg.Backfill.Enabled && cfg.Backfill.RunOnStartup { - go runBackfillPass(ctx, db, provider, cfg.Backfill, logger) + go runBackfillPass(ctx, db, backgroundEmbeddings, cfg.Backfill, logger) } if cfg.Backfill.Enabled && cfg.Backfill.Interval > 0 { @@ -94,14 +121,14 @@ func Run(ctx context.Context, configPath string) error { case <-ctx.Done(): return case <-ticker.C: - runBackfillPass(ctx, db, provider, cfg.Backfill, logger) + runBackfillPass(ctx, db, backgroundEmbeddings, cfg.Backfill, logger) } } }() } if cfg.MetadataRetry.Enabled && cfg.MetadataRetry.RunOnStartup { - go runMetadataRetryPass(ctx, db, provider, cfg, activeProjects, logger) + go runMetadataRetryPass(ctx, db, backgroundMetadata, cfg, activeProjects, logger) } if cfg.MetadataRetry.Enabled && cfg.MetadataRetry.Interval > 0 { @@ -113,13 +140,13 @@ func Run(ctx context.Context, configPath string) error { case <-ctx.Done(): return case <-ticker.C: - runMetadataRetryPass(ctx, db, provider, cfg, activeProjects, logger) + runMetadataRetryPass(ctx, db, backgroundMetadata, cfg, activeProjects, logger) } } }() } - handler, err := routes(logger, cfg, info, db, provider, keyring, oauthRegistry, tokenStore, authCodes, dynClients, activeProjects) + handler, err := routes(logger, cfg, info, db, foregroundEmbeddings, foregroundMetadata, backgroundEmbeddings, backgroundMetadata, keyring, oauthRegistry, tokenStore, authCodes, dynClients, activeProjects) if err != nil { return err } @@ -156,33 +183,33 @@ func Run(ctx context.Context, configPath string) error { } } -func routes(logger *slog.Logger, cfg *config.Config, info buildinfo.Info, db *store.DB, provider ai.Provider, keyring *auth.Keyring, oauthRegistry 
*auth.OAuthRegistry, tokenStore *auth.TokenStore, authCodes *auth.AuthCodeStore, dynClients *auth.DynamicClientStore, activeProjects *session.ActiveProjects) (http.Handler, error) { +func routes(logger *slog.Logger, cfg *config.Config, info buildinfo.Info, db *store.DB, embeddings *ai.EmbeddingRunner, metadata *ai.MetadataRunner, bgEmbeddings *ai.EmbeddingRunner, bgMetadata *ai.MetadataRunner, keyring *auth.Keyring, oauthRegistry *auth.OAuthRegistry, tokenStore *auth.TokenStore, authCodes *auth.AuthCodeStore, dynClients *auth.DynamicClientStore, activeProjects *session.ActiveProjects) (http.Handler, error) { mux := http.NewServeMux() accessTracker := auth.NewAccessTracker() oauthEnabled := oauthRegistry != nil && tokenStore != nil authMiddleware := auth.Middleware(cfg.Auth, keyring, oauthRegistry, tokenStore, accessTracker, logger) filesTool := tools.NewFilesTool(db, activeProjects) - enrichmentRetryer := tools.NewEnrichmentRetryer(context.Background(), db, provider, cfg.Capture, cfg.AI.Metadata.Timeout, activeProjects, logger) - backfillTool := tools.NewBackfillTool(db, provider, activeProjects, logger) + enrichmentRetryer := tools.NewEnrichmentRetryer(context.Background(), db, bgMetadata, cfg.Capture, cfg.AI.Metadata.Timeout, activeProjects, logger) + backfillTool := tools.NewBackfillTool(db, bgEmbeddings, activeProjects, logger) toolSet := mcpserver.ToolSet{ - Capture: tools.NewCaptureTool(db, provider, cfg.Capture, cfg.AI.Metadata.Timeout, activeProjects, enrichmentRetryer, backfillTool, logger), - Search: tools.NewSearchTool(db, provider, cfg.Search, activeProjects), + Capture: tools.NewCaptureTool(db, embeddings, metadata, cfg.Capture, cfg.AI.Metadata.Timeout, activeProjects, enrichmentRetryer, backfillTool, logger), + Search: tools.NewSearchTool(db, embeddings, cfg.Search, activeProjects), List: tools.NewListTool(db, cfg.Search, activeProjects), Stats: tools.NewStatsTool(db), Get: tools.NewGetTool(db), - Update: tools.NewUpdateTool(db, provider, cfg.Capture, logger), +
Update: tools.NewUpdateTool(db, embeddings, metadata, cfg.Capture, logger), Delete: tools.NewDeleteTool(db), Archive: tools.NewArchiveTool(db), Projects: tools.NewProjectsTool(db, activeProjects), Version: tools.NewVersionTool(cfg.MCP.ServerName, info), - Context: tools.NewContextTool(db, provider, cfg.Search, activeProjects), - Recall: tools.NewRecallTool(db, provider, cfg.Search, activeProjects), - Summarize: tools.NewSummarizeTool(db, provider, cfg.Search, activeProjects), - Links: tools.NewLinksTool(db, provider, cfg.Search), + Context: tools.NewContextTool(db, embeddings, cfg.Search, activeProjects), + Recall: tools.NewRecallTool(db, embeddings, cfg.Search, activeProjects), + Summarize: tools.NewSummarizeTool(db, embeddings, metadata, cfg.Search, activeProjects), + Links: tools.NewLinksTool(db, embeddings, cfg.Search), Files: filesTool, Backfill: backfillTool, - Reparse: tools.NewReparseMetadataTool(db, provider, cfg.Capture, activeProjects, logger), + Reparse: tools.NewReparseMetadataTool(db, bgMetadata, cfg.Capture, activeProjects, logger), RetryMetadata: tools.NewRetryEnrichmentTool(enrichmentRetryer), Maintenance: tools.NewMaintenanceTool(db), Skills: tools.NewSkillsTool(db, activeProjects), @@ -242,8 +269,8 @@ func routes(logger *slog.Logger, cfg *config.Config, info buildinfo.Info, db *st ), nil } -func runMetadataRetryPass(ctx context.Context, db *store.DB, provider ai.Provider, cfg *config.Config, activeProjects *session.ActiveProjects, logger *slog.Logger) { - retryer := tools.NewMetadataRetryer(ctx, db, provider, cfg.Capture, cfg.AI.Metadata.Timeout, activeProjects, logger) +func runMetadataRetryPass(ctx context.Context, db *store.DB, metadataRunner *ai.MetadataRunner, cfg *config.Config, activeProjects *session.ActiveProjects, logger *slog.Logger) { + retryer := tools.NewMetadataRetryer(ctx, db, metadataRunner, cfg.Capture, cfg.AI.Metadata.Timeout, activeProjects, logger) _, out, err := retryer.Handle(ctx, nil, tools.RetryMetadataInput{ Limit: 
cfg.MetadataRetry.MaxPerRun, IncludeArchived: cfg.MetadataRetry.IncludeArchived, @@ -261,8 +288,8 @@ func runMetadataRetryPass(ctx context.Context, db *store.DB, provider ai.Provide ) } -func runBackfillPass(ctx context.Context, db *store.DB, provider ai.Provider, cfg config.BackfillConfig, logger *slog.Logger) { - backfiller := tools.NewBackfillTool(db, provider, nil, logger) +func runBackfillPass(ctx context.Context, db *store.DB, embeddings *ai.EmbeddingRunner, cfg config.BackfillConfig, logger *slog.Logger) { + backfiller := tools.NewBackfillTool(db, embeddings, nil, logger) _, out, err := backfiller.Handle(ctx, nil, tools.BackfillInput{ Limit: cfg.MaxPerRun, IncludeArchived: cfg.IncludeArchived, diff --git a/internal/config/config.go b/internal/config/config.go index 46f8daa..281205a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -8,6 +8,7 @@ const ( ) type Config struct { + Version int `yaml:"version"` Server ServerConfig `yaml:"server"` MCP MCPConfig `yaml:"mcp"` Auth AuthConfig `yaml:"auth"` @@ -37,11 +38,8 @@ type MCPConfig struct { Version string `yaml:"version"` Transport string `yaml:"transport"` SessionTimeout time.Duration `yaml:"session_timeout"` - // PublicURL is the externally reachable base URL of this server (e.g. https://amcs.example.com). - // When set, it is used to build absolute icon URLs in the MCP server identity. - PublicURL string `yaml:"public_url"` - // Instructions is set at startup from the embedded memory.md and sent to MCP clients on initialise. - Instructions string `yaml:"-"` + PublicURL string `yaml:"public_url"` + Instructions string `yaml:"-"` } type AuthConfig struct { @@ -77,52 +75,82 @@ type DatabaseConfig struct { MaxConnIdleTime time.Duration `yaml:"max_conn_idle_time"` } +// AIConfig (v2): named providers + per-role chains. 
type AIConfig struct { - Provider string `yaml:"provider"` - Embeddings AIEmbeddingConfig `yaml:"embeddings"` - Metadata AIMetadataConfig `yaml:"metadata"` - LiteLLM LiteLLMConfig `yaml:"litellm"` - Ollama OllamaConfig `yaml:"ollama"` - OpenRouter OpenRouterAIConfig `yaml:"openrouter"` + Providers map[string]ProviderConfig `yaml:"providers"` + Embeddings EmbeddingsRoleConfig `yaml:"embeddings"` + Metadata MetadataRoleConfig `yaml:"metadata"` + Background *BackgroundRolesConfig `yaml:"background,omitempty"` } -type AIEmbeddingConfig struct { - Model string `yaml:"model"` - Dimensions int `yaml:"dimensions"` +type ProviderConfig struct { + Type string `yaml:"type"` + BaseURL string `yaml:"base_url"` + APIKey string `yaml:"api_key"` + RequestHeaders map[string]string `yaml:"request_headers,omitempty"` + AppName string `yaml:"app_name,omitempty"` + SiteURL string `yaml:"site_url,omitempty"` } -type AIMetadataConfig struct { - Model string `yaml:"model"` - FallbackModels []string `yaml:"fallback_models"` - FallbackModel string `yaml:"fallback_model"` // legacy single fallback +type RoleTarget struct { + Provider string `yaml:"provider"` + Model string `yaml:"model"` +} + +type RoleChain struct { + Primary RoleTarget `yaml:"primary"` + Fallbacks []RoleTarget `yaml:"fallbacks,omitempty"` +} + +type EmbeddingsRoleConfig struct { + Dimensions int `yaml:"dimensions"` + Primary RoleTarget `yaml:"primary"` + Fallbacks []RoleTarget `yaml:"fallbacks,omitempty"` +} + +type MetadataRoleConfig struct { Temperature float64 `yaml:"temperature"` LogConversations bool `yaml:"log_conversations"` Timeout time.Duration `yaml:"timeout"` + Primary RoleTarget `yaml:"primary"` + Fallbacks []RoleTarget `yaml:"fallbacks,omitempty"` } -type LiteLLMConfig struct { - BaseURL string `yaml:"base_url"` - APIKey string `yaml:"api_key"` - UseResponsesAPI bool `yaml:"use_responses_api"` - RequestHeaders map[string]string `yaml:"request_headers"` - EmbeddingModel string `yaml:"embedding_model"` - 
MetadataModel string `yaml:"metadata_model"` - FallbackMetadataModels []string `yaml:"fallback_metadata_models"` - FallbackMetadataModel string `yaml:"fallback_metadata_model"` // legacy single fallback +// BackgroundRolesConfig overrides the foreground chains for background workers +// (backfill_embeddings, metadata_retry, reparse_metadata). Either field may be +// nil to inherit the foreground role unchanged. +type BackgroundRolesConfig struct { + Embeddings *RoleChain `yaml:"embeddings,omitempty"` + Metadata *RoleChain `yaml:"metadata,omitempty"` } -type OllamaConfig struct { - BaseURL string `yaml:"base_url"` - APIKey string `yaml:"api_key"` - RequestHeaders map[string]string `yaml:"request_headers"` +// Chain returns primary followed by fallbacks (deduped, blanks dropped). +func (e EmbeddingsRoleConfig) Chain() []RoleTarget { + return dedupeTargets(append([]RoleTarget{e.Primary}, e.Fallbacks...)) } -type OpenRouterAIConfig struct { - BaseURL string `yaml:"base_url"` - APIKey string `yaml:"api_key"` - AppName string `yaml:"app_name"` - SiteURL string `yaml:"site_url"` - ExtraHeaders map[string]string `yaml:"extra_headers"` +func (m MetadataRoleConfig) Chain() []RoleTarget { + return dedupeTargets(append([]RoleTarget{m.Primary}, m.Fallbacks...)) +} + +func (c RoleChain) AsTargets() []RoleTarget { + return dedupeTargets(append([]RoleTarget{c.Primary}, c.Fallbacks...)) +} + +func dedupeTargets(in []RoleTarget) []RoleTarget { + out := make([]RoleTarget, 0, len(in)) + seen := make(map[RoleTarget]struct{}, len(in)) + for _, t := range in { + if t.Provider == "" || t.Model == "" { + continue + } + if _, ok := seen[t]; ok { + continue + } + seen[t] = struct{}{} + out = append(out, t) + } + return out } type CaptureConfig struct { @@ -167,45 +195,3 @@ type MetadataRetryConfig struct { MaxPerRun int `yaml:"max_per_run"` IncludeArchived bool `yaml:"include_archived"` } - -func (c AIMetadataConfig) EffectiveFallbackModels() []string { - models := make([]string, 0, 
len(c.FallbackModels)+1) - for _, model := range c.FallbackModels { - if model != "" { - models = append(models, model) - } - } - if c.FallbackModel != "" { - models = append(models, c.FallbackModel) - } - return dedupeNonEmpty(models) -} - -func (c LiteLLMConfig) EffectiveFallbackMetadataModels() []string { - models := make([]string, 0, len(c.FallbackMetadataModels)+1) - for _, model := range c.FallbackMetadataModels { - if model != "" { - models = append(models, model) - } - } - if c.FallbackMetadataModel != "" { - models = append(models, c.FallbackMetadataModel) - } - return dedupeNonEmpty(models) -} - -func dedupeNonEmpty(values []string) []string { - seen := make(map[string]struct{}, len(values)) - out := make([]string, 0, len(values)) - for _, value := range values { - if value == "" { - continue - } - if _, ok := seen[value]; ok { - continue - } - seen[value] = struct{}{} - out = append(out, value) - } - return out -} diff --git a/internal/config/loader.go b/internal/config/loader.go index d6b0a88..f133b59 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -2,6 +2,7 @@ package config import ( "fmt" + "log/slog" "os" "strconv" "strings" @@ -12,6 +13,12 @@ import ( ) func Load(explicitPath string) (*Config, string, error) { + return LoadWithLogger(explicitPath, nil) +} + +// LoadWithLogger is Load with a logger surface for migration notices. Passing +// nil is fine — migration events will simply not be logged. 
+func LoadWithLogger(explicitPath string, log *slog.Logger) (*Config, string, error) { path := ResolvePath(explicitPath) data, err := os.ReadFile(path) @@ -19,10 +26,40 @@ func Load(explicitPath string) (*Config, string, error) { return nil, path, fmt.Errorf("read config %q: %w", path, err) } - cfg := defaultConfig() - if err := yaml.Unmarshal(data, &cfg); err != nil { + raw := map[string]any{} + if err := yaml.Unmarshal(data, &raw); err != nil { return nil, path, fmt.Errorf("decode config %q: %w", path, err) } + if raw == nil { + raw = map[string]any{} + } + + applied, err := Migrate(raw) + if err != nil { + return nil, path, fmt.Errorf("migrate config %q: %w", path, err) + } + + if len(applied) > 0 { + if err := rewriteConfigFile(path, data, raw); err != nil { + return nil, path, err + } + if log != nil { + for _, step := range applied { + log.Warn("config migrated", + slog.String("path", path), + slog.Int("from_version", step.From), + slog.Int("to_version", step.To), + slog.String("describe", step.Describe), + ) + } + } + } + + cfg, err := decodeTyped(raw) + if err != nil { + return nil, path, fmt.Errorf("decode migrated config %q: %w", path, err) + } + cfg.Version = CurrentConfigVersion applyEnvOverrides(&cfg) if err := cfg.Validate(); err != nil { @@ -32,6 +69,34 @@ func Load(explicitPath string) (*Config, string, error) { return &cfg, path, nil } +func decodeTyped(raw map[string]any) (Config, error) { + out, err := yaml.Marshal(raw) + if err != nil { + return Config{}, fmt.Errorf("re-marshal migrated config: %w", err) + } + cfg := defaultConfig() + if err := yaml.Unmarshal(out, &cfg); err != nil { + return Config{}, err + } + return cfg, nil +} + +func rewriteConfigFile(path string, original []byte, migrated map[string]any) error { + backupPath := fmt.Sprintf("%s.bak.%d", path, time.Now().Unix()) + if err := os.WriteFile(backupPath, original, 0o600); err != nil { + return fmt.Errorf("write backup %q: %w", backupPath, err) + } + + out, err := 
yaml.Marshal(migrated) + if err != nil { + return fmt.Errorf("marshal migrated config: %w", err) + } + if err := os.WriteFile(path, out, 0o600); err != nil { + return fmt.Errorf("write migrated config %q: %w", path, err) + } + return nil +} + func ResolvePath(explicitPath string) string { if path := strings.TrimSpace(explicitPath); path != "" { if path != ".yaml" && path != ".yml" { @@ -49,6 +114,7 @@ func ResolvePath(explicitPath string) string { func defaultConfig() Config { info := buildinfo.Current() return Config{ + Version: CurrentConfigVersion, Server: ServerConfig{ Host: "0.0.0.0", Port: 8080, @@ -69,20 +135,14 @@ func defaultConfig() Config { QueryParam: "key", }, AI: AIConfig{ - Provider: "litellm", - Embeddings: AIEmbeddingConfig{ - Model: "openai/text-embedding-3-small", + Providers: map[string]ProviderConfig{}, + Embeddings: EmbeddingsRoleConfig{ Dimensions: 1536, }, - Metadata: AIMetadataConfig{ - Model: "gpt-4o-mini", + Metadata: MetadataRoleConfig{ Temperature: 0.1, Timeout: 10 * time.Second, }, - Ollama: OllamaConfig{ - BaseURL: "http://localhost:11434/v1", - APIKey: "ollama", - }, }, Capture: CaptureConfig{ Source: DefaultSource, @@ -119,11 +179,12 @@ func defaultConfig() Config { func applyEnvOverrides(cfg *Config) { overrideString(&cfg.Database.URL, "AMCS_DATABASE_URL") overrideString(&cfg.MCP.PublicURL, "AMCS_PUBLIC_URL") - overrideString(&cfg.AI.LiteLLM.BaseURL, "AMCS_LITELLM_BASE_URL") - overrideString(&cfg.AI.LiteLLM.APIKey, "AMCS_LITELLM_API_KEY") - overrideString(&cfg.AI.Ollama.BaseURL, "AMCS_OLLAMA_BASE_URL") - overrideString(&cfg.AI.Ollama.APIKey, "AMCS_OLLAMA_API_KEY") - overrideString(&cfg.AI.OpenRouter.APIKey, "AMCS_OPENROUTER_API_KEY") + + overrideProviderField(cfg, "AMCS_LITELLM_BASE_URL", "litellm", func(p *ProviderConfig, v string) { p.BaseURL = v }) + overrideProviderField(cfg, "AMCS_LITELLM_API_KEY", "litellm", func(p *ProviderConfig, v string) { p.APIKey = v }) + overrideProviderField(cfg, "AMCS_OLLAMA_BASE_URL", "ollama", 
func(p *ProviderConfig, v string) { p.BaseURL = v }) + overrideProviderField(cfg, "AMCS_OLLAMA_API_KEY", "ollama", func(p *ProviderConfig, v string) { p.APIKey = v }) + overrideProviderField(cfg, "AMCS_OPENROUTER_API_KEY", "openrouter", func(p *ProviderConfig, v string) { p.APIKey = v }) if value, ok := os.LookupEnv("AMCS_SERVER_PORT"); ok { if port, err := strconv.Atoi(strings.TrimSpace(value)); err == nil { @@ -132,6 +193,24 @@ func applyEnvOverrides(cfg *Config) { } } +// overrideProviderField applies an env var to every configured provider of the +// given type. This preserves the v1 behaviour where e.g. AMCS_LITELLM_API_KEY +// rewrote the single litellm block — in v2 it rewrites every litellm provider. +func overrideProviderField(cfg *Config, envKey, providerType string, apply func(*ProviderConfig, string)) { + value, ok := os.LookupEnv(envKey) + if !ok { + return + } + value = strings.TrimSpace(value) + for name, p := range cfg.AI.Providers { + if p.Type != providerType { + continue + } + apply(&p, value) + cfg.AI.Providers[name] = p + } +} + func overrideString(target *string, envKey string) { if value, ok := os.LookupEnv(envKey); ok { *target = strings.TrimSpace(value) diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index ac1025b..c0d685c 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -31,9 +31,8 @@ func TestResolvePathIgnoresBareYAMLExtension(t *testing.T) { } } -func TestLoadAppliesEnvOverrides(t *testing.T) { - configPath := filepath.Join(t.TempDir(), "test.yaml") - if err := os.WriteFile(configPath, []byte(` +const v2ConfigYAML = ` +version: 2 server: port: 8080 mcp: @@ -46,18 +45,30 @@ auth: database: url: "postgres://from-file" ai: - provider: "litellm" + providers: + default: + type: "litellm" + base_url: "http://localhost:4000/v1" + api_key: "file-key" embeddings: dimensions: 1536 - litellm: - base_url: "http://localhost:4000/v1" - api_key: "file-key" + primary: + provider: "default" 
+ model: "text-embed" + metadata: + primary: + provider: "default" + model: "gpt-4" search: default_limit: 10 max_limit: 50 logging: level: "info" -`), 0o600); err != nil { +` + +func TestLoadAppliesEnvOverrides(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "test.yaml") + if err := os.WriteFile(configPath, []byte(v2ConfigYAML), 0o600); err != nil { t.Fatalf("write config: %v", err) } @@ -76,8 +87,8 @@ logging: if cfg.Database.URL != "postgres://from-env" { t.Fatalf("database url = %q, want env override", cfg.Database.URL) } - if cfg.AI.LiteLLM.APIKey != "env-key" { - t.Fatalf("litellm api key = %q, want env override", cfg.AI.LiteLLM.APIKey) + if cfg.AI.Providers["default"].APIKey != "env-key" { + t.Fatalf("litellm api key = %q, want env override", cfg.AI.Providers["default"].APIKey) } if cfg.Server.Port != 9090 { t.Fatalf("server port = %d, want 9090", cfg.Server.Port) @@ -90,10 +101,12 @@ logging: func TestLoadAppliesOllamaEnvOverrides(t *testing.T) { configPath := filepath.Join(t.TempDir(), "test.yaml") if err := os.WriteFile(configPath, []byte(` +version: 2 server: port: 8080 mcp: path: "/mcp" + session_timeout: "10m" auth: keys: - id: "test" @@ -101,15 +114,20 @@ auth: database: url: "postgres://from-file" ai: - provider: "ollama" + providers: + local: + type: "ollama" + base_url: "http://localhost:11434/v1" + api_key: "ollama" embeddings: - model: "nomic-embed-text" dimensions: 768 + primary: + provider: "local" + model: "nomic-embed-text" metadata: - model: "llama3.2" - ollama: - base_url: "http://localhost:11434/v1" - api_key: "ollama" + primary: + provider: "local" + model: "llama3.2" search: default_limit: 10 max_limit: 50 @@ -127,10 +145,77 @@ logging: t.Fatalf("Load() error = %v", err) } - if cfg.AI.Ollama.BaseURL != "https://ollama.example.com/v1" { - t.Fatalf("ollama base url = %q, want env override", cfg.AI.Ollama.BaseURL) + p := cfg.AI.Providers["local"] + if p.BaseURL != "https://ollama.example.com/v1" { + t.Fatalf("ollama base url = 
%q, want env override", p.BaseURL) } - if cfg.AI.Ollama.APIKey != "remote-key" { - t.Fatalf("ollama api key = %q, want env override", cfg.AI.Ollama.APIKey) + if p.APIKey != "remote-key" { + t.Fatalf("ollama api key = %q, want env override", p.APIKey) + } +} + +func TestLoadMigratesV1Config(t *testing.T) { + configPath := filepath.Join(t.TempDir(), "v1.yaml") + v1 := ` +server: + port: 8080 +mcp: + path: "/mcp" + session_timeout: "10m" +auth: + keys: + - id: "test" + value: "secret" +database: + url: "postgres://from-file" +ai: + provider: "litellm" + embeddings: + model: "text-embed" + dimensions: 1536 + metadata: + model: "gpt-4" + temperature: 0.2 + fallback_models: ["gpt-3.5"] + litellm: + base_url: "http://localhost:4000/v1" + api_key: "file-key" +search: + default_limit: 10 + max_limit: 50 +logging: + level: "info" +` + if err := os.WriteFile(configPath, []byte(v1), 0o600); err != nil { + t.Fatalf("write config: %v", err) + } + + cfg, _, err := Load(configPath) + if err != nil { + t.Fatalf("Load() error = %v", err) + } + + if cfg.Version != CurrentConfigVersion { + t.Fatalf("version = %d, want %d", cfg.Version, CurrentConfigVersion) + } + if p, ok := cfg.AI.Providers["default"]; !ok || p.Type != "litellm" || p.APIKey != "file-key" { + t.Fatalf("providers[default] = %+v, want litellm/file-key", p) + } + if cfg.AI.Embeddings.Primary.Model != "text-embed" || cfg.AI.Embeddings.Primary.Provider != "default" { + t.Fatalf("embeddings.primary = %+v, want default/text-embed", cfg.AI.Embeddings.Primary) + } + if cfg.AI.Metadata.Primary.Model != "gpt-4" || cfg.AI.Metadata.Primary.Provider != "default" { + t.Fatalf("metadata.primary = %+v, want default/gpt-4", cfg.AI.Metadata.Primary) + } + if len(cfg.AI.Metadata.Fallbacks) != 1 || cfg.AI.Metadata.Fallbacks[0].Model != "gpt-3.5" { + t.Fatalf("metadata.fallbacks = %+v, want [default/gpt-3.5]", cfg.AI.Metadata.Fallbacks) + } + + entries, err := filepath.Glob(configPath + ".bak.*") + if err != nil { + t.Fatalf("glob backups: 
%v", err) + } + if len(entries) != 1 { + t.Fatalf("backup files = %d, want 1", len(entries)) } } diff --git a/internal/config/migrate.go b/internal/config/migrate.go new file mode 100644 index 0000000..8f88040 --- /dev/null +++ b/internal/config/migrate.go @@ -0,0 +1,341 @@ +package config + +import ( + "fmt" + "sort" +) + +// CurrentConfigVersion is the schema version this binary expects. Files at a +// lower version are migrated automatically when loaded. +const CurrentConfigVersion = 2 + +// ConfigMigration upgrades a raw YAML map by one version. +type ConfigMigration struct { + From, To int + Describe string + Apply func(map[string]any) error +} + +// migrations is the ordered ladder of upgrades. Add new entries at the end. +var migrations = []ConfigMigration{ + {From: 1, To: 2, Describe: "named providers + role chains", Apply: migrateV1toV2}, +} + +// Migrate brings raw up to CurrentConfigVersion in place. Returns the list of +// migrations that were applied (may be empty if already current). 
+func Migrate(raw map[string]any) ([]ConfigMigration, error) { + if raw == nil { + return nil, fmt.Errorf("migrate: raw config is nil") + } + + version := readVersion(raw) + if version > CurrentConfigVersion { + return nil, fmt.Errorf("migrate: config version %d is newer than supported version %d", version, CurrentConfigVersion) + } + + applied := make([]ConfigMigration, 0) + for { + if version >= CurrentConfigVersion { + break + } + step, ok := findMigration(version) + if !ok { + return nil, fmt.Errorf("migrate: no migration registered from version %d", version) + } + if err := step.Apply(raw); err != nil { + return nil, fmt.Errorf("migrate v%d->v%d: %w", step.From, step.To, err) + } + raw["version"] = step.To + version = step.To + applied = append(applied, step) + } + return applied, nil +} + +func findMigration(from int) (ConfigMigration, bool) { + for _, m := range migrations { + if m.From == from { + return m, true + } + } + return ConfigMigration{}, false +} + +// readVersion returns the version from raw. Files without a version field are +// treated as version 1 (the original schema). +func readVersion(raw map[string]any) int { + v, ok := raw["version"] + if !ok { + return 1 + } + switch n := v.(type) { + case int: + return n + case int64: + return int(n) + case float64: + return int(n) + } + return 1 +} + +// migrateV1toV2 lifts the single-provider config into the named-providers + +// role-chains layout. The pre-v2 config implicitly used one provider for both +// embeddings and metadata; we materialise that as a provider named "default". 
+func migrateV1toV2(raw map[string]any) error { + aiRaw := mapValue(raw, "ai") + if aiRaw == nil { + aiRaw = map[string]any{} + } + + providerType := stringValue(aiRaw, "provider") + if providerType == "" { + providerType = "litellm" + } + + providers, embeddingModel, metadataModel, fallbackModels := buildV1Provider(aiRaw, providerType) + + embeddingsOld := mapValue(aiRaw, "embeddings") + dimensions := intValue(embeddingsOld, "dimensions") + if dimensions <= 0 { + dimensions = 1536 + } + if embeddingModel == "" { + embeddingModel = stringValue(embeddingsOld, "model") + } + + metadataOld := mapValue(aiRaw, "metadata") + if metadataModel == "" { + metadataModel = stringValue(metadataOld, "model") + } + temperature := floatValue(metadataOld, "temperature") + logConversations := boolValue(metadataOld, "log_conversations") + timeoutStr := stringValue(metadataOld, "timeout") + + if list := stringListValue(metadataOld, "fallback_models"); len(list) > 0 { + fallbackModels = append(fallbackModels, list...) 
+ } + if v := stringValue(metadataOld, "fallback_model"); v != "" { + fallbackModels = append(fallbackModels, v) + } + + embeddings := map[string]any{ + "dimensions": dimensions, + "primary": map[string]any{"provider": "default", "model": embeddingModel}, + } + + metadata := map[string]any{ + "temperature": temperature, + "log_conversations": logConversations, + "primary": map[string]any{"provider": "default", "model": metadataModel}, + } + if timeoutStr != "" { + metadata["timeout"] = timeoutStr + } + if fallbacks := chainTargets("default", fallbackModels); len(fallbacks) > 0 { + metadata["fallbacks"] = fallbacks + } + + raw["ai"] = map[string]any{ + "providers": providers, + "embeddings": embeddings, + "metadata": metadata, + } + return nil +} + +func buildV1Provider(aiRaw map[string]any, providerType string) (map[string]any, string, string, []string) { + providers := map[string]any{} + defaultEntry := map[string]any{"type": providerType} + embedModel := "" + metaModel := "" + var fallbacks []string + + switch providerType { + case "litellm": + block := mapValue(aiRaw, "litellm") + copyKeys(defaultEntry, block, "base_url", "api_key") + copyHeaders(defaultEntry, block, "request_headers") + embedModel = stringValue(block, "embedding_model") + metaModel = stringValue(block, "metadata_model") + if list := stringListValue(block, "fallback_metadata_models"); len(list) > 0 { + fallbacks = append(fallbacks, list...) 
+ } + if v := stringValue(block, "fallback_metadata_model"); v != "" { + fallbacks = append(fallbacks, v) + } + case "ollama": + block := mapValue(aiRaw, "ollama") + copyKeys(defaultEntry, block, "base_url", "api_key") + copyHeaders(defaultEntry, block, "request_headers") + case "openrouter": + block := mapValue(aiRaw, "openrouter") + copyKeys(defaultEntry, block, "base_url", "api_key", "app_name", "site_url") + copyHeaders(defaultEntry, block, "extra_headers") + // rename: extra_headers → request_headers + if hdr, ok := defaultEntry["extra_headers"]; ok { + defaultEntry["request_headers"] = hdr + delete(defaultEntry, "extra_headers") + } + } + + providers["default"] = defaultEntry + return providers, embedModel, metaModel, fallbacks +} + +func chainTargets(provider string, models []string) []any { + out := make([]any, 0, len(models)) + seen := map[string]struct{}{} + for _, m := range models { + if m == "" { + continue + } + key := provider + "|" + m + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + out = append(out, map[string]any{"provider": provider, "model": m}) + } + return out +} + +func mapValue(raw map[string]any, key string) map[string]any { + if raw == nil { + return nil + } + v, ok := raw[key] + if !ok { + return nil + } + switch m := v.(type) { + case map[string]any: + return m + case map[any]any: + return convertAnyMap(m) + } + return nil +} + +func convertAnyMap(in map[any]any) map[string]any { + out := make(map[string]any, len(in)) + keys := make([]string, 0, len(in)) + for k, v := range in { + ks, ok := k.(string) + if !ok { + continue + } + keys = append(keys, ks) + out[ks] = v + } + sort.Strings(keys) + return out +} + +func stringValue(raw map[string]any, key string) string { + if raw == nil { + return "" + } + v, ok := raw[key] + if !ok { + return "" + } + if s, ok := v.(string); ok { + return s + } + return "" +} + +func intValue(raw map[string]any, key string) int { + if raw == nil { + return 0 + } + switch n := 
raw[key].(type) { + case int: + return n + case int64: + return int(n) + case float64: + return int(n) + } + return 0 +} + +func floatValue(raw map[string]any, key string) float64 { + if raw == nil { + return 0 + } + switch n := raw[key].(type) { + case float64: + return n + case int: + return float64(n) + case int64: + return float64(n) + } + return 0 +} + +func boolValue(raw map[string]any, key string) bool { + if raw == nil { + return false + } + if b, ok := raw[key].(bool); ok { + return b + } + return false +} + +func stringListValue(raw map[string]any, key string) []string { + if raw == nil { + return nil + } + v, ok := raw[key] + if !ok { + return nil + } + list, ok := v.([]any) + if !ok { + return nil + } + out := make([]string, 0, len(list)) + for _, item := range list { + if s, ok := item.(string); ok && s != "" { + out = append(out, s) + } + } + return out +} + +func copyKeys(dst, src map[string]any, keys ...string) { + if src == nil { + return + } + for _, k := range keys { + if v, ok := src[k]; ok { + dst[k] = v + } + } +} + +func copyHeaders(dst, src map[string]any, key string) { + if src == nil { + return + } + v, ok := src[key] + if !ok { + return + } + switch headers := v.(type) { + case map[string]any: + if len(headers) == 0 { + return + } + dst[key] = headers + case map[any]any: + if len(headers) == 0 { + return + } + dst[key] = convertAnyMap(headers) + } +} diff --git a/internal/config/migrate_test.go b/internal/config/migrate_test.go new file mode 100644 index 0000000..b3771f6 --- /dev/null +++ b/internal/config/migrate_test.go @@ -0,0 +1,77 @@ +package config + +import "testing" + +func TestMigrateV1ToV2Litellm(t *testing.T) { + raw := map[string]any{ + "ai": map[string]any{ + "provider": "litellm", + "embeddings": map[string]any{ + "model": "text-embedding-3-small", + "dimensions": 1536, + }, + "metadata": map[string]any{ + "model": "gpt-4o-mini", + "temperature": 0.2, + "fallback_models": []any{"gpt-4.1-mini"}, + }, + "litellm": 
map[string]any{ + "base_url": "http://localhost:4000/v1", + "api_key": "secret", + }, + }, + } + + applied, err := Migrate(raw) + if err != nil { + t.Fatalf("Migrate() error = %v", err) + } + if len(applied) != 1 || applied[0].From != 1 || applied[0].To != 2 { + t.Fatalf("applied = %+v, want [v1->v2]", applied) + } + if got := readVersion(raw); got != CurrentConfigVersion { + t.Fatalf("version = %d, want %d", got, CurrentConfigVersion) + } + + ai := mapValue(raw, "ai") + providers := mapValue(ai, "providers") + def := mapValue(providers, "default") + if got := stringValue(def, "type"); got != "litellm" { + t.Fatalf("providers.default.type = %q, want litellm", got) + } + if got := stringValue(def, "base_url"); got != "http://localhost:4000/v1" { + t.Fatalf("providers.default.base_url = %q", got) + } + + emb := mapValue(ai, "embeddings") + embPrimary := mapValue(emb, "primary") + if stringValue(embPrimary, "provider") != "default" || stringValue(embPrimary, "model") != "text-embedding-3-small" { + t.Fatalf("embeddings.primary = %+v, want default/text-embedding-3-small", embPrimary) + } + + meta := mapValue(ai, "metadata") + metaPrimary := mapValue(meta, "primary") + if stringValue(metaPrimary, "provider") != "default" || stringValue(metaPrimary, "model") != "gpt-4o-mini" { + t.Fatalf("metadata.primary = %+v, want default/gpt-4o-mini", metaPrimary) + } + fallbacks, ok := meta["fallbacks"].([]any) + if !ok || len(fallbacks) != 1 { + t.Fatalf("metadata.fallbacks = %#v, want len=1", meta["fallbacks"]) + } + firstFallback, ok := fallbacks[0].(map[string]any) + if !ok { + t.Fatalf("metadata.fallbacks[0] type = %T, want map[string]any", fallbacks[0]) + } + if stringValue(firstFallback, "provider") != "default" || stringValue(firstFallback, "model") != "gpt-4.1-mini" { + t.Fatalf("metadata fallback = %+v, want default/gpt-4.1-mini", firstFallback) + } +} + +func TestMigrateRejectsNewerVersion(t *testing.T) { + raw := map[string]any{"version": CurrentConfigVersion + 1} + + _, 
err := Migrate(raw) + if err == nil { + t.Fatal("Migrate() error = nil, want error for newer config version") + } +} diff --git a/internal/config/validate.go b/internal/config/validate.go index af40b37..33a8bd8 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -45,38 +45,8 @@ func (c Config) Validate() error { return fmt.Errorf("invalid config: mcp.session_timeout must be greater than zero") } - switch c.AI.Provider { - case "litellm", "ollama", "openrouter": - default: - return fmt.Errorf("invalid config: unsupported ai.provider %q", c.AI.Provider) - } - - if c.AI.Embeddings.Dimensions <= 0 { - return fmt.Errorf("invalid config: ai.embeddings.dimensions must be greater than zero") - } - - switch c.AI.Provider { - case "litellm": - if strings.TrimSpace(c.AI.LiteLLM.BaseURL) == "" { - return fmt.Errorf("invalid config: ai.litellm.base_url is required when ai.provider=litellm") - } - if strings.TrimSpace(c.AI.LiteLLM.APIKey) == "" { - return fmt.Errorf("invalid config: ai.litellm.api_key is required when ai.provider=litellm") - } - case "ollama": - if strings.TrimSpace(c.AI.Ollama.BaseURL) == "" { - return fmt.Errorf("invalid config: ai.ollama.base_url is required when ai.provider=ollama") - } - if strings.TrimSpace(c.AI.Ollama.APIKey) == "" { - return fmt.Errorf("invalid config: ai.ollama.api_key is required when ai.provider=ollama") - } - case "openrouter": - if strings.TrimSpace(c.AI.OpenRouter.BaseURL) == "" { - return fmt.Errorf("invalid config: ai.openrouter.base_url is required when ai.provider=openrouter") - } - if strings.TrimSpace(c.AI.OpenRouter.APIKey) == "" { - return fmt.Errorf("invalid config: ai.openrouter.api_key is required when ai.provider=openrouter") - } + if err := c.AI.validate(); err != nil { + return err } if c.Server.Port <= 0 { @@ -108,3 +78,61 @@ func (c Config) Validate() error { return nil } + +func (a AIConfig) validate() error { + if len(a.Providers) == 0 { + return fmt.Errorf("invalid config: ai.providers 
must contain at least one entry") + } + for name, p := range a.Providers { + if strings.TrimSpace(name) == "" { + return fmt.Errorf("invalid config: ai.providers contains an entry with an empty name") + } + switch p.Type { + case "litellm", "ollama", "openrouter": + default: + return fmt.Errorf("invalid config: ai.providers.%s.type %q is not supported", name, p.Type) + } + if strings.TrimSpace(p.BaseURL) == "" { + return fmt.Errorf("invalid config: ai.providers.%s.base_url is required", name) + } + if strings.TrimSpace(p.APIKey) == "" { + return fmt.Errorf("invalid config: ai.providers.%s.api_key is required", name) + } + } + + if a.Embeddings.Dimensions <= 0 { + return fmt.Errorf("invalid config: ai.embeddings.dimensions must be greater than zero") + } + + if err := a.validateChain("ai.embeddings", a.Embeddings.Chain()); err != nil { + return err + } + if err := a.validateChain("ai.metadata", a.Metadata.Chain()); err != nil { + return err + } + if a.Background != nil { + if a.Background.Embeddings != nil { + if err := a.validateChain("ai.background.embeddings", a.Background.Embeddings.AsTargets()); err != nil { + return err + } + } + if a.Background.Metadata != nil { + if err := a.validateChain("ai.background.metadata", a.Background.Metadata.AsTargets()); err != nil { + return err + } + } + } + return nil +} + +func (a AIConfig) validateChain(prefix string, chain []RoleTarget) error { + if len(chain) == 0 { + return fmt.Errorf("invalid config: %s.primary must reference a configured provider and model", prefix) + } + for i, target := range chain { + if _, ok := a.Providers[target.Provider]; !ok { + return fmt.Errorf("invalid config: %s[%d] references unknown provider %q", prefix, i, target.Provider) + } + } + return nil +} diff --git a/internal/config/validate_test.go b/internal/config/validate_test.go index 8cea09c..e0272ea 100644 --- a/internal/config/validate_test.go +++ b/internal/config/validate_test.go @@ -7,28 +7,23 @@ import ( func validConfig() Config { 
return Config{ - Server: ServerConfig{Port: 8080}, - MCP: MCPConfig{Path: "/mcp", SessionTimeout: 10 * time.Minute}, + Version: CurrentConfigVersion, + Server: ServerConfig{Port: 8080}, + MCP: MCPConfig{Path: "/mcp", SessionTimeout: 10 * time.Minute}, Auth: AuthConfig{ Keys: []APIKey{{ID: "test", Value: "secret"}}, }, Database: DatabaseConfig{URL: "postgres://example"}, AI: AIConfig{ - Provider: "litellm", - Embeddings: AIEmbeddingConfig{ + Providers: map[string]ProviderConfig{ + "default": {Type: "litellm", BaseURL: "http://localhost:4000/v1", APIKey: "key"}, + }, + Embeddings: EmbeddingsRoleConfig{ Dimensions: 1536, + Primary: RoleTarget{Provider: "default", Model: "text-embed"}, }, - LiteLLM: LiteLLMConfig{ - BaseURL: "http://localhost:4000/v1", - APIKey: "key", - }, - Ollama: OllamaConfig{ - BaseURL: "http://localhost:11434/v1", - APIKey: "ollama", - }, - OpenRouter: OpenRouterAIConfig{ - BaseURL: "https://openrouter.ai/api/v1", - APIKey: "key", + Metadata: MetadataRoleConfig{ + Primary: RoleTarget{Provider: "default", Model: "gpt-4"}, }, }, Search: SearchConfig{DefaultLimit: 10, MaxLimit: 50}, @@ -36,29 +31,44 @@ func validConfig() Config { } } -func TestValidateAcceptsSupportedProviders(t *testing.T) { - cfg := validConfig() - if err := cfg.Validate(); err != nil { - t.Fatalf("Validate litellm error = %v", err) - } - - cfg.AI.Provider = "ollama" - if err := cfg.Validate(); err != nil { - t.Fatalf("Validate ollama error = %v", err) - } - - cfg.AI.Provider = "openrouter" - if err := cfg.Validate(); err != nil { - t.Fatalf("Validate openrouter error = %v", err) +func TestValidateAcceptsSupportedProviderTypes(t *testing.T) { + for _, providerType := range []string{"litellm", "ollama", "openrouter"} { + cfg := validConfig() + p := cfg.AI.Providers["default"] + p.Type = providerType + cfg.AI.Providers["default"] = p + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate %s error = %v", providerType, err) + } } } -func TestValidateRejectsInvalidProvider(t 
*testing.T) { +func TestValidateRejectsInvalidProviderType(t *testing.T) { cfg := validConfig() - cfg.AI.Provider = "unknown" + p := cfg.AI.Providers["default"] + p.Type = "unknown" + cfg.AI.Providers["default"] = p if err := cfg.Validate(); err == nil { - t.Fatal("Validate() error = nil, want error for unsupported provider") + t.Fatal("Validate() error = nil, want error for unsupported provider type") + } +} + +func TestValidateRejectsChainWithUnknownProvider(t *testing.T) { + cfg := validConfig() + cfg.AI.Metadata.Primary = RoleTarget{Provider: "does-not-exist", Model: "x"} + + if err := cfg.Validate(); err == nil { + t.Fatal("Validate() error = nil, want error for chain referencing unknown provider") + } +} + +func TestValidateRejectsEmptyProviders(t *testing.T) { + cfg := validConfig() + cfg.AI.Providers = map[string]ProviderConfig{} + + if err := cfg.Validate(); err == nil { + t.Fatal("Validate() error = nil, want error for empty providers") } } diff --git a/internal/mcpserver/server.go b/internal/mcpserver/server.go index 8f3426f..c9e55b5 100644 --- a/internal/mcpserver/server.go +++ b/internal/mcpserver/server.go @@ -35,7 +35,7 @@ type ToolSet struct { Files *tools.FilesTool Backfill *tools.BackfillTool Reparse *tools.ReparseMetadataTool - RetryMetadata *tools.RetryMetadataTool + RetryMetadata *tools.RetryEnrichmentTool Maintenance *tools.MaintenanceTool Skills *tools.SkillsTool ChatHistory *tools.ChatHistoryTool diff --git a/internal/mcpserver/streamable_integration_test.go b/internal/mcpserver/streamable_integration_test.go index 699b178..89a82a7 100644 --- a/internal/mcpserver/streamable_integration_test.go +++ b/internal/mcpserver/streamable_integration_test.go @@ -126,7 +126,7 @@ func streamableTestToolSet() ToolSet { Files: new(tools.FilesTool), Backfill: new(tools.BackfillTool), Reparse: new(tools.ReparseMetadataTool), - RetryMetadata: new(tools.RetryMetadataTool), + RetryMetadata: new(tools.RetryEnrichmentTool), Maintenance: 
new(tools.MaintenanceTool), Skills: new(tools.SkillsTool), } diff --git a/internal/tools/backfill.go b/internal/tools/backfill.go index 521a9b7..52d96bc 100644 --- a/internal/tools/backfill.go +++ b/internal/tools/backfill.go @@ -18,10 +18,10 @@ import ( const backfillConcurrency = 4 type BackfillTool struct { - store *store.DB - provider ai.Provider - sessions *session.ActiveProjects - logger *slog.Logger + store *store.DB + embeddings *ai.EmbeddingRunner + sessions *session.ActiveProjects + logger *slog.Logger } type BackfillInput struct { @@ -47,15 +47,15 @@ type BackfillOutput struct { Failures []BackfillFailure `json:"failures,omitempty"` } -func NewBackfillTool(db *store.DB, provider ai.Provider, sessions *session.ActiveProjects, logger *slog.Logger) *BackfillTool { - return &BackfillTool{store: db, provider: provider, sessions: sessions, logger: logger} +func NewBackfillTool(db *store.DB, embeddings *ai.EmbeddingRunner, sessions *session.ActiveProjects, logger *slog.Logger) *BackfillTool { + return &BackfillTool{store: db, embeddings: embeddings, sessions: sessions, logger: logger} } // QueueThought queues a single thought for background embedding generation. // It is used by capture when the embedding provider is temporarily unavailable. 
func (t *BackfillTool) QueueThought(ctx context.Context, id uuid.UUID, content string) { go func() { - vec, err := t.provider.Embed(ctx, content) + result, err := t.embeddings.Embed(ctx, content) if err != nil { t.logger.Warn("background embedding retry failed", slog.String("thought_id", id.String()), @@ -63,15 +63,17 @@ func (t *BackfillTool) QueueThought(ctx context.Context, id uuid.UUID, content s ) return } - model := t.provider.EmbeddingModel() - if err := t.store.UpsertEmbedding(ctx, id, model, vec); err != nil { + if err := t.store.UpsertEmbedding(ctx, id, result.Model, result.Vector); err != nil { t.logger.Warn("background embedding upsert failed", slog.String("thought_id", id.String()), slog.String("error", err.Error()), ) return } - t.logger.Info("background embedding retry succeeded", slog.String("thought_id", id.String())) + t.logger.Info("background embedding retry succeeded", + slog.String("thought_id", id.String()), + slog.String("model", result.Model), + ) }() } @@ -91,15 +93,15 @@ func (t *BackfillTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in projectID = &project.ID } - model := t.provider.EmbeddingModel() + primaryModel := t.embeddings.PrimaryModel() - thoughts, err := t.store.ListThoughtsMissingEmbedding(ctx, model, limit, projectID, in.IncludeArchived, in.OlderThanDays) + thoughts, err := t.store.ListThoughtsMissingEmbedding(ctx, primaryModel, limit, projectID, in.IncludeArchived, in.OlderThanDays) if err != nil { return nil, BackfillOutput{}, err } out := BackfillOutput{ - Model: model, + Model: primaryModel, Scanned: len(thoughts), DryRun: in.DryRun, } @@ -125,7 +127,7 @@ func (t *BackfillTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in defer wg.Done() defer sem.Release(1) - vec, embedErr := t.provider.Embed(ctx, content) + result, embedErr := t.embeddings.Embed(ctx, content) if embedErr != nil { mu.Lock() out.Failures = append(out.Failures, BackfillFailure{ID: id.String(), Error: embedErr.Error()}) @@ -134,7 
+136,7 @@ func (t *BackfillTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in return } - if upsertErr := t.store.UpsertEmbedding(ctx, id, model, vec); upsertErr != nil { + if upsertErr := t.store.UpsertEmbedding(ctx, id, result.Model, result.Vector); upsertErr != nil { mu.Lock() out.Failures = append(out.Failures, BackfillFailure{ID: id.String(), Error: upsertErr.Error()}) mu.Unlock() @@ -154,7 +156,7 @@ func (t *BackfillTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in out.Skipped = out.Scanned - out.Embedded - out.Failed t.logger.Info("backfill completed", - slog.String("model", model), + slog.String("model", primaryModel), slog.Int("scanned", out.Scanned), slog.Int("embedded", out.Embedded), slog.Int("failed", out.Failed), diff --git a/internal/tools/capture.go b/internal/tools/capture.go index eb33ea8..4005410 100644 --- a/internal/tools/capture.go +++ b/internal/tools/capture.go @@ -22,13 +22,20 @@ type EmbeddingQueuer interface { QueueThought(ctx context.Context, id uuid.UUID, content string) } +// MetadataQueuer queues a thought for background metadata retry. Both +// MetadataRetryer and EnrichmentRetryer satisfy this. 
+type MetadataQueuer interface { + QueueThought(id uuid.UUID) +} + type CaptureTool struct { store *store.DB - provider ai.Provider + embeddings *ai.EmbeddingRunner + metadata *ai.MetadataRunner capture config.CaptureConfig sessions *session.ActiveProjects metadataTimeout time.Duration - retryer *MetadataRetryer + retryer MetadataQueuer embedRetryer EmbeddingQueuer log *slog.Logger } @@ -42,8 +49,8 @@ type CaptureOutput struct { Thought thoughttypes.Thought `json:"thought"` } -func NewCaptureTool(db *store.DB, provider ai.Provider, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, retryer *MetadataRetryer, embedRetryer EmbeddingQueuer, log *slog.Logger) *CaptureTool { - return &CaptureTool{store: db, provider: provider, capture: capture, sessions: sessions, metadataTimeout: metadataTimeout, retryer: retryer, embedRetryer: embedRetryer, log: log} +func NewCaptureTool(db *store.DB, embeddings *ai.EmbeddingRunner, metadata *ai.MetadataRunner, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, retryer MetadataQueuer, embedRetryer EmbeddingQueuer, log *slog.Logger) *CaptureTool { + return &CaptureTool{store: db, embeddings: embeddings, metadata: metadata, capture: capture, sessions: sessions, metadataTimeout: metadataTimeout, retryer: retryer, embedRetryer: embedRetryer, log: log} } func (t *CaptureTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in CaptureInput) (*mcp.CallToolResult, CaptureOutput, error) { @@ -66,7 +73,7 @@ func (t *CaptureTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in C thought.ProjectID = &project.ID } - created, err := t.store.InsertThought(ctx, thought, t.provider.EmbeddingModel()) + created, err := t.store.InsertThought(ctx, thought, t.embeddings.PrimaryModel()) if err != nil { return nil, CaptureOutput{}, err } @@ -89,7 +96,7 @@ func (t *CaptureTool) launchEnrichment(id uuid.UUID, content string) { if t.retryer != nil { attemptedAt 
:= time.Now().UTC() rawMetadata := metadata.Fallback(t.capture) - extracted, err := t.provider.ExtractMetadata(ctx, content) + extracted, err := t.metadata.ExtractMetadata(ctx, content) if err != nil { failed := metadata.MarkMetadataFailed(rawMetadata, t.capture, attemptedAt, err) if _, updateErr := t.store.UpdateThoughtMetadata(ctx, id, failed); updateErr != nil { @@ -100,7 +107,7 @@ func (t *CaptureTool) launchEnrichment(id uuid.UUID, content string) { } t.log.Warn("deferred metadata extraction failed", slog.String("thought_id", id.String()), - slog.String("provider", t.provider.Name()), + slog.String("provider", t.metadata.PrimaryProvider()), slog.String("error", err.Error()), ) t.retryer.QueueThought(id) @@ -116,10 +123,10 @@ func (t *CaptureTool) launchEnrichment(id uuid.UUID, content string) { } if t.embedRetryer != nil { - if _, err := t.provider.Embed(ctx, content); err != nil { + if _, err := t.embeddings.Embed(ctx, content); err != nil { t.log.Warn("deferred embedding failed", slog.String("thought_id", id.String()), - slog.String("provider", t.provider.Name()), + slog.String("provider", t.embeddings.PrimaryProvider()), slog.String("error", err.Error()), ) } diff --git a/internal/tools/context.go b/internal/tools/context.go index e65168e..ffc449e 100644 --- a/internal/tools/context.go +++ b/internal/tools/context.go @@ -15,10 +15,10 @@ import ( ) type ContextTool struct { - store *store.DB - provider ai.Provider - search config.SearchConfig - sessions *session.ActiveProjects + store *store.DB + embeddings *ai.EmbeddingRunner + search config.SearchConfig + sessions *session.ActiveProjects } type ProjectContextInput struct { @@ -41,8 +41,8 @@ type ProjectContextOutput struct { Items []ContextItem `json:"items"` } -func NewContextTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *ContextTool { - return &ContextTool{store: db, provider: provider, search: search, sessions: sessions} +func NewContextTool(db 
*store.DB, embeddings *ai.EmbeddingRunner, search config.SearchConfig, sessions *session.ActiveProjects) *ContextTool { + return &ContextTool{store: db, embeddings: embeddings, search: search, sessions: sessions} } func (t *ContextTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in ProjectContextInput) (*mcp.CallToolResult, ProjectContextOutput, error) { @@ -72,7 +72,7 @@ func (t *ContextTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in P query := strings.TrimSpace(in.Query) if query != "" { - semantic, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, t.search.DefaultThreshold, &project.ID, nil) + semantic, err := semanticSearch(ctx, t.store, t.embeddings, t.search, query, limit, t.search.DefaultThreshold, &project.ID, nil) if err != nil { return nil, ProjectContextOutput{}, err } diff --git a/internal/tools/enrichment_retry.go b/internal/tools/enrichment_retry.go index 6a3d4d4..ef770b3 100644 --- a/internal/tools/enrichment_retry.go +++ b/internal/tools/enrichment_retry.go @@ -32,7 +32,7 @@ var enrichmentRetryBackoff = []time.Duration{ type EnrichmentRetryer struct { backgroundCtx context.Context store *store.DB - provider ai.Provider + metadata *ai.MetadataRunner capture config.CaptureConfig sessions *session.ActiveProjects metadataTimeout time.Duration @@ -66,14 +66,14 @@ type RetryEnrichmentOutput struct { Failures []RetryEnrichmentFailure `json:"failures,omitempty"` } -func NewEnrichmentRetryer(backgroundCtx context.Context, db *store.DB, provider ai.Provider, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, logger *slog.Logger) *EnrichmentRetryer { +func NewEnrichmentRetryer(backgroundCtx context.Context, db *store.DB, metadataRunner *ai.MetadataRunner, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, logger *slog.Logger) *EnrichmentRetryer { if backgroundCtx == nil { backgroundCtx = context.Background() } return 
&EnrichmentRetryer{ backgroundCtx: backgroundCtx, store: db, - provider: provider, + metadata: metadataRunner, capture: capture, sessions: sessions, metadataTimeout: metadataTimeout, @@ -190,7 +190,7 @@ func (r *EnrichmentRetryer) retryOne(ctx context.Context, id uuid.UUID) (bool, e } attemptedAt := time.Now().UTC() - extracted, extractErr := r.provider.ExtractMetadata(attemptCtx, thought.Content) + extracted, extractErr := r.metadata.ExtractMetadata(attemptCtx, thought.Content) if extractErr != nil { failedMetadata := metadata.MarkMetadataFailed(thought.Metadata, r.capture, attemptedAt, extractErr) if _, updateErr := r.store.UpdateThoughtMetadata(ctx, thought.ID, failedMetadata); updateErr != nil { diff --git a/internal/tools/links.go b/internal/tools/links.go index 1cde072..a61d501 100644 --- a/internal/tools/links.go +++ b/internal/tools/links.go @@ -13,9 +13,9 @@ import ( ) type LinksTool struct { - store *store.DB - provider ai.Provider - search config.SearchConfig + store *store.DB + embeddings *ai.EmbeddingRunner + search config.SearchConfig } type LinkInput struct { @@ -47,8 +47,8 @@ type RelatedOutput struct { Related []RelatedThought `json:"related"` } -func NewLinksTool(db *store.DB, provider ai.Provider, search config.SearchConfig) *LinksTool { - return &LinksTool{store: db, provider: provider, search: search} +func NewLinksTool(db *store.DB, embeddings *ai.EmbeddingRunner, search config.SearchConfig) *LinksTool { + return &LinksTool{store: db, embeddings: embeddings, search: search} } func (t *LinksTool) Link(ctx context.Context, _ *mcp.CallToolRequest, in LinkInput) (*mcp.CallToolResult, LinkOutput, error) { @@ -117,7 +117,7 @@ func (t *LinksTool) Related(ctx context.Context, _ *mcp.CallToolRequest, in Rela } if includeSemantic { - semantic, err := semanticSearch(ctx, t.store, t.provider, t.search, thought.Content, t.search.DefaultLimit, t.search.DefaultThreshold, thought.ProjectID, &thought.ID) + semantic, err := semanticSearch(ctx, t.store, 
t.embeddings, t.search, thought.Content, t.search.DefaultLimit, t.search.DefaultThreshold, thought.ProjectID, &thought.ID) if err != nil { return nil, RelatedOutput{}, err } diff --git a/internal/tools/metadata_retry.go b/internal/tools/metadata_retry.go index ceb2268..06fa321 100644 --- a/internal/tools/metadata_retry.go +++ b/internal/tools/metadata_retry.go @@ -23,7 +23,7 @@ const metadataRetryConcurrency = 4 type MetadataRetryer struct { backgroundCtx context.Context store *store.DB - provider ai.Provider + metadata *ai.MetadataRunner capture config.CaptureConfig sessions *session.ActiveProjects metadataTimeout time.Duration @@ -87,14 +87,14 @@ type RetryMetadataOutput struct { Failures []RetryMetadataFailure `json:"failures,omitempty"` } -func NewMetadataRetryer(backgroundCtx context.Context, db *store.DB, provider ai.Provider, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, logger *slog.Logger) *MetadataRetryer { +func NewMetadataRetryer(backgroundCtx context.Context, db *store.DB, metadataRunner *ai.MetadataRunner, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, logger *slog.Logger) *MetadataRetryer { if backgroundCtx == nil { backgroundCtx = context.Background() } return &MetadataRetryer{ backgroundCtx: backgroundCtx, store: db, - provider: provider, + metadata: metadataRunner, capture: capture, sessions: sessions, metadataTimeout: metadataTimeout, @@ -223,7 +223,7 @@ func (r *MetadataRetryer) retryOne(ctx context.Context, id uuid.UUID) (bool, err } attemptedAt := time.Now().UTC() - extracted, extractErr := r.provider.ExtractMetadata(attemptCtx, thought.Content) + extracted, extractErr := r.metadata.ExtractMetadata(attemptCtx, thought.Content) if extractErr != nil { failedMetadata := metadata.MarkMetadataFailed(thought.Metadata, r.capture, attemptedAt, extractErr) if _, updateErr := r.store.UpdateThoughtMetadata(ctx, thought.ID, failedMetadata); updateErr != nil { 
diff --git a/internal/tools/recall.go b/internal/tools/recall.go index fa824b0..16fb784 100644 --- a/internal/tools/recall.go +++ b/internal/tools/recall.go @@ -15,10 +15,10 @@ import ( ) type RecallTool struct { - store *store.DB - provider ai.Provider - search config.SearchConfig - sessions *session.ActiveProjects + store *store.DB + embeddings *ai.EmbeddingRunner + search config.SearchConfig + sessions *session.ActiveProjects } type RecallInput struct { @@ -32,8 +32,8 @@ type RecallOutput struct { Items []ContextItem `json:"items"` } -func NewRecallTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *RecallTool { - return &RecallTool{store: db, provider: provider, search: search, sessions: sessions} +func NewRecallTool(db *store.DB, embeddings *ai.EmbeddingRunner, search config.SearchConfig, sessions *session.ActiveProjects) *RecallTool { + return &RecallTool{store: db, embeddings: embeddings, search: search, sessions: sessions} } func (t *RecallTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in RecallInput) (*mcp.CallToolResult, RecallOutput, error) { @@ -54,7 +54,7 @@ func (t *RecallTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in Re projectID = &project.ID } - semantic, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, t.search.DefaultThreshold, projectID, nil) + semantic, err := semanticSearch(ctx, t.store, t.embeddings, t.search, query, limit, t.search.DefaultThreshold, projectID, nil) if err != nil { return nil, RecallOutput{}, err } diff --git a/internal/tools/reparse_metadata.go b/internal/tools/reparse_metadata.go index a7d5bb2..7b5e31c 100644 --- a/internal/tools/reparse_metadata.go +++ b/internal/tools/reparse_metadata.go @@ -23,7 +23,7 @@ const metadataReparseConcurrency = 4 type ReparseMetadataTool struct { store *store.DB - provider ai.Provider + metadata *ai.MetadataRunner capture config.CaptureConfig sessions *session.ActiveProjects logger 
*slog.Logger @@ -53,8 +53,8 @@ type ReparseMetadataOutput struct { Failures []ReparseMetadataFailure `json:"failures,omitempty"` } -func NewReparseMetadataTool(db *store.DB, provider ai.Provider, capture config.CaptureConfig, sessions *session.ActiveProjects, logger *slog.Logger) *ReparseMetadataTool { - return &ReparseMetadataTool{store: db, provider: provider, capture: capture, sessions: sessions, logger: logger} +func NewReparseMetadataTool(db *store.DB, metadataRunner *ai.MetadataRunner, capture config.CaptureConfig, sessions *session.ActiveProjects, logger *slog.Logger) *ReparseMetadataTool { + return &ReparseMetadataTool{store: db, metadata: metadataRunner, capture: capture, sessions: sessions, logger: logger} } func (t *ReparseMetadataTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in ReparseMetadataInput) (*mcp.CallToolResult, ReparseMetadataOutput, error) { @@ -107,7 +107,7 @@ func (t *ReparseMetadataTool) Handle(ctx context.Context, req *mcp.CallToolReque normalizedCurrent := metadata.Normalize(thought.Metadata, t.capture) attemptedAt := time.Now().UTC() - extracted, extractErr := t.provider.ExtractMetadata(ctx, thought.Content) + extracted, extractErr := t.metadata.ExtractMetadata(ctx, thought.Content) normalizedTarget := normalizedCurrent if extractErr != nil { normalizedTarget = metadata.MarkMetadataFailed(normalizedCurrent, t.capture, attemptedAt, extractErr) diff --git a/internal/tools/retrieval.go b/internal/tools/retrieval.go index f937e74..5d51c18 100644 --- a/internal/tools/retrieval.go +++ b/internal/tools/retrieval.go @@ -11,12 +11,14 @@ import ( thoughttypes "git.warky.dev/wdevs/amcs/internal/types" ) -// semanticSearch runs vector similarity search if embeddings exist for the active model -// in the given scope, otherwise falls back to Postgres full-text search. 
+// semanticSearch runs vector similarity search if embeddings exist for the +// primary embedding model in the given scope, otherwise falls back to Postgres +// full-text search. Search always uses the primary model so query vectors +// match rows stored under the primary model name. func semanticSearch( ctx context.Context, db *store.DB, - provider ai.Provider, + embeddings *ai.EmbeddingRunner, search config.SearchConfig, query string, limit int, @@ -24,17 +26,18 @@ func semanticSearch( projectID *uuid.UUID, excludeID *uuid.UUID, ) ([]thoughttypes.SearchResult, error) { - hasEmbeddings, err := db.HasEmbeddingsForModel(ctx, provider.EmbeddingModel(), projectID) + model := embeddings.PrimaryModel() + hasEmbeddings, err := db.HasEmbeddingsForModel(ctx, model, projectID) if err != nil { return nil, err } if hasEmbeddings { - embedding, err := provider.Embed(ctx, query) + embedding, err := embeddings.EmbedPrimary(ctx, query) if err != nil { return nil, err } - return db.SearchSimilarThoughts(ctx, embedding, provider.EmbeddingModel(), threshold, limit, projectID, excludeID) + return db.SearchSimilarThoughts(ctx, embedding, model, threshold, limit, projectID, excludeID) } return db.SearchThoughtsText(ctx, query, limit, projectID, excludeID) diff --git a/internal/tools/search.go b/internal/tools/search.go index db05ab3..6aade33 100644 --- a/internal/tools/search.go +++ b/internal/tools/search.go @@ -15,10 +15,10 @@ import ( ) type SearchTool struct { - store *store.DB - provider ai.Provider - search config.SearchConfig - sessions *session.ActiveProjects + store *store.DB + embeddings *ai.EmbeddingRunner + search config.SearchConfig + sessions *session.ActiveProjects } type SearchInput struct { @@ -32,8 +32,8 @@ type SearchOutput struct { Results []thoughttypes.SearchResult `json:"results"` } -func NewSearchTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *SearchTool { - return &SearchTool{store: db, provider: 
provider, search: search, sessions: sessions} +func NewSearchTool(db *store.DB, embeddings *ai.EmbeddingRunner, search config.SearchConfig, sessions *session.ActiveProjects) *SearchTool { + return &SearchTool{store: db, embeddings: embeddings, search: search, sessions: sessions} } func (t *SearchTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in SearchInput) (*mcp.CallToolResult, SearchOutput, error) { @@ -56,7 +56,7 @@ func (t *SearchTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in Se _ = t.store.TouchProject(ctx, project.ID) } - results, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, threshold, projectID, nil) + results, err := semanticSearch(ctx, t.store, t.embeddings, t.search, query, limit, threshold, projectID, nil) if err != nil { return nil, SearchOutput{}, err } diff --git a/internal/tools/summarize.go b/internal/tools/summarize.go index 6c60939..3762bce 100644 --- a/internal/tools/summarize.go +++ b/internal/tools/summarize.go @@ -14,10 +14,11 @@ import ( ) type SummarizeTool struct { - store *store.DB - provider ai.Provider - search config.SearchConfig - sessions *session.ActiveProjects + store *store.DB + embeddings *ai.EmbeddingRunner + metadata *ai.MetadataRunner + search config.SearchConfig + sessions *session.ActiveProjects } type SummarizeInput struct { @@ -32,8 +33,8 @@ type SummarizeOutput struct { Count int `json:"count"` } -func NewSummarizeTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *SummarizeTool { - return &SummarizeTool{store: db, provider: provider, search: search, sessions: sessions} +func NewSummarizeTool(db *store.DB, embeddings *ai.EmbeddingRunner, metadata *ai.MetadataRunner, search config.SearchConfig, sessions *session.ActiveProjects) *SummarizeTool { + return &SummarizeTool{store: db, embeddings: embeddings, metadata: metadata, search: search, sessions: sessions} } func (t *SummarizeTool) Handle(ctx context.Context, req 
*mcp.CallToolRequest, in SummarizeInput) (*mcp.CallToolResult, SummarizeOutput, error) { @@ -52,7 +53,7 @@ func (t *SummarizeTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in if project != nil { projectID = &project.ID } - results, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, t.search.DefaultThreshold, projectID, nil) + results, err := semanticSearch(ctx, t.store, t.embeddings, t.search, query, limit, t.search.DefaultThreshold, projectID, nil) if err != nil { return nil, SummarizeOutput{}, err } @@ -77,7 +78,7 @@ func (t *SummarizeTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in userPrompt := formatContextBlock("Summarize the following thoughts into concise prose with themes, action items, and notable people.", lines) systemPrompt := "You summarize note collections. Be concise, concrete, and structured in plain prose." - summary, err := t.provider.Summarize(ctx, systemPrompt, userPrompt) + summary, err := t.metadata.Summarize(ctx, systemPrompt, userPrompt) if err != nil { return nil, SummarizeOutput{}, err } diff --git a/internal/tools/update.go b/internal/tools/update.go index 0f8a3b5..4865128 100644 --- a/internal/tools/update.go +++ b/internal/tools/update.go @@ -16,10 +16,11 @@ import ( ) type UpdateTool struct { - store *store.DB - provider ai.Provider - capture config.CaptureConfig - log *slog.Logger + store *store.DB + embeddings *ai.EmbeddingRunner + metadata *ai.MetadataRunner + capture config.CaptureConfig + log *slog.Logger } type UpdateInput struct { @@ -33,8 +34,8 @@ type UpdateOutput struct { Thought thoughttypes.Thought `json:"thought"` } -func NewUpdateTool(db *store.DB, provider ai.Provider, capture config.CaptureConfig, log *slog.Logger) *UpdateTool { - return &UpdateTool{store: db, provider: provider, capture: capture, log: log} +func NewUpdateTool(db *store.DB, embeddings *ai.EmbeddingRunner, metadata *ai.MetadataRunner, capture config.CaptureConfig, log *slog.Logger) *UpdateTool { + return 
&UpdateTool{store: db, embeddings: embeddings, metadata: metadata, capture: capture, log: log} } func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in UpdateInput) (*mcp.CallToolResult, UpdateOutput, error) { @@ -50,6 +51,7 @@ func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in Upda content := current.Content var embedding []float32 + embeddingModel := "" mergedMetadata := current.Metadata projectID := current.ProjectID @@ -58,11 +60,13 @@ func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in Upda if content == "" { return nil, UpdateOutput{}, errInvalidInput("content must not be empty") } - embedding, err = t.provider.Embed(ctx, content) + embedResult, err := t.embeddings.Embed(ctx, content) if err != nil { return nil, UpdateOutput{}, err } - extracted, extractErr := t.provider.ExtractMetadata(ctx, content) + embedding = embedResult.Vector + embeddingModel = embedResult.Model + extracted, extractErr := t.metadata.ExtractMetadata(ctx, content) if extractErr != nil { t.log.Warn("metadata extraction failed during update, keeping current metadata", slog.String("error", extractErr.Error())) mergedMetadata = metadata.MarkMetadataFailed(mergedMetadata, t.capture, time.Now().UTC(), extractErr) @@ -82,7 +86,7 @@ func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in Upda projectID = &project.ID } - updated, err := t.store.UpdateThought(ctx, id, content, embedding, t.provider.EmbeddingModel(), mergedMetadata, projectID) + updated, err := t.store.UpdateThought(ctx, id, content, embedding, embeddingModel, mergedMetadata, projectID) if err != nil { return nil, UpdateOutput{}, err }