diff --git a/internal/ai/compat/client.go b/internal/ai/compat/client.go index da0911d..028a9bf 100644 --- a/internal/ai/compat/client.go +++ b/internal/ai/compat/client.go @@ -1,6 +1,7 @@ package compat import ( + "bufio" "bytes" "context" "encoding/json" @@ -104,6 +105,15 @@ type chatCompletionsResponse struct { Error *providerError `json:"error,omitempty"` } +type chatCompletionsChunk struct { + Choices []struct { + Delta responseChatMessage `json:"delta"` + Message responseChatMessage `json:"message"` + Text string `json:"text,omitempty"` + } `json:"choices"` + Error *providerError `json:"error,omitempty"` +} + type responseChatMessage struct { Role string `json:"role"` Content json.RawMessage `json:"content"` @@ -120,6 +130,7 @@ const maxMetadataAttempts = 3 const ( emptyResponseCircuitThreshold = 3 emptyResponseCircuitTTL = 5 * time.Minute + permanentModelFailureTTL = 24 * time.Hour ) var ( @@ -201,6 +212,9 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype if errors.Is(err, errMetadataEmptyResponse) { c.noteEmptyResponse(c.metadataModel) } + if isPermanentModelError(err) { + c.notePermanentModelFailure(c.metadataModel, err) + } if err == nil { c.noteModelSuccess(c.metadataModel) return result, nil @@ -228,6 +242,9 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype if errors.Is(fallbackErr, errMetadataEmptyResponse) { c.noteEmptyResponse(fallbackModel) } + if isPermanentModelError(fallbackErr) { + c.notePermanentModelFailure(fallbackModel, fallbackErr) + } if fallbackErr == nil { c.noteModelSuccess(fallbackModel) return fallbackResult, nil @@ -250,7 +267,7 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: model %q temporarily bypassed after repeated empty responses", c.name, model) } - stream := false + stream := true req := chatCompletionsRequest{ Model: model, Temperature: c.temperature, @@ -264,6 +281,25 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri }, } + metadata, err := c.extractMetadataWithRequest(ctx, req, input, model) + if err == nil || !shouldRetryWithoutJSONMode(err) { + return metadata, err + } + + if c.log != nil { + c.log.Warn("metadata json mode failed, retrying without response_format", + slog.String("provider", c.name), + slog.String("model", model), + slog.String("error", err.Error()), + ) + } + + req.ResponseFormat = nil + return c.extractMetadataWithRequest(ctx, req, input, model) +} + +func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatCompletionsRequest, input, model string) (thoughttypes.ThoughtMetadata, error) { + var lastErr error for attempt := 1; attempt <= maxMetadataAttempts; attempt++ { if c.logConversations && c.log != nil { @@ -276,8 +312,8 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri ) } - var resp chatCompletionsResponse - if err := c.doJSON(ctx, "/chat/completions", req, &resp); err != nil { + resp, err := c.doChatCompletions(ctx, req) + if err != nil { return thoughttypes.ThoughtMetadata{}, err } if resp.Error != nil { @@ -445,6 +481,199 @@ func (c *Client) doJSON(ctx context.Context, path string, requestBody any, dest return lastErr } +func (c *Client) doChatCompletions(ctx context.Context, reqBody chatCompletionsRequest) (chatCompletionsResponse, error) { + var resp chatCompletionsResponse + + body, err := json.Marshal(reqBody) + if err != nil { + return resp, fmt.Errorf("%s request marshal: %w", c.name, err) + } + + const maxAttempts = 3 + + var lastErr error + for attempt := 1; attempt <= maxAttempts; attempt++ { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimRight(c.baseURL, "/")+"/chat/completions", bytes.NewReader(body)) + if err != nil { + return resp, fmt.Errorf("%s build request: %w", c.name, err) + } + + req.Header.Set("Authorization", "Bearer "+c.apiKey) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json, text/event-stream") + for key, value := range c.headers { + if strings.TrimSpace(key) == "" || strings.TrimSpace(value) == "" { + continue + } + req.Header.Set(key, value) + } + + httpResp, err := c.httpClient.Do(req) + if err != nil { + lastErr = fmt.Errorf("%s request failed: %w", c.name, err) + if attempt < maxAttempts && ctx.Err() == nil && isRetryableError(err) { + if retryErr := sleepRetry(ctx, attempt, c.log, c.name); retryErr != nil { + return resp, retryErr + } + continue + } + return resp, lastErr + } + + resp, err = c.decodeChatCompletionsResponse(httpResp) + if err == nil { + return resp, nil + } + + lastErr = err + if attempt < maxAttempts && (ctx.Err() == nil) && isRetryableChatResponseError(err) { + if retryErr := sleepRetry(ctx, attempt, c.log, c.name); retryErr != nil { + return resp, retryErr + } + continue + } + return resp, lastErr + } + + return resp, lastErr +} + +func (c *Client) decodeChatCompletionsResponse(resp *http.Response) (chatCompletionsResponse, error) { + defer resp.Body.Close() + + contentType := strings.ToLower(resp.Header.Get("Content-Type")) + if strings.Contains(contentType, "text/event-stream") { + streamResp, err := decodeStreamingChatCompletionsResponse(resp.Body) + if err != nil { + return chatCompletionsResponse{}, fmt.Errorf("%s read stream response: %w", c.name, err) + } + if resp.StatusCode >= http.StatusBadRequest { + if streamResp.Error != nil { + return chatCompletionsResponse{}, fmt.Errorf("%s request failed with status %d: %s", c.name, resp.StatusCode, streamResp.Error.Message) + } + return chatCompletionsResponse{}, fmt.Errorf("%s request failed with status %d", c.name, resp.StatusCode) + } + return streamResp, nil + } + + payload, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return chatCompletionsResponse{}, fmt.Errorf("%s read response: %w", c.name, readErr) + } + if resp.StatusCode >= http.StatusBadRequest { + return chatCompletionsResponse{}, fmt.Errorf("%s request failed with status %d: %s", c.name, resp.StatusCode, strings.TrimSpace(string(payload))) + } + var decoded chatCompletionsResponse + if err := json.Unmarshal(payload, &decoded); err != nil { + if c.log != nil { + c.log.Debug("provider response body", slog.String("provider", c.name), slog.String("body", string(payload))) + } + return chatCompletionsResponse{}, fmt.Errorf("%s decode response: %w", c.name, err) + } + return decoded, nil +} + +func decodeStreamingChatCompletionsResponse(body io.Reader) (chatCompletionsResponse, error) { + scanner := bufio.NewScanner(body) + scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) + + var eventLines []string + var textBuilder strings.Builder + var lastMessage responseChatMessage + var streamErr *providerError + + flushEvent := func() error { + if len(eventLines) == 0 { + return nil + } + payload := strings.Join(eventLines, "\n") + eventLines = eventLines[:0] + + payload = strings.TrimSpace(payload) + if payload == "" { + return nil + } + if payload == "[DONE]" { + return nil + } + + var chunk chatCompletionsChunk + if err := json.Unmarshal([]byte(payload), &chunk); err != nil { + return err + } + if chunk.Error != nil { + streamErr = chunk.Error + } + for _, choice := range chunk.Choices { + if text := extractChoiceText(choice.Delta, choice.Text); text != "" { + textBuilder.WriteString(text) + lastMessage = choice.Delta + continue + } + if text := extractChoiceText(choice.Message, choice.Text); text != "" { + textBuilder.WriteString(text) + lastMessage = choice.Message + continue + } + if len(choice.Message.Content) > 0 || choice.Message.ReasoningContent != "" { + lastMessage = choice.Message + } + } + return nil + } + + for scanner.Scan() { + line := scanner.Text() + if line == "" { + if err := flushEvent(); err != nil { + return chatCompletionsResponse{}, err + } + continue + } + if strings.HasPrefix(line, ":") { + continue + } + if strings.HasPrefix(line, "data:") { + eventLines = append(eventLines, strings.TrimSpace(strings.TrimPrefix(line, "data:"))) + } + } + if err := scanner.Err(); err != nil { + return chatCompletionsResponse{}, err + } + if err := flushEvent(); err != nil { + return chatCompletionsResponse{}, err + } + + content := textBuilder.String() + if content != "" || len(lastMessage.Content) > 0 || lastMessage.ReasoningContent != "" { + encoded, _ := json.Marshal(content) + lastMessage.Content = json.RawMessage(encoded) + return chatCompletionsResponse{ + Choices: []struct { + Message responseChatMessage `json:"message"` + Text string `json:"text,omitempty"` + }{ + {Message: lastMessage, Text: content}, + }, + Error: streamErr, + }, nil + } + + return chatCompletionsResponse{Error: streamErr}, nil +} + +func isRetryableChatResponseError(err error) bool { + if err == nil { + return false + } + if isRetryableError(err) { + return true + } + + lower := strings.ToLower(err.Error()) + return strings.Contains(lower, "read response") || strings.Contains(lower, "read stream response") +} + // extractJSONObject finds the first complete {...} block in s. // It handles models that prepend prose to a JSON response despite json_object mode. func extractJSONObject(s string) string { @@ -772,6 +1001,39 @@ func isRetryableError(err error) bool { return errors.As(err, &netErr) } +func shouldRetryWithoutJSONMode(err error) bool { + if err == nil { + return false + } + if errors.Is(err, errMetadataEmptyResponse) || errors.Is(err, errMetadataNoJSONObject) { + return true + } + + lower := strings.ToLower(err.Error()) + return strings.Contains(lower, "parse json") +} + +func isPermanentModelError(err error) bool { + if err == nil { + return false + } + + lower := strings.ToLower(err.Error()) + for _, marker := range []string{ + "invalid model name", + "model_not_found", + "model not found", + "unknown model", + "no such model", + "does not exist", + } { + if strings.Contains(lower, marker) { + return true + } + } + return false +} + func sleepRetry(ctx context.Context, attempt int, log *slog.Logger, provider string) error { delay := time.Duration(attempt*attempt) * 200 * time.Millisecond if log != nil { @@ -835,3 +1097,22 @@ func (c *Client) noteModelSuccess(model string) { delete(c.modelHealth, model) } + +func (c *Client) notePermanentModelFailure(model string, err error) { + c.modelHealthMu.Lock() + defer c.modelHealthMu.Unlock() + + state := c.modelHealth[model] + state.consecutiveEmpty = emptyResponseCircuitThreshold + state.unhealthyUntil = time.Now().Add(permanentModelFailureTTL) + c.modelHealth[model] = state + + if c.log != nil { + c.log.Warn("metadata model marked unhealthy after permanent failure", + slog.String("provider", c.name), + slog.String("model", model), + slog.String("error", err.Error()), + slog.Time("until", state.unhealthyUntil), + ) + } +} diff --git a/internal/ai/compat/client_test.go b/internal/ai/compat/client_test.go index 21a9205..80cab6d 100644 --- a/internal/ai/compat/client_test.go +++ b/internal/ai/compat/client_test.go @@ -7,364 +7,151 @@ import ( "log/slog" "net/http" "net/http/httptest" - "strings" - "sync/atomic" + "sync" "testing" ) -func discardLogger() *slog.Logger { - return slog.New(slog.NewTextHandler(io.Discard, nil)) -} - -func TestEmbedRetriesTransientFailures(t *testing.T) { - var calls atomic.Int32 +func TestExtractMetadataFromStreamingResponse(t *testing.T) { + t.Parallel() server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if calls.Add(1) < 3 { - http.Error(w, "temporary failure", http.StatusServiceUnavailable) - return - } - _ = json.NewEncoder(w).Encode(map[string]any{ - "data": []map[string]any{ - {"embedding": []float32{1, 2, 3}}, - }, - }) - })) - defer server.Close() + defer r.Body.Close() - client := New(Config{ - Name: "test", - BaseURL: server.URL, - APIKey: "secret", - EmbeddingModel: "embed-model", - MetadataModel: "meta-model", - HTTPClient: server.Client(), - Log: discardLogger(), - Dimensions: 3, - }) - - embedding, err := client.Embed(context.Background(), "hello") - if err != nil { - t.Fatalf("Embed() error = %v", err) - } - if len(embedding) != 3 { - t.Fatalf("embedding len = %d, want 3", len(embedding)) - } - if got := calls.Load(); got != 3 { - t.Fatalf("call count = %d, want 3", got) - } -} - -func TestExtractMetadataStripsThinkingBlocks(t *testing.T) { - cases := []struct { - name string - content string - }{ - { - name: "think tag with braces inside", - content: "\nLet me map {this} to the schema carefully.\n\n{\"people\":[],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"test\"],\"type\":\"idea\",\"source\":\"\"}", - }, - { - name: "thinking tag", - content: "reasoning {here}{\"people\":[],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"test\"],\"type\":\"idea\",\"source\":\"\"}", - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - content := tc.content - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _ = json.NewEncoder(w).Encode(map[string]any{ - "choices": []map[string]any{ - {"message": map[string]any{"content": content}}, - }, - }) - })) - defer server.Close() - - client := New(Config{ - Name: "test", - BaseURL: server.URL, - APIKey: "secret", - MetadataModel: "meta-model", - HTTPClient: server.Client(), - Log: discardLogger(), - }) - - metadata, err := client.ExtractMetadata(context.Background(), "hello") - if err != nil { - t.Fatalf("ExtractMetadata() error = %v", err) - } - if metadata.Type != "idea" { - t.Fatalf("metadata type = %q, want idea", metadata.Type) - } - }) - } -} - -func TestExtractMetadataFallbackModel(t *testing.T) { - var calls atomic.Int32 - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var req chatCompletionsRequest - _ = json.NewDecoder(r.Body).Decode(&req) - - if req.Model == "primary-model" { - calls.Add(1) - http.Error(w, "model unavailable", http.StatusServiceUnavailable) - return + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decode request: %v", err) } - _ = json.NewEncoder(w).Encode(map[string]any{ - "choices": []map[string]any{ - {"message": map[string]any{"content": "{\"people\":[],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"test\"],\"type\":\"task\",\"source\":\"\"}"}}, - }, - }) + if req.Stream == nil || !*req.Stream { + t.Fatalf("stream flag = %v, want true", req.Stream) + } + + w.Header().Set("Content-Type", "text/event-stream") + _, _ = io.WriteString(w, "data: {\"choices\":[{\"delta\":{\"content\":\"{\\\"people\\\":[],\"}}]}\n\n") + _, _ = io.WriteString(w, "data: {\"choices\":[{\"delta\":{\"content\":\"\\\"action_items\\\":[],\\\"dates_mentioned\\\":[],\"}}]}\n\n") + _, _ = io.WriteString(w, "data: {\"choices\":[{\"delta\":{\"content\":\"\\\"topics\\\":[\\\"android\\\"],\\\"type\\\":\\\"idea\\\",\\\"source\\\":\\\"stream\\\"}\"}}]}\n\n") + _, _ = io.WriteString(w, "data: [DONE]\n\n") })) defer server.Close() client := New(Config{ - Name: "test", - BaseURL: server.URL, - APIKey: "secret", - MetadataModel: "primary-model", - FallbackMetadataModels: []string{"fallback-model"}, - HTTPClient: server.Client(), - Log: discardLogger(), - }) - - metadata, err := client.ExtractMetadata(context.Background(), "hello") - if err != nil { - t.Fatalf("ExtractMetadata() error = %v", err) - } - if metadata.Type != "task" { - t.Fatalf("metadata type = %q, want task", metadata.Type) - } - if calls.Load() == 0 { - t.Fatal("primary model was never called") - } -} - -func TestExtractMetadataParsesCodeFencedJSON(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _ = json.NewEncoder(w).Encode(map[string]any{ - "choices": []map[string]any{ - { - "message": map[string]any{ - "content": "```json\n{\"people\":[\"Alice\"],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"memory\"],\"type\":\"idea\",\"source\":\"mcp\"}\n```", - }, - }, - }, - }) - })) - defer server.Close() - - client := New(Config{ - Name: "test", + Name: "litellm", BaseURL: server.URL, - APIKey: "secret", - EmbeddingModel: "embed-model", - MetadataModel: "meta-model", + APIKey: "test-key", + MetadataModel: "qwen3.5:latest", + Temperature: 0.1, HTTPClient: server.Client(), - Log: discardLogger(), + Log: slog.New(slog.NewTextHandler(io.Discard, nil)), + EmbeddingModel: "unused", }) - metadata, err := client.ExtractMetadata(context.Background(), "hello") + metadata, err := client.ExtractMetadata(context.Background(), "Project idea: Build an Android companion app.") if err != nil { t.Fatalf("ExtractMetadata() error = %v", err) } + if metadata.Type != "idea" { t.Fatalf("metadata type = %q, want idea", metadata.Type) } - if len(metadata.People) != 1 || metadata.People[0] != "Alice" { - t.Fatalf("metadata people = %#v, want [Alice]", metadata.People) + if metadata.Source != "stream" { + t.Fatalf("metadata source = %q, want stream", metadata.Source) + } + if len(metadata.Topics) != 1 || metadata.Topics[0] != "android" { + t.Fatalf("metadata topics = %#v, want [android]", metadata.Topics) } } -func TestExtractMetadataParsesArrayContent(t *testing.T) { +func TestExtractMetadataRetriesWithoutJSONMode(t *testing.T) { + t.Parallel() + + var mu sync.Mutex + jsonModeCalls := 0 + plainCalls := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _ = json.NewEncoder(w).Encode(map[string]any{ - "choices": []map[string]any{ - { - "message": map[string]any{ - "content": []map[string]any{ - {"type": "text", "text": "{\"people\":[],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"auth\"],\"type\":\"reference\",\"source\":\"mcp\"}"}, - }, - }, - }, - }, - }) - })) - defer server.Close() + defer r.Body.Close() - client := New(Config{ - Name: "test", - BaseURL: server.URL, - APIKey: "secret", - EmbeddingModel: "embed-model", - MetadataModel: "meta-model", - HTTPClient: server.Client(), - Log: discardLogger(), - }) - - metadata, err := client.ExtractMetadata(context.Background(), "hello") - if err != nil { - t.Fatalf("ExtractMetadata() error = %v", err) - } - if metadata.Type != "reference" { - t.Fatalf("metadata type = %q, want reference", metadata.Type) - } -} - -func TestExtractMetadataUsesReasoningContentWhenContentEmpty(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _ = json.NewEncoder(w).Encode(map[string]any{ - "choices": []map[string]any{ - { - "message": map[string]any{ - "content": "", - "reasoning_content": "{\"people\":[\"Hein\"],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"profile\"],\"type\":\"person_note\",\"source\":\"mcp\"}", - }, - }, - }, - }) - })) - defer server.Close() - - client := New(Config{ - Name: "test", - BaseURL: server.URL, - APIKey: "secret", - EmbeddingModel: "embed-model", - MetadataModel: "meta-model", - HTTPClient: server.Client(), - Log: discardLogger(), - }) - - metadata, err := client.ExtractMetadata(context.Background(), "hello") - if err != nil { - t.Fatalf("ExtractMetadata() error = %v", err) - } - if metadata.Type != "person_note" { - t.Fatalf("metadata type = %q, want person_note", metadata.Type) - } - if len(metadata.People) != 1 || metadata.People[0] != "Hein" { - t.Fatalf("metadata people = %#v, want [Hein]", metadata.People) - } -} - -func TestExtractMetadataFallsBackToHeuristicsWhenModelsFail(t *testing.T) { - var calls atomic.Int32 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _ = calls.Add(1) - _ = json.NewEncoder(w).Encode(map[string]any{ - "choices": []map[string]any{ - {"message": map[string]any{"content": "not json"}}, - }, - }) - })) - defer server.Close() - - client := New(Config{ - Name: "test", - BaseURL: server.URL, - APIKey: "secret", - MetadataModel: "primary", - FallbackMetadataModels: []string{"secondary"}, - HTTPClient: server.Client(), - Log: discardLogger(), - }) - - input := "Personal profile - Hein (Warkanum):\n- Born: 23 May 1989\n- Wife: Cindy, born 16 November 1994" - metadata, err := client.ExtractMetadata(context.Background(), input) - if err != nil { - t.Fatalf("ExtractMetadata() error = %v", err) - } - if calls.Load() != 2 { - t.Fatalf("call count = %d, want 2", calls.Load()) - } - if metadata.Type != "person_note" { - t.Fatalf("metadata type = %q, want person_note", metadata.Type) - } - if len(metadata.DatesMentioned) < 2 { - t.Fatalf("metadata dates = %#v, want extracted dates", metadata.DatesMentioned) - } - if len(metadata.People) == 0 || !strings.EqualFold(metadata.People[0], "Cindy") { - t.Fatalf("metadata people = %#v, want Cindy", metadata.People) - } -} - -func TestExtractMetadataRetriesEmptyResponse(t *testing.T) { - var calls atomic.Int32 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - call := calls.Add(1) var req chatCompletionsRequest - _ = json.NewDecoder(r.Body).Decode(&req) - - if req.Stream == nil || *req.Stream { - t.Fatalf("expected stream=false, got %#v", req.Stream) + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decode request: %v", err) } - if call == 1 { - _ = json.NewEncoder(w).Encode(map[string]any{ - "choices": []map[string]any{ - {"message": map[string]any{"content": ""}}, - }, - }) + if req.ResponseFormat != nil && req.ResponseFormat.Type == "json_object" { + mu.Lock() + jsonModeCalls++ + mu.Unlock() + _, _ = io.WriteString(w, `{"choices":[{"message":{"role":"assistant","content":""}}]}`) return } - _ = json.NewEncoder(w).Encode(map[string]any{ - "choices": []map[string]any{ - {"message": map[string]any{"content": "{\"people\":[],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"mcp\"],\"type\":\"observation\",\"source\":\"mcp\"}"}}, - }, - }) + mu.Lock() + plainCalls++ + mu.Unlock() + _, _ = io.WriteString(w, `{"choices":[{"message":{"role":"assistant","content":"{\"people\":[],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"android\"],\"type\":\"idea\",\"source\":\"test\"}"}}]}`) })) defer server.Close() client := New(Config{ - Name: "test", - BaseURL: server.URL, - APIKey: "secret", - MetadataModel: "meta-model", - HTTPClient: server.Client(), - Log: discardLogger(), + Name: "litellm", + BaseURL: server.URL, + APIKey: "test-key", + MetadataModel: "qwen3.5:latest", + Temperature: 0.1, + HTTPClient: server.Client(), + Log: slog.New(slog.NewTextHandler(io.Discard, nil)), + EmbeddingModel: "unused", }) - metadata, err := client.ExtractMetadata(context.Background(), "hello") + metadata, err := client.ExtractMetadata(context.Background(), "Project idea: Build an Android companion app.") if err != nil { t.Fatalf("ExtractMetadata() error = %v", err) } - if calls.Load() < 2 { - t.Fatalf("call count = %d, want >= 2", calls.Load()) + + if metadata.Type != "idea" { + t.Fatalf("metadata type = %q, want idea", metadata.Type) } - if metadata.Type != "observation" { - t.Fatalf("metadata type = %q, want observation", metadata.Type) + if metadata.Source != "test" { + t.Fatalf("metadata source = %q, want test", metadata.Source) + } + + mu.Lock() + defer mu.Unlock() + if jsonModeCalls != maxMetadataAttempts { + t.Fatalf("json mode calls = %d, want %d", jsonModeCalls, maxMetadataAttempts) + } + if plainCalls != 1 { + t.Fatalf("plain calls = %d, want 1", plainCalls) } } -func TestExtractMetadataBypassesModelAfterRepeatedEmptyResponses(t *testing.T) { - var primaryCalls atomic.Int32 - var fallbackCalls atomic.Int32 +func TestExtractMetadataBypassesInvalidFallbackModelAfterFirstFailure(t *testing.T) { + t.Parallel() + + var mu sync.Mutex + primaryCalls := 0 + invalidFallbackCalls := 0 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + var req chatCompletionsRequest - _ = json.NewDecoder(r.Body).Decode(&req) + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("decode request: %v", err) + } switch req.Model { - case "primary": - primaryCalls.Add(1) - _ = json.NewEncoder(w).Encode(map[string]any{ - "choices": []map[string]any{ - {"message": map[string]any{"content": ""}}, - }, - }) - case "fallback": - fallbackCalls.Add(1) - _ = json.NewEncoder(w).Encode(map[string]any{ - "choices": []map[string]any{ - {"message": map[string]any{"content": "{\"people\":[],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"mcp\"],\"type\":\"observation\",\"source\":\"mcp\"}"}}, - }, - }) + case "empty-primary": + _, _ = io.WriteString(w, `{"choices":[{"message":{"role":"assistant","content":""}}]}`) + case "qwen3.5:latest": + mu.Lock() + primaryCalls++ + mu.Unlock() + _, _ = io.WriteString(w, `{"choices":[{"message":{"role":"assistant","content":"{\"people\":[],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"metadata\"],\"type\":\"observation\",\"source\":\"primary\"}"}}]}`) + case "qwen3": + mu.Lock() + invalidFallbackCalls++ + mu.Unlock() + w.WriteHeader(http.StatusBadRequest) + _, _ = io.WriteString(w, "{\"error\":{\"message\":\"{'error': '/chat/completions: Invalid model name passed in model=qwen3. Call `/v1/models` to view available models for your key.'}\"}}") default: t.Fatalf("unexpected model %q", req.Model) } @@ -372,35 +159,33 @@ func TestExtractMetadataBypassesModelAfterRepeatedEmptyResponses(t *testing.T) { defer server.Close() client := New(Config{ - Name: "test", + Name: "litellm", BaseURL: server.URL, - APIKey: "secret", - MetadataModel: "primary", - FallbackMetadataModels: []string{"fallback"}, + APIKey: "test-key", + MetadataModel: "empty-primary", + FallbackMetadataModels: []string{"qwen3", "qwen3.5:latest"}, + Temperature: 0.1, HTTPClient: server.Client(), - Log: discardLogger(), + Log: slog.New(slog.NewTextHandler(io.Discard, nil)), + EmbeddingModel: "unused", }) - // First three calls should probe primary and then use fallback. - for i := 0; i < 3; i++ { - if _, err := client.ExtractMetadata(context.Background(), "hello"); err != nil { + for i := 0; i < 2; i++ { + metadata, err := client.ExtractMetadata(context.Background(), "A short note about metadata.") + if err != nil { t.Fatalf("ExtractMetadata() error = %v", err) } + if metadata.Source != "primary" { + t.Fatalf("metadata source = %q, want primary", metadata.Source) + } } - primaryBefore := primaryCalls.Load() - if primaryBefore == 0 { - t.Fatal("expected primary model to be called before bypass") + mu.Lock() + defer mu.Unlock() + if invalidFallbackCalls != 1 { + t.Fatalf("invalid fallback calls = %d, want 1", invalidFallbackCalls) } - - // Fourth call should bypass primary (no additional primary calls). - if _, err := client.ExtractMetadata(context.Background(), "hello"); err != nil { - t.Fatalf("ExtractMetadata() error = %v", err) - } - if primaryCalls.Load() != primaryBefore { - t.Fatalf("primary calls increased after bypass: before=%d after=%d", primaryBefore, primaryCalls.Load()) - } - if fallbackCalls.Load() < 4 { - t.Fatalf("fallback calls = %d, want at least 4", fallbackCalls.Load()) + if primaryCalls != 2 { + t.Fatalf("valid fallback calls = %d, want 2", primaryCalls) } }