Files
amcs/internal/config/loader.go
Hein 14e218d784
Some checks failed
CI / build-and-test (push) Failing after -32m22s
test(config): add migration tests for litellm provider
* Implement tests for migrating configuration from v1 to v2 for the litellm provider.
* Validate the structure and values of the migrated configuration.
* Ensure migration rejects newer versions of the configuration.
fix(validate): enhance AI provider validation logic
* Consolidate provider validation into a dedicated method.
* Ensure at least one provider is specified and validate its type.
* Check for required fields based on provider type.
fix(mcpserver): update tool set to use new enrichment tool
* Replace RetryMetadataTool with RetryEnrichmentTool in the ToolSet.
fix(tools): refactor tools to use embedding and metadata runners
* Update tools to utilize EmbeddingRunner and MetadataRunner instead of Provider.
* Adjust method calls to align with the new runner interfaces.
2026-04-21 21:14:28 +02:00

219 lines
5.7 KiB
Go

package config
import (
"fmt"
"log/slog"
"os"
"strconv"
"strings"
"time"
"git.warky.dev/wdevs/amcs/internal/buildinfo"
"gopkg.in/yaml.v3"
)
// Load reads, migrates and validates the configuration file without a
// migration logger. It is shorthand for LoadWithLogger(explicitPath, nil) and
// returns exactly what that call returns: the loaded config, the resolved file
// path (set even on error), and any error.
func Load(explicitPath string) (*Config, string, error) {
	cfg, resolved, err := LoadWithLogger(explicitPath, nil)
	return cfg, resolved, err
}
// LoadWithLogger is Load with a logger surface for migration notices. Passing
// nil is fine — migration events will simply not be logged.
//
// Steps: resolve the path (ResolvePath), decode the file into a generic map,
// run Migrate on it, decode the result on top of defaultConfig(), apply
// environment overrides, then Validate. The resolved path is returned in every
// case, including on error.
//
// When Migrate applied at least one step, the migrated document is persisted
// (with a timestamped backup of the original bytes, see rewriteConfigFile) and
// each step is logged. Persisting happens only after the migrated config has
// decoded and validated successfully, so a load that fails never leaves a
// rewritten file behind; the next successful load will migrate again.
func LoadWithLogger(explicitPath string, log *slog.Logger) (*Config, string, error) {
	path := ResolvePath(explicitPath)

	data, err := os.ReadFile(path)
	if err != nil {
		return nil, path, fmt.Errorf("read config %q: %w", path, err)
	}

	raw := map[string]any{}
	if err := yaml.Unmarshal(data, &raw); err != nil {
		return nil, path, fmt.Errorf("decode config %q: %w", path, err)
	}
	// The decoder can leave the map nil (e.g. for an explicit null document);
	// normalise so Migrate always receives a usable map.
	if raw == nil {
		raw = map[string]any{}
	}

	applied, err := Migrate(raw)
	if err != nil {
		return nil, path, fmt.Errorf("migrate config %q: %w", path, err)
	}

	cfg, err := decodeTyped(raw)
	if err != nil {
		return nil, path, fmt.Errorf("decode migrated config %q: %w", path, err)
	}
	cfg.Version = CurrentConfigVersion
	applyEnvOverrides(&cfg)
	if err := cfg.Validate(); err != nil {
		return nil, path, fmt.Errorf("validate config %q: %w", path, err)
	}

	if len(applied) > 0 {
		// Only now is the migrated document known to be loadable, so it is
		// safe to persist. raw is not touched by decodeTyped (it re-marshals
		// a copy) nor by applyEnvOverrides (it mutates cfg), so
		// environment-supplied values are never written to disk.
		if err := rewriteConfigFile(path, data, raw); err != nil {
			return nil, path, err
		}
		if log != nil {
			for _, step := range applied {
				log.Warn("config migrated",
					slog.String("path", path),
					slog.Int("from_version", step.From),
					slog.Int("to_version", step.To),
					slog.String("describe", step.Describe),
				)
			}
		}
	}

	return &cfg, path, nil
}
// decodeTyped turns the generic (already migrated) document into a typed
// Config. It round-trips raw through YAML and decodes the bytes on top of
// defaultConfig(), so keys absent from raw keep their default values. raw
// itself is not modified. On failure the zero Config is returned.
func decodeTyped(raw map[string]any) (Config, error) {
	var typed Config

	encoded, err := yaml.Marshal(raw)
	if err != nil {
		return typed, fmt.Errorf("re-marshal migrated config: %w", err)
	}

	typed = defaultConfig()
	if err = yaml.Unmarshal(encoded, &typed); err != nil {
		return Config{}, err
	}
	return typed, nil
}
// rewriteConfigFile persists a migrated config document to path.
//
// The migrated map is marshalled first, so an encoding failure leaves the
// filesystem untouched. Then the original bytes are copied to
// "<path>.bak.<unix-seconds>" so the pre-migration file can be restored by
// hand, and finally path is overwritten with the new YAML.
//
// Caveats:
//   - migrated is a generic map, so YAML comments and the original key order
//     of the file are not preserved.
//   - os.WriteFile applies mode 0o600 only when it creates a file; an existing
//     config keeps its current permissions.
//   - The backup name has one-second resolution; an existing backup with the
//     same name is overwritten.
func rewriteConfigFile(path string, original []byte, migrated map[string]any) error {
	out, err := yaml.Marshal(migrated)
	if err != nil {
		return fmt.Errorf("marshal migrated config: %w", err)
	}

	backupPath := fmt.Sprintf("%s.bak.%d", path, time.Now().Unix())
	if err := os.WriteFile(backupPath, original, 0o600); err != nil {
		return fmt.Errorf("write backup %q: %w", backupPath, err)
	}

	if err := os.WriteFile(path, out, 0o600); err != nil {
		return fmt.Errorf("write migrated config %q: %w", path, err)
	}
	return nil
}
// ResolvePath picks the config file location. Precedence:
//  1. explicitPath, trimmed, unless it is blank or exactly ".yaml" / ".yml";
//  2. the trimmed AMCS_CONFIG environment variable, if non-blank;
//  3. DefaultConfigPath.
//
// NOTE(review): the bare ".yaml"/".yml" exclusion presumably guards against a
// path assembled from an empty base name plus an extension — confirm with the
// caller that builds explicitPath.
func ResolvePath(explicitPath string) string {
	switch candidate := strings.TrimSpace(explicitPath); candidate {
	case "", ".yaml", ".yml":
		// Not a usable path; fall back to the environment / default.
	default:
		return candidate
	}

	if fromEnv := strings.TrimSpace(os.Getenv("AMCS_CONFIG")); fromEnv != "" {
		return fromEnv
	}
	return DefaultConfigPath
}
// defaultConfig returns the baseline configuration that decoded YAML is
// layered on top of (see decodeTyped). Sections not set here (for example
// Database) start at their zero value. The MCP server version comes from
// buildinfo.Current().
func defaultConfig() Config {
	build := buildinfo.Current()

	var cfg Config
	cfg.Version = CurrentConfigVersion

	cfg.Server = ServerConfig{
		Host:         "0.0.0.0",
		Port:         8080,
		ReadTimeout:  10 * time.Minute,
		WriteTimeout: 10 * time.Minute,
		IdleTimeout:  60 * time.Second,
	}

	cfg.MCP = MCPConfig{
		Path:           "/mcp",
		SSEPath:        "/sse",
		ServerName:     "amcs",
		Version:        build.Version,
		Transport:      "streamable_http",
		SessionTimeout: 10 * time.Minute,
	}

	cfg.Auth = AuthConfig{
		HeaderName: "x-brain-key",
		QueryParam: "key",
	}

	// Providers starts as an empty, non-nil map; the config file supplies
	// the actual provider entries.
	cfg.AI = AIConfig{
		Providers:  map[string]ProviderConfig{},
		Embeddings: EmbeddingsRoleConfig{Dimensions: 1536},
		Metadata: MetadataRoleConfig{
			Temperature: 0.1,
			Timeout:     10 * time.Second,
		},
	}

	cfg.Capture = CaptureConfig{
		Source: DefaultSource,
		MetadataDefaults: CaptureMetadataDefault{
			Type:          "observation",
			TopicFallback: "uncategorized",
		},
	}

	cfg.Search = SearchConfig{
		DefaultLimit:     10,
		DefaultThreshold: 0.5,
		MaxLimit:         50,
	}

	cfg.Logging = LoggingConfig{
		Level:  "info",
		Format: "json",
	}

	// Background jobs are opt-in: Enabled and RunOnStartup stay false.
	cfg.Backfill = BackfillConfig{
		Interval:  15 * time.Minute,
		BatchSize: 20,
		MaxPerRun: 100,
	}
	cfg.MetadataRetry = MetadataRetryConfig{
		Interval:  24 * time.Hour,
		MaxPerRun: 100,
	}

	return cfg
}
// applyEnvOverrides patches cfg in place from AMCS_* environment variables.
// Values are whitespace-trimmed; a variable that is set (even to an empty
// string) overrides the string fields, while unset variables leave cfg alone.
//
//   - AMCS_DATABASE_URL, AMCS_PUBLIC_URL: Database.URL and MCP.PublicURL.
//   - AMCS_{LITELLM,OLLAMA}_{BASE_URL,API_KEY}, AMCS_OPENROUTER_API_KEY:
//     applied to every configured provider of the matching type.
//   - AMCS_SERVER_PORT: Server.Port. NOTE(review): a value that is not a
//     valid integer is silently ignored and the existing port is kept.
func applyEnvOverrides(cfg *Config) {
	overrideString(&cfg.Database.URL, "AMCS_DATABASE_URL")
	overrideString(&cfg.MCP.PublicURL, "AMCS_PUBLIC_URL")

	setBaseURL := func(p *ProviderConfig, v string) { p.BaseURL = v }
	setAPIKey := func(p *ProviderConfig, v string) { p.APIKey = v }

	providerOverrides := []struct {
		env          string
		providerType string
		apply        func(*ProviderConfig, string)
	}{
		{"AMCS_LITELLM_BASE_URL", "litellm", setBaseURL},
		{"AMCS_LITELLM_API_KEY", "litellm", setAPIKey},
		{"AMCS_OLLAMA_BASE_URL", "ollama", setBaseURL},
		{"AMCS_OLLAMA_API_KEY", "ollama", setAPIKey},
		{"AMCS_OPENROUTER_API_KEY", "openrouter", setAPIKey},
	}
	for _, o := range providerOverrides {
		overrideProviderField(cfg, o.env, o.providerType, o.apply)
	}

	rawPort, present := os.LookupEnv("AMCS_SERVER_PORT")
	if !present {
		return
	}
	port, err := strconv.Atoi(strings.TrimSpace(rawPort))
	if err != nil {
		return
	}
	cfg.Server.Port = port
}
// overrideProviderField applies an env var to every configured provider of the
// given type. This preserves the v1 behaviour where e.g. AMCS_LITELLM_API_KEY
// rewrote the single litellm block — in v2 it rewrites every litellm provider.
//
// The value is whitespace-trimmed before apply sees it. If envKey is unset,
// nothing happens. If no provider has Type == providerType (exact,
// case-sensitive match), the variable has no effect; no provider is created.
func overrideProviderField(cfg *Config, envKey, providerType string, apply func(*ProviderConfig, string)) {
	raw, set := os.LookupEnv(envKey)
	if !set {
		return
	}
	trimmed := strings.TrimSpace(raw)

	for name := range cfg.AI.Providers {
		provider := cfg.AI.Providers[name]
		if provider.Type == providerType {
			// Map values are not addressable: mutate a copy, then store it back.
			apply(&provider, trimmed)
			cfg.AI.Providers[name] = provider
		}
	}
}
// overrideString replaces *target with the whitespace-trimmed value of the
// environment variable envKey when that variable is set. A variable that is
// set but empty (or all whitespace) still overrides and clears *target; an
// unset variable leaves *target untouched.
func overrideString(target *string, envKey string) {
	value, ok := os.LookupEnv(envKey)
	if !ok {
		return
	}
	*target = strings.TrimSpace(value)
}