feat(config): update fallback model handling to support multiple models

This commit is contained in:
2026-03-27 00:44:51 +02:00
parent 8775c3e4ce
commit 4f3d027f9e
7 changed files with 309 additions and 107 deletions

View File

@@ -13,6 +13,7 @@ import (
"regexp"
"slices"
"strings"
"sync"
"time"
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
@@ -35,33 +36,35 @@ Rules:
- Do not include any text outside the JSON object.`
type Client struct {
name string
baseURL string
apiKey string
embeddingModel string
metadataModel string
fallbackMetadataModel string
temperature float64
headers map[string]string
httpClient *http.Client
log *slog.Logger
dimensions int
logConversations bool
name string
baseURL string
apiKey string
embeddingModel string
metadataModel string
fallbackMetadataModels []string
temperature float64
headers map[string]string
httpClient *http.Client
log *slog.Logger
dimensions int
logConversations bool
modelHealthMu sync.Mutex
modelHealth map[string]modelHealthState
}
type Config struct {
Name string
BaseURL string
APIKey string
EmbeddingModel string
MetadataModel string
FallbackMetadataModel string
Temperature float64
Headers map[string]string
HTTPClient *http.Client
Log *slog.Logger
Dimensions int
LogConversations bool
Name string
BaseURL string
APIKey string
EmbeddingModel string
MetadataModel string
FallbackMetadataModels []string
Temperature float64
Headers map[string]string
HTTPClient *http.Client
Log *slog.Logger
Dimensions int
LogConversations bool
}
type embeddingsRequest struct {
@@ -114,20 +117,50 @@ type providerError struct {
const maxMetadataAttempts = 3
const (
emptyResponseCircuitThreshold = 3
emptyResponseCircuitTTL = 5 * time.Minute
)
var (
errMetadataEmptyResponse = errors.New("metadata empty response")
errMetadataNoJSONObject = errors.New("metadata response contains no JSON object")
)
// modelHealthState holds per-model circuit-breaker state used by the
// metadata-extraction fallback logic.
type modelHealthState struct {
	// consecutiveEmpty counts empty responses seen in a row for this model;
	// it is reset on any success (see noteModelSuccess).
	consecutiveEmpty int
	// unhealthyUntil, when non-zero and in the future, marks the model as
	// temporarily bypassed (see shouldBypassModel).
	unhealthyUntil time.Time
}
func New(cfg Config) *Client {
fallbacks := make([]string, 0, len(cfg.FallbackMetadataModels))
seen := make(map[string]struct{}, len(cfg.FallbackMetadataModels))
for _, model := range cfg.FallbackMetadataModels {
model = strings.TrimSpace(model)
if model == "" {
continue
}
if _, ok := seen[model]; ok {
continue
}
seen[model] = struct{}{}
fallbacks = append(fallbacks, model)
}
return &Client{
name: cfg.Name,
baseURL: cfg.BaseURL,
apiKey: cfg.APIKey,
embeddingModel: cfg.EmbeddingModel,
metadataModel: cfg.MetadataModel,
fallbackMetadataModel: cfg.FallbackMetadataModel,
temperature: cfg.Temperature,
headers: cfg.Headers,
httpClient: cfg.HTTPClient,
log: cfg.Log,
dimensions: cfg.Dimensions,
logConversations: cfg.LogConversations,
name: cfg.Name,
baseURL: cfg.BaseURL,
apiKey: cfg.APIKey,
embeddingModel: cfg.EmbeddingModel,
metadataModel: cfg.MetadataModel,
fallbackMetadataModels: fallbacks,
temperature: cfg.Temperature,
headers: cfg.Headers,
httpClient: cfg.HTTPClient,
log: cfg.Log,
dimensions: cfg.Dimensions,
logConversations: cfg.LogConversations,
modelHealth: make(map[string]modelHealthState),
}
}
@@ -165,21 +198,38 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype
}
result, err := c.extractMetadataWithModel(ctx, input, c.metadataModel)
if errors.Is(err, errMetadataEmptyResponse) {
c.noteEmptyResponse(c.metadataModel)
}
if err == nil {
c.noteModelSuccess(c.metadataModel)
return result, nil
}
if c.fallbackMetadataModel != "" && ctx.Err() == nil {
for _, fallbackModel := range c.fallbackMetadataModels {
if ctx.Err() != nil {
break
}
if fallbackModel == "" || fallbackModel == c.metadataModel {
continue
}
if c.shouldBypassModel(fallbackModel) {
continue
}
if c.log != nil {
c.log.Warn("metadata extraction failed, trying fallback model",
slog.String("provider", c.name),
slog.String("primary_model", c.metadataModel),
slog.String("fallback_model", c.fallbackMetadataModel),
slog.String("fallback_model", fallbackModel),
slog.String("error", err.Error()),
)
}
fallbackResult, fallbackErr := c.extractMetadataWithModel(ctx, input, c.fallbackMetadataModel)
fallbackResult, fallbackErr := c.extractMetadataWithModel(ctx, input, fallbackModel)
if errors.Is(fallbackErr, errMetadataEmptyResponse) {
c.noteEmptyResponse(fallbackModel)
}
if fallbackErr == nil {
c.noteModelSuccess(fallbackModel)
return fallbackResult, nil
}
err = fallbackErr
@@ -196,6 +246,10 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype
}
func (c *Client) extractMetadataWithModel(ctx context.Context, input, model string) (thoughttypes.ThoughtMetadata, error) {
if c.shouldBypassModel(model) {
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: model %q temporarily bypassed after repeated empty responses", c.name, model)
}
stream := false
req := chatCompletionsRequest{
Model: model,
@@ -249,8 +303,9 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri
metadataText = stripCodeFence(metadataText)
metadataText = extractJSONObject(metadataText)
if metadataText == "" {
lastErr = fmt.Errorf("%s metadata: response contains no JSON object", c.name)
lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataNoJSONObject)
if strings.TrimSpace(rawResponse) == "" && attempt < maxMetadataAttempts && ctx.Err() == nil {
lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataEmptyResponse)
if c.log != nil {
c.log.Warn("metadata response empty, waiting and retrying",
slog.String("provider", c.name),
@@ -263,6 +318,9 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri
}
continue
}
if strings.TrimSpace(rawResponse) == "" {
lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataEmptyResponse)
}
return thoughttypes.ThoughtMetadata{}, lastErr
}
@@ -278,7 +336,7 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri
if lastErr != nil {
return thoughttypes.ThoughtMetadata{}, lastErr
}
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: response contains no JSON object", c.name)
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: %w", c.name, errMetadataNoJSONObject)
}
func (c *Client) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) {
@@ -740,3 +798,40 @@ func sleepMetadataRetry(ctx context.Context, attempt int) error {
return nil
}
}
// shouldBypassModel reports whether the circuit breaker for model is
// currently open, i.e. the model was marked unhealthy after repeated empty
// responses and the bypass window has not yet expired.
func (c *Client) shouldBypassModel(model string) bool {
	c.modelHealthMu.Lock()
	defer c.modelHealthMu.Unlock()
	if state, ok := c.modelHealth[model]; ok {
		// A zero unhealthyUntil is never after time.Now(), so a model with
		// recorded-but-below-threshold failures falls through to healthy.
		return time.Now().Before(state.unhealthyUntil)
	}
	return false
}
// noteEmptyResponse records one more consecutive empty response for model.
// Once the count reaches emptyResponseCircuitThreshold, the model is marked
// unhealthy for emptyResponseCircuitTTL so subsequent requests bypass it.
func (c *Client) noteEmptyResponse(model string) {
	c.modelHealthMu.Lock()
	defer c.modelHealthMu.Unlock()

	health := c.modelHealth[model]
	health.consecutiveEmpty++
	if health.consecutiveEmpty >= emptyResponseCircuitThreshold {
		// Threshold hit: open the circuit for the configured TTL.
		health.unhealthyUntil = time.Now().Add(emptyResponseCircuitTTL)
		if c.log != nil {
			c.log.Warn("metadata model marked temporarily unhealthy after repeated empty responses",
				slog.String("provider", c.name),
				slog.String("model", model),
				slog.Time("until", health.unhealthyUntil),
			)
		}
	}
	c.modelHealth[model] = health
}
// noteModelSuccess clears any recorded circuit-breaker state for model,
// making it immediately eligible for metadata requests again.
func (c *Client) noteModelSuccess(model string) {
	c.modelHealthMu.Lock()
	delete(c.modelHealth, model)
	c.modelHealthMu.Unlock()
}

View File

@@ -123,13 +123,13 @@ func TestExtractMetadataFallbackModel(t *testing.T) {
defer server.Close()
client := New(Config{
Name: "test",
BaseURL: server.URL,
APIKey: "secret",
MetadataModel: "primary-model",
FallbackMetadataModel: "fallback-model",
HTTPClient: server.Client(),
Log: discardLogger(),
Name: "test",
BaseURL: server.URL,
APIKey: "secret",
MetadataModel: "primary-model",
FallbackMetadataModels: []string{"fallback-model"},
HTTPClient: server.Client(),
Log: discardLogger(),
})
metadata, err := client.ExtractMetadata(context.Background(), "hello")
@@ -265,13 +265,13 @@ func TestExtractMetadataFallsBackToHeuristicsWhenModelsFail(t *testing.T) {
defer server.Close()
client := New(Config{
Name: "test",
BaseURL: server.URL,
APIKey: "secret",
MetadataModel: "primary",
FallbackMetadataModel: "secondary",
HTTPClient: server.Client(),
Log: discardLogger(),
Name: "test",
BaseURL: server.URL,
APIKey: "secret",
MetadataModel: "primary",
FallbackMetadataModels: []string{"secondary"},
HTTPClient: server.Client(),
Log: discardLogger(),
})
input := "Personal profile - Hein (Warkanum):\n- Born: 23 May 1989\n- Wife: Cindy, born 16 November 1994"
@@ -341,3 +341,66 @@ func TestExtractMetadataRetriesEmptyResponse(t *testing.T) {
t.Fatalf("metadata type = %q, want observation", metadata.Type)
}
}
// TestExtractMetadataBypassesModelAfterRepeatedEmptyResponses verifies the
// empty-response circuit breaker: once the primary model has produced enough
// consecutive empty responses, later ExtractMetadata calls go straight to the
// fallback model without probing the primary again.
func TestExtractMetadataBypassesModelAfterRepeatedEmptyResponses(t *testing.T) {
	var primaryCalls atomic.Int32
	var fallbackCalls atomic.Int32
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var req chatCompletionsRequest
		_ = json.NewDecoder(r.Body).Decode(&req)
		switch req.Model {
		case "primary":
			// Primary always returns an empty completion to trip the breaker.
			primaryCalls.Add(1)
			_ = json.NewEncoder(w).Encode(map[string]any{
				"choices": []map[string]any{
					{"message": map[string]any{"content": ""}},
				},
			})
		case "fallback":
			// Fallback returns valid metadata JSON.
			fallbackCalls.Add(1)
			_ = json.NewEncoder(w).Encode(map[string]any{
				"choices": []map[string]any{
					{"message": map[string]any{"content": "{\"people\":[],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"mcp\"],\"type\":\"observation\",\"source\":\"mcp\"}"}},
				},
			})
		default:
			// This handler runs on a server goroutine; testing.T.Fatalf must
			// only be called from the test goroutine, so report via Errorf
			// and fail the request instead.
			t.Errorf("unexpected model %q", req.Model)
			http.Error(w, "unexpected model", http.StatusBadRequest)
		}
	}))
	defer server.Close()
	client := New(Config{
		Name:                   "test",
		BaseURL:                server.URL,
		APIKey:                 "secret",
		MetadataModel:          "primary",
		FallbackMetadataModels: []string{"fallback"},
		HTTPClient:             server.Client(),
		Log:                    discardLogger(),
	})
	// First three calls should probe primary and then use fallback.
	for i := 0; i < 3; i++ {
		if _, err := client.ExtractMetadata(context.Background(), "hello"); err != nil {
			t.Fatalf("ExtractMetadata() error = %v", err)
		}
	}
	primaryBefore := primaryCalls.Load()
	if primaryBefore == 0 {
		t.Fatal("expected primary model to be called before bypass")
	}
	// Fourth call should bypass primary (no additional primary calls).
	if _, err := client.ExtractMetadata(context.Background(), "hello"); err != nil {
		t.Fatalf("ExtractMetadata() error = %v", err)
	}
	if primaryCalls.Load() != primaryBefore {
		t.Fatalf("primary calls increased after bypass: before=%d after=%d", primaryBefore, primaryCalls.Load())
	}
	if fallbackCalls.Load() < 4 {
		t.Fatalf("fallback calls = %d, want at least 4", fallbackCalls.Load())
	}
}

View File

@@ -9,22 +9,22 @@ import (
)
func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) {
fallback := cfg.LiteLLM.FallbackMetadataModel
if fallback == "" {
fallback = cfg.Metadata.FallbackModel
fallbacks := cfg.LiteLLM.EffectiveFallbackMetadataModels()
if len(fallbacks) == 0 {
fallbacks = cfg.Metadata.EffectiveFallbackModels()
}
return compat.New(compat.Config{
Name: "litellm",
BaseURL: cfg.LiteLLM.BaseURL,
APIKey: cfg.LiteLLM.APIKey,
EmbeddingModel: cfg.LiteLLM.EmbeddingModel,
MetadataModel: cfg.LiteLLM.MetadataModel,
FallbackMetadataModel: fallback,
Temperature: cfg.Metadata.Temperature,
Headers: cfg.LiteLLM.RequestHeaders,
HTTPClient: httpClient,
Log: log,
Dimensions: cfg.Embeddings.Dimensions,
LogConversations: cfg.Metadata.LogConversations,
Name: "litellm",
BaseURL: cfg.LiteLLM.BaseURL,
APIKey: cfg.LiteLLM.APIKey,
EmbeddingModel: cfg.LiteLLM.EmbeddingModel,
MetadataModel: cfg.LiteLLM.MetadataModel,
FallbackMetadataModels: fallbacks,
Temperature: cfg.Metadata.Temperature,
Headers: cfg.LiteLLM.RequestHeaders,
HTTPClient: httpClient,
Log: log,
Dimensions: cfg.Embeddings.Dimensions,
LogConversations: cfg.Metadata.LogConversations,
}), nil
}

View File

@@ -10,17 +10,17 @@ import (
func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) {
return compat.New(compat.Config{
Name: "ollama",
BaseURL: cfg.Ollama.BaseURL,
APIKey: cfg.Ollama.APIKey,
EmbeddingModel: cfg.Embeddings.Model,
MetadataModel: cfg.Metadata.Model,
FallbackMetadataModel: cfg.Metadata.FallbackModel,
Temperature: cfg.Metadata.Temperature,
Headers: cfg.Ollama.RequestHeaders,
HTTPClient: httpClient,
Log: log,
Dimensions: cfg.Embeddings.Dimensions,
LogConversations: cfg.Metadata.LogConversations,
Name: "ollama",
BaseURL: cfg.Ollama.BaseURL,
APIKey: cfg.Ollama.APIKey,
EmbeddingModel: cfg.Embeddings.Model,
MetadataModel: cfg.Metadata.Model,
FallbackMetadataModels: cfg.Metadata.EffectiveFallbackModels(),
Temperature: cfg.Metadata.Temperature,
Headers: cfg.Ollama.RequestHeaders,
HTTPClient: httpClient,
Log: log,
Dimensions: cfg.Embeddings.Dimensions,
LogConversations: cfg.Metadata.LogConversations,
}), nil
}

View File

@@ -21,17 +21,17 @@ func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compa
}
return compat.New(compat.Config{
Name: "openrouter",
BaseURL: cfg.OpenRouter.BaseURL,
APIKey: cfg.OpenRouter.APIKey,
EmbeddingModel: cfg.Embeddings.Model,
MetadataModel: cfg.Metadata.Model,
FallbackMetadataModel: cfg.Metadata.FallbackModel,
Temperature: cfg.Metadata.Temperature,
Headers: headers,
HTTPClient: httpClient,
Log: log,
Dimensions: cfg.Embeddings.Dimensions,
LogConversations: cfg.Metadata.LogConversations,
Name: "openrouter",
BaseURL: cfg.OpenRouter.BaseURL,
APIKey: cfg.OpenRouter.APIKey,
EmbeddingModel: cfg.Embeddings.Model,
MetadataModel: cfg.Metadata.Model,
FallbackMetadataModels: cfg.Metadata.EffectiveFallbackModels(),
Temperature: cfg.Metadata.Temperature,
Headers: headers,
HTTPClient: httpClient,
Log: log,
Dimensions: cfg.Embeddings.Dimensions,
LogConversations: cfg.Metadata.LogConversations,
}), nil
}