feat(config): add fallback model support for AI configurations
This commit is contained in:
@@ -33,29 +33,31 @@ Rules:
|
||||
- Do not include any text outside the JSON object.`
|
||||
|
||||
// Client carries the per-provider settings used by the methods on this type.
// Values are populated from a Config by New.
type Client struct {
	// name identifies the provider; it prefixes error messages and is logged
	// under the "provider" key.
	name string
	// baseURL is copied from Config.BaseURL.
	// NOTE(review): presumably the root that request paths are appended to — confirm in doJSON.
	baseURL string
	// apiKey is copied from Config.APIKey.
	// NOTE(review): how it is attached to requests is not visible here.
	apiKey string
	// embeddingModel is copied from Config.EmbeddingModel.
	// NOTE(review): presumably the model name sent with embedding requests — confirm.
	embeddingModel string
	// metadataModel is the model ExtractMetadata tries first.
	metadataModel string
	// fallbackMetadataModel is tried by ExtractMetadata when the primary model
	// fails and the context has not ended. An empty value disables the fallback.
	fallbackMetadataModel string
	// temperature is sent as the Temperature of metadata chat requests.
	temperature float64
	// headers is copied from Config.Headers.
	// NOTE(review): presumably extra HTTP headers added to each request — confirm.
	headers map[string]string
	// httpClient performs the outgoing HTTP requests.
	httpClient *http.Client
	// log receives warnings; ExtractMetadata checks it for nil before logging.
	log *slog.Logger
	// dimensions is copied from Config.Dimensions.
	// NOTE(review): presumably the requested embedding vector size — confirm.
	dimensions int
}
|
||||
|
||||
// Config is the set of values New copies into a Client. New applies no
// validation or defaulting, so every field is stored exactly as given.
type Config struct {
	// Name is the provider name used in error messages and log records.
	Name string
	// BaseURL is stored as the client's base URL.
	BaseURL string
	// APIKey is stored as the client's API key.
	APIKey string
	// EmbeddingModel is stored as the client's embedding model name.
	EmbeddingModel string
	// MetadataModel is the model tried first for metadata extraction.
	MetadataModel string
	// FallbackMetadataModel is the model tried when extraction with
	// MetadataModel fails. Leave it empty to disable the fallback.
	FallbackMetadataModel string
	// Temperature is sent with metadata chat requests.
	Temperature float64
	// Headers is stored as the client's header map (the map is not copied).
	Headers map[string]string
	// HTTPClient is used to perform the HTTP requests.
	HTTPClient *http.Client
	// Log receives warnings; the fallback path checks it for nil before logging.
	Log *slog.Logger
	// Dimensions is stored as the client's dimensions value.
	// NOTE(review): presumably the embedding vector size — confirm against the embeddings request.
	Dimensions int
}
|
||||
|
||||
type embeddingsRequest struct {
|
||||
@@ -100,16 +102,17 @@ type providerError struct {
|
||||
|
||||
func New(cfg Config) *Client {
|
||||
return &Client{
|
||||
name: cfg.Name,
|
||||
baseURL: cfg.BaseURL,
|
||||
apiKey: cfg.APIKey,
|
||||
embeddingModel: cfg.EmbeddingModel,
|
||||
metadataModel: cfg.MetadataModel,
|
||||
temperature: cfg.Temperature,
|
||||
headers: cfg.Headers,
|
||||
httpClient: cfg.HTTPClient,
|
||||
log: cfg.Log,
|
||||
dimensions: cfg.Dimensions,
|
||||
name: cfg.Name,
|
||||
baseURL: cfg.BaseURL,
|
||||
apiKey: cfg.APIKey,
|
||||
embeddingModel: cfg.EmbeddingModel,
|
||||
metadataModel: cfg.MetadataModel,
|
||||
fallbackMetadataModel: cfg.FallbackMetadataModel,
|
||||
temperature: cfg.Temperature,
|
||||
headers: cfg.Headers,
|
||||
httpClient: cfg.HTTPClient,
|
||||
log: cfg.Log,
|
||||
dimensions: cfg.Dimensions,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -146,8 +149,24 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype
|
||||
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s extract metadata: input must not be empty", c.name)
|
||||
}
|
||||
|
||||
result, err := c.extractMetadataWithModel(ctx, input, c.metadataModel)
|
||||
if err != nil && c.fallbackMetadataModel != "" && ctx.Err() == nil {
|
||||
if c.log != nil {
|
||||
c.log.Warn("metadata extraction failed, trying fallback model",
|
||||
slog.String("provider", c.name),
|
||||
slog.String("primary_model", c.metadataModel),
|
||||
slog.String("fallback_model", c.fallbackMetadataModel),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
return c.extractMetadataWithModel(ctx, input, c.fallbackMetadataModel)
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (c *Client) extractMetadataWithModel(ctx context.Context, input, model string) (thoughttypes.ThoughtMetadata, error) {
|
||||
req := chatCompletionsRequest{
|
||||
Model: c.metadataModel,
|
||||
Model: model,
|
||||
Temperature: c.temperature,
|
||||
ResponseFormat: &responseType{
|
||||
Type: "json_object",
|
||||
@@ -172,6 +191,9 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype
|
||||
metadataText := strings.TrimSpace(resp.Choices[0].Message.Content)
|
||||
metadataText = stripCodeFence(metadataText)
|
||||
metadataText = extractJSONObject(metadataText)
|
||||
if metadataText == "" {
|
||||
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: response contains no JSON object", c.name)
|
||||
}
|
||||
|
||||
var metadata thoughttypes.ThoughtMetadata
|
||||
if err := json.Unmarshal([]byte(metadataText), &metadata); err != nil {
|
||||
@@ -241,7 +263,7 @@ func (c *Client) doJSON(ctx context.Context, path string, requestBody any, dest
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("%s request failed: %w", c.name, err)
|
||||
if attempt < maxAttempts && isRetryableError(err) {
|
||||
if attempt < maxAttempts && ctx.Err() == nil && isRetryableError(err) {
|
||||
if retryErr := sleepRetry(ctx, attempt, c.log, c.name); retryErr != nil {
|
||||
return retryErr
|
||||
}
|
||||
@@ -293,7 +315,7 @@ func extractJSONObject(s string) string {
|
||||
start := strings.Index(s, "{")
|
||||
end := strings.LastIndex(s, "}")
|
||||
if start == -1 || end == -1 || end <= start {
|
||||
return s
|
||||
return ""
|
||||
}
|
||||
return s[start : end+1]
|
||||
}
|
||||
|
||||
@@ -9,16 +9,21 @@ import (
|
||||
)
|
||||
|
||||
func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) {
|
||||
fallback := cfg.LiteLLM.FallbackMetadataModel
|
||||
if fallback == "" {
|
||||
fallback = cfg.Metadata.FallbackModel
|
||||
}
|
||||
return compat.New(compat.Config{
|
||||
Name: "litellm",
|
||||
BaseURL: cfg.LiteLLM.BaseURL,
|
||||
APIKey: cfg.LiteLLM.APIKey,
|
||||
EmbeddingModel: cfg.LiteLLM.EmbeddingModel,
|
||||
MetadataModel: cfg.LiteLLM.MetadataModel,
|
||||
Temperature: cfg.Metadata.Temperature,
|
||||
Headers: cfg.LiteLLM.RequestHeaders,
|
||||
HTTPClient: httpClient,
|
||||
Log: log,
|
||||
Dimensions: cfg.Embeddings.Dimensions,
|
||||
Name: "litellm",
|
||||
BaseURL: cfg.LiteLLM.BaseURL,
|
||||
APIKey: cfg.LiteLLM.APIKey,
|
||||
EmbeddingModel: cfg.LiteLLM.EmbeddingModel,
|
||||
MetadataModel: cfg.LiteLLM.MetadataModel,
|
||||
FallbackMetadataModel: fallback,
|
||||
Temperature: cfg.Metadata.Temperature,
|
||||
Headers: cfg.LiteLLM.RequestHeaders,
|
||||
HTTPClient: httpClient,
|
||||
Log: log,
|
||||
Dimensions: cfg.Embeddings.Dimensions,
|
||||
}), nil
|
||||
}
|
||||
|
||||
@@ -10,15 +10,16 @@ import (
|
||||
|
||||
func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) {
|
||||
return compat.New(compat.Config{
|
||||
Name: "ollama",
|
||||
BaseURL: cfg.Ollama.BaseURL,
|
||||
APIKey: cfg.Ollama.APIKey,
|
||||
EmbeddingModel: cfg.Embeddings.Model,
|
||||
MetadataModel: cfg.Metadata.Model,
|
||||
Temperature: cfg.Metadata.Temperature,
|
||||
Headers: cfg.Ollama.RequestHeaders,
|
||||
HTTPClient: httpClient,
|
||||
Log: log,
|
||||
Dimensions: cfg.Embeddings.Dimensions,
|
||||
Name: "ollama",
|
||||
BaseURL: cfg.Ollama.BaseURL,
|
||||
APIKey: cfg.Ollama.APIKey,
|
||||
EmbeddingModel: cfg.Embeddings.Model,
|
||||
MetadataModel: cfg.Metadata.Model,
|
||||
FallbackMetadataModel: cfg.Metadata.FallbackModel,
|
||||
Temperature: cfg.Metadata.Temperature,
|
||||
Headers: cfg.Ollama.RequestHeaders,
|
||||
HTTPClient: httpClient,
|
||||
Log: log,
|
||||
Dimensions: cfg.Embeddings.Dimensions,
|
||||
}), nil
|
||||
}
|
||||
|
||||
@@ -21,15 +21,16 @@ func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compa
|
||||
}
|
||||
|
||||
return compat.New(compat.Config{
|
||||
Name: "openrouter",
|
||||
BaseURL: cfg.OpenRouter.BaseURL,
|
||||
APIKey: cfg.OpenRouter.APIKey,
|
||||
EmbeddingModel: cfg.Embeddings.Model,
|
||||
MetadataModel: cfg.Metadata.Model,
|
||||
Temperature: cfg.Metadata.Temperature,
|
||||
Headers: headers,
|
||||
HTTPClient: httpClient,
|
||||
Log: log,
|
||||
Dimensions: cfg.Embeddings.Dimensions,
|
||||
Name: "openrouter",
|
||||
BaseURL: cfg.OpenRouter.BaseURL,
|
||||
APIKey: cfg.OpenRouter.APIKey,
|
||||
EmbeddingModel: cfg.Embeddings.Model,
|
||||
MetadataModel: cfg.Metadata.Model,
|
||||
FallbackMetadataModel: cfg.Metadata.FallbackModel,
|
||||
Temperature: cfg.Metadata.Temperature,
|
||||
Headers: headers,
|
||||
HTTPClient: httpClient,
|
||||
Log: log,
|
||||
Dimensions: cfg.Embeddings.Dimensions,
|
||||
}), nil
|
||||
}
|
||||
|
||||
@@ -84,17 +84,19 @@ type AIEmbeddingConfig struct {
|
||||
}
|
||||
|
||||
// AIMetadataConfig holds the metadata-extraction settings read from the
// "model", "fallback_model" and "temperature" YAML keys.
type AIMetadataConfig struct {
	// Model is the primary metadata model name.
	Model string `yaml:"model"`
	// FallbackModel is the model handed to providers as the fallback metadata
	// model. It is optional; an empty value means no fallback is configured.
	FallbackModel string `yaml:"fallback_model"`
	// Temperature is the sampling temperature handed to providers.
	Temperature float64 `yaml:"temperature"`
}
|
||||
|
||||
// LiteLLMConfig holds the LiteLLM provider settings read from YAML.
type LiteLLMConfig struct {
	// BaseURL is the base URL handed to the LiteLLM client.
	BaseURL string `yaml:"base_url"`
	// APIKey is the API key handed to the LiteLLM client.
	APIKey string `yaml:"api_key"`
	// UseResponsesAPI is read from "use_responses_api".
	// NOTE(review): its consumer is not visible here — confirm what it toggles.
	UseResponsesAPI bool `yaml:"use_responses_api"`
	// RequestHeaders is handed to the LiteLLM client as its header map.
	RequestHeaders map[string]string `yaml:"request_headers"`
	// EmbeddingModel is the embedding model name handed to the LiteLLM client.
	EmbeddingModel string `yaml:"embedding_model"`
	// MetadataModel is the primary metadata model name for LiteLLM.
	MetadataModel string `yaml:"metadata_model"`
	// FallbackMetadataModel is the LiteLLM-specific fallback metadata model.
	// When empty, the LiteLLM constructor falls back to the shared metadata
	// fallback model setting instead.
	FallbackMetadataModel string `yaml:"fallback_metadata_model"`
}
|
||||
|
||||
type OllamaConfig struct {
|
||||
|
||||
Reference in New Issue
Block a user