feat(config): update fallback model handling to support multiple models

This commit is contained in:
2026-03-27 00:44:51 +02:00
parent 8775c3e4ce
commit 4f3d027f9e
7 changed files with 309 additions and 107 deletions

View File

@@ -42,7 +42,7 @@ ai:
dimensions: 1536
metadata:
model: "gpt-4o-mini"
fallback_model: ""
fallback_models: []
temperature: 0.1
log_conversations: false
litellm:
@@ -52,7 +52,7 @@ ai:
request_headers: {}
embedding_model: "openrouter/openai/text-embedding-3-small"
metadata_model: "gpt-4o-mini"
fallback_metadata_model: ""
fallback_metadata_models: []
ollama:
base_url: "http://localhost:11434/v1"
api_key: "ollama"

View File

@@ -13,6 +13,7 @@ import (
"regexp"
"slices"
"strings"
"sync"
"time"
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
@@ -40,13 +41,15 @@ type Client struct {
apiKey string
embeddingModel string
metadataModel string
fallbackMetadataModel string
fallbackMetadataModels []string
temperature float64
headers map[string]string
httpClient *http.Client
log *slog.Logger
dimensions int
logConversations bool
modelHealthMu sync.Mutex
modelHealth map[string]modelHealthState
}
type Config struct {
@@ -55,7 +58,7 @@ type Config struct {
APIKey string
EmbeddingModel string
MetadataModel string
FallbackMetadataModel string
FallbackMetadataModels []string
Temperature float64
Headers map[string]string
HTTPClient *http.Client
@@ -114,20 +117,50 @@ type providerError struct {
// maxMetadataAttempts bounds how many times a single model is retried for
// one metadata-extraction request before the attempt is abandoned.
const maxMetadataAttempts = 3

const (
	// emptyResponseCircuitThreshold is the number of consecutive empty
	// completions after which a model's circuit breaker trips.
	emptyResponseCircuitThreshold = 3
	// emptyResponseCircuitTTL is how long a tripped model stays bypassed
	// before it is probed again.
	emptyResponseCircuitTTL = 5 * time.Minute
)

// Sentinel errors classifying metadata-extraction failures; callers match
// them with errors.Is to decide whether to retry, fall back, or record an
// empty response against the model's health state.
var (
	errMetadataEmptyResponse = errors.New("metadata empty response")
	errMetadataNoJSONObject  = errors.New("metadata response contains no JSON object")
)

// modelHealthState is the per-model circuit-breaker record: how many empty
// responses were seen in a row, and until when the model is bypassed
// (zero time means the circuit is closed).
type modelHealthState struct {
	consecutiveEmpty int
	unhealthyUntil   time.Time
}
func New(cfg Config) *Client {
fallbacks := make([]string, 0, len(cfg.FallbackMetadataModels))
seen := make(map[string]struct{}, len(cfg.FallbackMetadataModels))
for _, model := range cfg.FallbackMetadataModels {
model = strings.TrimSpace(model)
if model == "" {
continue
}
if _, ok := seen[model]; ok {
continue
}
seen[model] = struct{}{}
fallbacks = append(fallbacks, model)
}
return &Client{
name: cfg.Name,
baseURL: cfg.BaseURL,
apiKey: cfg.APIKey,
embeddingModel: cfg.EmbeddingModel,
metadataModel: cfg.MetadataModel,
fallbackMetadataModel: cfg.FallbackMetadataModel,
fallbackMetadataModels: fallbacks,
temperature: cfg.Temperature,
headers: cfg.Headers,
httpClient: cfg.HTTPClient,
log: cfg.Log,
dimensions: cfg.Dimensions,
logConversations: cfg.LogConversations,
modelHealth: make(map[string]modelHealthState),
}
}
@@ -165,21 +198,38 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype
}
result, err := c.extractMetadataWithModel(ctx, input, c.metadataModel)
if errors.Is(err, errMetadataEmptyResponse) {
c.noteEmptyResponse(c.metadataModel)
}
if err == nil {
c.noteModelSuccess(c.metadataModel)
return result, nil
}
if c.fallbackMetadataModel != "" && ctx.Err() == nil {
for _, fallbackModel := range c.fallbackMetadataModels {
if ctx.Err() != nil {
break
}
if fallbackModel == "" || fallbackModel == c.metadataModel {
continue
}
if c.shouldBypassModel(fallbackModel) {
continue
}
if c.log != nil {
c.log.Warn("metadata extraction failed, trying fallback model",
slog.String("provider", c.name),
slog.String("primary_model", c.metadataModel),
slog.String("fallback_model", c.fallbackMetadataModel),
slog.String("fallback_model", fallbackModel),
slog.String("error", err.Error()),
)
}
fallbackResult, fallbackErr := c.extractMetadataWithModel(ctx, input, c.fallbackMetadataModel)
fallbackResult, fallbackErr := c.extractMetadataWithModel(ctx, input, fallbackModel)
if errors.Is(fallbackErr, errMetadataEmptyResponse) {
c.noteEmptyResponse(fallbackModel)
}
if fallbackErr == nil {
c.noteModelSuccess(fallbackModel)
return fallbackResult, nil
}
err = fallbackErr
@@ -196,6 +246,10 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype
}
func (c *Client) extractMetadataWithModel(ctx context.Context, input, model string) (thoughttypes.ThoughtMetadata, error) {
if c.shouldBypassModel(model) {
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: model %q temporarily bypassed after repeated empty responses", c.name, model)
}
stream := false
req := chatCompletionsRequest{
Model: model,
@@ -249,8 +303,9 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri
metadataText = stripCodeFence(metadataText)
metadataText = extractJSONObject(metadataText)
if metadataText == "" {
lastErr = fmt.Errorf("%s metadata: response contains no JSON object", c.name)
lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataNoJSONObject)
if strings.TrimSpace(rawResponse) == "" && attempt < maxMetadataAttempts && ctx.Err() == nil {
lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataEmptyResponse)
if c.log != nil {
c.log.Warn("metadata response empty, waiting and retrying",
slog.String("provider", c.name),
@@ -263,6 +318,9 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri
}
continue
}
if strings.TrimSpace(rawResponse) == "" {
lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataEmptyResponse)
}
return thoughttypes.ThoughtMetadata{}, lastErr
}
@@ -278,7 +336,7 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri
if lastErr != nil {
return thoughttypes.ThoughtMetadata{}, lastErr
}
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: response contains no JSON object", c.name)
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: %w", c.name, errMetadataNoJSONObject)
}
func (c *Client) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) {
@@ -740,3 +798,40 @@ func sleepMetadataRetry(ctx context.Context, attempt int) error {
return nil
}
}
// shouldBypassModel reports whether model is currently inside an unhealthy
// window opened by the empty-response circuit breaker. Models with no
// recorded state, or whose window has expired, are not bypassed.
func (c *Client) shouldBypassModel(model string) bool {
	c.modelHealthMu.Lock()
	defer c.modelHealthMu.Unlock()

	if state, tracked := c.modelHealth[model]; tracked {
		return !state.unhealthyUntil.IsZero() && time.Now().Before(state.unhealthyUntil)
	}
	return false
}
// noteEmptyResponse records that model returned an empty completion. Once
// emptyResponseCircuitThreshold consecutive empties accumulate, the model
// is marked unhealthy until now+emptyResponseCircuitTTL and a warning is
// logged (when a logger is configured).
func (c *Client) noteEmptyResponse(model string) {
	c.modelHealthMu.Lock()
	defer c.modelHealthMu.Unlock()

	health := c.modelHealth[model]
	health.consecutiveEmpty++
	tripped := health.consecutiveEmpty >= emptyResponseCircuitThreshold
	if tripped {
		health.unhealthyUntil = time.Now().Add(emptyResponseCircuitTTL)
	}
	c.modelHealth[model] = health

	if tripped && c.log != nil {
		c.log.Warn("metadata model marked temporarily unhealthy after repeated empty responses",
			slog.String("provider", c.name),
			slog.String("model", model),
			slog.Time("until", health.unhealthyUntil),
		)
	}
}
// noteModelSuccess clears all recorded health state for model, closing any
// open circuit so the model is usable again immediately.
func (c *Client) noteModelSuccess(model string) {
	c.modelHealthMu.Lock()
	delete(c.modelHealth, model)
	c.modelHealthMu.Unlock()
}

View File

@@ -127,7 +127,7 @@ func TestExtractMetadataFallbackModel(t *testing.T) {
BaseURL: server.URL,
APIKey: "secret",
MetadataModel: "primary-model",
FallbackMetadataModel: "fallback-model",
FallbackMetadataModels: []string{"fallback-model"},
HTTPClient: server.Client(),
Log: discardLogger(),
})
@@ -269,7 +269,7 @@ func TestExtractMetadataFallsBackToHeuristicsWhenModelsFail(t *testing.T) {
BaseURL: server.URL,
APIKey: "secret",
MetadataModel: "primary",
FallbackMetadataModel: "secondary",
FallbackMetadataModels: []string{"secondary"},
HTTPClient: server.Client(),
Log: discardLogger(),
})
@@ -341,3 +341,66 @@ func TestExtractMetadataRetriesEmptyResponse(t *testing.T) {
t.Fatalf("metadata type = %q, want observation", metadata.Type)
}
}
// TestExtractMetadataBypassesModelAfterRepeatedEmptyResponses verifies the
// empty-response circuit breaker: after emptyResponseCircuitThreshold
// consecutive empty completions from the primary model, subsequent requests
// skip the primary entirely and go straight to the fallback model.
func TestExtractMetadataBypassesModelAfterRepeatedEmptyResponses(t *testing.T) {
	var primaryCalls atomic.Int32
	var fallbackCalls atomic.Int32
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var req chatCompletionsRequest
		_ = json.NewDecoder(r.Body).Decode(&req)
		switch req.Model {
		case "primary":
			primaryCalls.Add(1)
			// Always reply with empty content so the circuit trips.
			_ = json.NewEncoder(w).Encode(map[string]any{
				"choices": []map[string]any{
					{"message": map[string]any{"content": ""}},
				},
			})
		case "fallback":
			fallbackCalls.Add(1)
			_ = json.NewEncoder(w).Encode(map[string]any{
				"choices": []map[string]any{
					{"message": map[string]any{"content": "{\"people\":[],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"mcp\"],\"type\":\"observation\",\"source\":\"mcp\"}"}},
				},
			})
		default:
			// t.Fatalf must not be used here: FailNow is only valid on the
			// goroutine running the test, and this handler runs on the HTTP
			// server's goroutine. Record the failure and answer with an
			// error instead.
			t.Errorf("unexpected model %q", req.Model)
			http.Error(w, "unexpected model", http.StatusBadRequest)
		}
	}))
	defer server.Close()
	client := New(Config{
		Name:                   "test",
		BaseURL:                server.URL,
		APIKey:                 "secret",
		MetadataModel:          "primary",
		FallbackMetadataModels: []string{"fallback"},
		HTTPClient:             server.Client(),
		Log:                    discardLogger(),
	})
	// First three calls should probe primary and then use fallback.
	for i := 0; i < 3; i++ {
		if _, err := client.ExtractMetadata(context.Background(), "hello"); err != nil {
			t.Fatalf("ExtractMetadata() error = %v", err)
		}
	}
	primaryBefore := primaryCalls.Load()
	if primaryBefore == 0 {
		t.Fatal("expected primary model to be called before bypass")
	}
	// Fourth call should bypass primary (no additional primary calls).
	if _, err := client.ExtractMetadata(context.Background(), "hello"); err != nil {
		t.Fatalf("ExtractMetadata() error = %v", err)
	}
	if primaryCalls.Load() != primaryBefore {
		t.Fatalf("primary calls increased after bypass: before=%d after=%d", primaryBefore, primaryCalls.Load())
	}
	if fallbackCalls.Load() < 4 {
		t.Fatalf("fallback calls = %d, want at least 4", fallbackCalls.Load())
	}
}

