diff --git a/internal/ai/compat/client.go b/internal/ai/compat/client.go
index 885d87d..c6671db 100644
--- a/internal/ai/compat/client.go
+++ b/internal/ai/compat/client.go
@@ -189,6 +189,7 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri
 	}
 
 	metadataText := strings.TrimSpace(resp.Choices[0].Message.Content)
+	metadataText = stripThinkingBlocks(metadataText)
 	metadataText = stripCodeFence(metadataText)
 	metadataText = extractJSONObject(metadataText)
 	if metadataText == "" {
@@ -320,6 +321,36 @@ func extractJSONObject(s string) string {
 	return s[start : end+1]
 }
 
+// stripThinkingBlocks removes <think>...</think> and <thinking>...</thinking>
+// blocks produced by reasoning models (DeepSeek R1, QwQ, etc.) so that the
+// remaining text can be parsed as JSON without interference from thinking
+// content that may itself contain braces.
+//
+// Every occurrence of each block is removed. An opening tag that is never
+// closed discards everything from that tag to the end of the input. The
+// result is trimmed of leading and trailing whitespace.
+func stripThinkingBlocks(s string) string {
+	for _, tag := range []string{"think", "thinking"} {
+		openTag := "<" + tag + ">"
+		closeTag := "</" + tag + ">"
+		for {
+			start := strings.Index(s, openTag)
+			if start == -1 {
+				break
+			}
+			// Search from the opening tag onward; end is relative to start.
+			end := strings.Index(s[start:], closeTag)
+			if end == -1 {
+				// Unterminated block: drop the tag and everything after it.
+				s = s[:start]
+				break
+			}
+			s = s[:start] + s[start+end+len(closeTag):]
+		}
+	}
+	return strings.TrimSpace(s)
+}
+
 func stripCodeFence(value string) string {
 	value = strings.TrimSpace(value)
 	if !strings.HasPrefix(value, "```") {
diff --git a/internal/ai/compat/client_test.go b/internal/ai/compat/client_test.go
index e25a841..881c1a9 100644
--- a/internal/ai/compat/client_test.go
+++ b/internal/ai/compat/client_test.go
@@ -54,6 +54,99 @@ func TestEmbedRetriesTransientFailures(t *testing.T) {
 	}
 }
 
+// TestExtractMetadataStripsThinkingBlocks checks that <think>/<thinking>
+// blocks, including ones containing braces, do not break metadata parsing.
+func TestExtractMetadataStripsThinkingBlocks(t *testing.T) {
+	cases := []struct {
+		name    string
+		content string
+	}{
+		{
+			name:    "think tag with braces inside",
+			content: "<think>\nLet me map {this} to the schema carefully.\n</think>\n{\"people\":[],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"test\"],\"type\":\"idea\",\"source\":\"\"}",
+		},
+		{
+			name:    "thinking tag",
+			content: "<thinking>reasoning {here}</thinking>{\"people\":[],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"test\"],\"type\":\"idea\",\"source\":\"\"}",
+		},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			content := tc.content
+			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				_ = json.NewEncoder(w).Encode(map[string]any{
+					"choices": []map[string]any{
+						{"message": map[string]any{"content": content}},
+					},
+				})
+			}))
+			defer server.Close()
+
+			client := New(Config{
+				Name:          "test",
+				BaseURL:       server.URL,
+				APIKey:        "secret",
+				MetadataModel: "meta-model",
+				HTTPClient:    server.Client(),
+				Log:           discardLogger(),
+			})
+
+			metadata, err := client.ExtractMetadata(context.Background(), "hello")
+			if err != nil {
+				t.Fatalf("ExtractMetadata() error = %v", err)
+			}
+			if metadata.Type != "idea" {
+				t.Fatalf("metadata type = %q, want idea", metadata.Type)
+			}
+		})
+	}
+}
+
+// TestExtractMetadataFallbackModel checks that metadata is still returned when
+// the primary model answers 503 and only other models receive a valid reply.
+func TestExtractMetadataFallbackModel(t *testing.T) {
+	var calls atomic.Int32
+
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		var req chatCompletionsRequest
+		_ = json.NewDecoder(r.Body).Decode(&req)
+
+		if req.Model == "primary-model" {
+			calls.Add(1)
+			http.Error(w, "model unavailable", http.StatusServiceUnavailable)
+			return
+		}
+		_ = json.NewEncoder(w).Encode(map[string]any{
+			"choices": []map[string]any{
+				{"message": map[string]any{"content": "{\"people\":[],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"test\"],\"type\":\"task\",\"source\":\"\"}"}},
+			},
+		})
+	}))
+	defer server.Close()
+
+	client := New(Config{
+		Name:                  "test",
+		BaseURL:               server.URL,
+		APIKey:                "secret",
+		MetadataModel:         "primary-model",
+		FallbackMetadataModel: "fallback-model",
+		HTTPClient:            server.Client(),
+		Log:                   discardLogger(),
+	})
+
+	metadata, err := client.ExtractMetadata(context.Background(), "hello")
+	if err != nil {
+		t.Fatalf("ExtractMetadata() error = %v", err)
+	}
+	if metadata.Type != "task" {
+		t.Fatalf("metadata type = %q, want task", metadata.Type)
+	}
+	if calls.Load() == 0 {
+		t.Fatal("primary model was never called")
+	}
+}
+
 func TestExtractMetadataParsesCodeFencedJSON(t *testing.T) {
 	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		_ = json.NewEncoder(w).Encode(map[string]any{