feat(config): add fallback model support for AI configurations

This commit is contained in:
2026-03-26 23:54:15 +02:00
parent 0eb6ac7ee5
commit a5c7b90f49
6 changed files with 104 additions and 71 deletions

View File

@@ -42,6 +42,7 @@ ai:
dimensions: 1536 dimensions: 1536
metadata: metadata:
model: "gpt-4o-mini" model: "gpt-4o-mini"
fallback_model: ""
temperature: 0.1 temperature: 0.1
litellm: litellm:
base_url: "http://localhost:4000/v1" base_url: "http://localhost:4000/v1"
@@ -50,6 +51,7 @@ ai:
request_headers: {} request_headers: {}
embedding_model: "openrouter/openai/text-embedding-3-small" embedding_model: "openrouter/openai/text-embedding-3-small"
metadata_model: "gpt-4o-mini" metadata_model: "gpt-4o-mini"
fallback_metadata_model: ""
ollama: ollama:
base_url: "http://localhost:11434/v1" base_url: "http://localhost:11434/v1"
api_key: "ollama" api_key: "ollama"

View File

@@ -33,29 +33,31 @@ Rules:
- Do not include any text outside the JSON object.` - Do not include any text outside the JSON object.`
// Client holds the per-provider settings that New copies out of Config.
//
// name prefixes error messages and is logged as the "provider" field.
// metadataModel is the model tried first for metadata extraction; when that
// attempt returns an error, fallbackMetadataModel is non-empty and the context
// is still live, the request is repeated once with fallbackMetadataModel.
// temperature is sent with the metadata chat-completions request, httpClient
// performs the HTTP calls, and log may be nil (it is nil-checked before the
// fallback warning is written).
type Client struct {
	name                  string
	baseURL               string
	apiKey                string
	embeddingModel        string
	metadataModel         string
	fallbackMetadataModel string
	temperature           float64
	headers               map[string]string
	httpClient            *http.Client
	log                   *slog.Logger
	dimensions            int
}
// Config is the input to New. Every field is copied verbatim into the
// same-named (unexported) field of the returned Client.
//
// FallbackMetadataModel is optional: an empty string leaves the metadata
// fallback disabled.
type Config struct {
	Name                  string
	BaseURL               string
	APIKey                string
	EmbeddingModel        string
	MetadataModel         string
	FallbackMetadataModel string
	Temperature           float64
	Headers               map[string]string
	HTTPClient            *http.Client
	Log                   *slog.Logger
	Dimensions            int
}
type embeddingsRequest struct { type embeddingsRequest struct {
@@ -100,16 +102,17 @@ type providerError struct {
func New(cfg Config) *Client { func New(cfg Config) *Client {
return &Client{ return &Client{
name: cfg.Name, name: cfg.Name,
baseURL: cfg.BaseURL, baseURL: cfg.BaseURL,
apiKey: cfg.APIKey, apiKey: cfg.APIKey,
embeddingModel: cfg.EmbeddingModel, embeddingModel: cfg.EmbeddingModel,
metadataModel: cfg.MetadataModel, metadataModel: cfg.MetadataModel,
temperature: cfg.Temperature, fallbackMetadataModel: cfg.FallbackMetadataModel,
headers: cfg.Headers, temperature: cfg.Temperature,
httpClient: cfg.HTTPClient, headers: cfg.Headers,
log: cfg.Log, httpClient: cfg.HTTPClient,
dimensions: cfg.Dimensions, log: cfg.Log,
dimensions: cfg.Dimensions,
} }
} }
@@ -146,8 +149,24 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s extract metadata: input must not be empty", c.name) return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s extract metadata: input must not be empty", c.name)
} }
result, err := c.extractMetadataWithModel(ctx, input, c.metadataModel)
if err != nil && c.fallbackMetadataModel != "" && ctx.Err() == nil {
if c.log != nil {
c.log.Warn("metadata extraction failed, trying fallback model",
slog.String("provider", c.name),
slog.String("primary_model", c.metadataModel),
slog.String("fallback_model", c.fallbackMetadataModel),
slog.String("error", err.Error()),
)
}
return c.extractMetadataWithModel(ctx, input, c.fallbackMetadataModel)
}
return result, err
}
func (c *Client) extractMetadataWithModel(ctx context.Context, input, model string) (thoughttypes.ThoughtMetadata, error) {
req := chatCompletionsRequest{ req := chatCompletionsRequest{
Model: c.metadataModel, Model: model,
Temperature: c.temperature, Temperature: c.temperature,
ResponseFormat: &responseType{ ResponseFormat: &responseType{
Type: "json_object", Type: "json_object",
@@ -172,6 +191,9 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype
metadataText := strings.TrimSpace(resp.Choices[0].Message.Content) metadataText := strings.TrimSpace(resp.Choices[0].Message.Content)
metadataText = stripCodeFence(metadataText) metadataText = stripCodeFence(metadataText)
metadataText = extractJSONObject(metadataText) metadataText = extractJSONObject(metadataText)
if metadataText == "" {
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: response contains no JSON object", c.name)
}
var metadata thoughttypes.ThoughtMetadata var metadata thoughttypes.ThoughtMetadata
if err := json.Unmarshal([]byte(metadataText), &metadata); err != nil { if err := json.Unmarshal([]byte(metadataText), &metadata); err != nil {
@@ -241,7 +263,7 @@ func (c *Client) doJSON(ctx context.Context, path string, requestBody any, dest
resp, err := c.httpClient.Do(req) resp, err := c.httpClient.Do(req)
if err != nil { if err != nil {
lastErr = fmt.Errorf("%s request failed: %w", c.name, err) lastErr = fmt.Errorf("%s request failed: %w", c.name, err)
if attempt < maxAttempts && isRetryableError(err) { if attempt < maxAttempts && ctx.Err() == nil && isRetryableError(err) {
if retryErr := sleepRetry(ctx, attempt, c.log, c.name); retryErr != nil { if retryErr := sleepRetry(ctx, attempt, c.log, c.name); retryErr != nil {
return retryErr return retryErr
} }
@@ -293,7 +315,7 @@ func extractJSONObject(s string) string {
start := strings.Index(s, "{") start := strings.Index(s, "{")
end := strings.LastIndex(s, "}") end := strings.LastIndex(s, "}")
if start == -1 || end == -1 || end <= start { if start == -1 || end == -1 || end <= start {
return s return ""
} }
return s[start : end+1] return s[start : end+1]
} }

View File

@@ -9,16 +9,21 @@ import (
) )
func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) { func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) {
fallback := cfg.LiteLLM.FallbackMetadataModel
if fallback == "" {
fallback = cfg.Metadata.FallbackModel
}
return compat.New(compat.Config{ return compat.New(compat.Config{
Name: "litellm", Name: "litellm",
BaseURL: cfg.LiteLLM.BaseURL, BaseURL: cfg.LiteLLM.BaseURL,
APIKey: cfg.LiteLLM.APIKey, APIKey: cfg.LiteLLM.APIKey,
EmbeddingModel: cfg.LiteLLM.EmbeddingModel, EmbeddingModel: cfg.LiteLLM.EmbeddingModel,
MetadataModel: cfg.LiteLLM.MetadataModel, MetadataModel: cfg.LiteLLM.MetadataModel,
Temperature: cfg.Metadata.Temperature, FallbackMetadataModel: fallback,
Headers: cfg.LiteLLM.RequestHeaders, Temperature: cfg.Metadata.Temperature,
HTTPClient: httpClient, Headers: cfg.LiteLLM.RequestHeaders,
Log: log, HTTPClient: httpClient,
Dimensions: cfg.Embeddings.Dimensions, Log: log,
Dimensions: cfg.Embeddings.Dimensions,
}), nil }), nil
} }

View File

@@ -10,15 +10,16 @@ import (
func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) { func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) {
return compat.New(compat.Config{ return compat.New(compat.Config{
Name: "ollama", Name: "ollama",
BaseURL: cfg.Ollama.BaseURL, BaseURL: cfg.Ollama.BaseURL,
APIKey: cfg.Ollama.APIKey, APIKey: cfg.Ollama.APIKey,
EmbeddingModel: cfg.Embeddings.Model, EmbeddingModel: cfg.Embeddings.Model,
MetadataModel: cfg.Metadata.Model, MetadataModel: cfg.Metadata.Model,
Temperature: cfg.Metadata.Temperature, FallbackMetadataModel: cfg.Metadata.FallbackModel,
Headers: cfg.Ollama.RequestHeaders, Temperature: cfg.Metadata.Temperature,
HTTPClient: httpClient, Headers: cfg.Ollama.RequestHeaders,
Log: log, HTTPClient: httpClient,
Dimensions: cfg.Embeddings.Dimensions, Log: log,
Dimensions: cfg.Embeddings.Dimensions,
}), nil }), nil
} }

View File

@@ -21,15 +21,16 @@ func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compa
} }
return compat.New(compat.Config{ return compat.New(compat.Config{
Name: "openrouter", Name: "openrouter",
BaseURL: cfg.OpenRouter.BaseURL, BaseURL: cfg.OpenRouter.BaseURL,
APIKey: cfg.OpenRouter.APIKey, APIKey: cfg.OpenRouter.APIKey,
EmbeddingModel: cfg.Embeddings.Model, EmbeddingModel: cfg.Embeddings.Model,
MetadataModel: cfg.Metadata.Model, MetadataModel: cfg.Metadata.Model,
Temperature: cfg.Metadata.Temperature, FallbackMetadataModel: cfg.Metadata.FallbackModel,
Headers: headers, Temperature: cfg.Metadata.Temperature,
HTTPClient: httpClient, Headers: headers,
Log: log, HTTPClient: httpClient,
Dimensions: cfg.Embeddings.Dimensions, Log: log,
Dimensions: cfg.Embeddings.Dimensions,
}), nil }), nil
} }

View File

@@ -84,17 +84,19 @@ type AIEmbeddingConfig struct {
} }
// AIMetadataConfig is the YAML-mapped metadata section of the AI settings.
//
// Model is read from "model", FallbackModel from "fallback_model" and
// Temperature from "temperature". FallbackModel may be left empty.
type AIMetadataConfig struct {
	Model         string  `yaml:"model"`
	FallbackModel string  `yaml:"fallback_model"`
	Temperature   float64 `yaml:"temperature"`
}
// LiteLLMConfig is the YAML-mapped "litellm" provider section.
//
// Each field is read from the key named in its yaml tag.
// FallbackMetadataModel ("fallback_metadata_model") may be left empty.
type LiteLLMConfig struct {
	BaseURL               string            `yaml:"base_url"`
	APIKey                string            `yaml:"api_key"`
	UseResponsesAPI       bool              `yaml:"use_responses_api"`
	RequestHeaders        map[string]string `yaml:"request_headers"`
	EmbeddingModel        string            `yaml:"embedding_model"`
	MetadataModel         string            `yaml:"metadata_model"`
	FallbackMetadataModel string            `yaml:"fallback_metadata_model"`
}
type OllamaConfig struct { type OllamaConfig struct {