feat(config): update fallback model handling to support multiple models

This commit is contained in:
2026-03-27 00:44:51 +02:00
parent 8775c3e4ce
commit 4f3d027f9e
7 changed files with 309 additions and 107 deletions

View File

@@ -42,7 +42,7 @@ ai:
dimensions: 1536 dimensions: 1536
metadata: metadata:
model: "gpt-4o-mini" model: "gpt-4o-mini"
fallback_model: "" fallback_models: []
temperature: 0.1 temperature: 0.1
log_conversations: false log_conversations: false
litellm: litellm:
@@ -52,7 +52,7 @@ ai:
request_headers: {} request_headers: {}
embedding_model: "openrouter/openai/text-embedding-3-small" embedding_model: "openrouter/openai/text-embedding-3-small"
metadata_model: "gpt-4o-mini" metadata_model: "gpt-4o-mini"
fallback_metadata_model: "" fallback_metadata_models: []
ollama: ollama:
base_url: "http://localhost:11434/v1" base_url: "http://localhost:11434/v1"
api_key: "ollama" api_key: "ollama"

View File

@@ -13,6 +13,7 @@ import (
"regexp" "regexp"
"slices" "slices"
"strings" "strings"
"sync"
"time" "time"
thoughttypes "git.warky.dev/wdevs/amcs/internal/types" thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
@@ -35,33 +36,35 @@ Rules:
- Do not include any text outside the JSON object.` - Do not include any text outside the JSON object.`
type Client struct { type Client struct {
name string name string
baseURL string baseURL string
apiKey string apiKey string
embeddingModel string embeddingModel string
metadataModel string metadataModel string
fallbackMetadataModel string fallbackMetadataModels []string
temperature float64 temperature float64
headers map[string]string headers map[string]string
httpClient *http.Client httpClient *http.Client
log *slog.Logger log *slog.Logger
dimensions int dimensions int
logConversations bool logConversations bool
modelHealthMu sync.Mutex
modelHealth map[string]modelHealthState
} }
type Config struct { type Config struct {
Name string Name string
BaseURL string BaseURL string
APIKey string APIKey string
EmbeddingModel string EmbeddingModel string
MetadataModel string MetadataModel string
FallbackMetadataModel string FallbackMetadataModels []string
Temperature float64 Temperature float64
Headers map[string]string Headers map[string]string
HTTPClient *http.Client HTTPClient *http.Client
Log *slog.Logger Log *slog.Logger
Dimensions int Dimensions int
LogConversations bool LogConversations bool
} }
type embeddingsRequest struct { type embeddingsRequest struct {
@@ -114,20 +117,50 @@ type providerError struct {
const maxMetadataAttempts = 3 const maxMetadataAttempts = 3
// Circuit-breaker tuning for metadata models that keep returning empty
// responses: after emptyResponseCircuitThreshold consecutive empty replies,
// a model is bypassed for emptyResponseCircuitTTL.
const (
emptyResponseCircuitThreshold = 3
emptyResponseCircuitTTL = 5 * time.Minute
)
// Sentinel errors distinguishing the two "no usable metadata" failure modes;
// callers match them with errors.Is (e.g. to feed the empty-response circuit).
var (
errMetadataEmptyResponse = errors.New("metadata empty response")
errMetadataNoJSONObject = errors.New("metadata response contains no JSON object")
)
// modelHealthState tracks per-model empty-response health: consecutiveEmpty
// counts back-to-back empty replies, and unhealthyUntil, when non-zero, is
// the deadline before which the model is bypassed.
type modelHealthState struct {
consecutiveEmpty int
unhealthyUntil time.Time
}
func New(cfg Config) *Client { func New(cfg Config) *Client {
fallbacks := make([]string, 0, len(cfg.FallbackMetadataModels))
seen := make(map[string]struct{}, len(cfg.FallbackMetadataModels))
for _, model := range cfg.FallbackMetadataModels {
model = strings.TrimSpace(model)
if model == "" {
continue
}
if _, ok := seen[model]; ok {
continue
}
seen[model] = struct{}{}
fallbacks = append(fallbacks, model)
}
return &Client{ return &Client{
name: cfg.Name, name: cfg.Name,
baseURL: cfg.BaseURL, baseURL: cfg.BaseURL,
apiKey: cfg.APIKey, apiKey: cfg.APIKey,
embeddingModel: cfg.EmbeddingModel, embeddingModel: cfg.EmbeddingModel,
metadataModel: cfg.MetadataModel, metadataModel: cfg.MetadataModel,
fallbackMetadataModel: cfg.FallbackMetadataModel, fallbackMetadataModels: fallbacks,
temperature: cfg.Temperature, temperature: cfg.Temperature,
headers: cfg.Headers, headers: cfg.Headers,
httpClient: cfg.HTTPClient, httpClient: cfg.HTTPClient,
log: cfg.Log, log: cfg.Log,
dimensions: cfg.Dimensions, dimensions: cfg.Dimensions,
logConversations: cfg.LogConversations, logConversations: cfg.LogConversations,
modelHealth: make(map[string]modelHealthState),
} }
} }
@@ -165,21 +198,38 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype
} }
result, err := c.extractMetadataWithModel(ctx, input, c.metadataModel) result, err := c.extractMetadataWithModel(ctx, input, c.metadataModel)
if errors.Is(err, errMetadataEmptyResponse) {
c.noteEmptyResponse(c.metadataModel)
}
if err == nil { if err == nil {
c.noteModelSuccess(c.metadataModel)
return result, nil return result, nil
} }
if c.fallbackMetadataModel != "" && ctx.Err() == nil { for _, fallbackModel := range c.fallbackMetadataModels {
if ctx.Err() != nil {
break
}
if fallbackModel == "" || fallbackModel == c.metadataModel {
continue
}
if c.shouldBypassModel(fallbackModel) {
continue
}
if c.log != nil { if c.log != nil {
c.log.Warn("metadata extraction failed, trying fallback model", c.log.Warn("metadata extraction failed, trying fallback model",
slog.String("provider", c.name), slog.String("provider", c.name),
slog.String("primary_model", c.metadataModel), slog.String("primary_model", c.metadataModel),
slog.String("fallback_model", c.fallbackMetadataModel), slog.String("fallback_model", fallbackModel),
slog.String("error", err.Error()), slog.String("error", err.Error()),
) )
} }
fallbackResult, fallbackErr := c.extractMetadataWithModel(ctx, input, c.fallbackMetadataModel) fallbackResult, fallbackErr := c.extractMetadataWithModel(ctx, input, fallbackModel)
if errors.Is(fallbackErr, errMetadataEmptyResponse) {
c.noteEmptyResponse(fallbackModel)
}
if fallbackErr == nil { if fallbackErr == nil {
c.noteModelSuccess(fallbackModel)
return fallbackResult, nil return fallbackResult, nil
} }
err = fallbackErr err = fallbackErr
@@ -196,6 +246,10 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype
} }
func (c *Client) extractMetadataWithModel(ctx context.Context, input, model string) (thoughttypes.ThoughtMetadata, error) { func (c *Client) extractMetadataWithModel(ctx context.Context, input, model string) (thoughttypes.ThoughtMetadata, error) {
if c.shouldBypassModel(model) {
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: model %q temporarily bypassed after repeated empty responses", c.name, model)
}
stream := false stream := false
req := chatCompletionsRequest{ req := chatCompletionsRequest{
Model: model, Model: model,
@@ -249,8 +303,9 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri
metadataText = stripCodeFence(metadataText) metadataText = stripCodeFence(metadataText)
metadataText = extractJSONObject(metadataText) metadataText = extractJSONObject(metadataText)
if metadataText == "" { if metadataText == "" {
lastErr = fmt.Errorf("%s metadata: response contains no JSON object", c.name) lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataNoJSONObject)
if strings.TrimSpace(rawResponse) == "" && attempt < maxMetadataAttempts && ctx.Err() == nil { if strings.TrimSpace(rawResponse) == "" && attempt < maxMetadataAttempts && ctx.Err() == nil {
lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataEmptyResponse)
if c.log != nil { if c.log != nil {
c.log.Warn("metadata response empty, waiting and retrying", c.log.Warn("metadata response empty, waiting and retrying",
slog.String("provider", c.name), slog.String("provider", c.name),
@@ -263,6 +318,9 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri
} }
continue continue
} }
if strings.TrimSpace(rawResponse) == "" {
lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataEmptyResponse)
}
return thoughttypes.ThoughtMetadata{}, lastErr return thoughttypes.ThoughtMetadata{}, lastErr
} }
@@ -278,7 +336,7 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri
if lastErr != nil { if lastErr != nil {
return thoughttypes.ThoughtMetadata{}, lastErr return thoughttypes.ThoughtMetadata{}, lastErr
} }
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: response contains no JSON object", c.name) return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: %w", c.name, errMetadataNoJSONObject)
} }
func (c *Client) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) { func (c *Client) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) {
@@ -740,3 +798,40 @@ func sleepMetadataRetry(ctx context.Context, attempt int) error {
return nil return nil
} }
} }
// shouldBypassModel reports whether model is currently marked unhealthy
// (its bypass deadline is set and still in the future) and should be
// skipped for metadata extraction.
func (c *Client) shouldBypassModel(model string) bool {
	c.modelHealthMu.Lock()
	defer c.modelHealthMu.Unlock()

	state, tracked := c.modelHealth[model]
	return tracked && !state.unhealthyUntil.IsZero() && time.Now().Before(state.unhealthyUntil)
}
// noteEmptyResponse records an empty metadata response for model. Once the
// model reaches emptyResponseCircuitThreshold consecutive empty responses it
// is marked unhealthy until time.Now() plus emptyResponseCircuitTTL, and a
// warning is logged (when a logger is configured).
func (c *Client) noteEmptyResponse(model string) {
	c.modelHealthMu.Lock()
	defer c.modelHealthMu.Unlock()

	health := c.modelHealth[model]
	health.consecutiveEmpty++
	if health.consecutiveEmpty >= emptyResponseCircuitThreshold {
		health.unhealthyUntil = time.Now().Add(emptyResponseCircuitTTL)
		if c.log != nil {
			c.log.Warn("metadata model marked temporarily unhealthy after repeated empty responses",
				slog.String("provider", c.name),
				slog.String("model", model),
				slog.Time("until", health.unhealthyUntil),
			)
		}
	}
	c.modelHealth[model] = health
}
// noteModelSuccess clears any recorded empty-response history for model, so
// a recovered model is no longer bypassed and its streak starts from zero.
func (c *Client) noteModelSuccess(model string) {
	c.modelHealthMu.Lock()
	delete(c.modelHealth, model)
	c.modelHealthMu.Unlock()
}

