test(config): add migration tests for litellm provider
Some checks failed
CI / build-and-test (push) Failing after -32m22s
Some checks failed
CI / build-and-test (push) Failing after -32m22s
* Implement tests for migrating configuration from v1 to v2 for the litellm provider. * Validate the structure and values of the migrated configuration. * Ensure migration rejects newer versions of the configuration. fix(validate): enhance AI provider validation logic * Consolidate provider validation into a dedicated method. * Ensure at least one provider is specified and validate its type. * Check for required fields based on provider type. fix(mcpserver): update tool set to use new enrichment tool * Replace RetryMetadataTool with RetryEnrichmentTool in the ToolSet. fix(tools): refactor tools to use embedding and metadata runners * Update tools to utilize EmbeddingRunner and MetadataRunner instead of Provider. * Adjust method calls to align with the new runner interfaces.
This commit is contained in:
@@ -2,6 +2,7 @@ package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -12,6 +13,12 @@ import (
|
||||
)
|
||||
|
||||
func Load(explicitPath string) (*Config, string, error) {
|
||||
return LoadWithLogger(explicitPath, nil)
|
||||
}
|
||||
|
||||
// LoadWithLogger is Load with a logger surface for migration notices. Passing
|
||||
// nil is fine — migration events will simply not be logged.
|
||||
func LoadWithLogger(explicitPath string, log *slog.Logger) (*Config, string, error) {
|
||||
path := ResolvePath(explicitPath)
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
@@ -19,10 +26,40 @@ func Load(explicitPath string) (*Config, string, error) {
|
||||
return nil, path, fmt.Errorf("read config %q: %w", path, err)
|
||||
}
|
||||
|
||||
cfg := defaultConfig()
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
raw := map[string]any{}
|
||||
if err := yaml.Unmarshal(data, &raw); err != nil {
|
||||
return nil, path, fmt.Errorf("decode config %q: %w", path, err)
|
||||
}
|
||||
if raw == nil {
|
||||
raw = map[string]any{}
|
||||
}
|
||||
|
||||
applied, err := Migrate(raw)
|
||||
if err != nil {
|
||||
return nil, path, fmt.Errorf("migrate config %q: %w", path, err)
|
||||
}
|
||||
|
||||
if len(applied) > 0 {
|
||||
if err := rewriteConfigFile(path, data, raw); err != nil {
|
||||
return nil, path, err
|
||||
}
|
||||
if log != nil {
|
||||
for _, step := range applied {
|
||||
log.Warn("config migrated",
|
||||
slog.String("path", path),
|
||||
slog.Int("from_version", step.From),
|
||||
slog.Int("to_version", step.To),
|
||||
slog.String("describe", step.Describe),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cfg, err := decodeTyped(raw)
|
||||
if err != nil {
|
||||
return nil, path, fmt.Errorf("decode migrated config %q: %w", path, err)
|
||||
}
|
||||
cfg.Version = CurrentConfigVersion
|
||||
|
||||
applyEnvOverrides(&cfg)
|
||||
if err := cfg.Validate(); err != nil {
|
||||
@@ -32,6 +69,34 @@ func Load(explicitPath string) (*Config, string, error) {
|
||||
return &cfg, path, nil
|
||||
}
|
||||
|
||||
func decodeTyped(raw map[string]any) (Config, error) {
|
||||
out, err := yaml.Marshal(raw)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("re-marshal migrated config: %w", err)
|
||||
}
|
||||
cfg := defaultConfig()
|
||||
if err := yaml.Unmarshal(out, &cfg); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func rewriteConfigFile(path string, original []byte, migrated map[string]any) error {
|
||||
backupPath := fmt.Sprintf("%s.bak.%d", path, time.Now().Unix())
|
||||
if err := os.WriteFile(backupPath, original, 0o600); err != nil {
|
||||
return fmt.Errorf("write backup %q: %w", backupPath, err)
|
||||
}
|
||||
|
||||
out, err := yaml.Marshal(migrated)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal migrated config: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(path, out, 0o600); err != nil {
|
||||
return fmt.Errorf("write migrated config %q: %w", path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ResolvePath(explicitPath string) string {
|
||||
if path := strings.TrimSpace(explicitPath); path != "" {
|
||||
if path != ".yaml" && path != ".yml" {
|
||||
@@ -49,6 +114,7 @@ func ResolvePath(explicitPath string) string {
|
||||
func defaultConfig() Config {
|
||||
info := buildinfo.Current()
|
||||
return Config{
|
||||
Version: CurrentConfigVersion,
|
||||
Server: ServerConfig{
|
||||
Host: "0.0.0.0",
|
||||
Port: 8080,
|
||||
@@ -69,20 +135,14 @@ func defaultConfig() Config {
|
||||
QueryParam: "key",
|
||||
},
|
||||
AI: AIConfig{
|
||||
Provider: "litellm",
|
||||
Embeddings: AIEmbeddingConfig{
|
||||
Model: "openai/text-embedding-3-small",
|
||||
Providers: map[string]ProviderConfig{},
|
||||
Embeddings: EmbeddingsRoleConfig{
|
||||
Dimensions: 1536,
|
||||
},
|
||||
Metadata: AIMetadataConfig{
|
||||
Model: "gpt-4o-mini",
|
||||
Metadata: MetadataRoleConfig{
|
||||
Temperature: 0.1,
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
Ollama: OllamaConfig{
|
||||
BaseURL: "http://localhost:11434/v1",
|
||||
APIKey: "ollama",
|
||||
},
|
||||
},
|
||||
Capture: CaptureConfig{
|
||||
Source: DefaultSource,
|
||||
@@ -119,11 +179,12 @@ func defaultConfig() Config {
|
||||
func applyEnvOverrides(cfg *Config) {
|
||||
overrideString(&cfg.Database.URL, "AMCS_DATABASE_URL")
|
||||
overrideString(&cfg.MCP.PublicURL, "AMCS_PUBLIC_URL")
|
||||
overrideString(&cfg.AI.LiteLLM.BaseURL, "AMCS_LITELLM_BASE_URL")
|
||||
overrideString(&cfg.AI.LiteLLM.APIKey, "AMCS_LITELLM_API_KEY")
|
||||
overrideString(&cfg.AI.Ollama.BaseURL, "AMCS_OLLAMA_BASE_URL")
|
||||
overrideString(&cfg.AI.Ollama.APIKey, "AMCS_OLLAMA_API_KEY")
|
||||
overrideString(&cfg.AI.OpenRouter.APIKey, "AMCS_OPENROUTER_API_KEY")
|
||||
|
||||
overrideProviderField(cfg, "AMCS_LITELLM_BASE_URL", "litellm", func(p *ProviderConfig, v string) { p.BaseURL = v })
|
||||
overrideProviderField(cfg, "AMCS_LITELLM_API_KEY", "litellm", func(p *ProviderConfig, v string) { p.APIKey = v })
|
||||
overrideProviderField(cfg, "AMCS_OLLAMA_BASE_URL", "ollama", func(p *ProviderConfig, v string) { p.BaseURL = v })
|
||||
overrideProviderField(cfg, "AMCS_OLLAMA_API_KEY", "ollama", func(p *ProviderConfig, v string) { p.APIKey = v })
|
||||
overrideProviderField(cfg, "AMCS_OPENROUTER_API_KEY", "openrouter", func(p *ProviderConfig, v string) { p.APIKey = v })
|
||||
|
||||
if value, ok := os.LookupEnv("AMCS_SERVER_PORT"); ok {
|
||||
if port, err := strconv.Atoi(strings.TrimSpace(value)); err == nil {
|
||||
@@ -132,6 +193,24 @@ func applyEnvOverrides(cfg *Config) {
|
||||
}
|
||||
}
|
||||
|
||||
// overrideProviderField applies an env var to every configured provider of the
|
||||
// given type. This preserves the v1 behaviour where e.g. AMCS_LITELLM_API_KEY
|
||||
// rewrote the single litellm block — in v2 it rewrites every litellm provider.
|
||||
func overrideProviderField(cfg *Config, envKey, providerType string, apply func(*ProviderConfig, string)) {
|
||||
value, ok := os.LookupEnv(envKey)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
value = strings.TrimSpace(value)
|
||||
for name, p := range cfg.AI.Providers {
|
||||
if p.Type != providerType {
|
||||
continue
|
||||
}
|
||||
apply(&p, value)
|
||||
cfg.AI.Providers[name] = p
|
||||
}
|
||||
}
|
||||
|
||||
func overrideString(target *string, envKey string) {
|
||||
if value, ok := os.LookupEnv(envKey); ok {
|
||||
*target = strings.TrimSpace(value)
|
||||
|
||||
Reference in New Issue
Block a user