From f76d1bbd23c7314430d58cf7ab4ffe6acf0a3e64 Mon Sep 17 00:00:00 2001 From: Hein Date: Fri, 27 Mar 2026 00:24:16 +0200 Subject: [PATCH] feat(metadata): enhance metadata extraction with reasoning content support --- internal/ai/compat/client.go | 121 ++++++++++++++++++++++++++++-- internal/ai/compat/client_test.go | 72 ++++++++++++++++++ 2 files changed, 185 insertions(+), 8 deletions(-) diff --git a/internal/ai/compat/client.go b/internal/ai/compat/client.go index bbf257d..ccce7be 100644 --- a/internal/ai/compat/client.go +++ b/internal/ai/compat/client.go @@ -92,11 +92,18 @@ type chatMessage struct { type chatCompletionsResponse struct { Choices []struct { - Message chatMessage `json:"message"` + Message responseChatMessage `json:"message"` + Text string `json:"text,omitempty"` } `json:"choices"` Error *providerError `json:"error,omitempty"` } +type responseChatMessage struct { + Role string `json:"role"` + Content json.RawMessage `json:"content"` + ReasoningContent string `json:"reasoning_content,omitempty"` +} + type providerError struct { Message string `json:"message"` Type string `json:"type,omitempty"` @@ -200,7 +207,7 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: no choices returned", c.name) } - rawResponse := resp.Choices[0].Message.Content + rawResponse := extractChoiceText(resp.Choices[0].Message, resp.Choices[0].Text) if c.logConversations && c.log != nil { c.log.Info("metadata conversation response", @@ -247,7 +254,7 @@ func (c *Client) Summarize(ctx context.Context, systemPrompt, userPrompt string) return "", fmt.Errorf("%s summarize: no choices returned", c.name) } - return strings.TrimSpace(resp.Choices[0].Message.Content), nil + return extractChoiceText(resp.Choices[0].Message, resp.Choices[0].Text), nil } func (c *Client) Name() string { @@ -335,12 +342,45 @@ func (c *Client) doJSON(ctx context.Context, path string, requestBody any, dest // extractJSONObject finds the first complete {...} block in s. // It handles models that prepend prose to a JSON response despite json_object mode. func extractJSONObject(s string) string { - start := strings.Index(s, "{") - end := strings.LastIndex(s, "}") - if start == -1 || end == -1 || end <= start { - return "" + for start := 0; start < len(s); start++ { + if s[start] != '{' { + continue + } + depth := 0 + inString := false + escaped := false + for end := start; end < len(s); end++ { + ch := s[end] + if escaped { + escaped = false + continue + } + if ch == '\\' && inString { + escaped = true + continue + } + if ch == '"' { + inString = !inString + continue + } + if inString { + continue + } + switch ch { + case '{': + depth++ + case '}': + depth-- + if depth == 0 { + candidate := s[start : end+1] + if json.Valid([]byte(candidate)) { + return candidate + } + } + } + } } - return s[start : end+1] + return "" } // stripThinkingBlocks removes ... and ... @@ -379,6 +419,71 @@ func stripCodeFence(value string) string { return strings.TrimSpace(value) } +func extractChoiceText(message responseChatMessage, fallbackText string) string { + if text := strings.TrimSpace(extractMessageContent(message.Content)); text != "" { + return text + } + if text := strings.TrimSpace(message.ReasoningContent); text != "" { + return text + } + return strings.TrimSpace(fallbackText) +} + +func extractMessageContent(raw json.RawMessage) string { + raw = bytes.TrimSpace(raw) + if len(raw) == 0 { + return "" + } + + var contentString string + if err := json.Unmarshal(raw, &contentString); err == nil { + return contentString + } + + var contentArray []any + if err := json.Unmarshal(raw, &contentArray); err == nil { + parts := make([]string, 0, len(contentArray)) + for _, item := range contentArray { + if text := strings.TrimSpace(extractTextFromAny(item)); text != "" { + parts = append(parts, text) + } + } + return strings.Join(parts, "\n") + } + + var contentObject map[string]any + if err := json.Unmarshal(raw, &contentObject); err == nil { + return extractTextFromAny(contentObject) + } + + return strings.TrimSpace(string(raw)) +} + +func extractTextFromAny(value any) string { + switch typed := value.(type) { + case string: + return typed + case []any: + parts := make([]string, 0, len(typed)) + for _, item := range typed { + if text := strings.TrimSpace(extractTextFromAny(item)); text != "" { + parts = append(parts, text) + } + } + return strings.Join(parts, "\n") + case map[string]any: + // Common provider shapes for chat content parts. + for _, key := range []string{"text", "output_text", "content", "value"} { + if nested, ok := typed[key]; ok { + if text := strings.TrimSpace(extractTextFromAny(nested)); text != "" { + return text + } + } + } + } + return "" +} + func isRetryableStatus(status int) bool { switch status { case http.StatusTooManyRequests, http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout: diff --git a/internal/ai/compat/client_test.go b/internal/ai/compat/client_test.go index 881c1a9..df0cd28 100644 --- a/internal/ai/compat/client_test.go +++ b/internal/ai/compat/client_test.go @@ -178,3 +178,75 @@ func TestExtractMetadataParsesCodeFencedJSON(t *testing.T) { t.Fatalf("metadata people = %#v, want [Alice]", metadata.People) } } + +func TestExtractMetadataParsesArrayContent(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{ + "content": []map[string]any{ + {"type": "text", "text": "{\"people\":[],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"auth\"],\"type\":\"reference\",\"source\":\"mcp\"}"}, + }, + }, + }, + }, + }) + })) + defer server.Close() + + client := New(Config{ + Name: "test", + BaseURL: server.URL, + APIKey: "secret", + EmbeddingModel: "embed-model", + MetadataModel: "meta-model", + HTTPClient: server.Client(), + Log: discardLogger(), + }) + + metadata, err := client.ExtractMetadata(context.Background(), "hello") + if err != nil { + t.Fatalf("ExtractMetadata() error = %v", err) + } + if metadata.Type != "reference" { + t.Fatalf("metadata type = %q, want reference", metadata.Type) + } +} + +func TestExtractMetadataUsesReasoningContentWhenContentEmpty(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{ + "choices": []map[string]any{ + { + "message": map[string]any{ + "content": "", + "reasoning_content": "{\"people\":[\"Hein\"],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"profile\"],\"type\":\"person_note\",\"source\":\"mcp\"}", + }, + }, + }, + }) + })) + defer server.Close() + + client := New(Config{ + Name: "test", + BaseURL: server.URL, + APIKey: "secret", + EmbeddingModel: "embed-model", + MetadataModel: "meta-model", + HTTPClient: server.Client(), + Log: discardLogger(), + }) + + metadata, err := client.ExtractMetadata(context.Background(), "hello") + if err != nil { + t.Fatalf("ExtractMetadata() error = %v", err) + } + if metadata.Type != "person_note" { + t.Fatalf("metadata type = %q, want person_note", metadata.Type) + } + if len(metadata.People) != 1 || metadata.People[0] != "Hein" { + t.Fatalf("metadata people = %#v, want [Hein]", metadata.People) + } +}