feat(config): add fallback model support for AI configurations

This commit is contained in:
2026-03-26 23:54:15 +02:00
parent 0eb6ac7ee5
commit a5c7b90f49
6 changed files with 104 additions and 71 deletions

View File

@@ -33,29 +33,31 @@ Rules:
- Do not include any text outside the JSON object.`
// Client holds the per-provider settings that New copies out of Config:
// provider name, base URL, API key, model names, sampling temperature,
// extra headers, HTTP client, logger and embedding dimensions.
//
// NOTE(review): this span appears to be a diff rendering with the +/- markers
// lost, so the field list is shown twice (pre-change, then post-change). The
// second list differs only by the added fallbackMetadataModel field.
type Client struct {
name string
baseURL string
apiKey string
embeddingModel string
metadataModel string
temperature float64
headers map[string]string
httpClient *http.Client
log *slog.Logger
dimensions int
name string
baseURL string
apiKey string
embeddingModel string
metadataModel string
// fallbackMetadataModel is the model ExtractMetadata retries with when the
// call using metadataModel returns an error; an empty string disables the retry.
fallbackMetadataModel string
temperature float64
headers map[string]string
httpClient *http.Client
log *slog.Logger
dimensions int
}
// Config is the exported set of options passed to New; each field is copied
// into the unexported Client field of the same name.
//
// NOTE(review): this span appears to be a diff rendering with the +/- markers
// lost, so the field list is shown twice (pre-change, then post-change). The
// second list differs only by the added FallbackMetadataModel field.
type Config struct {
Name string
BaseURL string
APIKey string
EmbeddingModel string
MetadataModel string
Temperature float64
Headers map[string]string
HTTPClient *http.Client
Log *slog.Logger
Dimensions int
Name string
BaseURL string
APIKey string
EmbeddingModel string
MetadataModel string
// FallbackMetadataModel is optional; when non-empty it names the model that
// metadata extraction retries with after the primary MetadataModel fails.
FallbackMetadataModel string
Temperature float64
Headers map[string]string
HTTPClient *http.Client
Log *slog.Logger
Dimensions int
}
type embeddingsRequest struct {
@@ -100,16 +102,17 @@ type providerError struct {
// New returns a Client whose fields are copied one-for-one from cfg. The
// visible body applies no validation or defaulting.
//
// NOTE(review): this span appears to be a diff rendering with the +/- markers
// lost, so the literal's entries are shown twice (pre-change, then
// post-change). The second set adds fallbackMetadataModel.
// NOTE(review): cfg.HTTPClient and cfg.Log are stored as-is; the logger is
// nil-checked before use in ExtractMetadata, but the HTTP client is used
// directly by doJSON, so callers presumably must pass a non-nil client — confirm.
func New(cfg Config) *Client {
return &Client{
name: cfg.Name,
baseURL: cfg.BaseURL,
apiKey: cfg.APIKey,
embeddingModel: cfg.EmbeddingModel,
metadataModel: cfg.MetadataModel,
temperature: cfg.Temperature,
headers: cfg.Headers,
httpClient: cfg.HTTPClient,
log: cfg.Log,
dimensions: cfg.Dimensions,
name: cfg.Name,
baseURL: cfg.BaseURL,
apiKey: cfg.APIKey,
embeddingModel: cfg.EmbeddingModel,
metadataModel: cfg.MetadataModel,
fallbackMetadataModel: cfg.FallbackMetadataModel,
temperature: cfg.Temperature,
headers: cfg.Headers,
httpClient: cfg.HTTPClient,
log: cfg.Log,
dimensions: cfg.Dimensions,
}
}
@@ -146,8 +149,24 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s extract metadata: input must not be empty", c.name)
}
result, err := c.extractMetadataWithModel(ctx, input, c.metadataModel)
if err != nil && c.fallbackMetadataModel != "" && ctx.Err() == nil {
if c.log != nil {
c.log.Warn("metadata extraction failed, trying fallback model",
slog.String("provider", c.name),
slog.String("primary_model", c.metadataModel),
slog.String("fallback_model", c.fallbackMetadataModel),
slog.String("error", err.Error()),
)
}
return c.extractMetadataWithModel(ctx, input, c.fallbackMetadataModel)
}
return result, err
}
func (c *Client) extractMetadataWithModel(ctx context.Context, input, model string) (thoughttypes.ThoughtMetadata, error) {
req := chatCompletionsRequest{
Model: c.metadataModel,
Model: model,
Temperature: c.temperature,
ResponseFormat: &responseType{
Type: "json_object",
@@ -172,6 +191,9 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype
metadataText := strings.TrimSpace(resp.Choices[0].Message.Content)
metadataText = stripCodeFence(metadataText)
metadataText = extractJSONObject(metadataText)
if metadataText == "" {
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: response contains no JSON object", c.name)
}
var metadata thoughttypes.ThoughtMetadata
if err := json.Unmarshal([]byte(metadataText), &metadata); err != nil {
@@ -241,7 +263,7 @@ func (c *Client) doJSON(ctx context.Context, path string, requestBody any, dest
resp, err := c.httpClient.Do(req)
if err != nil {
lastErr = fmt.Errorf("%s request failed: %w", c.name, err)
if attempt < maxAttempts && isRetryableError(err) {
if attempt < maxAttempts && ctx.Err() == nil && isRetryableError(err) {
if retryErr := sleepRetry(ctx, attempt, c.log, c.name); retryErr != nil {
return retryErr
}
@@ -293,7 +315,7 @@ func extractJSONObject(s string) string {
start := strings.Index(s, "{")
end := strings.LastIndex(s, "}")
if start == -1 || end == -1 || end <= start {
return s
return ""
}
return s[start : end+1]
}

View File

@@ -9,16 +9,21 @@ import (
)
// New builds the compat client for the "litellm" provider from cfg, using the
// supplied HTTP client and logger. The returned error is always nil in the
// visible code.
//
// NOTE(review): this span appears to be a diff rendering with the +/- markers
// lost, so the compat.Config entries are shown twice (pre-change, then
// post-change). The second set adds FallbackMetadataModel.
func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) {
// Prefer the LiteLLM-specific fallback model; use the shared metadata
// fallback only when the provider-specific one is unset.
fallback := cfg.LiteLLM.FallbackMetadataModel
if fallback == "" {
fallback = cfg.Metadata.FallbackModel
}
return compat.New(compat.Config{
Name: "litellm",
BaseURL: cfg.LiteLLM.BaseURL,
APIKey: cfg.LiteLLM.APIKey,
EmbeddingModel: cfg.LiteLLM.EmbeddingModel,
MetadataModel: cfg.LiteLLM.MetadataModel,
Temperature: cfg.Metadata.Temperature,
Headers: cfg.LiteLLM.RequestHeaders,
HTTPClient: httpClient,
Log: log,
Dimensions: cfg.Embeddings.Dimensions,
Name: "litellm",
BaseURL: cfg.LiteLLM.BaseURL,
APIKey: cfg.LiteLLM.APIKey,
EmbeddingModel: cfg.LiteLLM.EmbeddingModel,
MetadataModel: cfg.LiteLLM.MetadataModel,
FallbackMetadataModel: fallback,
Temperature: cfg.Metadata.Temperature,
Headers: cfg.LiteLLM.RequestHeaders,
HTTPClient: httpClient,
Log: log,
Dimensions: cfg.Embeddings.Dimensions,
}), nil
}

View File

@@ -10,15 +10,16 @@ import (
// New builds the compat client for the "ollama" provider from cfg, using the
// supplied HTTP client and logger. Connection settings come from cfg.Ollama;
// model names, temperature and dimensions come from the shared cfg.Embeddings
// and cfg.Metadata sections. The returned error is always nil in the visible code.
//
// NOTE(review): this span appears to be a diff rendering with the +/- markers
// lost, so the compat.Config entries are shown twice (pre-change, then
// post-change). The second set adds FallbackMetadataModel, taken from
// cfg.Metadata.FallbackModel.
func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) {
return compat.New(compat.Config{
Name: "ollama",
BaseURL: cfg.Ollama.BaseURL,
APIKey: cfg.Ollama.APIKey,
EmbeddingModel: cfg.Embeddings.Model,
MetadataModel: cfg.Metadata.Model,
Temperature: cfg.Metadata.Temperature,
Headers: cfg.Ollama.RequestHeaders,
HTTPClient: httpClient,
Log: log,
Dimensions: cfg.Embeddings.Dimensions,
Name: "ollama",
BaseURL: cfg.Ollama.BaseURL,
APIKey: cfg.Ollama.APIKey,
EmbeddingModel: cfg.Embeddings.Model,
MetadataModel: cfg.Metadata.Model,
FallbackMetadataModel: cfg.Metadata.FallbackModel,
Temperature: cfg.Metadata.Temperature,
Headers: cfg.Ollama.RequestHeaders,
HTTPClient: httpClient,
Log: log,
Dimensions: cfg.Embeddings.Dimensions,
}), nil
}

View File

@@ -21,15 +21,16 @@ func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compa
}
return compat.New(compat.Config{
Name: "openrouter",
BaseURL: cfg.OpenRouter.BaseURL,
APIKey: cfg.OpenRouter.APIKey,
EmbeddingModel: cfg.Embeddings.Model,
MetadataModel: cfg.Metadata.Model,
Temperature: cfg.Metadata.Temperature,
Headers: headers,
HTTPClient: httpClient,
Log: log,
Dimensions: cfg.Embeddings.Dimensions,
Name: "openrouter",
BaseURL: cfg.OpenRouter.BaseURL,
APIKey: cfg.OpenRouter.APIKey,
EmbeddingModel: cfg.Embeddings.Model,
MetadataModel: cfg.Metadata.Model,
FallbackMetadataModel: cfg.Metadata.FallbackModel,
Temperature: cfg.Metadata.Temperature,
Headers: headers,
HTTPClient: httpClient,
Log: log,
Dimensions: cfg.Embeddings.Dimensions,
}), nil
}