diff --git a/internal/ai/compat/client.go b/internal/ai/compat/client.go index ccce7be..42de883 100644 --- a/internal/ai/compat/client.go +++ b/internal/ai/compat/client.go @@ -10,6 +10,8 @@ import ( "log/slog" "net" "net/http" + "regexp" + "slices" "strings" "time" @@ -78,6 +80,7 @@ type chatCompletionsRequest struct { Model string `json:"model"` Temperature float64 `json:"temperature,omitempty"` ResponseFormat *responseType `json:"response_format,omitempty"` + Stream *bool `json:"stream,omitempty"` Messages []chatMessage `json:"messages"` } @@ -109,6 +112,8 @@ type providerError struct { Type string `json:"type,omitempty"` } +const maxMetadataAttempts = 3 + func New(cfg Config) *Client { return &Client{ name: cfg.Name, @@ -160,7 +165,11 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype } result, err := c.extractMetadataWithModel(ctx, input, c.metadataModel) - if err != nil && c.fallbackMetadataModel != "" && ctx.Err() == nil { + if err == nil { + return result, nil + } + + if c.fallbackMetadataModel != "" && ctx.Err() == nil { if c.log != nil { c.log.Warn("metadata extraction failed, trying fallback model", slog.String("provider", c.name), @@ -169,68 +178,107 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype slog.String("error", err.Error()), ) } - return c.extractMetadataWithModel(ctx, input, c.fallbackMetadataModel) + fallbackResult, fallbackErr := c.extractMetadataWithModel(ctx, input, c.fallbackMetadataModel) + if fallbackErr == nil { + return fallbackResult, nil + } + err = fallbackErr } - return result, err + + heuristic := heuristicMetadataFromInput(input) + if c.log != nil { + c.log.Warn("metadata extraction failed for all models, using heuristic fallback", + slog.String("provider", c.name), + slog.String("error", err.Error()), + ) + } + return heuristic, nil } func (c *Client) extractMetadataWithModel(ctx context.Context, input, model string) (thoughttypes.ThoughtMetadata, error) { + stream := false req := chatCompletionsRequest{ Model: model, Temperature: c.temperature, ResponseFormat: &responseType{ Type: "json_object", }, + Stream: &stream, Messages: []chatMessage{ {Role: "system", Content: metadataSystemPrompt}, {Role: "user", Content: input}, }, } - if c.logConversations && c.log != nil { - c.log.Info("metadata conversation request", - slog.String("provider", c.name), - slog.String("model", model), - slog.String("system", metadataSystemPrompt), - slog.String("input", input), - ) + var lastErr error + for attempt := 1; attempt <= maxMetadataAttempts; attempt++ { + if c.logConversations && c.log != nil { + c.log.Info("metadata conversation request", + slog.String("provider", c.name), + slog.String("model", model), + slog.Int("attempt", attempt), + slog.String("system", metadataSystemPrompt), + slog.String("input", input), + ) + } + + var resp chatCompletionsResponse + if err := c.doJSON(ctx, "/chat/completions", req, &resp); err != nil { + return thoughttypes.ThoughtMetadata{}, err + } + if resp.Error != nil { + return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata error: %s", c.name, resp.Error.Message) + } + if len(resp.Choices) == 0 { + return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: no choices returned", c.name) + } + + rawResponse := extractChoiceText(resp.Choices[0].Message, resp.Choices[0].Text) + + if c.logConversations && c.log != nil { + c.log.Info("metadata conversation response", + slog.String("provider", c.name), + slog.String("model", model), + slog.Int("attempt", attempt), + slog.String("response", rawResponse), + ) + } + + metadataText := strings.TrimSpace(rawResponse) + metadataText = stripThinkingBlocks(metadataText) + metadataText = stripCodeFence(metadataText) + metadataText = extractJSONObject(metadataText) + if metadataText == "" { + lastErr = fmt.Errorf("%s metadata: response contains no JSON object", c.name) + if strings.TrimSpace(rawResponse) == "" && attempt < maxMetadataAttempts && ctx.Err() == nil { + if c.log != nil { + c.log.Warn("metadata response empty, waiting and retrying", + slog.String("provider", c.name), + slog.String("model", model), + slog.Int("attempt", attempt+1), + ) + } + if err := sleepMetadataRetry(ctx, attempt); err != nil { + return thoughttypes.ThoughtMetadata{}, err + } + continue + } + return thoughttypes.ThoughtMetadata{}, lastErr + } + + var metadata thoughttypes.ThoughtMetadata + if err := json.Unmarshal([]byte(metadataText), &metadata); err != nil { + lastErr = fmt.Errorf("%s metadata: parse json: %w", c.name, err) + return thoughttypes.ThoughtMetadata{}, lastErr + } + + return metadata, nil } - var resp chatCompletionsResponse - if err := c.doJSON(ctx, "/chat/completions", req, &resp); err != nil { - return thoughttypes.ThoughtMetadata{}, err + if lastErr != nil { + return thoughttypes.ThoughtMetadata{}, lastErr } - if resp.Error != nil { - return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata error: %s", c.name, resp.Error.Message) - } - if len(resp.Choices) == 0 { - return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: no choices returned", c.name) - } - - rawResponse := extractChoiceText(resp.Choices[0].Message, resp.Choices[0].Text) - - if c.logConversations && c.log != nil { - c.log.Info("metadata conversation response", - slog.String("provider", c.name), - slog.String("model", model), - slog.String("response", rawResponse), - ) - } - - metadataText := strings.TrimSpace(rawResponse) - metadataText = stripThinkingBlocks(metadataText) - metadataText = stripCodeFence(metadataText) - metadataText = extractJSONObject(metadataText) - if metadataText == "" { - return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: response contains no JSON object", c.name) - } - - var metadata thoughttypes.ThoughtMetadata - if err := json.Unmarshal([]byte(metadataText), &metadata); err != nil { - return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: parse json: %w", c.name, err) - } - - return metadata, nil + return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: response contains no JSON object", c.name) } func (c *Client) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) { @@ -484,6 +532,168 @@ func extractTextFromAny(value any) string { return "" } +var ( + monthDatePattern = regexp.MustCompile(`(?i)\b\d{1,2}\s+(?:jan(?:uary)?|feb(?:ruary)?|mar(?:ch)?|apr(?:il)?|may|jun(?:e)?|jul(?:y)?|aug(?:ust)?|sep(?:t(?:ember)?)?|oct(?:ober)?|nov(?:ember)?|dec(?:ember)?)\s+\d{4}\b`) + isoDatePattern = regexp.MustCompile(`\b\d{4}-\d{2}-\d{2}\b`) + wordPattern = regexp.MustCompile(`[a-zA-Z][a-zA-Z0-9_/-]{2,}`) +) + +func heuristicMetadataFromInput(input string) thoughttypes.ThoughtMetadata { + text := strings.TrimSpace(input) + lower := strings.ToLower(text) + + metadata := thoughttypes.ThoughtMetadata{ + People: heuristicPeople(text), + ActionItems: heuristicActionItems(text), + DatesMentioned: heuristicDates(text), + Topics: heuristicTopics(lower), + Type: heuristicType(lower), + Source: "", + } + + if len(metadata.Topics) == 0 { + metadata.Topics = []string{"uncategorized"} + } + if metadata.Type == "" { + metadata.Type = "observation" + } + return metadata +} + +func heuristicType(lower string) string { + switch { + case strings.Contains(lower, "preferred name"), strings.Contains(lower, "personal profile"), strings.Contains(lower, "wife:"), strings.Contains(lower, "daughter:"), strings.Contains(lower, "born:"): + return "person_note" + case strings.Contains(lower, "todo"), strings.Contains(lower, "action item"), strings.Contains(lower, "need to"), strings.Contains(lower, "must "), strings.Contains(lower, "should "): + return "task" + case strings.Contains(lower, "idea"), strings.Contains(lower, "proposal"), strings.Contains(lower, "brainstorm"): + return "idea" + case strings.Contains(lower, "reference"), strings.Contains(lower, "rfc "), strings.Contains(lower, "docs"), strings.Contains(lower, "spec"): + return "reference" + default: + return "observation" + } +} + +func heuristicTopics(lower string) []string { + candidates := []string{ + "mcp", "auth", "oauth", "api_keys", "token", "middleware", "postgres", "search", "embeddings", "metadata", + "go", "server", "project", "memory", "claude", "automation", "calendar", "email", "atlassian", "n8n", + } + + topics := make([]string, 0, 6) + for _, topic := range candidates { + if strings.Contains(lower, topic) { + topics = append(topics, topic) + } + if len(topics) >= 6 { + break + } + } + + if len(topics) > 0 { + return topics + } + + words := wordPattern.FindAllString(lower, -1) + for _, w := range words { + if len(w) < 4 { + continue + } + if slices.Contains(topics, w) { + continue + } + topics = append(topics, w) + if len(topics) >= 4 { + break + } + } + return topics +} + +func heuristicDates(text string) []string { + values := make([]string, 0, 4) + seen := map[string]struct{}{} + + for _, match := range monthDatePattern.FindAllString(text, -1) { + key := strings.ToLower(match) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + values = append(values, match) + } + + for _, match := range isoDatePattern.FindAllString(text, -1) { + key := strings.ToLower(match) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + values = append(values, match) + } + + return values +} + +func heuristicPeople(text string) []string { + lines := strings.Split(text, "\n") + people := make([]string, 0, 4) + seen := map[string]struct{}{} + + add := func(name string) { + name = strings.TrimSpace(name) + if name == "" { + return + } + key := strings.ToLower(name) + if _, ok := seen[key]; ok { + return + } + seen[key] = struct{}{} + people = append(people, name) + } + + for _, line := range lines { + l := strings.TrimSpace(line) + l = strings.TrimSpace(strings.TrimPrefix(strings.TrimPrefix(l, "-"), "*")) + if strings.Contains(strings.ToLower(l), "preferred name") && strings.Contains(l, " is ") { + parts := strings.SplitN(l, " is ", 2) + add(parts[0]) + } + for _, marker := range []string{"Wife:", "Daughter:", "Son:", "Partner:", "Name:"} { + if strings.HasPrefix(l, marker) { + rest := strings.TrimSpace(strings.TrimPrefix(l, marker)) + if idx := strings.Index(rest, ","); idx > 0 { + rest = rest[:idx] + } + add(rest) + } + } + } + + return people +} + +func heuristicActionItems(text string) []string { + lines := strings.Split(text, "\n") + items := make([]string, 0, 5) + for _, line := range lines { + l := strings.TrimSpace(strings.TrimPrefix(strings.TrimPrefix(line, "-"), "*")) + if l == "" { + continue + } + ll := strings.ToLower(l) + if strings.Contains(ll, "todo") || strings.HasPrefix(ll, "fix ") || strings.HasPrefix(ll, "add ") || strings.HasPrefix(ll, "update ") || strings.HasPrefix(ll, "implement ") { + items = append(items, l) + } + if len(items) >= 5 { + break + } + } + return items +} + func isRetryableStatus(status int) bool { switch status { case http.StatusTooManyRequests, http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout: @@ -518,3 +728,15 @@ func sleepRetry(ctx context.Context, attempt int, log *slog.Logger, provider str return nil } } + +func sleepMetadataRetry(ctx context.Context, attempt int) error { + delay := time.Duration(attempt) * 350 * time.Millisecond + timer := time.NewTimer(delay) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} diff --git a/internal/ai/compat/client_test.go b/internal/ai/compat/client_test.go index df0cd28..a79d518 100644 --- a/internal/ai/compat/client_test.go +++ b/internal/ai/compat/client_test.go @@ -7,6 +7,7 @@ import ( "log/slog" "net/http" "net/http/httptest" + "strings" "sync/atomic" "testing" ) @@ -250,3 +251,93 @@ func TestExtractMetadataUsesReasoningContentWhenContentEmpty(t *testing.T) { t.Fatalf("metadata people = %#v, want [Hein]", metadata.People) } } + +func TestExtractMetadataFallsBackToHeuristicsWhenModelsFail(t *testing.T) { + var calls atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = calls.Add(1) + _ = json.NewEncoder(w).Encode(map[string]any{ + "choices": []map[string]any{ + {"message": map[string]any{"content": "not json"}}, + }, + }) + })) + defer server.Close() + + client := New(Config{ + Name: "test", + BaseURL: server.URL, + APIKey: "secret", + MetadataModel: "primary", + FallbackMetadataModel: "secondary", + HTTPClient: server.Client(), + Log: discardLogger(), + }) + + input := "Personal profile - Hein (Warkanum):\n- Born: 23 May 1989\n- Wife: Cindy, born 16 November 1994" + metadata, err := client.ExtractMetadata(context.Background(), input) + if err != nil { + t.Fatalf("ExtractMetadata() error = %v", err) + } + if calls.Load() != 2 { + t.Fatalf("call count = %d, want 2", calls.Load()) + } + if metadata.Type != "person_note" { + t.Fatalf("metadata type = %q, want person_note", metadata.Type) + } + if len(metadata.DatesMentioned) < 2 { + t.Fatalf("metadata dates = %#v, want extracted dates", metadata.DatesMentioned) + } + if len(metadata.People) == 0 || !strings.EqualFold(metadata.People[0], "Cindy") { + t.Fatalf("metadata people = %#v, want Cindy", metadata.People) + } +} + +func TestExtractMetadataRetriesEmptyResponse(t *testing.T) { + var calls atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + call := calls.Add(1) + var req chatCompletionsRequest + _ = json.NewDecoder(r.Body).Decode(&req) + + if req.Stream == nil || *req.Stream { + t.Fatalf("expected stream=false, got %#v", req.Stream) + } + + if call == 1 { + _ = json.NewEncoder(w).Encode(map[string]any{ + "choices": []map[string]any{ + {"message": map[string]any{"content": ""}}, + }, + }) + return + } + + _ = json.NewEncoder(w).Encode(map[string]any{ + "choices": []map[string]any{ + {"message": map[string]any{"content": "{\"people\":[],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"mcp\"],\"type\":\"observation\",\"source\":\"mcp\"}"}}, + }, + }) + })) + defer server.Close() + + client := New(Config{ + Name: "test", + BaseURL: server.URL, + APIKey: "secret", + MetadataModel: "meta-model", + HTTPClient: server.Client(), + Log: discardLogger(), + }) + + metadata, err := client.ExtractMetadata(context.Background(), "hello") + if err != nil { + t.Fatalf("ExtractMetadata() error = %v", err) + } + if calls.Load() < 2 { + t.Fatalf("call count = %d, want >= 2", calls.Load()) + } + if metadata.Type != "observation" { + t.Fatalf("metadata type = %q, want observation", metadata.Type) + } +}