feat(metadata): enhance metadata extraction with heuristic fallback support
This commit is contained in:
@@ -10,6 +10,8 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -78,6 +80,7 @@ type chatCompletionsRequest struct {
|
|||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
ResponseFormat *responseType `json:"response_format,omitempty"`
|
ResponseFormat *responseType `json:"response_format,omitempty"`
|
||||||
|
Stream *bool `json:"stream,omitempty"`
|
||||||
Messages []chatMessage `json:"messages"`
|
Messages []chatMessage `json:"messages"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -109,6 +112,8 @@ type providerError struct {
|
|||||||
Type string `json:"type,omitempty"`
|
Type string `json:"type,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// maxMetadataAttempts is the upper bound on requests sent to a single model by
// extractMetadataWithModel; a further attempt is only made after a blank response.
const maxMetadataAttempts = 3
|
||||||
|
|
||||||
func New(cfg Config) *Client {
|
func New(cfg Config) *Client {
|
||||||
return &Client{
|
return &Client{
|
||||||
name: cfg.Name,
|
name: cfg.Name,
|
||||||
@@ -160,7 +165,11 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype
|
|||||||
}
|
}
|
||||||
|
|
||||||
result, err := c.extractMetadataWithModel(ctx, input, c.metadataModel)
|
result, err := c.extractMetadataWithModel(ctx, input, c.metadataModel)
|
||||||
if err != nil && c.fallbackMetadataModel != "" && ctx.Err() == nil {
|
if err == nil {
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.fallbackMetadataModel != "" && ctx.Err() == nil {
|
||||||
if c.log != nil {
|
if c.log != nil {
|
||||||
c.log.Warn("metadata extraction failed, trying fallback model",
|
c.log.Warn("metadata extraction failed, trying fallback model",
|
||||||
slog.String("provider", c.name),
|
slog.String("provider", c.name),
|
||||||
@@ -169,28 +178,45 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype
|
|||||||
slog.String("error", err.Error()),
|
slog.String("error", err.Error()),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
return c.extractMetadataWithModel(ctx, input, c.fallbackMetadataModel)
|
fallbackResult, fallbackErr := c.extractMetadataWithModel(ctx, input, c.fallbackMetadataModel)
|
||||||
|
if fallbackErr == nil {
|
||||||
|
return fallbackResult, nil
|
||||||
}
|
}
|
||||||
return result, err
|
err = fallbackErr
|
||||||
|
}
|
||||||
|
|
||||||
|
heuristic := heuristicMetadataFromInput(input)
|
||||||
|
if c.log != nil {
|
||||||
|
c.log.Warn("metadata extraction failed for all models, using heuristic fallback",
|
||||||
|
slog.String("provider", c.name),
|
||||||
|
slog.String("error", err.Error()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return heuristic, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) extractMetadataWithModel(ctx context.Context, input, model string) (thoughttypes.ThoughtMetadata, error) {
|
func (c *Client) extractMetadataWithModel(ctx context.Context, input, model string) (thoughttypes.ThoughtMetadata, error) {
|
||||||
|
stream := false
|
||||||
req := chatCompletionsRequest{
|
req := chatCompletionsRequest{
|
||||||
Model: model,
|
Model: model,
|
||||||
Temperature: c.temperature,
|
Temperature: c.temperature,
|
||||||
ResponseFormat: &responseType{
|
ResponseFormat: &responseType{
|
||||||
Type: "json_object",
|
Type: "json_object",
|
||||||
},
|
},
|
||||||
|
Stream: &stream,
|
||||||
Messages: []chatMessage{
|
Messages: []chatMessage{
|
||||||
{Role: "system", Content: metadataSystemPrompt},
|
{Role: "system", Content: metadataSystemPrompt},
|
||||||
{Role: "user", Content: input},
|
{Role: "user", Content: input},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var lastErr error
|
||||||
|
for attempt := 1; attempt <= maxMetadataAttempts; attempt++ {
|
||||||
if c.logConversations && c.log != nil {
|
if c.logConversations && c.log != nil {
|
||||||
c.log.Info("metadata conversation request",
|
c.log.Info("metadata conversation request",
|
||||||
slog.String("provider", c.name),
|
slog.String("provider", c.name),
|
||||||
slog.String("model", model),
|
slog.String("model", model),
|
||||||
|
slog.Int("attempt", attempt),
|
||||||
slog.String("system", metadataSystemPrompt),
|
slog.String("system", metadataSystemPrompt),
|
||||||
slog.String("input", input),
|
slog.String("input", input),
|
||||||
)
|
)
|
||||||
@@ -213,6 +239,7 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri
|
|||||||
c.log.Info("metadata conversation response",
|
c.log.Info("metadata conversation response",
|
||||||
slog.String("provider", c.name),
|
slog.String("provider", c.name),
|
||||||
slog.String("model", model),
|
slog.String("model", model),
|
||||||
|
slog.Int("attempt", attempt),
|
||||||
slog.String("response", rawResponse),
|
slog.String("response", rawResponse),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -222,15 +249,36 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri
|
|||||||
metadataText = stripCodeFence(metadataText)
|
metadataText = stripCodeFence(metadataText)
|
||||||
metadataText = extractJSONObject(metadataText)
|
metadataText = extractJSONObject(metadataText)
|
||||||
if metadataText == "" {
|
if metadataText == "" {
|
||||||
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: response contains no JSON object", c.name)
|
lastErr = fmt.Errorf("%s metadata: response contains no JSON object", c.name)
|
||||||
|
if strings.TrimSpace(rawResponse) == "" && attempt < maxMetadataAttempts && ctx.Err() == nil {
|
||||||
|
if c.log != nil {
|
||||||
|
c.log.Warn("metadata response empty, waiting and retrying",
|
||||||
|
slog.String("provider", c.name),
|
||||||
|
slog.String("model", model),
|
||||||
|
slog.Int("attempt", attempt+1),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if err := sleepMetadataRetry(ctx, attempt); err != nil {
|
||||||
|
return thoughttypes.ThoughtMetadata{}, err
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return thoughttypes.ThoughtMetadata{}, lastErr
|
||||||
}
|
}
|
||||||
|
|
||||||
var metadata thoughttypes.ThoughtMetadata
|
var metadata thoughttypes.ThoughtMetadata
|
||||||
if err := json.Unmarshal([]byte(metadataText), &metadata); err != nil {
|
if err := json.Unmarshal([]byte(metadataText), &metadata); err != nil {
|
||||||
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: parse json: %w", c.name, err)
|
lastErr = fmt.Errorf("%s metadata: parse json: %w", c.name, err)
|
||||||
|
return thoughttypes.ThoughtMetadata{}, lastErr
|
||||||
}
|
}
|
||||||
|
|
||||||
return metadata, nil
|
return metadata, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastErr != nil {
|
||||||
|
return thoughttypes.ThoughtMetadata{}, lastErr
|
||||||
|
}
|
||||||
|
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: response contains no JSON object", c.name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) {
|
func (c *Client) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) {
|
||||||
@@ -484,6 +532,168 @@ func extractTextFromAny(value any) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// monthDatePattern matches day-month-year dates such as "23 May 1989" or
// "16 nov 1994". Month names may be abbreviated and matching ignores case.
var monthDatePattern = regexp.MustCompile(`(?i)\b\d{1,2}\s+(?:jan(?:uary)?|feb(?:ruary)?|mar(?:ch)?|apr(?:il)?|may|jun(?:e)?|jul(?:y)?|aug(?:ust)?|sep(?:t(?:ember)?)?|oct(?:ober)?|nov(?:ember)?|dec(?:ember)?)\s+\d{4}\b`)

// isoDatePattern matches YYYY-MM-DD shaped dates such as "2024-01-31". The
// digits are not range-checked.
var isoDatePattern = regexp.MustCompile(`\b\d{4}-\d{2}-\d{2}\b`)

// wordPattern matches tokens of three or more characters that start with an
// ASCII letter and continue with letters, digits, '_', '/' or '-'.
var wordPattern = regexp.MustCompile(`[a-zA-Z][a-zA-Z0-9_/-]{2,}`)
|
||||||
|
|
||||||
|
func heuristicMetadataFromInput(input string) thoughttypes.ThoughtMetadata {
|
||||||
|
text := strings.TrimSpace(input)
|
||||||
|
lower := strings.ToLower(text)
|
||||||
|
|
||||||
|
metadata := thoughttypes.ThoughtMetadata{
|
||||||
|
People: heuristicPeople(text),
|
||||||
|
ActionItems: heuristicActionItems(text),
|
||||||
|
DatesMentioned: heuristicDates(text),
|
||||||
|
Topics: heuristicTopics(lower),
|
||||||
|
Type: heuristicType(lower),
|
||||||
|
Source: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(metadata.Topics) == 0 {
|
||||||
|
metadata.Topics = []string{"uncategorized"}
|
||||||
|
}
|
||||||
|
if metadata.Type == "" {
|
||||||
|
metadata.Type = "observation"
|
||||||
|
}
|
||||||
|
return metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
// heuristicType classifies already-lowercased text by substring cues.
//
// Rules are evaluated in order and the first rule with any matching cue wins:
// person_note, task, idea, reference. Text matching no cue is an "observation".
func heuristicType(lower string) string {
	rules := []struct {
		kind string
		cues []string
	}{
		{"person_note", []string{"preferred name", "personal profile", "wife:", "daughter:", "born:"}},
		{"task", []string{"todo", "action item", "need to", "must ", "should "}},
		{"idea", []string{"idea", "proposal", "brainstorm"}},
		{"reference", []string{"reference", "rfc ", "docs", "spec"}},
	}

	for _, rule := range rules {
		for _, cue := range rule.cues {
			if strings.Contains(lower, cue) {
				return rule.kind
			}
		}
	}
	return "observation"
}
|
||||||
|
|
||||||
|
func heuristicTopics(lower string) []string {
|
||||||
|
candidates := []string{
|
||||||
|
"mcp", "auth", "oauth", "api_keys", "token", "middleware", "postgres", "search", "embeddings", "metadata",
|
||||||
|
"go", "server", "project", "memory", "claude", "automation", "calendar", "email", "atlassian", "n8n",
|
||||||
|
}
|
||||||
|
|
||||||
|
topics := make([]string, 0, 6)
|
||||||
|
for _, topic := range candidates {
|
||||||
|
if strings.Contains(lower, topic) {
|
||||||
|
topics = append(topics, topic)
|
||||||
|
}
|
||||||
|
if len(topics) >= 6 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(topics) > 0 {
|
||||||
|
return topics
|
||||||
|
}
|
||||||
|
|
||||||
|
words := wordPattern.FindAllString(lower, -1)
|
||||||
|
for _, w := range words {
|
||||||
|
if len(w) < 4 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if slices.Contains(topics, w) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
topics = append(topics, w)
|
||||||
|
if len(topics) >= 4 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return topics
|
||||||
|
}
|
||||||
|
|
||||||
|
func heuristicDates(text string) []string {
|
||||||
|
values := make([]string, 0, 4)
|
||||||
|
seen := map[string]struct{}{}
|
||||||
|
|
||||||
|
for _, match := range monthDatePattern.FindAllString(text, -1) {
|
||||||
|
key := strings.ToLower(match)
|
||||||
|
if _, ok := seen[key]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[key] = struct{}{}
|
||||||
|
values = append(values, match)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, match := range isoDatePattern.FindAllString(text, -1) {
|
||||||
|
key := strings.ToLower(match)
|
||||||
|
if _, ok := seen[key]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[key] = struct{}{}
|
||||||
|
values = append(values, match)
|
||||||
|
}
|
||||||
|
|
||||||
|
return values
|
||||||
|
}
|
||||||
|
|
||||||
|
// heuristicPeople extracts person names from line-oriented notes.
//
// Each line is trimmed and stripped of one leading "-" and/or "*" bullet. Two
// cues are recognised:
//   - a line mentioning "preferred name" that contains " is ": the text before
//     the first " is " is taken as the name;
//   - a line starting with a relation marker (Wife:, Daughter:, Son:, Partner:,
//     Name:): the text after the marker, up to the first comma, is the name.
//
// Markers are matched case-insensitively so that "wife: Cindy" is handled the
// same as "Wife: Cindy", in line with the lowercase cues heuristicType uses.
// Names are de-duplicated case-insensitively, keeping the first spelling seen.
// The result may be empty but is never nil.
func heuristicPeople(text string) []string {
	markers := []string{"Wife:", "Daughter:", "Son:", "Partner:", "Name:"}

	people := make([]string, 0, 4)
	seen := make(map[string]struct{})

	add := func(name string) {
		name = strings.TrimSpace(name)
		if name == "" {
			return
		}
		key := strings.ToLower(name)
		if _, ok := seen[key]; ok {
			return
		}
		seen[key] = struct{}{}
		people = append(people, name)
	}

	for _, line := range strings.Split(text, "\n") {
		l := strings.TrimSpace(line)
		l = strings.TrimSpace(strings.TrimPrefix(strings.TrimPrefix(l, "-"), "*"))

		if strings.Contains(strings.ToLower(l), "preferred name") {
			if before, _, ok := strings.Cut(l, " is "); ok {
				add(before)
			}
		}

		for _, marker := range markers {
			// Markers are ASCII, so a byte-length prefix comparison is safe.
			if len(l) < len(marker) || !strings.EqualFold(l[:len(marker)], marker) {
				continue
			}
			rest := strings.TrimSpace(l[len(marker):])
			if idx := strings.Index(rest, ","); idx > 0 {
				rest = rest[:idx]
			}
			add(rest)
			break // markers are mutually exclusive prefixes
		}
	}

	return people
}
|
||||||
|
|
||||||
|
// heuristicActionItems returns up to five lines of text that look like tasks.
//
// Each line is trimmed, then stripped of one leading "-" and/or "*" bullet and
// trimmed again. Trimming first means indented bullets ("  - fix x") are
// handled like top-level ones. A line counts as an action item when it
// contains "todo" or starts with "fix ", "add ", "update " or "implement "
// (case-insensitive). Items keep their original casing. The result may be
// empty but is never nil.
func heuristicActionItems(text string) []string {
	const maxItems = 5
	verbPrefixes := []string{"fix ", "add ", "update ", "implement "}

	items := make([]string, 0, maxItems)
	for _, line := range strings.Split(text, "\n") {
		l := strings.TrimSpace(line)
		l = strings.TrimSpace(strings.TrimPrefix(strings.TrimPrefix(l, "-"), "*"))
		if l == "" {
			continue
		}

		ll := strings.ToLower(l)
		matched := strings.Contains(ll, "todo")
		for i := 0; !matched && i < len(verbPrefixes); i++ {
			matched = strings.HasPrefix(ll, verbPrefixes[i])
		}
		if !matched {
			continue
		}

		items = append(items, l)
		if len(items) >= maxItems {
			break
		}
	}
	return items
}
|
||||||
|
|
||||||
func isRetryableStatus(status int) bool {
|
func isRetryableStatus(status int) bool {
|
||||||
switch status {
|
switch status {
|
||||||
case http.StatusTooManyRequests, http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
|
case http.StatusTooManyRequests, http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
|
||||||
@@ -518,3 +728,15 @@ func sleepRetry(ctx context.Context, attempt int, log *slog.Logger, provider str
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sleepMetadataRetry blocks for attempt × 350ms before the next metadata
// request, or until ctx is done, whichever happens first.
//
// It returns nil once the delay has elapsed and ctx.Err() if the context ended
// while waiting.
func sleepMetadataRetry(ctx context.Context, attempt int) error {
	const step = 350 * time.Millisecond

	backoff := time.NewTimer(time.Duration(attempt) * step)
	defer backoff.Stop()

	select {
	case <-backoff.C:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@@ -250,3 +251,93 @@ func TestExtractMetadataUsesReasoningContentWhenContentEmpty(t *testing.T) {
|
|||||||
t.Fatalf("metadata people = %#v, want [Hein]", metadata.People)
|
t.Fatalf("metadata people = %#v, want [Hein]", metadata.People)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExtractMetadataFallsBackToHeuristicsWhenModelsFail(t *testing.T) {
|
||||||
|
var calls atomic.Int32
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_ = calls.Add(1)
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"choices": []map[string]any{
|
||||||
|
{"message": map[string]any{"content": "not json"}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := New(Config{
|
||||||
|
Name: "test",
|
||||||
|
BaseURL: server.URL,
|
||||||
|
APIKey: "secret",
|
||||||
|
MetadataModel: "primary",
|
||||||
|
FallbackMetadataModel: "secondary",
|
||||||
|
HTTPClient: server.Client(),
|
||||||
|
Log: discardLogger(),
|
||||||
|
})
|
||||||
|
|
||||||
|
input := "Personal profile - Hein (Warkanum):\n- Born: 23 May 1989\n- Wife: Cindy, born 16 November 1994"
|
||||||
|
metadata, err := client.ExtractMetadata(context.Background(), input)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExtractMetadata() error = %v", err)
|
||||||
|
}
|
||||||
|
if calls.Load() != 2 {
|
||||||
|
t.Fatalf("call count = %d, want 2", calls.Load())
|
||||||
|
}
|
||||||
|
if metadata.Type != "person_note" {
|
||||||
|
t.Fatalf("metadata type = %q, want person_note", metadata.Type)
|
||||||
|
}
|
||||||
|
if len(metadata.DatesMentioned) < 2 {
|
||||||
|
t.Fatalf("metadata dates = %#v, want extracted dates", metadata.DatesMentioned)
|
||||||
|
}
|
||||||
|
if len(metadata.People) == 0 || !strings.EqualFold(metadata.People[0], "Cindy") {
|
||||||
|
t.Fatalf("metadata people = %#v, want Cindy", metadata.People)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractMetadataRetriesEmptyResponse(t *testing.T) {
|
||||||
|
var calls atomic.Int32
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
call := calls.Add(1)
|
||||||
|
var req chatCompletionsRequest
|
||||||
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
|
||||||
|
if req.Stream == nil || *req.Stream {
|
||||||
|
t.Fatalf("expected stream=false, got %#v", req.Stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
if call == 1 {
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"choices": []map[string]any{
|
||||||
|
{"message": map[string]any{"content": ""}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"choices": []map[string]any{
|
||||||
|
{"message": map[string]any{"content": "{\"people\":[],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"mcp\"],\"type\":\"observation\",\"source\":\"mcp\"}"}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
client := New(Config{
|
||||||
|
Name: "test",
|
||||||
|
BaseURL: server.URL,
|
||||||
|
APIKey: "secret",
|
||||||
|
MetadataModel: "meta-model",
|
||||||
|
HTTPClient: server.Client(),
|
||||||
|
Log: discardLogger(),
|
||||||
|
})
|
||||||
|
|
||||||
|
metadata, err := client.ExtractMetadata(context.Background(), "hello")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExtractMetadata() error = %v", err)
|
||||||
|
}
|
||||||
|
if calls.Load() < 2 {
|
||||||
|
t.Fatalf("call count = %d, want >= 2", calls.Load())
|
||||||
|
}
|
||||||
|
if metadata.Type != "observation" {
|
||||||
|
t.Fatalf("metadata type = %q, want observation", metadata.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user