View File

@@ -123,13 +123,13 @@ func TestExtractMetadataFallbackModel(t *testing.T) {
defer server.Close() defer server.Close()
client := New(Config{ client := New(Config{
Name: "test", Name: "test",
BaseURL: server.URL, BaseURL: server.URL,
APIKey: "secret", APIKey: "secret",
MetadataModel: "primary-model", MetadataModel: "primary-model",
FallbackMetadataModel: "fallback-model", FallbackMetadataModels: []string{"fallback-model"},
HTTPClient: server.Client(), HTTPClient: server.Client(),
Log: discardLogger(), Log: discardLogger(),
}) })
metadata, err := client.ExtractMetadata(context.Background(), "hello") metadata, err := client.ExtractMetadata(context.Background(), "hello")
@@ -265,13 +265,13 @@ func TestExtractMetadataFallsBackToHeuristicsWhenModelsFail(t *testing.T) {
defer server.Close() defer server.Close()
client := New(Config{ client := New(Config{
Name: "test", Name: "test",
BaseURL: server.URL, BaseURL: server.URL,
APIKey: "secret", APIKey: "secret",
MetadataModel: "primary", MetadataModel: "primary",
FallbackMetadataModel: "secondary", FallbackMetadataModels: []string{"secondary"},
HTTPClient: server.Client(), HTTPClient: server.Client(),
Log: discardLogger(), Log: discardLogger(),
}) })
input := "Personal profile - Hein (Warkanum):\n- Born: 23 May 1989\n- Wife: Cindy, born 16 November 1994" input := "Personal profile - Hein (Warkanum):\n- Born: 23 May 1989\n- Wife: Cindy, born 16 November 1994"
@@ -341,3 +341,66 @@ func TestExtractMetadataRetriesEmptyResponse(t *testing.T) {
t.Fatalf("metadata type = %q, want observation", metadata.Type) t.Fatalf("metadata type = %q, want observation", metadata.Type)
} }
} }
// TestExtractMetadataBypassesModelAfterRepeatedEmptyResponses exercises the
// empty-response circuit breaker: "primary" always returns empty content and
// "fallback" always succeeds, so after three failed probes of primary a
// fourth ExtractMetadata call must skip primary entirely.
func TestExtractMetadataBypassesModelAfterRepeatedEmptyResponses(t *testing.T) {
var primaryCalls atomic.Int32
var fallbackCalls atomic.Int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Route by the requested model so each side of the fallback chain can
// be counted and given a distinct canned response.
var req chatCompletionsRequest
_ = json.NewDecoder(r.Body).Decode(&req)
switch req.Model {
case "primary":
// Always empty content — counts toward the circuit threshold.
primaryCalls.Add(1)
_ = json.NewEncoder(w).Encode(map[string]any{
"choices": []map[string]any{
{"message": map[string]any{"content": ""}},
},
})
case "fallback":
// Always a valid metadata JSON object, so every call succeeds.
fallbackCalls.Add(1)
_ = json.NewEncoder(w).Encode(map[string]any{
"choices": []map[string]any{
{"message": map[string]any{"content": "{\"people\":[],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"mcp\"],\"type\":\"observation\",\"source\":\"mcp\"}"}},
},
})
default:
t.Fatalf("unexpected model %q", req.Model)
}
}))
defer server.Close()
client := New(Config{
Name: "test",
BaseURL: server.URL,
APIKey: "secret",
MetadataModel: "primary",
FallbackMetadataModels: []string{"fallback"},
HTTPClient: server.Client(),
Log: discardLogger(),
})
// First three calls should probe primary and then use fallback.
for i := 0; i < 3; i++ {
if _, err := client.ExtractMetadata(context.Background(), "hello"); err != nil {
t.Fatalf("ExtractMetadata() error = %v", err)
}
}
// Snapshot the primary call count once the circuit should have tripped.
primaryBefore := primaryCalls.Load()
if primaryBefore == 0 {
t.Fatal("expected primary model to be called before bypass")
}
// Fourth call should bypass primary (no additional primary calls).
if _, err := client.ExtractMetadata(context.Background(), "hello"); err != nil {
t.Fatalf("ExtractMetadata() error = %v", err)
}
if primaryCalls.Load() != primaryBefore {
t.Fatalf("primary calls increased after bypass: before=%d after=%d", primaryBefore, primaryCalls.Load())
}
// Every one of the four ExtractMetadata calls must have reached fallback.
if fallbackCalls.Load() < 4 {
t.Fatalf("fallback calls = %d, want at least 4", fallbackCalls.Load())
}
}

