feat(config): add fallback model support for AI configurations
This commit is contained in:
@@ -42,6 +42,7 @@ ai:
|
|||||||
dimensions: 1536
|
dimensions: 1536
|
||||||
metadata:
|
metadata:
|
||||||
model: "gpt-4o-mini"
|
model: "gpt-4o-mini"
|
||||||
|
fallback_model: ""
|
||||||
temperature: 0.1
|
temperature: 0.1
|
||||||
litellm:
|
litellm:
|
||||||
base_url: "http://localhost:4000/v1"
|
base_url: "http://localhost:4000/v1"
|
||||||
@@ -50,6 +51,7 @@ ai:
|
|||||||
request_headers: {}
|
request_headers: {}
|
||||||
embedding_model: "openrouter/openai/text-embedding-3-small"
|
embedding_model: "openrouter/openai/text-embedding-3-small"
|
||||||
metadata_model: "gpt-4o-mini"
|
metadata_model: "gpt-4o-mini"
|
||||||
|
fallback_metadata_model: ""
|
||||||
ollama:
|
ollama:
|
||||||
base_url: "http://localhost:11434/v1"
|
base_url: "http://localhost:11434/v1"
|
||||||
api_key: "ollama"
|
api_key: "ollama"
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ type Client struct {
|
|||||||
apiKey string
|
apiKey string
|
||||||
embeddingModel string
|
embeddingModel string
|
||||||
metadataModel string
|
metadataModel string
|
||||||
|
fallbackMetadataModel string
|
||||||
temperature float64
|
temperature float64
|
||||||
headers map[string]string
|
headers map[string]string
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
@@ -51,6 +52,7 @@ type Config struct {
|
|||||||
APIKey string
|
APIKey string
|
||||||
EmbeddingModel string
|
EmbeddingModel string
|
||||||
MetadataModel string
|
MetadataModel string
|
||||||
|
FallbackMetadataModel string
|
||||||
Temperature float64
|
Temperature float64
|
||||||
Headers map[string]string
|
Headers map[string]string
|
||||||
HTTPClient *http.Client
|
HTTPClient *http.Client
|
||||||
@@ -105,6 +107,7 @@ func New(cfg Config) *Client {
|
|||||||
apiKey: cfg.APIKey,
|
apiKey: cfg.APIKey,
|
||||||
embeddingModel: cfg.EmbeddingModel,
|
embeddingModel: cfg.EmbeddingModel,
|
||||||
metadataModel: cfg.MetadataModel,
|
metadataModel: cfg.MetadataModel,
|
||||||
|
fallbackMetadataModel: cfg.FallbackMetadataModel,
|
||||||
temperature: cfg.Temperature,
|
temperature: cfg.Temperature,
|
||||||
headers: cfg.Headers,
|
headers: cfg.Headers,
|
||||||
httpClient: cfg.HTTPClient,
|
httpClient: cfg.HTTPClient,
|
||||||
@@ -146,8 +149,24 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype
|
|||||||
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s extract metadata: input must not be empty", c.name)
|
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s extract metadata: input must not be empty", c.name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
result, err := c.extractMetadataWithModel(ctx, input, c.metadataModel)
|
||||||
|
if err != nil && c.fallbackMetadataModel != "" && ctx.Err() == nil {
|
||||||
|
if c.log != nil {
|
||||||
|
c.log.Warn("metadata extraction failed, trying fallback model",
|
||||||
|
slog.String("provider", c.name),
|
||||||
|
slog.String("primary_model", c.metadataModel),
|
||||||
|
slog.String("fallback_model", c.fallbackMetadataModel),
|
||||||
|
slog.String("error", err.Error()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return c.extractMetadataWithModel(ctx, input, c.fallbackMetadataModel)
|
||||||
|
}
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) extractMetadataWithModel(ctx context.Context, input, model string) (thoughttypes.ThoughtMetadata, error) {
|
||||||
req := chatCompletionsRequest{
|
req := chatCompletionsRequest{
|
||||||
Model: c.metadataModel,
|
Model: model,
|
||||||
Temperature: c.temperature,
|
Temperature: c.temperature,
|
||||||
ResponseFormat: &responseType{
|
ResponseFormat: &responseType{
|
||||||
Type: "json_object",
|
Type: "json_object",
|
||||||
@@ -172,6 +191,9 @@ func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttype
|
|||||||
metadataText := strings.TrimSpace(resp.Choices[0].Message.Content)
|
metadataText := strings.TrimSpace(resp.Choices[0].Message.Content)
|
||||||
metadataText = stripCodeFence(metadataText)
|
metadataText = stripCodeFence(metadataText)
|
||||||
metadataText = extractJSONObject(metadataText)
|
metadataText = extractJSONObject(metadataText)
|
||||||
|
if metadataText == "" {
|
||||||
|
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: response contains no JSON object", c.name)
|
||||||
|
}
|
||||||
|
|
||||||
var metadata thoughttypes.ThoughtMetadata
|
var metadata thoughttypes.ThoughtMetadata
|
||||||
if err := json.Unmarshal([]byte(metadataText), &metadata); err != nil {
|
if err := json.Unmarshal([]byte(metadataText), &metadata); err != nil {
|
||||||
@@ -241,7 +263,7 @@ func (c *Client) doJSON(ctx context.Context, path string, requestBody any, dest
|
|||||||
resp, err := c.httpClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lastErr = fmt.Errorf("%s request failed: %w", c.name, err)
|
lastErr = fmt.Errorf("%s request failed: %w", c.name, err)
|
||||||
if attempt < maxAttempts && isRetryableError(err) {
|
if attempt < maxAttempts && ctx.Err() == nil && isRetryableError(err) {
|
||||||
if retryErr := sleepRetry(ctx, attempt, c.log, c.name); retryErr != nil {
|
if retryErr := sleepRetry(ctx, attempt, c.log, c.name); retryErr != nil {
|
||||||
return retryErr
|
return retryErr
|
||||||
}
|
}
|
||||||
@@ -293,7 +315,7 @@ func extractJSONObject(s string) string {
|
|||||||
start := strings.Index(s, "{")
|
start := strings.Index(s, "{")
|
||||||
end := strings.LastIndex(s, "}")
|
end := strings.LastIndex(s, "}")
|
||||||
if start == -1 || end == -1 || end <= start {
|
if start == -1 || end == -1 || end <= start {
|
||||||
return s
|
return ""
|
||||||
}
|
}
|
||||||
return s[start : end+1]
|
return s[start : end+1]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,12 +9,17 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) {
|
func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) {
|
||||||
|
fallback := cfg.LiteLLM.FallbackMetadataModel
|
||||||
|
if fallback == "" {
|
||||||
|
fallback = cfg.Metadata.FallbackModel
|
||||||
|
}
|
||||||
return compat.New(compat.Config{
|
return compat.New(compat.Config{
|
||||||
Name: "litellm",
|
Name: "litellm",
|
||||||
BaseURL: cfg.LiteLLM.BaseURL,
|
BaseURL: cfg.LiteLLM.BaseURL,
|
||||||
APIKey: cfg.LiteLLM.APIKey,
|
APIKey: cfg.LiteLLM.APIKey,
|
||||||
EmbeddingModel: cfg.LiteLLM.EmbeddingModel,
|
EmbeddingModel: cfg.LiteLLM.EmbeddingModel,
|
||||||
MetadataModel: cfg.LiteLLM.MetadataModel,
|
MetadataModel: cfg.LiteLLM.MetadataModel,
|
||||||
|
FallbackMetadataModel: fallback,
|
||||||
Temperature: cfg.Metadata.Temperature,
|
Temperature: cfg.Metadata.Temperature,
|
||||||
Headers: cfg.LiteLLM.RequestHeaders,
|
Headers: cfg.LiteLLM.RequestHeaders,
|
||||||
HTTPClient: httpClient,
|
HTTPClient: httpClient,
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compa
|
|||||||
APIKey: cfg.Ollama.APIKey,
|
APIKey: cfg.Ollama.APIKey,
|
||||||
EmbeddingModel: cfg.Embeddings.Model,
|
EmbeddingModel: cfg.Embeddings.Model,
|
||||||
MetadataModel: cfg.Metadata.Model,
|
MetadataModel: cfg.Metadata.Model,
|
||||||
|
FallbackMetadataModel: cfg.Metadata.FallbackModel,
|
||||||
Temperature: cfg.Metadata.Temperature,
|
Temperature: cfg.Metadata.Temperature,
|
||||||
Headers: cfg.Ollama.RequestHeaders,
|
Headers: cfg.Ollama.RequestHeaders,
|
||||||
HTTPClient: httpClient,
|
HTTPClient: httpClient,
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compa
|
|||||||
APIKey: cfg.OpenRouter.APIKey,
|
APIKey: cfg.OpenRouter.APIKey,
|
||||||
EmbeddingModel: cfg.Embeddings.Model,
|
EmbeddingModel: cfg.Embeddings.Model,
|
||||||
MetadataModel: cfg.Metadata.Model,
|
MetadataModel: cfg.Metadata.Model,
|
||||||
|
FallbackMetadataModel: cfg.Metadata.FallbackModel,
|
||||||
Temperature: cfg.Metadata.Temperature,
|
Temperature: cfg.Metadata.Temperature,
|
||||||
Headers: headers,
|
Headers: headers,
|
||||||
HTTPClient: httpClient,
|
HTTPClient: httpClient,
|
||||||
|
|||||||
@@ -85,6 +85,7 @@ type AIEmbeddingConfig struct {
|
|||||||
|
|
||||||
type AIMetadataConfig struct {
|
type AIMetadataConfig struct {
|
||||||
Model string `yaml:"model"`
|
Model string `yaml:"model"`
|
||||||
|
FallbackModel string `yaml:"fallback_model"`
|
||||||
Temperature float64 `yaml:"temperature"`
|
Temperature float64 `yaml:"temperature"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,6 +96,7 @@ type LiteLLMConfig struct {
|
|||||||
RequestHeaders map[string]string `yaml:"request_headers"`
|
RequestHeaders map[string]string `yaml:"request_headers"`
|
||||||
EmbeddingModel string `yaml:"embedding_model"`
|
EmbeddingModel string `yaml:"embedding_model"`
|
||||||
MetadataModel string `yaml:"metadata_model"`
|
MetadataModel string `yaml:"metadata_model"`
|
||||||
|
FallbackMetadataModel string `yaml:"fallback_metadata_model"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OllamaConfig struct {
|
type OllamaConfig struct {
|
||||||
|
|||||||
Reference in New Issue
Block a user