feat(config): add fallback model support for AI configurations
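
Add an optional fallback model for AI metadata extraction. When the
primary model fails and the request context is still live, the client
logs a warning and retries the extraction once with the configured
fallback. ai.metadata.fallback_model sets it globally; LiteLLM can
override it with ai.litellm.fallback_metadata_model. Both default to
"" (disabled). Related hardening: extractJSONObject now returns ""
when the response contains no JSON object so the caller can report a
clear error, and HTTP retries are skipped once the context has been
cancelled.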

2026-03-26 23:54:15 +02:00
parent 0eb6ac7ee5
commit a5c7b90f49
6 changed files with 104 additions and 71 deletions

View File

@@ -42,6 +42,7 @@ ai:
     dimensions: 1536
   metadata:
     model: "gpt-4o-mini"
+    fallback_model: ""
     temperature: 0.1
   litellm:
     base_url: "http://localhost:4000/v1"
@@ -50,6 +51,7 @@ ai:
     request_headers: {}
     embedding_model: "openrouter/openai/text-embedding-3-small"
     metadata_model: "gpt-4o-mini"
+    fallback_metadata_model: ""
   ollama:
     base_url: "http://localhost:11434/v1"
     api_key: "ollama"
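
For illustration, a config with the fallback enabled might look like the sketch below; the model names are placeholders, not defaults shipped by this commit:

```yaml
ai:
  metadata:
    model: "gpt-4o-mini"      # primary model for metadata extraction
    fallback_model: "gpt-4o"  # tried once if the primary call fails
    temperature: 0.1
```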

View File

@@ -38,6 +38,7 @@ type Client struct {
 	apiKey                string
 	embeddingModel        string
 	metadataModel         string
+	fallbackMetadataModel string
 	temperature           float64
 	headers               map[string]string
 	httpClient            *http.Client
@@ -51,6 +52,7 @@ type Config struct {
 	APIKey                string
 	EmbeddingModel        string
 	MetadataModel         string
+	FallbackMetadataModel string
 	Temperature           float64
 	Headers               map[string]string
 	HTTPClient            *http.Client
@@ -105,6 +107,7 @@ func New(cfg Config) *Client {
 		apiKey:                cfg.APIKey,
 		embeddingModel:        cfg.EmbeddingModel,
 		metadataModel:         cfg.MetadataModel,
+		fallbackMetadataModel: cfg.FallbackMetadataModel,
 		temperature:           cfg.Temperature,
 		headers:               cfg.Headers,
 		httpClient:            cfg.HTTPClient,
@@ -146,8 +149,24 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype
 		return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s extract metadata: input must not be empty", c.name)
 	}
+	result, err := c.extractMetadataWithModel(ctx, input, c.metadataModel)
+	if err != nil && c.fallbackMetadataModel != "" && ctx.Err() == nil {
+		if c.log != nil {
+			c.log.Warn("metadata extraction failed, trying fallback model",
+				slog.String("provider", c.name),
+				slog.String("primary_model", c.metadataModel),
+				slog.String("fallback_model", c.fallbackMetadataModel),
+				slog.String("error", err.Error()),
+			)
+		}
+		return c.extractMetadataWithModel(ctx, input, c.fallbackMetadataModel)
+	}
+	return result, err
+}
 
+func (c *Client) extractMetadataWithModel(ctx context.Context, input, model string) (thoughttypes.ThoughtMetadata, error) {
 	req := chatCompletionsRequest{
-		Model:       c.metadataModel,
+		Model:       model,
 		Temperature: c.temperature,
 		ResponseFormat: &responseType{
 			Type: "json_object",
@@ -172,6 +191,9 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype
 	metadataText := strings.TrimSpace(resp.Choices[0].Message.Content)
 	metadataText = stripCodeFence(metadataText)
 	metadataText = extractJSONObject(metadataText)
+	if metadataText == "" {
+		return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: response contains no JSON object", c.name)
+	}
 
 	var metadata thoughttypes.ThoughtMetadata
 	if err := json.Unmarshal([]byte(metadataText), &metadata); err != nil {
@@ -241,7 +263,7 @@ func (c *Client) doJSON(ctx context.Context, path string, requestBody any, dest
 		resp, err := c.httpClient.Do(req)
 		if err != nil {
 			lastErr = fmt.Errorf("%s request failed: %w", c.name, err)
-			if attempt < maxAttempts && isRetryableError(err) {
+			if attempt < maxAttempts && ctx.Err() == nil && isRetryableError(err) {
 				if retryErr := sleepRetry(ctx, attempt, c.log, c.name); retryErr != nil {
 					return retryErr
 				}
@@ -293,7 +315,7 @@ func extractJSONObject(s string) string {
 	start := strings.Index(s, "{")
 	end := strings.LastIndex(s, "}")
 	if start == -1 || end == -1 || end <= start {
-		return s
+		return ""
 	}
 	return s[start : end+1]
 }
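
Boiled down, the new retry path follows a simple try-primary-then-fallback shape. The sketch below is a standalone approximation, not the repo's actual API: extractFn and withFallback are illustrative names standing in for extractMetadataWithModel bound to a model:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
)

// extractFn stands in for the provider call; in the real client this is
// extractMetadataWithModel invoked with a specific model name.
type extractFn func(ctx context.Context, model string) (string, error)

// withFallback tries the primary model first and, if that fails for any
// reason other than context cancellation, retries once with the fallback.
// An empty fallback disables the second attempt, matching the "" default
// of the new config keys.
func withFallback(ctx context.Context, primary, fallback string, call extractFn, log *slog.Logger) (string, error) {
	result, err := call(ctx, primary)
	if err != nil && fallback != "" && ctx.Err() == nil {
		if log != nil {
			log.Warn("metadata extraction failed, trying fallback model",
				slog.String("primary_model", primary),
				slog.String("fallback_model", fallback),
				slog.String("error", err.Error()),
			)
		}
		return call(ctx, fallback)
	}
	return result, err
}

func main() {
	calls := 0
	flaky := func(ctx context.Context, model string) (string, error) {
		calls++
		if model == "gpt-4o-mini" {
			return "", errors.New("rate limited") // primary fails
		}
		return "ok from " + model, nil // fallback succeeds
	}
	out, err := withFallback(context.Background(), "gpt-4o-mini", "gpt-4o", flaky, slog.Default())
	fmt.Println(out, err, calls) // "ok from gpt-4o" <nil> 2
}
```

The extractJSONObject change works with the new emptiness check: returning "" instead of echoing the raw text lets the caller distinguish "no JSON object in the response" from "malformed JSON" and report which one occurred.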

View File

@@ -9,12 +9,17 @@ import (
 )
 
 func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) {
+	fallback := cfg.LiteLLM.FallbackMetadataModel
+	if fallback == "" {
+		fallback = cfg.Metadata.FallbackModel
+	}
 	return compat.New(compat.Config{
 		Name:                  "litellm",
 		BaseURL:               cfg.LiteLLM.BaseURL,
 		APIKey:                cfg.LiteLLM.APIKey,
 		EmbeddingModel:        cfg.LiteLLM.EmbeddingModel,
 		MetadataModel:         cfg.LiteLLM.MetadataModel,
+		FallbackMetadataModel: fallback,
 		Temperature:           cfg.Metadata.Temperature,
 		Headers:               cfg.LiteLLM.RequestHeaders,
 		HTTPClient:            httpClient,
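
The LiteLLM factory resolves the fallback with a first-non-empty rule: the provider-scoped fallback_metadata_model wins, and the shared ai.metadata.fallback_model is the default. A minimal sketch of that precedence (firstNonEmpty is a hypothetical helper, not in the diff):

```go
package main

import "fmt"

// firstNonEmpty returns the first non-empty string, mirroring the
// precedence the litellm factory applies: provider-specific key first,
// shared ai.metadata key second.
func firstNonEmpty(values ...string) string {
	for _, v := range values {
		if v != "" {
			return v
		}
	}
	return ""
}

func main() {
	// provider key unset -> the shared key is used
	fmt.Println(firstNonEmpty("", "gpt-4o")) // gpt-4o
	// provider key set -> it wins
	fmt.Println(firstNonEmpty("openrouter/openai/gpt-4o", "gpt-4o")) // openrouter/openai/gpt-4o
}
```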

View File

@@ -15,6 +15,7 @@ func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compa
 		APIKey:                cfg.Ollama.APIKey,
 		EmbeddingModel:        cfg.Embeddings.Model,
 		MetadataModel:         cfg.Metadata.Model,
+		FallbackMetadataModel: cfg.Metadata.FallbackModel,
 		Temperature:           cfg.Metadata.Temperature,
 		Headers:               cfg.Ollama.RequestHeaders,
 		HTTPClient:            httpClient,

View File

@@ -26,6 +26,7 @@ func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compa
 		APIKey:                cfg.OpenRouter.APIKey,
 		EmbeddingModel:        cfg.Embeddings.Model,
 		MetadataModel:         cfg.Metadata.Model,
+		FallbackMetadataModel: cfg.Metadata.FallbackModel,
 		Temperature:           cfg.Metadata.Temperature,
 		Headers:               headers,
 		HTTPClient:            httpClient,

View File

@@ -85,6 +85,7 @@ type AIEmbeddingConfig struct {
 }
 
 type AIMetadataConfig struct {
 	Model         string  `yaml:"model"`
+	FallbackModel string  `yaml:"fallback_model"`
 	Temperature   float64 `yaml:"temperature"`
 }
@@ -95,6 +96,7 @@ type LiteLLMConfig struct {
 	RequestHeaders        map[string]string `yaml:"request_headers"`
 	EmbeddingModel        string            `yaml:"embedding_model"`
 	MetadataModel         string            `yaml:"metadata_model"`
+	FallbackMetadataModel string            `yaml:"fallback_metadata_model"`
 }
 
 type OllamaConfig struct {