View File

@@ -9,22 +9,22 @@ import (
) )
func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) { func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) {
fallback := cfg.LiteLLM.FallbackMetadataModel fallbacks := cfg.LiteLLM.EffectiveFallbackMetadataModels()
if fallback == "" { if len(fallbacks) == 0 {
fallback = cfg.Metadata.FallbackModel fallbacks = cfg.Metadata.EffectiveFallbackModels()
} }
return compat.New(compat.Config{ return compat.New(compat.Config{
Name: "litellm", Name: "litellm",
BaseURL: cfg.LiteLLM.BaseURL, BaseURL: cfg.LiteLLM.BaseURL,
APIKey: cfg.LiteLLM.APIKey, APIKey: cfg.LiteLLM.APIKey,
EmbeddingModel: cfg.LiteLLM.EmbeddingModel, EmbeddingModel: cfg.LiteLLM.EmbeddingModel,
MetadataModel: cfg.LiteLLM.MetadataModel, MetadataModel: cfg.LiteLLM.MetadataModel,
FallbackMetadataModel: fallback, FallbackMetadataModels: fallbacks,
Temperature: cfg.Metadata.Temperature, Temperature: cfg.Metadata.Temperature,
Headers: cfg.LiteLLM.RequestHeaders, Headers: cfg.LiteLLM.RequestHeaders,
HTTPClient: httpClient, HTTPClient: httpClient,
Log: log, Log: log,
Dimensions: cfg.Embeddings.Dimensions, Dimensions: cfg.Embeddings.Dimensions,
LogConversations: cfg.Metadata.LogConversations, LogConversations: cfg.Metadata.LogConversations,
}), nil }), nil
} }

