From 4f3d027f9e85c3617be87378e41d6dab50f3916d Mon Sep 17 00:00:00 2001 From: Hein Date: Fri, 27 Mar 2026 00:44:51 +0200 Subject: [PATCH] feat(config): update fallback model handling to support multiple models --- configs/config.example.yaml | 4 +- internal/ai/compat/client.go | 177 +++++++++++++++++++++++------- internal/ai/compat/client_test.go | 91 ++++++++++++--- internal/ai/litellm/client.go | 30 ++--- internal/ai/ollama/client.go | 24 ++-- internal/ai/openrouter/client.go | 24 ++-- internal/config/config.go | 66 +++++++++-- 7 files changed, 309 insertions(+), 107 deletions(-) diff --git a/configs/config.example.yaml b/configs/config.example.yaml index e29ce8b..332b346 100644 --- a/configs/config.example.yaml +++ b/configs/config.example.yaml @@ -42,7 +42,7 @@ ai: dimensions: 1536 metadata: model: "gpt-4o-mini" - fallback_model: "" + fallback_models: [] temperature: 0.1 log_conversations: false litellm: @@ -52,7 +52,7 @@ ai: request_headers: {} embedding_model: "openrouter/openai/text-embedding-3-small" metadata_model: "gpt-4o-mini" - fallback_metadata_model: "" + fallback_metadata_models: [] ollama: base_url: "http://localhost:11434/v1" api_key: "ollama" diff --git a/internal/ai/compat/client.go b/internal/ai/compat/client.go index 42de883..da0911d 100644 --- a/internal/ai/compat/client.go +++ b/internal/ai/compat/client.go @@ -13,6 +13,7 @@ import ( "regexp" "slices" "strings" + "sync" "time" thoughttypes "git.warky.dev/wdevs/amcs/internal/types" @@ -35,33 +36,35 @@ Rules: - Do not include any text outside the JSON object.` type Client struct { - name string - baseURL string - apiKey string - embeddingModel string - metadataModel string - fallbackMetadataModel string - temperature float64 - headers map[string]string - httpClient *http.Client - log *slog.Logger - dimensions int - logConversations bool + name string + baseURL string + apiKey string + embeddingModel string + metadataModel string + fallbackMetadataModels []string + temperature float64 + headers map[string]string + httpClient *http.Client + log *slog.Logger + dimensions int + logConversations bool + modelHealthMu sync.Mutex + modelHealth map[string]modelHealthState } type Config struct { - Name string - BaseURL string - APIKey string - EmbeddingModel string - MetadataModel string - FallbackMetadataModel string - Temperature float64 - Headers map[string]string - HTTPClient *http.Client - Log *slog.Logger - Dimensions int - LogConversations bool + Name string + BaseURL string + APIKey string + EmbeddingModel string + MetadataModel string + FallbackMetadataModels []string + Temperature float64 + Headers map[string]string + HTTPClient *http.Client + Log *slog.Logger + Dimensions int + LogConversations bool } type embeddingsRequest struct { @@ -114,20 +117,50 @@ type providerError struct { const maxMetadataAttempts = 3 +const ( + emptyResponseCircuitThreshold = 3 + emptyResponseCircuitTTL = 5 * time.Minute +) + +var ( + errMetadataEmptyResponse = errors.New("metadata empty response") + errMetadataNoJSONObject = errors.New("metadata response contains no JSON object") +) + +type modelHealthState struct { + consecutiveEmpty int + unhealthyUntil time.Time +} + func New(cfg Config) *Client { + fallbacks := make([]string, 0, len(cfg.FallbackMetadataModels)) + seen := make(map[string]struct{}, len(cfg.FallbackMetadataModels)) + for _, model := range cfg.FallbackMetadataModels { + model = strings.TrimSpace(model) + if model == "" { + continue + } + if _, ok := seen[model]; ok { + continue + } + seen[model] = struct{}{} + fallbacks = 
append(fallbacks, model) + } + return &Client{ - name: cfg.Name, - baseURL: cfg.BaseURL, - apiKey: cfg.APIKey, - embeddingModel: cfg.EmbeddingModel, - metadataModel: cfg.MetadataModel, - fallbackMetadataModel: cfg.FallbackMetadataModel, - temperature: cfg.Temperature, - headers: cfg.Headers, - httpClient: cfg.HTTPClient, - log: cfg.Log, - dimensions: cfg.Dimensions, - logConversations: cfg.LogConversations, + name: cfg.Name, + baseURL: cfg.BaseURL, + apiKey: cfg.APIKey, + embeddingModel: cfg.EmbeddingModel, + metadataModel: cfg.MetadataModel, + fallbackMetadataModels: fallbacks, + temperature: cfg.Temperature, + headers: cfg.Headers, + httpClient: cfg.HTTPClient, + log: cfg.Log, + dimensions: cfg.Dimensions, + logConversations: cfg.LogConversations, + modelHealth: make(map[string]modelHealthState), } } @@ -165,21 +198,38 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype } result, err := c.extractMetadataWithModel(ctx, input, c.metadataModel) + if errors.Is(err, errMetadataEmptyResponse) { + c.noteEmptyResponse(c.metadataModel) + } if err == nil { + c.noteModelSuccess(c.metadataModel) return result, nil } - if c.fallbackMetadataModel != "" && ctx.Err() == nil { + for _, fallbackModel := range c.fallbackMetadataModels { + if ctx.Err() != nil { + break + } + if fallbackModel == "" || fallbackModel == c.metadataModel { + continue + } + if c.shouldBypassModel(fallbackModel) { + continue + } if c.log != nil { c.log.Warn("metadata extraction failed, trying fallback model", slog.String("provider", c.name), slog.String("primary_model", c.metadataModel), - slog.String("fallback_model", c.fallbackMetadataModel), + slog.String("fallback_model", fallbackModel), slog.String("error", err.Error()), ) } - fallbackResult, fallbackErr := c.extractMetadataWithModel(ctx, input, c.fallbackMetadataModel) + fallbackResult, fallbackErr := c.extractMetadataWithModel(ctx, input, fallbackModel) + if errors.Is(fallbackErr, errMetadataEmptyResponse) { + c.noteEmptyResponse(fallbackModel) + } if fallbackErr == nil { + c.noteModelSuccess(fallbackModel) return fallbackResult, nil } err = fallbackErr @@ -196,6 +246,10 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype } func (c *Client) extractMetadataWithModel(ctx context.Context, input, model string) (thoughttypes.ThoughtMetadata, error) { + if c.shouldBypassModel(model) { + return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: model %q temporarily bypassed after repeated empty responses", c.name, model) + } + stream := false req := chatCompletionsRequest{ Model: model, @@ -249,8 +303,9 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri metadataText = stripCodeFence(metadataText) metadataText = extractJSONObject(metadataText) if metadataText == "" { - lastErr = fmt.Errorf("%s metadata: response contains no JSON object", c.name) + lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataNoJSONObject) if strings.TrimSpace(rawResponse) == "" && attempt < maxMetadataAttempts && ctx.Err() == nil { + lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataEmptyResponse) if c.log != nil { c.log.Warn("metadata response empty, waiting and retrying", slog.String("provider", c.name), @@ -263,6 +318,9 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri } continue } + if strings.TrimSpace(rawResponse) == "" { + lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataEmptyResponse) + } return thoughttypes.ThoughtMetadata{}, lastErr } @@ 
-278,7 +336,7 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri if lastErr != nil { return thoughttypes.ThoughtMetadata{}, lastErr } - return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: response contains no JSON object", c.name) + return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: %w", c.name, errMetadataNoJSONObject) } func (c *Client) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) { @@ -740,3 +798,40 @@ func sleepMetadataRetry(ctx context.Context, attempt int) error { return nil } } + +func (c *Client) shouldBypassModel(model string) bool { + c.modelHealthMu.Lock() + defer c.modelHealthMu.Unlock() + + state, ok := c.modelHealth[model] + if !ok { + return false + } + return !state.unhealthyUntil.IsZero() && time.Now().Before(state.unhealthyUntil) +} + +func (c *Client) noteEmptyResponse(model string) { + c.modelHealthMu.Lock() + defer c.modelHealthMu.Unlock() + + state := c.modelHealth[model] + state.consecutiveEmpty++ + if state.consecutiveEmpty >= emptyResponseCircuitThreshold { + state.unhealthyUntil = time.Now().Add(emptyResponseCircuitTTL) + if c.log != nil { + c.log.Warn("metadata model marked temporarily unhealthy after repeated empty responses", + slog.String("provider", c.name), + slog.String("model", model), + slog.Time("until", state.unhealthyUntil), + ) + } + } + c.modelHealth[model] = state +} + +func (c *Client) noteModelSuccess(model string) { + c.modelHealthMu.Lock() + defer c.modelHealthMu.Unlock() + + delete(c.modelHealth, model) +} diff --git a/internal/ai/compat/client_test.go b/internal/ai/compat/client_test.go index a79d518..21a9205 100644 --- a/internal/ai/compat/client_test.go +++ b/internal/ai/compat/client_test.go @@ -123,13 +123,13 @@ func TestExtractMetadataFallbackModel(t *testing.T) { defer server.Close() client := New(Config{ - Name: "test", - BaseURL: server.URL, - APIKey: "secret", - MetadataModel: "primary-model", - FallbackMetadataModel: "fallback-model", - HTTPClient: server.Client(), - Log: discardLogger(), + Name: "test", + BaseURL: server.URL, + APIKey: "secret", + MetadataModel: "primary-model", + FallbackMetadataModels: []string{"fallback-model"}, + HTTPClient: server.Client(), + Log: discardLogger(), }) metadata, err := client.ExtractMetadata(context.Background(), "hello") @@ -265,13 +265,13 @@ func TestExtractMetadataFallsBackToHeuristicsWhenModelsFail(t *testing.T) { defer server.Close() client := New(Config{ - Name: "test", - BaseURL: server.URL, - APIKey: "secret", - MetadataModel: "primary", - FallbackMetadataModel: "secondary", - HTTPClient: server.Client(), - Log: discardLogger(), + Name: "test", + BaseURL: server.URL, + APIKey: "secret", + MetadataModel: "primary", + FallbackMetadataModels: []string{"secondary"}, + HTTPClient: server.Client(), + Log: discardLogger(), }) input := "Personal profile - Hein (Warkanum):\n- Born: 23 May 1989\n- Wife: Cindy, born 16 November 1994" @@ -341,3 +341,66 @@ func TestExtractMetadataRetriesEmptyResponse(t *testing.T) { t.Fatalf("metadata type = %q, want observation", metadata.Type) } } + +func TestExtractMetadataBypassesModelAfterRepeatedEmptyResponses(t *testing.T) { + var primaryCalls atomic.Int32 + var fallbackCalls atomic.Int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var req chatCompletionsRequest + _ = json.NewDecoder(r.Body).Decode(&req) + + switch req.Model { + case "primary": + primaryCalls.Add(1) + _ = json.NewEncoder(w).Encode(map[string]any{ + 
"choices": []map[string]any{ + {"message": map[string]any{"content": ""}}, + }, + }) + case "fallback": + fallbackCalls.Add(1) + _ = json.NewEncoder(w).Encode(map[string]any{ + "choices": []map[string]any{ + {"message": map[string]any{"content": "{\"people\":[],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"mcp\"],\"type\":\"observation\",\"source\":\"mcp\"}"}}, + }, + }) + default: + t.Fatalf("unexpected model %q", req.Model) + } + })) + defer server.Close() + + client := New(Config{ + Name: "test", + BaseURL: server.URL, + APIKey: "secret", + MetadataModel: "primary", + FallbackMetadataModels: []string{"fallback"}, + HTTPClient: server.Client(), + Log: discardLogger(), + }) + + // First three calls should probe primary and then use fallback. + for i := 0; i < 3; i++ { + if _, err := client.ExtractMetadata(context.Background(), "hello"); err != nil { + t.Fatalf("ExtractMetadata() error = %v", err) + } + } + + primaryBefore := primaryCalls.Load() + if primaryBefore == 0 { + t.Fatal("expected primary model to be called before bypass") + } + + // Fourth call should bypass primary (no additional primary calls). + if _, err := client.ExtractMetadata(context.Background(), "hello"); err != nil { + t.Fatalf("ExtractMetadata() error = %v", err) + } + if primaryCalls.Load() != primaryBefore { + t.Fatalf("primary calls increased after bypass: before=%d after=%d", primaryBefore, primaryCalls.Load()) + } + if fallbackCalls.Load() < 4 { + t.Fatalf("fallback calls = %d, want at least 4", fallbackCalls.Load()) + } +} diff --git a/internal/ai/litellm/client.go b/internal/ai/litellm/client.go index afd48a8..3c9f1b0 100644 --- a/internal/ai/litellm/client.go +++ b/internal/ai/litellm/client.go @@ -9,22 +9,22 @@ import ( ) func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) { - fallback := cfg.LiteLLM.FallbackMetadataModel - if fallback == "" { - fallback = cfg.Metadata.FallbackModel + fallbacks := cfg.LiteLLM.EffectiveFallbackMetadataModels() + if len(fallbacks) == 0 { + fallbacks = cfg.Metadata.EffectiveFallbackModels() } return compat.New(compat.Config{ - Name: "litellm", - BaseURL: cfg.LiteLLM.BaseURL, - APIKey: cfg.LiteLLM.APIKey, - EmbeddingModel: cfg.LiteLLM.EmbeddingModel, - MetadataModel: cfg.LiteLLM.MetadataModel, - FallbackMetadataModel: fallback, - Temperature: cfg.Metadata.Temperature, - Headers: cfg.LiteLLM.RequestHeaders, - HTTPClient: httpClient, - Log: log, - Dimensions: cfg.Embeddings.Dimensions, - LogConversations: cfg.Metadata.LogConversations, + Name: "litellm", + BaseURL: cfg.LiteLLM.BaseURL, + APIKey: cfg.LiteLLM.APIKey, + EmbeddingModel: cfg.LiteLLM.EmbeddingModel, + MetadataModel: cfg.LiteLLM.MetadataModel, + FallbackMetadataModels: fallbacks, + Temperature: cfg.Metadata.Temperature, + Headers: cfg.LiteLLM.RequestHeaders, + HTTPClient: httpClient, + Log: log, + Dimensions: cfg.Embeddings.Dimensions, + LogConversations: cfg.Metadata.LogConversations, }), nil } diff --git a/internal/ai/ollama/client.go b/internal/ai/ollama/client.go index c5f692a..69abf8e 100644 --- a/internal/ai/ollama/client.go +++ b/internal/ai/ollama/client.go @@ -10,17 +10,17 @@ import ( func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) { return compat.New(compat.Config{ - Name: "ollama", - BaseURL: cfg.Ollama.BaseURL, - APIKey: cfg.Ollama.APIKey, - EmbeddingModel: cfg.Embeddings.Model, - MetadataModel: cfg.Metadata.Model, - FallbackMetadataModel: cfg.Metadata.FallbackModel, - Temperature: 
cfg.Metadata.Temperature, - Headers: cfg.Ollama.RequestHeaders, - HTTPClient: httpClient, - Log: log, - Dimensions: cfg.Embeddings.Dimensions, - LogConversations: cfg.Metadata.LogConversations, + Name: "ollama", + BaseURL: cfg.Ollama.BaseURL, + APIKey: cfg.Ollama.APIKey, + EmbeddingModel: cfg.Embeddings.Model, + MetadataModel: cfg.Metadata.Model, + FallbackMetadataModels: cfg.Metadata.EffectiveFallbackModels(), + Temperature: cfg.Metadata.Temperature, + Headers: cfg.Ollama.RequestHeaders, + HTTPClient: httpClient, + Log: log, + Dimensions: cfg.Embeddings.Dimensions, + LogConversations: cfg.Metadata.LogConversations, }), nil } diff --git a/internal/ai/openrouter/client.go b/internal/ai/openrouter/client.go index e5e94c6..b2fe6d0 100644 --- a/internal/ai/openrouter/client.go +++ b/internal/ai/openrouter/client.go @@ -21,17 +21,17 @@ func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compa } return compat.New(compat.Config{ - Name: "openrouter", - BaseURL: cfg.OpenRouter.BaseURL, - APIKey: cfg.OpenRouter.APIKey, - EmbeddingModel: cfg.Embeddings.Model, - MetadataModel: cfg.Metadata.Model, - FallbackMetadataModel: cfg.Metadata.FallbackModel, - Temperature: cfg.Metadata.Temperature, - Headers: headers, - HTTPClient: httpClient, - Log: log, - Dimensions: cfg.Embeddings.Dimensions, - LogConversations: cfg.Metadata.LogConversations, + Name: "openrouter", + BaseURL: cfg.OpenRouter.BaseURL, + APIKey: cfg.OpenRouter.APIKey, + EmbeddingModel: cfg.Embeddings.Model, + MetadataModel: cfg.Metadata.Model, + FallbackMetadataModels: cfg.Metadata.EffectiveFallbackModels(), + Temperature: cfg.Metadata.Temperature, + Headers: headers, + HTTPClient: httpClient, + Log: log, + Dimensions: cfg.Embeddings.Dimensions, + LogConversations: cfg.Metadata.LogConversations, }), nil } diff --git a/internal/config/config.go b/internal/config/config.go index d588ce1..da39545 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -84,20 +84,22 @@ type AIEmbeddingConfig struct { } type AIMetadataConfig struct { - Model string `yaml:"model"` - FallbackModel string `yaml:"fallback_model"` - Temperature float64 `yaml:"temperature"` - LogConversations bool `yaml:"log_conversations"` + Model string `yaml:"model"` + FallbackModels []string `yaml:"fallback_models"` + FallbackModel string `yaml:"fallback_model"` // legacy single fallback + Temperature float64 `yaml:"temperature"` + LogConversations bool `yaml:"log_conversations"` } type LiteLLMConfig struct { - BaseURL string `yaml:"base_url"` - APIKey string `yaml:"api_key"` - UseResponsesAPI bool `yaml:"use_responses_api"` - RequestHeaders map[string]string `yaml:"request_headers"` - EmbeddingModel string `yaml:"embedding_model"` - MetadataModel string `yaml:"metadata_model"` - FallbackMetadataModel string `yaml:"fallback_metadata_model"` + BaseURL string `yaml:"base_url"` + APIKey string `yaml:"api_key"` + UseResponsesAPI bool `yaml:"use_responses_api"` + RequestHeaders map[string]string `yaml:"request_headers"` + EmbeddingModel string `yaml:"embedding_model"` + MetadataModel string `yaml:"metadata_model"` + FallbackMetadataModels []string `yaml:"fallback_metadata_models"` + FallbackMetadataModel string `yaml:"fallback_metadata_model"` // legacy single fallback } type OllamaConfig struct { @@ -148,3 +150,45 @@ type BackfillConfig struct { MaxPerRun int `yaml:"max_per_run"` IncludeArchived bool `yaml:"include_archived"` } + +func (c AIMetadataConfig) EffectiveFallbackModels() []string { + models := make([]string, 0, len(c.FallbackModels)+1) + 
for _, model := range c.FallbackModels { + if model != "" { + models = append(models, model) + } + } + if c.FallbackModel != "" { + models = append(models, c.FallbackModel) + } + return dedupeNonEmpty(models) +} + +func (c LiteLLMConfig) EffectiveFallbackMetadataModels() []string { + models := make([]string, 0, len(c.FallbackMetadataModels)+1) + for _, model := range c.FallbackMetadataModels { + if model != "" { + models = append(models, model) + } + } + if c.FallbackMetadataModel != "" { + models = append(models, c.FallbackMetadataModel) + } + return dedupeNonEmpty(models) +} + +func dedupeNonEmpty(values []string) []string { + seen := make(map[string]struct{}, len(values)) + out := make([]string, 0, len(values)) + for _, value := range values { + if value == "" { + continue + } + if _, ok := seen[value]; ok { + continue + } + seen[value] = struct{}{} + out = append(out, value) + } + return out +}
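
Usage note (reviewer aside, not part of the applied diff): fallback
resolution is now additive and order-preserving, with the legacy
single-value keys still parsed and appended last. Below is a minimal
sketch of the resolution behavior, assuming the helpers this patch adds
in internal/config/config.go and the module path used elsewhere in the
repo; the model names are placeholders, not recommendations:

	package main

	import (
		"fmt"

		"git.warky.dev/wdevs/amcs/internal/config"
	)

	func main() {
		cfg := config.AIMetadataConfig{
			Model: "gpt-4o-mini",
			// New list key: tried in order after the primary model fails.
			FallbackModels: []string{"model-a", "", "model-b", "model-a"},
			// Legacy single key: still honored, appended after the list.
			FallbackModel: "model-c",
		}
		// Empties dropped, duplicates removed, order preserved:
		// prints [model-a model-b model-c]
		fmt.Println(cfg.EffectiveFallbackModels())
	}

The equivalent YAML, mirroring configs/config.example.yaml:

	ai:
	  metadata:
	    model: "gpt-4o-mini"
	    fallback_models: ["model-a", "model-b"] # tried in order
	    fallback_model: "model-c"               # legacy key, appended last

At request time compat.Client applies two further filters on top of this
list: it skips any fallback equal to the primary model, and it skips any
model the new empty-response circuit breaker has marked unhealthy (three
consecutive empty responses bypass the model for five minutes; a
successful extraction resets the counter).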