diff --git a/configs/config.example.yaml b/configs/config.example.yaml index 53afcc9..3cf4960 100644 --- a/configs/config.example.yaml +++ b/configs/config.example.yaml @@ -42,6 +42,7 @@ ai: dimensions: 1536 metadata: model: "gpt-4o-mini" + fallback_model: "" temperature: 0.1 litellm: base_url: "http://localhost:4000/v1" @@ -50,6 +51,7 @@ ai: request_headers: {} embedding_model: "openrouter/openai/text-embedding-3-small" metadata_model: "gpt-4o-mini" + fallback_metadata_model: "" ollama: base_url: "http://localhost:11434/v1" api_key: "ollama" diff --git a/internal/ai/compat/client.go b/internal/ai/compat/client.go index 62b7a68..885d87d 100644 --- a/internal/ai/compat/client.go +++ b/internal/ai/compat/client.go @@ -33,29 +33,31 @@ Rules: - Do not include any text outside the JSON object.` type Client struct { - name string - baseURL string - apiKey string - embeddingModel string - metadataModel string - temperature float64 - headers map[string]string - httpClient *http.Client - log *slog.Logger - dimensions int + name string + baseURL string + apiKey string + embeddingModel string + metadataModel string + fallbackMetadataModel string + temperature float64 + headers map[string]string + httpClient *http.Client + log *slog.Logger + dimensions int } type Config struct { - Name string - BaseURL string - APIKey string - EmbeddingModel string - MetadataModel string - Temperature float64 - Headers map[string]string - HTTPClient *http.Client - Log *slog.Logger - Dimensions int + Name string + BaseURL string + APIKey string + EmbeddingModel string + MetadataModel string + FallbackMetadataModel string + Temperature float64 + Headers map[string]string + HTTPClient *http.Client + Log *slog.Logger + Dimensions int } type embeddingsRequest struct { @@ -100,16 +102,17 @@ type providerError struct { func New(cfg Config) *Client { return &Client{ - name: cfg.Name, - baseURL: cfg.BaseURL, - apiKey: cfg.APIKey, - embeddingModel: cfg.EmbeddingModel, - metadataModel: cfg.MetadataModel, - temperature: cfg.Temperature, - headers: cfg.Headers, - httpClient: cfg.HTTPClient, - log: cfg.Log, - dimensions: cfg.Dimensions, + name: cfg.Name, + baseURL: cfg.BaseURL, + apiKey: cfg.APIKey, + embeddingModel: cfg.EmbeddingModel, + metadataModel: cfg.MetadataModel, + fallbackMetadataModel: cfg.FallbackMetadataModel, + temperature: cfg.Temperature, + headers: cfg.Headers, + httpClient: cfg.HTTPClient, + log: cfg.Log, + dimensions: cfg.Dimensions, } } @@ -146,8 +149,24 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s extract metadata: input must not be empty", c.name) } + result, err := c.extractMetadataWithModel(ctx, input, c.metadataModel) + if err != nil && c.fallbackMetadataModel != "" && ctx.Err() == nil { + if c.log != nil { + c.log.Warn("metadata extraction failed, trying fallback model", + slog.String("provider", c.name), + slog.String("primary_model", c.metadataModel), + slog.String("fallback_model", c.fallbackMetadataModel), + slog.String("error", err.Error()), + ) + } + return c.extractMetadataWithModel(ctx, input, c.fallbackMetadataModel) + } + return result, err +} + +func (c *Client) extractMetadataWithModel(ctx context.Context, input, model string) (thoughttypes.ThoughtMetadata, error) { req := chatCompletionsRequest{ - Model: c.metadataModel, + Model: model, Temperature: c.temperature, ResponseFormat: &responseType{ Type: "json_object", @@ -172,6 +191,9 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype metadataText := strings.TrimSpace(resp.Choices[0].Message.Content) metadataText = stripCodeFence(metadataText) metadataText = extractJSONObject(metadataText) + if metadataText == "" { + return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: response contains no JSON object", c.name) + } var metadata thoughttypes.ThoughtMetadata if err := json.Unmarshal([]byte(metadataText), &metadata); err != nil { @@ -241,7 +263,7 @@ func (c *Client) doJSON(ctx context.Context, path string, requestBody any, dest resp, err := c.httpClient.Do(req) if err != nil { lastErr = fmt.Errorf("%s request failed: %w", c.name, err) - if attempt < maxAttempts && isRetryableError(err) { + if attempt < maxAttempts && ctx.Err() == nil && isRetryableError(err) { if retryErr := sleepRetry(ctx, attempt, c.log, c.name); retryErr != nil { return retryErr } @@ -293,7 +315,7 @@ func extractJSONObject(s string) string { start := strings.Index(s, "{") end := strings.LastIndex(s, "}") if start == -1 || end == -1 || end <= start { - return s + return "" } return s[start : end+1] } diff --git a/internal/ai/litellm/client.go b/internal/ai/litellm/client.go index c2753dc..88bff34 100644 --- a/internal/ai/litellm/client.go +++ b/internal/ai/litellm/client.go @@ -9,16 +9,21 @@ import ( ) func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) { + fallback := cfg.LiteLLM.FallbackMetadataModel + if fallback == "" { + fallback = cfg.Metadata.FallbackModel + } return compat.New(compat.Config{ - Name: "litellm", - BaseURL: cfg.LiteLLM.BaseURL, - APIKey: cfg.LiteLLM.APIKey, - EmbeddingModel: cfg.LiteLLM.EmbeddingModel, - MetadataModel: cfg.LiteLLM.MetadataModel, - Temperature: cfg.Metadata.Temperature, - Headers: cfg.LiteLLM.RequestHeaders, - HTTPClient: httpClient, - Log: log, - Dimensions: cfg.Embeddings.Dimensions, + Name: "litellm", + BaseURL: cfg.LiteLLM.BaseURL, + APIKey: cfg.LiteLLM.APIKey, + EmbeddingModel: cfg.LiteLLM.EmbeddingModel, + MetadataModel: cfg.LiteLLM.MetadataModel, + FallbackMetadataModel: fallback, + Temperature: cfg.Metadata.Temperature, + Headers: cfg.LiteLLM.RequestHeaders, + HTTPClient: httpClient, + Log: log, + Dimensions: cfg.Embeddings.Dimensions, }), nil } diff --git a/internal/ai/ollama/client.go b/internal/ai/ollama/client.go index ca83a34..71443f4 100644 --- a/internal/ai/ollama/client.go +++ b/internal/ai/ollama/client.go @@ -10,15 +10,16 @@ import ( func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) { return compat.New(compat.Config{ - Name: "ollama", - BaseURL: cfg.Ollama.BaseURL, - APIKey: cfg.Ollama.APIKey, - EmbeddingModel: cfg.Embeddings.Model, - MetadataModel: cfg.Metadata.Model, - Temperature: cfg.Metadata.Temperature, - Headers: cfg.Ollama.RequestHeaders, - HTTPClient: httpClient, - Log: log, - Dimensions: cfg.Embeddings.Dimensions, + Name: "ollama", + BaseURL: cfg.Ollama.BaseURL, + APIKey: cfg.Ollama.APIKey, + EmbeddingModel: cfg.Embeddings.Model, + MetadataModel: cfg.Metadata.Model, + FallbackMetadataModel: cfg.Metadata.FallbackModel, + Temperature: cfg.Metadata.Temperature, + Headers: cfg.Ollama.RequestHeaders, + HTTPClient: httpClient, + Log: log, + Dimensions: cfg.Embeddings.Dimensions, }), nil } diff --git a/internal/ai/openrouter/client.go b/internal/ai/openrouter/client.go index c715080..0195edc 100644 --- a/internal/ai/openrouter/client.go +++ b/internal/ai/openrouter/client.go @@ -21,15 +21,16 @@ func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compa } return compat.New(compat.Config{ - Name: "openrouter", - BaseURL: cfg.OpenRouter.BaseURL, - APIKey: cfg.OpenRouter.APIKey, - EmbeddingModel: cfg.Embeddings.Model, - MetadataModel: cfg.Metadata.Model, - Temperature: cfg.Metadata.Temperature, - Headers: headers, - HTTPClient: httpClient, - Log: log, - Dimensions: cfg.Embeddings.Dimensions, + Name: "openrouter", + BaseURL: cfg.OpenRouter.BaseURL, + APIKey: cfg.OpenRouter.APIKey, + EmbeddingModel: cfg.Embeddings.Model, + MetadataModel: cfg.Metadata.Model, + FallbackMetadataModel: cfg.Metadata.FallbackModel, + Temperature: cfg.Metadata.Temperature, + Headers: headers, + HTTPClient: httpClient, + Log: log, + Dimensions: cfg.Embeddings.Dimensions, }), nil } diff --git a/internal/config/config.go b/internal/config/config.go index 1a57682..345712f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -84,17 +84,19 @@ type AIEmbeddingConfig struct { } type AIMetadataConfig struct { - Model string `yaml:"model"` - Temperature float64 `yaml:"temperature"` + Model string `yaml:"model"` + FallbackModel string `yaml:"fallback_model"` + Temperature float64 `yaml:"temperature"` } type LiteLLMConfig struct { - BaseURL string `yaml:"base_url"` - APIKey string `yaml:"api_key"` - UseResponsesAPI bool `yaml:"use_responses_api"` - RequestHeaders map[string]string `yaml:"request_headers"` - EmbeddingModel string `yaml:"embedding_model"` - MetadataModel string `yaml:"metadata_model"` + BaseURL string `yaml:"base_url"` + APIKey string `yaml:"api_key"` + UseResponsesAPI bool `yaml:"use_responses_api"` + RequestHeaders map[string]string `yaml:"request_headers"` + EmbeddingModel string `yaml:"embedding_model"` + MetadataModel string `yaml:"metadata_model"` + FallbackMetadataModel string `yaml:"fallback_metadata_model"` } type OllamaConfig struct {