View File

@@ -10,17 +10,17 @@ import (
func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) { func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) {
return compat.New(compat.Config{ return compat.New(compat.Config{
Name: "ollama", Name: "ollama",
BaseURL: cfg.Ollama.BaseURL, BaseURL: cfg.Ollama.BaseURL,
APIKey: cfg.Ollama.APIKey, APIKey: cfg.Ollama.APIKey,
EmbeddingModel: cfg.Embeddings.Model, EmbeddingModel: cfg.Embeddings.Model,
MetadataModel: cfg.Metadata.Model, MetadataModel: cfg.Metadata.Model,
FallbackMetadataModel: cfg.Metadata.FallbackModel, FallbackMetadataModels: cfg.Metadata.EffectiveFallbackModels(),
Temperature: cfg.Metadata.Temperature, Temperature: cfg.Metadata.Temperature,
Headers: cfg.Ollama.RequestHeaders, Headers: cfg.Ollama.RequestHeaders,
HTTPClient: httpClient, HTTPClient: httpClient,
Log: log, Log: log,
Dimensions: cfg.Embeddings.Dimensions, Dimensions: cfg.Embeddings.Dimensions,
LogConversations: cfg.Metadata.LogConversations, LogConversations: cfg.Metadata.LogConversations,
}), nil }), nil
} }

