test(config): add migration tests for litellm provider
Some checks failed
CI / build-and-test (push) Failing after -32m22s
Some checks failed
CI / build-and-test (push) Failing after -32m22s
* Implement tests for migrating configuration from v1 to v2 for the litellm provider. * Validate the structure and values of the migrated configuration. * Ensure migration rejects newer versions of the configuration. fix(validate): enhance AI provider validation logic * Consolidate provider validation into a dedicated method. * Ensure at least one provider is specified and validate its type. * Check for required fields based on provider type. fix(mcpserver): update tool set to use new enrichment tool * Replace RetryMetadataTool with RetryEnrichmentTool in the ToolSet. fix(tools): refactor tools to use embedding and metadata runners * Update tools to utilize EmbeddingRunner and MetadataRunner instead of Provider. * Adjust method calls to align with the new runner interfaces.
This commit is contained in:
@@ -8,6 +8,7 @@ const (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Version int `yaml:"version"`
|
||||
Server ServerConfig `yaml:"server"`
|
||||
MCP MCPConfig `yaml:"mcp"`
|
||||
Auth AuthConfig `yaml:"auth"`
|
||||
@@ -37,11 +38,8 @@ type MCPConfig struct {
|
||||
Version string `yaml:"version"`
|
||||
Transport string `yaml:"transport"`
|
||||
SessionTimeout time.Duration `yaml:"session_timeout"`
|
||||
// PublicURL is the externally reachable base URL of this server (e.g. https://amcs.example.com).
|
||||
// When set, it is used to build absolute icon URLs in the MCP server identity.
|
||||
PublicURL string `yaml:"public_url"`
|
||||
// Instructions is set at startup from the embedded memory.md and sent to MCP clients on initialise.
|
||||
Instructions string `yaml:"-"`
|
||||
PublicURL string `yaml:"public_url"`
|
||||
Instructions string `yaml:"-"`
|
||||
}
|
||||
|
||||
type AuthConfig struct {
|
||||
@@ -77,52 +75,82 @@ type DatabaseConfig struct {
|
||||
MaxConnIdleTime time.Duration `yaml:"max_conn_idle_time"`
|
||||
}
|
||||
|
||||
// AIConfig (v2): named providers + per-role chains.
|
||||
type AIConfig struct {
|
||||
Provider string `yaml:"provider"`
|
||||
Embeddings AIEmbeddingConfig `yaml:"embeddings"`
|
||||
Metadata AIMetadataConfig `yaml:"metadata"`
|
||||
LiteLLM LiteLLMConfig `yaml:"litellm"`
|
||||
Ollama OllamaConfig `yaml:"ollama"`
|
||||
OpenRouter OpenRouterAIConfig `yaml:"openrouter"`
|
||||
Providers map[string]ProviderConfig `yaml:"providers"`
|
||||
Embeddings EmbeddingsRoleConfig `yaml:"embeddings"`
|
||||
Metadata MetadataRoleConfig `yaml:"metadata"`
|
||||
Background *BackgroundRolesConfig `yaml:"background,omitempty"`
|
||||
}
|
||||
|
||||
type AIEmbeddingConfig struct {
|
||||
Model string `yaml:"model"`
|
||||
Dimensions int `yaml:"dimensions"`
|
||||
type ProviderConfig struct {
|
||||
Type string `yaml:"type"`
|
||||
BaseURL string `yaml:"base_url"`
|
||||
APIKey string `yaml:"api_key"`
|
||||
RequestHeaders map[string]string `yaml:"request_headers,omitempty"`
|
||||
AppName string `yaml:"app_name,omitempty"`
|
||||
SiteURL string `yaml:"site_url,omitempty"`
|
||||
}
|
||||
|
||||
type AIMetadataConfig struct {
|
||||
Model string `yaml:"model"`
|
||||
FallbackModels []string `yaml:"fallback_models"`
|
||||
FallbackModel string `yaml:"fallback_model"` // legacy single fallback
|
||||
type RoleTarget struct {
|
||||
Provider string `yaml:"provider"`
|
||||
Model string `yaml:"model"`
|
||||
}
|
||||
|
||||
type RoleChain struct {
|
||||
Primary RoleTarget `yaml:"primary"`
|
||||
Fallbacks []RoleTarget `yaml:"fallbacks,omitempty"`
|
||||
}
|
||||
|
||||
type EmbeddingsRoleConfig struct {
|
||||
Dimensions int `yaml:"dimensions"`
|
||||
Primary RoleTarget `yaml:"primary"`
|
||||
Fallbacks []RoleTarget `yaml:"fallbacks,omitempty"`
|
||||
}
|
||||
|
||||
type MetadataRoleConfig struct {
|
||||
Temperature float64 `yaml:"temperature"`
|
||||
LogConversations bool `yaml:"log_conversations"`
|
||||
Timeout time.Duration `yaml:"timeout"`
|
||||
Primary RoleTarget `yaml:"primary"`
|
||||
Fallbacks []RoleTarget `yaml:"fallbacks,omitempty"`
|
||||
}
|
||||
|
||||
type LiteLLMConfig struct {
|
||||
BaseURL string `yaml:"base_url"`
|
||||
APIKey string `yaml:"api_key"`
|
||||
UseResponsesAPI bool `yaml:"use_responses_api"`
|
||||
RequestHeaders map[string]string `yaml:"request_headers"`
|
||||
EmbeddingModel string `yaml:"embedding_model"`
|
||||
MetadataModel string `yaml:"metadata_model"`
|
||||
FallbackMetadataModels []string `yaml:"fallback_metadata_models"`
|
||||
FallbackMetadataModel string `yaml:"fallback_metadata_model"` // legacy single fallback
|
||||
// BackgroundRolesConfig overrides the foreground chains for background workers
|
||||
// (backfill_embeddings, metadata_retry, reparse_metadata). Either field may be
|
||||
// nil to inherit the foreground role unchanged.
|
||||
type BackgroundRolesConfig struct {
|
||||
Embeddings *RoleChain `yaml:"embeddings,omitempty"`
|
||||
Metadata *RoleChain `yaml:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
type OllamaConfig struct {
|
||||
BaseURL string `yaml:"base_url"`
|
||||
APIKey string `yaml:"api_key"`
|
||||
RequestHeaders map[string]string `yaml:"request_headers"`
|
||||
// Chain returns primary followed by fallbacks (deduped, blanks dropped).
|
||||
func (e EmbeddingsRoleConfig) Chain() []RoleTarget {
|
||||
return dedupeTargets(append([]RoleTarget{e.Primary}, e.Fallbacks...))
|
||||
}
|
||||
|
||||
type OpenRouterAIConfig struct {
|
||||
BaseURL string `yaml:"base_url"`
|
||||
APIKey string `yaml:"api_key"`
|
||||
AppName string `yaml:"app_name"`
|
||||
SiteURL string `yaml:"site_url"`
|
||||
ExtraHeaders map[string]string `yaml:"extra_headers"`
|
||||
func (m MetadataRoleConfig) Chain() []RoleTarget {
|
||||
return dedupeTargets(append([]RoleTarget{m.Primary}, m.Fallbacks...))
|
||||
}
|
||||
|
||||
func (c RoleChain) AsTargets() []RoleTarget {
|
||||
return dedupeTargets(append([]RoleTarget{c.Primary}, c.Fallbacks...))
|
||||
}
|
||||
|
||||
func dedupeTargets(in []RoleTarget) []RoleTarget {
|
||||
out := make([]RoleTarget, 0, len(in))
|
||||
seen := make(map[RoleTarget]struct{}, len(in))
|
||||
for _, t := range in {
|
||||
if t.Provider == "" || t.Model == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[t]; ok {
|
||||
continue
|
||||
}
|
||||
seen[t] = struct{}{}
|
||||
out = append(out, t)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
type CaptureConfig struct {
|
||||
@@ -167,45 +195,3 @@ type MetadataRetryConfig struct {
|
||||
MaxPerRun int `yaml:"max_per_run"`
|
||||
IncludeArchived bool `yaml:"include_archived"`
|
||||
}
|
||||
|
||||
func (c AIMetadataConfig) EffectiveFallbackModels() []string {
|
||||
models := make([]string, 0, len(c.FallbackModels)+1)
|
||||
for _, model := range c.FallbackModels {
|
||||
if model != "" {
|
||||
models = append(models, model)
|
||||
}
|
||||
}
|
||||
if c.FallbackModel != "" {
|
||||
models = append(models, c.FallbackModel)
|
||||
}
|
||||
return dedupeNonEmpty(models)
|
||||
}
|
||||
|
||||
func (c LiteLLMConfig) EffectiveFallbackMetadataModels() []string {
|
||||
models := make([]string, 0, len(c.FallbackMetadataModels)+1)
|
||||
for _, model := range c.FallbackMetadataModels {
|
||||
if model != "" {
|
||||
models = append(models, model)
|
||||
}
|
||||
}
|
||||
if c.FallbackMetadataModel != "" {
|
||||
models = append(models, c.FallbackMetadataModel)
|
||||
}
|
||||
return dedupeNonEmpty(models)
|
||||
}
|
||||
|
||||
func dedupeNonEmpty(values []string) []string {
|
||||
seen := make(map[string]struct{}, len(values))
|
||||
out := make([]string, 0, len(values))
|
||||
for _, value := range values {
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[value]; ok {
|
||||
continue
|
||||
}
|
||||
seen[value] = struct{}{}
|
||||
out = append(out, value)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -12,6 +13,12 @@ import (
|
||||
)
|
||||
|
||||
func Load(explicitPath string) (*Config, string, error) {
|
||||
return LoadWithLogger(explicitPath, nil)
|
||||
}
|
||||
|
||||
// LoadWithLogger is Load with a logger surface for migration notices. Passing
|
||||
// nil is fine — migration events will simply not be logged.
|
||||
func LoadWithLogger(explicitPath string, log *slog.Logger) (*Config, string, error) {
|
||||
path := ResolvePath(explicitPath)
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
@@ -19,10 +26,40 @@ func Load(explicitPath string) (*Config, string, error) {
|
||||
return nil, path, fmt.Errorf("read config %q: %w", path, err)
|
||||
}
|
||||
|
||||
cfg := defaultConfig()
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
raw := map[string]any{}
|
||||
if err := yaml.Unmarshal(data, &raw); err != nil {
|
||||
return nil, path, fmt.Errorf("decode config %q: %w", path, err)
|
||||
}
|
||||
if raw == nil {
|
||||
raw = map[string]any{}
|
||||
}
|
||||
|
||||
applied, err := Migrate(raw)
|
||||
if err != nil {
|
||||
return nil, path, fmt.Errorf("migrate config %q: %w", path, err)
|
||||
}
|
||||
|
||||
if len(applied) > 0 {
|
||||
if err := rewriteConfigFile(path, data, raw); err != nil {
|
||||
return nil, path, err
|
||||
}
|
||||
if log != nil {
|
||||
for _, step := range applied {
|
||||
log.Warn("config migrated",
|
||||
slog.String("path", path),
|
||||
slog.Int("from_version", step.From),
|
||||
slog.Int("to_version", step.To),
|
||||
slog.String("describe", step.Describe),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cfg, err := decodeTyped(raw)
|
||||
if err != nil {
|
||||
return nil, path, fmt.Errorf("decode migrated config %q: %w", path, err)
|
||||
}
|
||||
cfg.Version = CurrentConfigVersion
|
||||
|
||||
applyEnvOverrides(&cfg)
|
||||
if err := cfg.Validate(); err != nil {
|
||||
@@ -32,6 +69,34 @@ func Load(explicitPath string) (*Config, string, error) {
|
||||
return &cfg, path, nil
|
||||
}
|
||||
|
||||
func decodeTyped(raw map[string]any) (Config, error) {
|
||||
out, err := yaml.Marshal(raw)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("re-marshal migrated config: %w", err)
|
||||
}
|
||||
cfg := defaultConfig()
|
||||
if err := yaml.Unmarshal(out, &cfg); err != nil {
|
||||
return Config{}, err
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func rewriteConfigFile(path string, original []byte, migrated map[string]any) error {
|
||||
backupPath := fmt.Sprintf("%s.bak.%d", path, time.Now().Unix())
|
||||
if err := os.WriteFile(backupPath, original, 0o600); err != nil {
|
||||
return fmt.Errorf("write backup %q: %w", backupPath, err)
|
||||
}
|
||||
|
||||
out, err := yaml.Marshal(migrated)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal migrated config: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(path, out, 0o600); err != nil {
|
||||
return fmt.Errorf("write migrated config %q: %w", path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ResolvePath(explicitPath string) string {
|
||||
if path := strings.TrimSpace(explicitPath); path != "" {
|
||||
if path != ".yaml" && path != ".yml" {
|
||||
@@ -49,6 +114,7 @@ func ResolvePath(explicitPath string) string {
|
||||
func defaultConfig() Config {
|
||||
info := buildinfo.Current()
|
||||
return Config{
|
||||
Version: CurrentConfigVersion,
|
||||
Server: ServerConfig{
|
||||
Host: "0.0.0.0",
|
||||
Port: 8080,
|
||||
@@ -69,20 +135,14 @@ func defaultConfig() Config {
|
||||
QueryParam: "key",
|
||||
},
|
||||
AI: AIConfig{
|
||||
Provider: "litellm",
|
||||
Embeddings: AIEmbeddingConfig{
|
||||
Model: "openai/text-embedding-3-small",
|
||||
Providers: map[string]ProviderConfig{},
|
||||
Embeddings: EmbeddingsRoleConfig{
|
||||
Dimensions: 1536,
|
||||
},
|
||||
Metadata: AIMetadataConfig{
|
||||
Model: "gpt-4o-mini",
|
||||
Metadata: MetadataRoleConfig{
|
||||
Temperature: 0.1,
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
Ollama: OllamaConfig{
|
||||
BaseURL: "http://localhost:11434/v1",
|
||||
APIKey: "ollama",
|
||||
},
|
||||
},
|
||||
Capture: CaptureConfig{
|
||||
Source: DefaultSource,
|
||||
@@ -119,11 +179,12 @@ func defaultConfig() Config {
|
||||
func applyEnvOverrides(cfg *Config) {
|
||||
overrideString(&cfg.Database.URL, "AMCS_DATABASE_URL")
|
||||
overrideString(&cfg.MCP.PublicURL, "AMCS_PUBLIC_URL")
|
||||
overrideString(&cfg.AI.LiteLLM.BaseURL, "AMCS_LITELLM_BASE_URL")
|
||||
overrideString(&cfg.AI.LiteLLM.APIKey, "AMCS_LITELLM_API_KEY")
|
||||
overrideString(&cfg.AI.Ollama.BaseURL, "AMCS_OLLAMA_BASE_URL")
|
||||
overrideString(&cfg.AI.Ollama.APIKey, "AMCS_OLLAMA_API_KEY")
|
||||
overrideString(&cfg.AI.OpenRouter.APIKey, "AMCS_OPENROUTER_API_KEY")
|
||||
|
||||
overrideProviderField(cfg, "AMCS_LITELLM_BASE_URL", "litellm", func(p *ProviderConfig, v string) { p.BaseURL = v })
|
||||
overrideProviderField(cfg, "AMCS_LITELLM_API_KEY", "litellm", func(p *ProviderConfig, v string) { p.APIKey = v })
|
||||
overrideProviderField(cfg, "AMCS_OLLAMA_BASE_URL", "ollama", func(p *ProviderConfig, v string) { p.BaseURL = v })
|
||||
overrideProviderField(cfg, "AMCS_OLLAMA_API_KEY", "ollama", func(p *ProviderConfig, v string) { p.APIKey = v })
|
||||
overrideProviderField(cfg, "AMCS_OPENROUTER_API_KEY", "openrouter", func(p *ProviderConfig, v string) { p.APIKey = v })
|
||||
|
||||
if value, ok := os.LookupEnv("AMCS_SERVER_PORT"); ok {
|
||||
if port, err := strconv.Atoi(strings.TrimSpace(value)); err == nil {
|
||||
@@ -132,6 +193,24 @@ func applyEnvOverrides(cfg *Config) {
|
||||
}
|
||||
}
|
||||
|
||||
// overrideProviderField applies an env var to every configured provider of the
|
||||
// given type. This preserves the v1 behaviour where e.g. AMCS_LITELLM_API_KEY
|
||||
// rewrote the single litellm block — in v2 it rewrites every litellm provider.
|
||||
func overrideProviderField(cfg *Config, envKey, providerType string, apply func(*ProviderConfig, string)) {
|
||||
value, ok := os.LookupEnv(envKey)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
value = strings.TrimSpace(value)
|
||||
for name, p := range cfg.AI.Providers {
|
||||
if p.Type != providerType {
|
||||
continue
|
||||
}
|
||||
apply(&p, value)
|
||||
cfg.AI.Providers[name] = p
|
||||
}
|
||||
}
|
||||
|
||||
func overrideString(target *string, envKey string) {
|
||||
if value, ok := os.LookupEnv(envKey); ok {
|
||||
*target = strings.TrimSpace(value)
|
||||
|
||||
@@ -31,9 +31,8 @@ func TestResolvePathIgnoresBareYAMLExtension(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadAppliesEnvOverrides(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "test.yaml")
|
||||
if err := os.WriteFile(configPath, []byte(`
|
||||
const v2ConfigYAML = `
|
||||
version: 2
|
||||
server:
|
||||
port: 8080
|
||||
mcp:
|
||||
@@ -46,18 +45,30 @@ auth:
|
||||
database:
|
||||
url: "postgres://from-file"
|
||||
ai:
|
||||
provider: "litellm"
|
||||
providers:
|
||||
default:
|
||||
type: "litellm"
|
||||
base_url: "http://localhost:4000/v1"
|
||||
api_key: "file-key"
|
||||
embeddings:
|
||||
dimensions: 1536
|
||||
litellm:
|
||||
base_url: "http://localhost:4000/v1"
|
||||
api_key: "file-key"
|
||||
primary:
|
||||
provider: "default"
|
||||
model: "text-embed"
|
||||
metadata:
|
||||
primary:
|
||||
provider: "default"
|
||||
model: "gpt-4"
|
||||
search:
|
||||
default_limit: 10
|
||||
max_limit: 50
|
||||
logging:
|
||||
level: "info"
|
||||
`), 0o600); err != nil {
|
||||
`
|
||||
|
||||
func TestLoadAppliesEnvOverrides(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "test.yaml")
|
||||
if err := os.WriteFile(configPath, []byte(v2ConfigYAML), 0o600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
@@ -76,8 +87,8 @@ logging:
|
||||
if cfg.Database.URL != "postgres://from-env" {
|
||||
t.Fatalf("database url = %q, want env override", cfg.Database.URL)
|
||||
}
|
||||
if cfg.AI.LiteLLM.APIKey != "env-key" {
|
||||
t.Fatalf("litellm api key = %q, want env override", cfg.AI.LiteLLM.APIKey)
|
||||
if cfg.AI.Providers["default"].APIKey != "env-key" {
|
||||
t.Fatalf("litellm api key = %q, want env override", cfg.AI.Providers["default"].APIKey)
|
||||
}
|
||||
if cfg.Server.Port != 9090 {
|
||||
t.Fatalf("server port = %d, want 9090", cfg.Server.Port)
|
||||
@@ -90,10 +101,12 @@ logging:
|
||||
func TestLoadAppliesOllamaEnvOverrides(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "test.yaml")
|
||||
if err := os.WriteFile(configPath, []byte(`
|
||||
version: 2
|
||||
server:
|
||||
port: 8080
|
||||
mcp:
|
||||
path: "/mcp"
|
||||
session_timeout: "10m"
|
||||
auth:
|
||||
keys:
|
||||
- id: "test"
|
||||
@@ -101,15 +114,20 @@ auth:
|
||||
database:
|
||||
url: "postgres://from-file"
|
||||
ai:
|
||||
provider: "ollama"
|
||||
providers:
|
||||
local:
|
||||
type: "ollama"
|
||||
base_url: "http://localhost:11434/v1"
|
||||
api_key: "ollama"
|
||||
embeddings:
|
||||
model: "nomic-embed-text"
|
||||
dimensions: 768
|
||||
primary:
|
||||
provider: "local"
|
||||
model: "nomic-embed-text"
|
||||
metadata:
|
||||
model: "llama3.2"
|
||||
ollama:
|
||||
base_url: "http://localhost:11434/v1"
|
||||
api_key: "ollama"
|
||||
primary:
|
||||
provider: "local"
|
||||
model: "llama3.2"
|
||||
search:
|
||||
default_limit: 10
|
||||
max_limit: 50
|
||||
@@ -127,10 +145,77 @@ logging:
|
||||
t.Fatalf("Load() error = %v", err)
|
||||
}
|
||||
|
||||
if cfg.AI.Ollama.BaseURL != "https://ollama.example.com/v1" {
|
||||
t.Fatalf("ollama base url = %q, want env override", cfg.AI.Ollama.BaseURL)
|
||||
p := cfg.AI.Providers["local"]
|
||||
if p.BaseURL != "https://ollama.example.com/v1" {
|
||||
t.Fatalf("ollama base url = %q, want env override", p.BaseURL)
|
||||
}
|
||||
if cfg.AI.Ollama.APIKey != "remote-key" {
|
||||
t.Fatalf("ollama api key = %q, want env override", cfg.AI.Ollama.APIKey)
|
||||
if p.APIKey != "remote-key" {
|
||||
t.Fatalf("ollama api key = %q, want env override", p.APIKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadMigratesV1Config(t *testing.T) {
|
||||
configPath := filepath.Join(t.TempDir(), "v1.yaml")
|
||||
v1 := `
|
||||
server:
|
||||
port: 8080
|
||||
mcp:
|
||||
path: "/mcp"
|
||||
session_timeout: "10m"
|
||||
auth:
|
||||
keys:
|
||||
- id: "test"
|
||||
value: "secret"
|
||||
database:
|
||||
url: "postgres://from-file"
|
||||
ai:
|
||||
provider: "litellm"
|
||||
embeddings:
|
||||
model: "text-embed"
|
||||
dimensions: 1536
|
||||
metadata:
|
||||
model: "gpt-4"
|
||||
temperature: 0.2
|
||||
fallback_models: ["gpt-3.5"]
|
||||
litellm:
|
||||
base_url: "http://localhost:4000/v1"
|
||||
api_key: "file-key"
|
||||
search:
|
||||
default_limit: 10
|
||||
max_limit: 50
|
||||
logging:
|
||||
level: "info"
|
||||
`
|
||||
if err := os.WriteFile(configPath, []byte(v1), 0o600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
cfg, _, err := Load(configPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Load() error = %v", err)
|
||||
}
|
||||
|
||||
if cfg.Version != CurrentConfigVersion {
|
||||
t.Fatalf("version = %d, want %d", cfg.Version, CurrentConfigVersion)
|
||||
}
|
||||
if p, ok := cfg.AI.Providers["default"]; !ok || p.Type != "litellm" || p.APIKey != "file-key" {
|
||||
t.Fatalf("providers[default] = %+v, want litellm/file-key", p)
|
||||
}
|
||||
if cfg.AI.Embeddings.Primary.Model != "text-embed" || cfg.AI.Embeddings.Primary.Provider != "default" {
|
||||
t.Fatalf("embeddings.primary = %+v, want default/text-embed", cfg.AI.Embeddings.Primary)
|
||||
}
|
||||
if cfg.AI.Metadata.Primary.Model != "gpt-4" || cfg.AI.Metadata.Primary.Provider != "default" {
|
||||
t.Fatalf("metadata.primary = %+v, want default/gpt-4", cfg.AI.Metadata.Primary)
|
||||
}
|
||||
if len(cfg.AI.Metadata.Fallbacks) != 1 || cfg.AI.Metadata.Fallbacks[0].Model != "gpt-3.5" {
|
||||
t.Fatalf("metadata.fallbacks = %+v, want [default/gpt-3.5]", cfg.AI.Metadata.Fallbacks)
|
||||
}
|
||||
|
||||
entries, err := filepath.Glob(configPath + ".bak.*")
|
||||
if err != nil {
|
||||
t.Fatalf("glob backups: %v", err)
|
||||
}
|
||||
if len(entries) != 1 {
|
||||
t.Fatalf("backup files = %d, want 1", len(entries))
|
||||
}
|
||||
}
|
||||
|
||||
341
internal/config/migrate.go
Normal file
341
internal/config/migrate.go
Normal file
@@ -0,0 +1,341 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// CurrentConfigVersion is the schema version this binary expects. Files at a
|
||||
// lower version are migrated automatically when loaded.
|
||||
const CurrentConfigVersion = 2
|
||||
|
||||
// ConfigMigration upgrades a raw YAML map by one version.
|
||||
type ConfigMigration struct {
|
||||
From, To int
|
||||
Describe string
|
||||
Apply func(map[string]any) error
|
||||
}
|
||||
|
||||
// migrations is the ordered ladder of upgrades. Add new entries at the end.
|
||||
var migrations = []ConfigMigration{
|
||||
{From: 1, To: 2, Describe: "named providers + role chains", Apply: migrateV1toV2},
|
||||
}
|
||||
|
||||
// Migrate brings raw up to CurrentConfigVersion in place. Returns the list of
|
||||
// migrations that were applied (may be empty if already current).
|
||||
func Migrate(raw map[string]any) ([]ConfigMigration, error) {
|
||||
if raw == nil {
|
||||
return nil, fmt.Errorf("migrate: raw config is nil")
|
||||
}
|
||||
|
||||
version := readVersion(raw)
|
||||
if version > CurrentConfigVersion {
|
||||
return nil, fmt.Errorf("migrate: config version %d is newer than supported version %d", version, CurrentConfigVersion)
|
||||
}
|
||||
|
||||
applied := make([]ConfigMigration, 0)
|
||||
for {
|
||||
if version >= CurrentConfigVersion {
|
||||
break
|
||||
}
|
||||
step, ok := findMigration(version)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("migrate: no migration registered from version %d", version)
|
||||
}
|
||||
if err := step.Apply(raw); err != nil {
|
||||
return nil, fmt.Errorf("migrate v%d->v%d: %w", step.From, step.To, err)
|
||||
}
|
||||
raw["version"] = step.To
|
||||
version = step.To
|
||||
applied = append(applied, step)
|
||||
}
|
||||
return applied, nil
|
||||
}
|
||||
|
||||
func findMigration(from int) (ConfigMigration, bool) {
|
||||
for _, m := range migrations {
|
||||
if m.From == from {
|
||||
return m, true
|
||||
}
|
||||
}
|
||||
return ConfigMigration{}, false
|
||||
}
|
||||
|
||||
// readVersion returns the version from raw. Files without a version field are
|
||||
// treated as version 1 (the original schema).
|
||||
func readVersion(raw map[string]any) int {
|
||||
v, ok := raw["version"]
|
||||
if !ok {
|
||||
return 1
|
||||
}
|
||||
switch n := v.(type) {
|
||||
case int:
|
||||
return n
|
||||
case int64:
|
||||
return int(n)
|
||||
case float64:
|
||||
return int(n)
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// migrateV1toV2 lifts the single-provider config into the named-providers +
|
||||
// role-chains layout. The pre-v2 config implicitly used one provider for both
|
||||
// embeddings and metadata; we materialise that as a provider named "default".
|
||||
func migrateV1toV2(raw map[string]any) error {
|
||||
aiRaw := mapValue(raw, "ai")
|
||||
if aiRaw == nil {
|
||||
aiRaw = map[string]any{}
|
||||
}
|
||||
|
||||
providerType := stringValue(aiRaw, "provider")
|
||||
if providerType == "" {
|
||||
providerType = "litellm"
|
||||
}
|
||||
|
||||
providers, embeddingModel, metadataModel, fallbackModels := buildV1Provider(aiRaw, providerType)
|
||||
|
||||
embeddingsOld := mapValue(aiRaw, "embeddings")
|
||||
dimensions := intValue(embeddingsOld, "dimensions")
|
||||
if dimensions <= 0 {
|
||||
dimensions = 1536
|
||||
}
|
||||
if embeddingModel == "" {
|
||||
embeddingModel = stringValue(embeddingsOld, "model")
|
||||
}
|
||||
|
||||
metadataOld := mapValue(aiRaw, "metadata")
|
||||
if metadataModel == "" {
|
||||
metadataModel = stringValue(metadataOld, "model")
|
||||
}
|
||||
temperature := floatValue(metadataOld, "temperature")
|
||||
logConversations := boolValue(metadataOld, "log_conversations")
|
||||
timeoutStr := stringValue(metadataOld, "timeout")
|
||||
|
||||
if list := stringListValue(metadataOld, "fallback_models"); len(list) > 0 {
|
||||
fallbackModels = append(fallbackModels, list...)
|
||||
}
|
||||
if v := stringValue(metadataOld, "fallback_model"); v != "" {
|
||||
fallbackModels = append(fallbackModels, v)
|
||||
}
|
||||
|
||||
embeddings := map[string]any{
|
||||
"dimensions": dimensions,
|
||||
"primary": map[string]any{"provider": "default", "model": embeddingModel},
|
||||
}
|
||||
|
||||
metadata := map[string]any{
|
||||
"temperature": temperature,
|
||||
"log_conversations": logConversations,
|
||||
"primary": map[string]any{"provider": "default", "model": metadataModel},
|
||||
}
|
||||
if timeoutStr != "" {
|
||||
metadata["timeout"] = timeoutStr
|
||||
}
|
||||
if fallbacks := chainTargets("default", fallbackModels); len(fallbacks) > 0 {
|
||||
metadata["fallbacks"] = fallbacks
|
||||
}
|
||||
|
||||
raw["ai"] = map[string]any{
|
||||
"providers": providers,
|
||||
"embeddings": embeddings,
|
||||
"metadata": metadata,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildV1Provider(aiRaw map[string]any, providerType string) (map[string]any, string, string, []string) {
|
||||
providers := map[string]any{}
|
||||
defaultEntry := map[string]any{"type": providerType}
|
||||
embedModel := ""
|
||||
metaModel := ""
|
||||
var fallbacks []string
|
||||
|
||||
switch providerType {
|
||||
case "litellm":
|
||||
block := mapValue(aiRaw, "litellm")
|
||||
copyKeys(defaultEntry, block, "base_url", "api_key")
|
||||
copyHeaders(defaultEntry, block, "request_headers")
|
||||
embedModel = stringValue(block, "embedding_model")
|
||||
metaModel = stringValue(block, "metadata_model")
|
||||
if list := stringListValue(block, "fallback_metadata_models"); len(list) > 0 {
|
||||
fallbacks = append(fallbacks, list...)
|
||||
}
|
||||
if v := stringValue(block, "fallback_metadata_model"); v != "" {
|
||||
fallbacks = append(fallbacks, v)
|
||||
}
|
||||
case "ollama":
|
||||
block := mapValue(aiRaw, "ollama")
|
||||
copyKeys(defaultEntry, block, "base_url", "api_key")
|
||||
copyHeaders(defaultEntry, block, "request_headers")
|
||||
case "openrouter":
|
||||
block := mapValue(aiRaw, "openrouter")
|
||||
copyKeys(defaultEntry, block, "base_url", "api_key", "app_name", "site_url")
|
||||
copyHeaders(defaultEntry, block, "extra_headers")
|
||||
// rename: extra_headers → request_headers
|
||||
if hdr, ok := defaultEntry["extra_headers"]; ok {
|
||||
defaultEntry["request_headers"] = hdr
|
||||
delete(defaultEntry, "extra_headers")
|
||||
}
|
||||
}
|
||||
|
||||
providers["default"] = defaultEntry
|
||||
return providers, embedModel, metaModel, fallbacks
|
||||
}
|
||||
|
||||
func chainTargets(provider string, models []string) []any {
|
||||
out := make([]any, 0, len(models))
|
||||
seen := map[string]struct{}{}
|
||||
for _, m := range models {
|
||||
if m == "" {
|
||||
continue
|
||||
}
|
||||
key := provider + "|" + m
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, map[string]any{"provider": provider, "model": m})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func mapValue(raw map[string]any, key string) map[string]any {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
v, ok := raw[key]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
switch m := v.(type) {
|
||||
case map[string]any:
|
||||
return m
|
||||
case map[any]any:
|
||||
return convertAnyMap(m)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertAnyMap(in map[any]any) map[string]any {
|
||||
out := make(map[string]any, len(in))
|
||||
keys := make([]string, 0, len(in))
|
||||
for k, v := range in {
|
||||
ks, ok := k.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, ks)
|
||||
out[ks] = v
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return out
|
||||
}
|
||||
|
||||
func stringValue(raw map[string]any, key string) string {
|
||||
if raw == nil {
|
||||
return ""
|
||||
}
|
||||
v, ok := raw[key]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func intValue(raw map[string]any, key string) int {
|
||||
if raw == nil {
|
||||
return 0
|
||||
}
|
||||
switch n := raw[key].(type) {
|
||||
case int:
|
||||
return n
|
||||
case int64:
|
||||
return int(n)
|
||||
case float64:
|
||||
return int(n)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func floatValue(raw map[string]any, key string) float64 {
|
||||
if raw == nil {
|
||||
return 0
|
||||
}
|
||||
switch n := raw[key].(type) {
|
||||
case float64:
|
||||
return n
|
||||
case int:
|
||||
return float64(n)
|
||||
case int64:
|
||||
return float64(n)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func boolValue(raw map[string]any, key string) bool {
|
||||
if raw == nil {
|
||||
return false
|
||||
}
|
||||
if b, ok := raw[key].(bool); ok {
|
||||
return b
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func stringListValue(raw map[string]any, key string) []string {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
v, ok := raw[key]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
list, ok := v.([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(list))
|
||||
for _, item := range list {
|
||||
if s, ok := item.(string); ok && s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func copyKeys(dst, src map[string]any, keys ...string) {
|
||||
if src == nil {
|
||||
return
|
||||
}
|
||||
for _, k := range keys {
|
||||
if v, ok := src[k]; ok {
|
||||
dst[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func copyHeaders(dst, src map[string]any, key string) {
|
||||
if src == nil {
|
||||
return
|
||||
}
|
||||
v, ok := src[key]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
switch headers := v.(type) {
|
||||
case map[string]any:
|
||||
if len(headers) == 0 {
|
||||
return
|
||||
}
|
||||
dst[key] = headers
|
||||
case map[any]any:
|
||||
if len(headers) == 0 {
|
||||
return
|
||||
}
|
||||
dst[key] = convertAnyMap(headers)
|
||||
}
|
||||
}
|
||||
77
internal/config/migrate_test.go
Normal file
77
internal/config/migrate_test.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package config
|
||||
|
||||
import "testing"
|
||||
|
||||
// TestMigrateV1ToV2Litellm checks the v1 -> v2 migration for a config using
// the litellm provider: the flat ai.litellm section must become the "default"
// entry under ai.providers, and the embeddings/metadata model settings must be
// rewritten as primary/fallback role targets referencing that provider.
func TestMigrateV1ToV2Litellm(t *testing.T) {
	// v1-shaped raw config; the absence of a "version" key marks it as v1
	// (verified below by readVersion after migration).
	raw := map[string]any{
		"ai": map[string]any{
			"provider": "litellm",
			"embeddings": map[string]any{
				"model":      "text-embedding-3-small",
				"dimensions": 1536,
			},
			"metadata": map[string]any{
				"model":           "gpt-4o-mini",
				"temperature":     0.2,
				"fallback_models": []any{"gpt-4.1-mini"},
			},
			"litellm": map[string]any{
				"base_url": "http://localhost:4000/v1",
				"api_key":  "secret",
			},
		},
	}

	applied, err := Migrate(raw)
	if err != nil {
		t.Fatalf("Migrate() error = %v", err)
	}
	// Exactly one migration step (1 -> 2) should have been applied.
	if len(applied) != 1 || applied[0].From != 1 || applied[0].To != 2 {
		t.Fatalf("applied = %+v, want [v1->v2]", applied)
	}
	// The raw map must now carry the current version marker.
	if got := readVersion(raw); got != CurrentConfigVersion {
		t.Fatalf("version = %d, want %d", got, CurrentConfigVersion)
	}

	// The old ai.litellm connection settings become ai.providers.default.
	ai := mapValue(raw, "ai")
	providers := mapValue(ai, "providers")
	def := mapValue(providers, "default")
	if got := stringValue(def, "type"); got != "litellm" {
		t.Fatalf("providers.default.type = %q, want litellm", got)
	}
	if got := stringValue(def, "base_url"); got != "http://localhost:4000/v1" {
		t.Fatalf("providers.default.base_url = %q", got)
	}

	// The embeddings model is rewritten as a primary target on "default".
	emb := mapValue(ai, "embeddings")
	embPrimary := mapValue(emb, "primary")
	if stringValue(embPrimary, "provider") != "default" || stringValue(embPrimary, "model") != "text-embedding-3-small" {
		t.Fatalf("embeddings.primary = %+v, want default/text-embedding-3-small", embPrimary)
	}

	// The metadata model becomes the primary target, and each entry of the
	// old fallback_models list becomes a provider/model fallback target.
	meta := mapValue(ai, "metadata")
	metaPrimary := mapValue(meta, "primary")
	if stringValue(metaPrimary, "provider") != "default" || stringValue(metaPrimary, "model") != "gpt-4o-mini" {
		t.Fatalf("metadata.primary = %+v, want default/gpt-4o-mini", metaPrimary)
	}
	fallbacks, ok := meta["fallbacks"].([]any)
	if !ok || len(fallbacks) != 1 {
		t.Fatalf("metadata.fallbacks = %#v, want len=1", meta["fallbacks"])
	}
	firstFallback, ok := fallbacks[0].(map[string]any)
	if !ok {
		t.Fatalf("metadata.fallbacks[0] type = %T, want map[string]any", fallbacks[0])
	}
	if stringValue(firstFallback, "provider") != "default" || stringValue(firstFallback, "model") != "gpt-4.1-mini" {
		t.Fatalf("metadata fallback = %+v, want default/gpt-4.1-mini", firstFallback)
	}
}
|
||||
|
||||
func TestMigrateRejectsNewerVersion(t *testing.T) {
|
||||
raw := map[string]any{"version": CurrentConfigVersion + 1}
|
||||
|
||||
_, err := Migrate(raw)
|
||||
if err == nil {
|
||||
t.Fatal("Migrate() error = nil, want error for newer config version")
|
||||
}
|
||||
}
|
||||
@@ -45,38 +45,8 @@ func (c Config) Validate() error {
|
||||
return fmt.Errorf("invalid config: mcp.session_timeout must be greater than zero")
|
||||
}
|
||||
|
||||
switch c.AI.Provider {
|
||||
case "litellm", "ollama", "openrouter":
|
||||
default:
|
||||
return fmt.Errorf("invalid config: unsupported ai.provider %q", c.AI.Provider)
|
||||
}
|
||||
|
||||
if c.AI.Embeddings.Dimensions <= 0 {
|
||||
return fmt.Errorf("invalid config: ai.embeddings.dimensions must be greater than zero")
|
||||
}
|
||||
|
||||
switch c.AI.Provider {
|
||||
case "litellm":
|
||||
if strings.TrimSpace(c.AI.LiteLLM.BaseURL) == "" {
|
||||
return fmt.Errorf("invalid config: ai.litellm.base_url is required when ai.provider=litellm")
|
||||
}
|
||||
if strings.TrimSpace(c.AI.LiteLLM.APIKey) == "" {
|
||||
return fmt.Errorf("invalid config: ai.litellm.api_key is required when ai.provider=litellm")
|
||||
}
|
||||
case "ollama":
|
||||
if strings.TrimSpace(c.AI.Ollama.BaseURL) == "" {
|
||||
return fmt.Errorf("invalid config: ai.ollama.base_url is required when ai.provider=ollama")
|
||||
}
|
||||
if strings.TrimSpace(c.AI.Ollama.APIKey) == "" {
|
||||
return fmt.Errorf("invalid config: ai.ollama.api_key is required when ai.provider=ollama")
|
||||
}
|
||||
case "openrouter":
|
||||
if strings.TrimSpace(c.AI.OpenRouter.BaseURL) == "" {
|
||||
return fmt.Errorf("invalid config: ai.openrouter.base_url is required when ai.provider=openrouter")
|
||||
}
|
||||
if strings.TrimSpace(c.AI.OpenRouter.APIKey) == "" {
|
||||
return fmt.Errorf("invalid config: ai.openrouter.api_key is required when ai.provider=openrouter")
|
||||
}
|
||||
if err := c.AI.validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.Server.Port <= 0 {
|
||||
@@ -108,3 +78,61 @@ func (c Config) Validate() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a AIConfig) validate() error {
|
||||
if len(a.Providers) == 0 {
|
||||
return fmt.Errorf("invalid config: ai.providers must contain at least one entry")
|
||||
}
|
||||
for name, p := range a.Providers {
|
||||
if strings.TrimSpace(name) == "" {
|
||||
return fmt.Errorf("invalid config: ai.providers contains an entry with an empty name")
|
||||
}
|
||||
switch p.Type {
|
||||
case "litellm", "ollama", "openrouter":
|
||||
default:
|
||||
return fmt.Errorf("invalid config: ai.providers.%s.type %q is not supported", name, p.Type)
|
||||
}
|
||||
if strings.TrimSpace(p.BaseURL) == "" {
|
||||
return fmt.Errorf("invalid config: ai.providers.%s.base_url is required", name)
|
||||
}
|
||||
if strings.TrimSpace(p.APIKey) == "" {
|
||||
return fmt.Errorf("invalid config: ai.providers.%s.api_key is required", name)
|
||||
}
|
||||
}
|
||||
|
||||
if a.Embeddings.Dimensions <= 0 {
|
||||
return fmt.Errorf("invalid config: ai.embeddings.dimensions must be greater than zero")
|
||||
}
|
||||
|
||||
if err := a.validateChain("ai.embeddings", a.Embeddings.Chain()); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := a.validateChain("ai.metadata", a.Metadata.Chain()); err != nil {
|
||||
return err
|
||||
}
|
||||
if a.Background != nil {
|
||||
if a.Background.Embeddings != nil {
|
||||
if err := a.validateChain("ai.background.embeddings", a.Background.Embeddings.AsTargets()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if a.Background.Metadata != nil {
|
||||
if err := a.validateChain("ai.background.metadata", a.Background.Metadata.AsTargets()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a AIConfig) validateChain(prefix string, chain []RoleTarget) error {
|
||||
if len(chain) == 0 {
|
||||
return fmt.Errorf("invalid config: %s.primary must reference a configured provider and model", prefix)
|
||||
}
|
||||
for i, target := range chain {
|
||||
if _, ok := a.Providers[target.Provider]; !ok {
|
||||
return fmt.Errorf("invalid config: %s[%d] references unknown provider %q", prefix, i, target.Provider)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -7,28 +7,23 @@ import (
|
||||
|
||||
func validConfig() Config {
|
||||
return Config{
|
||||
Server: ServerConfig{Port: 8080},
|
||||
MCP: MCPConfig{Path: "/mcp", SessionTimeout: 10 * time.Minute},
|
||||
Version: CurrentConfigVersion,
|
||||
Server: ServerConfig{Port: 8080},
|
||||
MCP: MCPConfig{Path: "/mcp", SessionTimeout: 10 * time.Minute},
|
||||
Auth: AuthConfig{
|
||||
Keys: []APIKey{{ID: "test", Value: "secret"}},
|
||||
},
|
||||
Database: DatabaseConfig{URL: "postgres://example"},
|
||||
AI: AIConfig{
|
||||
Provider: "litellm",
|
||||
Embeddings: AIEmbeddingConfig{
|
||||
Providers: map[string]ProviderConfig{
|
||||
"default": {Type: "litellm", BaseURL: "http://localhost:4000/v1", APIKey: "key"},
|
||||
},
|
||||
Embeddings: EmbeddingsRoleConfig{
|
||||
Dimensions: 1536,
|
||||
Primary: RoleTarget{Provider: "default", Model: "text-embed"},
|
||||
},
|
||||
LiteLLM: LiteLLMConfig{
|
||||
BaseURL: "http://localhost:4000/v1",
|
||||
APIKey: "key",
|
||||
},
|
||||
Ollama: OllamaConfig{
|
||||
BaseURL: "http://localhost:11434/v1",
|
||||
APIKey: "ollama",
|
||||
},
|
||||
OpenRouter: OpenRouterAIConfig{
|
||||
BaseURL: "https://openrouter.ai/api/v1",
|
||||
APIKey: "key",
|
||||
Metadata: MetadataRoleConfig{
|
||||
Primary: RoleTarget{Provider: "default", Model: "gpt-4"},
|
||||
},
|
||||
},
|
||||
Search: SearchConfig{DefaultLimit: 10, MaxLimit: 50},
|
||||
@@ -36,29 +31,44 @@ func validConfig() Config {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAcceptsSupportedProviders(t *testing.T) {
|
||||
cfg := validConfig()
|
||||
if err := cfg.Validate(); err != nil {
|
||||
t.Fatalf("Validate litellm error = %v", err)
|
||||
}
|
||||
|
||||
cfg.AI.Provider = "ollama"
|
||||
if err := cfg.Validate(); err != nil {
|
||||
t.Fatalf("Validate ollama error = %v", err)
|
||||
}
|
||||
|
||||
cfg.AI.Provider = "openrouter"
|
||||
if err := cfg.Validate(); err != nil {
|
||||
t.Fatalf("Validate openrouter error = %v", err)
|
||||
func TestValidateAcceptsSupportedProviderTypes(t *testing.T) {
|
||||
for _, providerType := range []string{"litellm", "ollama", "openrouter"} {
|
||||
cfg := validConfig()
|
||||
p := cfg.AI.Providers["default"]
|
||||
p.Type = providerType
|
||||
cfg.AI.Providers["default"] = p
|
||||
if err := cfg.Validate(); err != nil {
|
||||
t.Fatalf("Validate %s error = %v", providerType, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRejectsInvalidProvider(t *testing.T) {
|
||||
func TestValidateRejectsInvalidProviderType(t *testing.T) {
|
||||
cfg := validConfig()
|
||||
cfg.AI.Provider = "unknown"
|
||||
p := cfg.AI.Providers["default"]
|
||||
p.Type = "unknown"
|
||||
cfg.AI.Providers["default"] = p
|
||||
|
||||
if err := cfg.Validate(); err == nil {
|
||||
t.Fatal("Validate() error = nil, want error for unsupported provider")
|
||||
t.Fatal("Validate() error = nil, want error for unsupported provider type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRejectsChainWithUnknownProvider(t *testing.T) {
|
||||
cfg := validConfig()
|
||||
cfg.AI.Metadata.Primary = RoleTarget{Provider: "does-not-exist", Model: "x"}
|
||||
|
||||
if err := cfg.Validate(); err == nil {
|
||||
t.Fatal("Validate() error = nil, want error for chain referencing unknown provider")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRejectsEmptyProviders(t *testing.T) {
|
||||
cfg := validConfig()
|
||||
cfg.AI.Providers = map[string]ProviderConfig{}
|
||||
|
||||
if err := cfg.Validate(); err == nil {
|
||||
t.Fatal("Validate() error = nil, want error for empty providers")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user