Files
amcs/internal/config/migrate.go
Hein 14e218d784
Some checks failed
CI / build-and-test (push) Failing after -32m22s
test(config): add migration tests for litellm provider
* Implement tests for migrating configuration from v1 to v2 for the litellm provider.
* Validate the structure and values of the migrated configuration.
* Ensure migration rejects newer versions of the configuration.
fix(validate): enhance AI provider validation logic
* Consolidate provider validation into a dedicated method.
* Ensure at least one provider is specified and validate its type.
* Check for required fields based on provider type.
fix(mcpserver): update tool set to use new enrichment tool
* Replace RetryMetadataTool with RetryEnrichmentTool in the ToolSet.
fix(tools): refactor tools to use embedding and metadata runners
* Update tools to utilize EmbeddingRunner and MetadataRunner instead of Provider.
* Adjust method calls to align with the new runner interfaces.
2026-04-21 21:14:28 +02:00

342 lines
7.8 KiB
Go

package config
import (
	"fmt"
	"sort"
	"strconv"
	"strings"
)
// CurrentConfigVersion is the configuration schema version understood by this
// binary. Migrate upgrades raw configs that declare a lower version (a config
// with no version field counts as version 1) until they reach this value.
const (
	CurrentConfigVersion = 2
)
// ConfigMigration is one rung of the upgrade ladder: it rewrites a raw
// (decoded, untyped) configuration map from schema version From to To.
type ConfigMigration struct {
	// From is the schema version this step accepts.
	From int
	// To is the schema version Migrate records once Apply succeeds.
	To int
	// Describe is a short human-readable summary of the change.
	Describe string
	// Apply mutates the raw map in place; a non-nil error aborts Migrate.
	Apply func(map[string]any) error
}
// migrations is the ordered upgrade ladder consulted by Migrate, one version
// per step. Append new steps at the end and keep each From unique:
// findMigration returns the first entry whose From matches.
var migrations = []ConfigMigration{
	{
		From:     1,
		To:       2,
		Describe: "named providers + role chains",
		Apply:    migrateV1toV2,
	},
}
// Migrate upgrades raw, in place, to CurrentConfigVersion by applying the
// registered migrations one step at a time. It returns the steps it applied,
// in order. When raw is already current, the result is an empty, non-nil
// slice.
//
// After each successful step, raw["version"] is set to that step's To value.
// An error is returned in any of these cases:
//   - raw is nil;
//   - its version is newer than CurrentConfigVersion;
//   - no step is registered for an intermediate version;
//   - a registered step would not advance the version;
//   - a step's Apply fails.
//
// On error, raw may already have been partially modified by earlier steps or
// by the failing step itself.
func Migrate(raw map[string]any) ([]ConfigMigration, error) {
	if raw == nil {
		return nil, fmt.Errorf("migrate: raw config is nil")
	}
	version := readVersion(raw)
	if version > CurrentConfigVersion {
		return nil, fmt.Errorf("migrate: config version %d is newer than supported version %d", version, CurrentConfigVersion)
	}
	applied := make([]ConfigMigration, 0)
	for version < CurrentConfigVersion {
		step, ok := findMigration(version)
		if !ok {
			return nil, fmt.Errorf("migrate: no migration registered from version %d", version)
		}
		// A step that does not move the version forward would make this loop
		// spin forever; surface it as a registration bug instead of hanging.
		if step.To <= step.From {
			return nil, fmt.Errorf("migrate: migration from version %d does not advance the version (to %d)", step.From, step.To)
		}
		if err := step.Apply(raw); err != nil {
			return nil, fmt.Errorf("migrate v%d->v%d: %w", step.From, step.To, err)
		}
		raw["version"] = step.To
		version = step.To
		applied = append(applied, step)
	}
	return applied, nil
}
// findMigration looks up the registered step whose From equals from. It
// returns the first match in registration order and reports whether one
// exists.
func findMigration(from int) (ConfigMigration, bool) {
	for i := range migrations {
		if candidate := migrations[i]; candidate.From == from {
			return candidate, true
		}
	}
	var none ConfigMigration
	return none, false
}
// readVersion returns the schema version recorded under raw["version"].
//
// A missing field, an explicit null, or a nil raw counts as version 1, the
// original schema. The following values are accepted:
//   - integers of any width and signedness;
//   - floating-point values, which are truncated;
//   - numeric strings such as "2" (surrounding whitespace allowed).
//
// Quoted numbers must be parsed here. Otherwise a quoted "2" would read as
// version 1, and Migrate would re-run the v1->v2 step over an already-v2 file.
//
// NOTE(review): any other value still falls back to 1, matching the previous
// behaviour. Confirm whether an unrecognised version should be an error.
func readVersion(raw map[string]any) int {
	switch n := raw["version"].(type) {
	case int:
		return n
	case int8:
		return int(n)
	case int16:
		return int(n)
	case int32:
		return int(n)
	case int64:
		return int(n)
	case uint:
		return int(n)
	case uint8:
		return int(n)
	case uint16:
		return int(n)
	case uint32:
		return int(n)
	case uint64:
		return int(n)
	case float32:
		return int(n)
	case float64:
		return int(n)
	case string:
		if parsed, err := strconv.Atoi(strings.TrimSpace(n)); err == nil {
			return parsed
		}
	}
	return 1
}
// migrateV1toV2 rewrites the v1 "ai" section into the v2 layout of named
// providers plus per-role model chains.
//
// A v1 file configured one backend (ai.provider, defaulting to "litellm") that
// served both embeddings and metadata. buildV1Provider turns it into a
// provider named "default", and both roles point their primary target at it.
//
// Model names reported by buildV1Provider take precedence. Otherwise
// ai.embeddings.model and ai.metadata.model are used.
//
// Embedding dimensions default to 1536 when absent or non-positive.
//
// Metadata fallbacks are collected from the backend block, then
// ai.metadata.fallback_models, then ai.metadata.fallback_model. They are
// de-duplicated by chainTargets and all target "default".
//
// The old ai section is replaced wholesale. This step always returns nil.
func migrateV1toV2(raw map[string]any) error {
	legacy := mapValue(raw, "ai")
	if legacy == nil {
		legacy = map[string]any{}
	}

	kind := stringValue(legacy, "provider")
	if kind == "" {
		kind = "litellm"
	}
	providers, embedModel, metaModel, fallbacks := buildV1Provider(legacy, kind)

	oldEmbeddings := mapValue(legacy, "embeddings")
	oldMetadata := mapValue(legacy, "metadata")

	dims := intValue(oldEmbeddings, "dimensions")
	if dims <= 0 {
		dims = 1536
	}
	if embedModel == "" {
		embedModel = stringValue(oldEmbeddings, "model")
	}
	if metaModel == "" {
		metaModel = stringValue(oldMetadata, "model")
	}

	// Order matters for the resulting chain: backend-block fallbacks first,
	// then the list form, then the single-model form from ai.metadata.
	fallbacks = append(fallbacks, stringListValue(oldMetadata, "fallback_models")...)
	if single := stringValue(oldMetadata, "fallback_model"); single != "" {
		fallbacks = append(fallbacks, single)
	}

	metadata := map[string]any{
		"temperature":       floatValue(oldMetadata, "temperature"),
		"log_conversations": boolValue(oldMetadata, "log_conversations"),
		"primary":           map[string]any{"provider": "default", "model": metaModel},
	}
	if timeout := stringValue(oldMetadata, "timeout"); timeout != "" {
		metadata["timeout"] = timeout
	}
	if chain := chainTargets("default", fallbacks); len(chain) > 0 {
		metadata["fallbacks"] = chain
	}

	raw["ai"] = map[string]any{
		"providers": providers,
		"embeddings": map[string]any{
			"dimensions": dims,
			"primary":    map[string]any{"provider": "default", "model": embedModel},
		},
		"metadata": metadata,
	}
	return nil
}
// buildV1Provider converts the v1 backend block selected by providerType into
// a v2 providers map holding a single entry named "default".
//
// It also returns the embedding model, the metadata model and the metadata
// fallback models declared in that block. Only the "litellm" block carries
// model names; for "ollama" and "openrouter" those results are empty.
//
// For openrouter, the "extra_headers" key is renamed to "request_headers".
// An unrecognised providerType yields an entry containing only "type".
func buildV1Provider(aiRaw map[string]any, providerType string) (map[string]any, string, string, []string) {
	entry := map[string]any{"type": providerType}
	var (
		embeddingModel string
		metadataModel  string
		fallbackModels []string
	)

	switch providerType {
	case "litellm":
		section := mapValue(aiRaw, "litellm")
		copyKeys(entry, section, "base_url", "api_key")
		copyHeaders(entry, section, "request_headers")
		embeddingModel = stringValue(section, "embedding_model")
		metadataModel = stringValue(section, "metadata_model")
		fallbackModels = append(fallbackModels, stringListValue(section, "fallback_metadata_models")...)
		if single := stringValue(section, "fallback_metadata_model"); single != "" {
			fallbackModels = append(fallbackModels, single)
		}
	case "ollama":
		section := mapValue(aiRaw, "ollama")
		copyKeys(entry, section, "base_url", "api_key")
		copyHeaders(entry, section, "request_headers")
	case "openrouter":
		section := mapValue(aiRaw, "openrouter")
		copyKeys(entry, section, "base_url", "api_key", "app_name", "site_url")
		copyHeaders(entry, section, "extra_headers")
		// v2 expects request_headers, the key litellm and ollama already use.
		if headers, found := entry["extra_headers"]; found {
			delete(entry, "extra_headers")
			entry["request_headers"] = headers
		}
	}

	return map[string]any{"default": entry}, embeddingModel, metadataModel, fallbackModels
}
// chainTargets turns an ordered list of model names into chain entries of the
// form {"provider": provider, "model": name}. Empty names are skipped, and only
// the first occurrence of each name is kept, so input order is preserved. The
// result is never nil.
func chainTargets(provider string, models []string) []any {
	targets := make([]any, 0, len(models))
	// provider is fixed for the whole call, so the model name alone is a
	// sufficient dedup key.
	seen := make(map[string]bool, len(models))
	for _, model := range models {
		if model == "" || seen[model] {
			continue
		}
		seen[model] = true
		targets = append(targets, map[string]any{"provider": provider, "model": model})
	}
	return targets
}
// mapValue returns raw[key] as a map[string]any. A map[any]any value is
// converted with convertAnyMap. A nil raw, a missing key or any other value
// type yields nil.
func mapValue(raw map[string]any, key string) map[string]any {
	// Indexing a nil map is safe and yields nil, which matches no case below.
	switch nested := raw[key].(type) {
	case map[string]any:
		return nested
	case map[any]any:
		return convertAnyMap(nested)
	default:
		return nil
	}
}
// convertAnyMap copies the string-keyed entries of in into a new
// map[string]any and silently skips entries whose key is not a string. Values
// are copied as-is; nested map[any]any values are not converted.
func convertAnyMap(in map[any]any) map[string]any {
	result := make(map[string]any, len(in))
	names := make([]string, 0, len(in))
	for rawKey, val := range in {
		name, isString := rawKey.(string)
		if !isString {
			continue
		}
		result[name] = val
		names = append(names, name)
	}
	// NOTE(review): names is never read after this point, so the sort has no
	// effect on the returned map. It is also the only visible use of the
	// "sort" import, so removing it means removing that import as well.
	sort.Strings(names)
	return result
}
// stringValue returns raw[key] when it holds a string, and "" otherwise,
// including when raw is nil or key is absent.
func stringValue(raw map[string]any, key string) string {
	// The comma-ok assertion yields the zero value "" on any mismatch.
	s, _ := raw[key].(string)
	return s
}
// intValue returns raw[key] as an int when it holds an int, int64 or float64
// (floats are truncated toward zero). Any other value, a missing key or a nil
// raw yields 0.
func intValue(raw map[string]any, key string) int {
	switch value := raw[key].(type) {
	case int:
		return value
	case int64:
		return int(value)
	case float64:
		return int(value)
	default:
		return 0
	}
}
// floatValue returns raw[key] as a float64 when it holds a float64, int or
// int64. Any other value, a missing key or a nil raw yields 0.
func floatValue(raw map[string]any, key string) float64 {
	switch value := raw[key].(type) {
	case float64:
		return value
	case int:
		return float64(value)
	case int64:
		return float64(value)
	default:
		return 0
	}
}
// boolValue returns raw[key] when it holds a bool, and false otherwise,
// including when raw is nil or key is absent.
func boolValue(raw map[string]any, key string) bool {
	// The comma-ok assertion yields false on any type mismatch.
	b, _ := raw[key].(bool)
	return b
}
// stringListValue returns the non-empty strings from the list at raw[key],
// preserving their order.
//
// The list may be a []any, in which case non-string items are skipped, or a
// []string. A []string was previously ignored, which silently dropped lists
// from maps built in Go code.
//
// It returns nil when raw is nil, the key is absent, or the value is not a
// list. A list with no usable strings yields an empty, non-nil slice. The
// result never aliases the input slice.
func stringListValue(raw map[string]any, key string) []string {
	switch list := raw[key].(type) {
	case []any:
		out := make([]string, 0, len(list))
		for _, item := range list {
			if s, ok := item.(string); ok && s != "" {
				out = append(out, s)
			}
		}
		return out
	case []string:
		out := make([]string, 0, len(list))
		for _, s := range list {
			if s != "" {
				out = append(out, s)
			}
		}
		return out
	}
	return nil
}
// copyKeys copies each listed key that is present in src into dst. Values are
// copied as-is, with no deep copy, and may be nil. Keys absent from src leave
// dst untouched, and a nil src is a no-op.
func copyKeys(dst, src map[string]any, keys ...string) {
	for _, name := range keys {
		// Looking up a key in a nil src reports "not present", so nil needs no
		// special case.
		value, present := src[name]
		if !present {
			continue
		}
		dst[name] = value
	}
}
// copyHeaders copies the header map stored at src[key] into dst[key].
//
// A map[string]any is stored as-is, so dst shares it with src. A map[any]any
// is converted with convertAnyMap. Either map is copied only when it is
// non-empty before conversion. A nil src, a missing key or any other value
// type leaves dst untouched.
func copyHeaders(dst, src map[string]any, key string) {
	switch headers := src[key].(type) {
	case map[string]any:
		if len(headers) > 0 {
			dst[key] = headers
		}
	case map[any]any:
		// Length is checked before conversion: a map with only non-string
		// keys is still stored, as an empty map.
		if len(headers) > 0 {
			dst[key] = convertAnyMap(headers)
		}
	}
}