View File

@@ -21,17 +21,17 @@ func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compa
} }
return compat.New(compat.Config{ return compat.New(compat.Config{
Name: "openrouter", Name: "openrouter",
BaseURL: cfg.OpenRouter.BaseURL, BaseURL: cfg.OpenRouter.BaseURL,
APIKey: cfg.OpenRouter.APIKey, APIKey: cfg.OpenRouter.APIKey,
EmbeddingModel: cfg.Embeddings.Model, EmbeddingModel: cfg.Embeddings.Model,
MetadataModel: cfg.Metadata.Model, MetadataModel: cfg.Metadata.Model,
FallbackMetadataModel: cfg.Metadata.FallbackModel, FallbackMetadataModels: cfg.Metadata.EffectiveFallbackModels(),
Temperature: cfg.Metadata.Temperature, Temperature: cfg.Metadata.Temperature,
Headers: headers, Headers: headers,
HTTPClient: httpClient, HTTPClient: httpClient,
Log: log, Log: log,
Dimensions: cfg.Embeddings.Dimensions, Dimensions: cfg.Embeddings.Dimensions,
LogConversations: cfg.Metadata.LogConversations, LogConversations: cfg.Metadata.LogConversations,
}), nil }), nil
} }

View File

@@ -84,20 +84,22 @@ type AIEmbeddingConfig struct {
} }
type AIMetadataConfig struct { type AIMetadataConfig struct {
Model string `yaml:"model"` Model string `yaml:"model"`
FallbackModel string `yaml:"fallback_model"` FallbackModels []string `yaml:"fallback_models"`
Temperature float64 `yaml:"temperature"` FallbackModel string `yaml:"fallback_model"` // legacy single fallback
LogConversations bool `yaml:"log_conversations"` Temperature float64 `yaml:"temperature"`
LogConversations bool `yaml:"log_conversations"`
} }
type LiteLLMConfig struct { type LiteLLMConfig struct {
BaseURL string `yaml:"base_url"` BaseURL string `yaml:"base_url"`
APIKey string `yaml:"api_key"` APIKey string `yaml:"api_key"`
UseResponsesAPI bool `yaml:"use_responses_api"` UseResponsesAPI bool `yaml:"use_responses_api"`
RequestHeaders map[string]string `yaml:"request_headers"` RequestHeaders map[string]string `yaml:"request_headers"`
EmbeddingModel string `yaml:"embedding_model"` EmbeddingModel string `yaml:"embedding_model"`
MetadataModel string `yaml:"metadata_model"` MetadataModel string `yaml:"metadata_model"`
FallbackMetadataModel string `yaml:"fallback_metadata_model"` FallbackMetadataModels []string `yaml:"fallback_metadata_models"`
FallbackMetadataModel string `yaml:"fallback_metadata_model"` // legacy single fallback
} }
type OllamaConfig struct { type OllamaConfig struct {
@@ -148,3 +150,45 @@ type BackfillConfig struct {
MaxPerRun int `yaml:"max_per_run"` MaxPerRun int `yaml:"max_per_run"`
IncludeArchived bool `yaml:"include_archived"` IncludeArchived bool `yaml:"include_archived"`
} }
// EffectiveFallbackModels merges the fallback_models list with the legacy
// single fallback_model value (appended last), returning the combination
// with blanks dropped and duplicates removed in first-occurrence order.
func (c AIMetadataConfig) EffectiveFallbackModels() []string {
	combined := make([]string, 0, len(c.FallbackModels)+1)
	combined = append(combined, c.FallbackModels...)
	combined = append(combined, c.FallbackModel)
	return dedupeNonEmpty(combined)
}
// EffectiveFallbackMetadataModels merges the fallback_metadata_models list
// with the legacy single fallback_metadata_model value (appended last),
// returning the combination with blanks dropped and duplicates removed in
// first-occurrence order.
func (c LiteLLMConfig) EffectiveFallbackMetadataModels() []string {
	combined := make([]string, 0, len(c.FallbackMetadataModels)+1)
	combined = append(combined, c.FallbackMetadataModels...)
	combined = append(combined, c.FallbackMetadataModel)
	return dedupeNonEmpty(combined)
}
// dedupeNonEmpty returns values with empty strings removed and duplicates
// collapsed, preserving the order of each value's first occurrence.
func dedupeNonEmpty(values []string) []string {
	seen := make(map[string]struct{}, len(values))
	result := make([]string, 0, len(values))
	for _, v := range values {
		if _, dup := seen[v]; v == "" || dup {
			continue
		}
		seen[v] = struct{}{}
		result = append(result, v)
	}
	return result
}