View File

@@ -9,9 +9,9 @@ import (
)
func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) {
fallback := cfg.LiteLLM.FallbackMetadataModel
if fallback == "" {
fallback = cfg.Metadata.FallbackModel
fallbacks := cfg.LiteLLM.EffectiveFallbackMetadataModels()
if len(fallbacks) == 0 {
fallbacks = cfg.Metadata.EffectiveFallbackModels()
}
return compat.New(compat.Config{
Name: "litellm",
@@ -19,7 +19,7 @@ func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compa
APIKey: cfg.LiteLLM.APIKey,
EmbeddingModel: cfg.LiteLLM.EmbeddingModel,
MetadataModel: cfg.LiteLLM.MetadataModel,
FallbackMetadataModel: fallback,
FallbackMetadataModels: fallbacks,
Temperature: cfg.Metadata.Temperature,
Headers: cfg.LiteLLM.RequestHeaders,
HTTPClient: httpClient,

View File

@@ -15,7 +15,7 @@ func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compa
APIKey: cfg.Ollama.APIKey,
EmbeddingModel: cfg.Embeddings.Model,
MetadataModel: cfg.Metadata.Model,
FallbackMetadataModel: cfg.Metadata.FallbackModel,
FallbackMetadataModels: cfg.Metadata.EffectiveFallbackModels(),
Temperature: cfg.Metadata.Temperature,
Headers: cfg.Ollama.RequestHeaders,
HTTPClient: httpClient,

View File

@@ -26,7 +26,7 @@ func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compa
APIKey: cfg.OpenRouter.APIKey,
EmbeddingModel: cfg.Embeddings.Model,
MetadataModel: cfg.Metadata.Model,
FallbackMetadataModel: cfg.Metadata.FallbackModel,
FallbackMetadataModels: cfg.Metadata.EffectiveFallbackModels(),
Temperature: cfg.Metadata.Temperature,
Headers: headers,
HTTPClient: httpClient,

View File

@@ -85,7 +85,8 @@ type AIEmbeddingConfig struct {
type AIMetadataConfig struct {
Model string `yaml:"model"`
FallbackModel string `yaml:"fallback_model"`
FallbackModels []string `yaml:"fallback_models"`
FallbackModel string `yaml:"fallback_model"` // legacy single fallback
Temperature float64 `yaml:"temperature"`
LogConversations bool `yaml:"log_conversations"`
}
@@ -97,7 +98,8 @@ type LiteLLMConfig struct {
RequestHeaders map[string]string `yaml:"request_headers"`
EmbeddingModel string `yaml:"embedding_model"`
MetadataModel string `yaml:"metadata_model"`
FallbackMetadataModel string `yaml:"fallback_metadata_model"`
FallbackMetadataModels []string `yaml:"fallback_metadata_models"`
FallbackMetadataModel string `yaml:"fallback_metadata_model"` // legacy single fallback
}
type OllamaConfig struct {
@@ -148,3 +150,45 @@ type BackfillConfig struct {
MaxPerRun int `yaml:"max_per_run"`
IncludeArchived bool `yaml:"include_archived"`
}
// EffectiveFallbackModels merges the fallback_models list with the legacy
// single fallback_model field (appended last, so explicit list entries take
// precedence) and returns the result with empty entries removed and
// duplicates collapsed, preserving first-seen order.
func (c AIMetadataConfig) EffectiveFallbackModels() []string {
	combined := make([]string, 0, len(c.FallbackModels)+1)
	combined = append(combined, c.FallbackModels...)
	combined = append(combined, c.FallbackModel)
	return dedupeNonEmpty(combined)
}
// EffectiveFallbackMetadataModels merges fallback_metadata_models with the
// legacy single fallback_metadata_model field (appended last) and returns
// the result with empty entries removed and duplicates collapsed, keeping
// first-seen order.
func (c LiteLLMConfig) EffectiveFallbackMetadataModels() []string {
	combined := make([]string, 0, len(c.FallbackMetadataModels)+1)
	combined = append(combined, c.FallbackMetadataModels...)
	combined = append(combined, c.FallbackMetadataModel)
	return dedupeNonEmpty(combined)
}
// dedupeNonEmpty returns values with empty strings dropped and only the
// first occurrence of each remaining string kept, preserving input order.
// The result is always non-nil.
func dedupeNonEmpty(values []string) []string {
	out := make([]string, 0, len(values))
	seen := make(map[string]bool, len(values))
	for _, value := range values {
		if value == "" || seen[value] {
			continue
		}
		seen[value] = true
		out = append(out, value)
	}
	return out
}