test(config): add migration tests for litellm provider
Some checks failed
CI / build-and-test (push) Failing after -32m22s
Some checks failed
CI / build-and-test (push) Failing after -32m22s
* Implement tests for migrating configuration from v1 to v2 for the litellm provider. * Validate the structure and values of the migrated configuration. * Ensure migration rejects newer versions of the configuration. fix(validate): enhance AI provider validation logic * Consolidate provider validation into a dedicated method. * Ensure at least one provider is specified and validate its type. * Check for required fields based on provider type. fix(mcpserver): update tool set to use new enrichment tool * Replace RetryMetadataTool with RetryEnrichmentTool in the ToolSet. fix(tools): refactor tools to use embedding and metadata runners * Update tools to utilize EmbeddingRunner and MetadataRunner instead of Provider. * Adjust method calls to align with the new runner interfaces.
This commit is contained in:
341
internal/config/migrate.go
Normal file
341
internal/config/migrate.go
Normal file
@@ -0,0 +1,341 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// CurrentConfigVersion is the schema version this binary expects. Files at a
|
||||
// lower version are migrated automatically when loaded.
|
||||
const CurrentConfigVersion = 2
|
||||
|
||||
// ConfigMigration upgrades a raw YAML map by one version.
|
||||
type ConfigMigration struct {
|
||||
From, To int
|
||||
Describe string
|
||||
Apply func(map[string]any) error
|
||||
}
|
||||
|
||||
// migrations is the ordered ladder of upgrades. Add new entries at the end.
|
||||
var migrations = []ConfigMigration{
|
||||
{From: 1, To: 2, Describe: "named providers + role chains", Apply: migrateV1toV2},
|
||||
}
|
||||
|
||||
// Migrate brings raw up to CurrentConfigVersion in place. Returns the list of
|
||||
// migrations that were applied (may be empty if already current).
|
||||
func Migrate(raw map[string]any) ([]ConfigMigration, error) {
|
||||
if raw == nil {
|
||||
return nil, fmt.Errorf("migrate: raw config is nil")
|
||||
}
|
||||
|
||||
version := readVersion(raw)
|
||||
if version > CurrentConfigVersion {
|
||||
return nil, fmt.Errorf("migrate: config version %d is newer than supported version %d", version, CurrentConfigVersion)
|
||||
}
|
||||
|
||||
applied := make([]ConfigMigration, 0)
|
||||
for {
|
||||
if version >= CurrentConfigVersion {
|
||||
break
|
||||
}
|
||||
step, ok := findMigration(version)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("migrate: no migration registered from version %d", version)
|
||||
}
|
||||
if err := step.Apply(raw); err != nil {
|
||||
return nil, fmt.Errorf("migrate v%d->v%d: %w", step.From, step.To, err)
|
||||
}
|
||||
raw["version"] = step.To
|
||||
version = step.To
|
||||
applied = append(applied, step)
|
||||
}
|
||||
return applied, nil
|
||||
}
|
||||
|
||||
func findMigration(from int) (ConfigMigration, bool) {
|
||||
for _, m := range migrations {
|
||||
if m.From == from {
|
||||
return m, true
|
||||
}
|
||||
}
|
||||
return ConfigMigration{}, false
|
||||
}
|
||||
|
||||
// readVersion returns the version from raw. Files without a version field are
|
||||
// treated as version 1 (the original schema).
|
||||
func readVersion(raw map[string]any) int {
|
||||
v, ok := raw["version"]
|
||||
if !ok {
|
||||
return 1
|
||||
}
|
||||
switch n := v.(type) {
|
||||
case int:
|
||||
return n
|
||||
case int64:
|
||||
return int(n)
|
||||
case float64:
|
||||
return int(n)
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
// migrateV1toV2 lifts the single-provider config into the named-providers +
|
||||
// role-chains layout. The pre-v2 config implicitly used one provider for both
|
||||
// embeddings and metadata; we materialise that as a provider named "default".
|
||||
func migrateV1toV2(raw map[string]any) error {
|
||||
aiRaw := mapValue(raw, "ai")
|
||||
if aiRaw == nil {
|
||||
aiRaw = map[string]any{}
|
||||
}
|
||||
|
||||
providerType := stringValue(aiRaw, "provider")
|
||||
if providerType == "" {
|
||||
providerType = "litellm"
|
||||
}
|
||||
|
||||
providers, embeddingModel, metadataModel, fallbackModels := buildV1Provider(aiRaw, providerType)
|
||||
|
||||
embeddingsOld := mapValue(aiRaw, "embeddings")
|
||||
dimensions := intValue(embeddingsOld, "dimensions")
|
||||
if dimensions <= 0 {
|
||||
dimensions = 1536
|
||||
}
|
||||
if embeddingModel == "" {
|
||||
embeddingModel = stringValue(embeddingsOld, "model")
|
||||
}
|
||||
|
||||
metadataOld := mapValue(aiRaw, "metadata")
|
||||
if metadataModel == "" {
|
||||
metadataModel = stringValue(metadataOld, "model")
|
||||
}
|
||||
temperature := floatValue(metadataOld, "temperature")
|
||||
logConversations := boolValue(metadataOld, "log_conversations")
|
||||
timeoutStr := stringValue(metadataOld, "timeout")
|
||||
|
||||
if list := stringListValue(metadataOld, "fallback_models"); len(list) > 0 {
|
||||
fallbackModels = append(fallbackModels, list...)
|
||||
}
|
||||
if v := stringValue(metadataOld, "fallback_model"); v != "" {
|
||||
fallbackModels = append(fallbackModels, v)
|
||||
}
|
||||
|
||||
embeddings := map[string]any{
|
||||
"dimensions": dimensions,
|
||||
"primary": map[string]any{"provider": "default", "model": embeddingModel},
|
||||
}
|
||||
|
||||
metadata := map[string]any{
|
||||
"temperature": temperature,
|
||||
"log_conversations": logConversations,
|
||||
"primary": map[string]any{"provider": "default", "model": metadataModel},
|
||||
}
|
||||
if timeoutStr != "" {
|
||||
metadata["timeout"] = timeoutStr
|
||||
}
|
||||
if fallbacks := chainTargets("default", fallbackModels); len(fallbacks) > 0 {
|
||||
metadata["fallbacks"] = fallbacks
|
||||
}
|
||||
|
||||
raw["ai"] = map[string]any{
|
||||
"providers": providers,
|
||||
"embeddings": embeddings,
|
||||
"metadata": metadata,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildV1Provider(aiRaw map[string]any, providerType string) (map[string]any, string, string, []string) {
|
||||
providers := map[string]any{}
|
||||
defaultEntry := map[string]any{"type": providerType}
|
||||
embedModel := ""
|
||||
metaModel := ""
|
||||
var fallbacks []string
|
||||
|
||||
switch providerType {
|
||||
case "litellm":
|
||||
block := mapValue(aiRaw, "litellm")
|
||||
copyKeys(defaultEntry, block, "base_url", "api_key")
|
||||
copyHeaders(defaultEntry, block, "request_headers")
|
||||
embedModel = stringValue(block, "embedding_model")
|
||||
metaModel = stringValue(block, "metadata_model")
|
||||
if list := stringListValue(block, "fallback_metadata_models"); len(list) > 0 {
|
||||
fallbacks = append(fallbacks, list...)
|
||||
}
|
||||
if v := stringValue(block, "fallback_metadata_model"); v != "" {
|
||||
fallbacks = append(fallbacks, v)
|
||||
}
|
||||
case "ollama":
|
||||
block := mapValue(aiRaw, "ollama")
|
||||
copyKeys(defaultEntry, block, "base_url", "api_key")
|
||||
copyHeaders(defaultEntry, block, "request_headers")
|
||||
case "openrouter":
|
||||
block := mapValue(aiRaw, "openrouter")
|
||||
copyKeys(defaultEntry, block, "base_url", "api_key", "app_name", "site_url")
|
||||
copyHeaders(defaultEntry, block, "extra_headers")
|
||||
// rename: extra_headers → request_headers
|
||||
if hdr, ok := defaultEntry["extra_headers"]; ok {
|
||||
defaultEntry["request_headers"] = hdr
|
||||
delete(defaultEntry, "extra_headers")
|
||||
}
|
||||
}
|
||||
|
||||
providers["default"] = defaultEntry
|
||||
return providers, embedModel, metaModel, fallbacks
|
||||
}
|
||||
|
||||
func chainTargets(provider string, models []string) []any {
|
||||
out := make([]any, 0, len(models))
|
||||
seen := map[string]struct{}{}
|
||||
for _, m := range models {
|
||||
if m == "" {
|
||||
continue
|
||||
}
|
||||
key := provider + "|" + m
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, map[string]any{"provider": provider, "model": m})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func mapValue(raw map[string]any, key string) map[string]any {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
v, ok := raw[key]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
switch m := v.(type) {
|
||||
case map[string]any:
|
||||
return m
|
||||
case map[any]any:
|
||||
return convertAnyMap(m)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertAnyMap(in map[any]any) map[string]any {
|
||||
out := make(map[string]any, len(in))
|
||||
keys := make([]string, 0, len(in))
|
||||
for k, v := range in {
|
||||
ks, ok := k.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, ks)
|
||||
out[ks] = v
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return out
|
||||
}
|
||||
|
||||
func stringValue(raw map[string]any, key string) string {
|
||||
if raw == nil {
|
||||
return ""
|
||||
}
|
||||
v, ok := raw[key]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func intValue(raw map[string]any, key string) int {
|
||||
if raw == nil {
|
||||
return 0
|
||||
}
|
||||
switch n := raw[key].(type) {
|
||||
case int:
|
||||
return n
|
||||
case int64:
|
||||
return int(n)
|
||||
case float64:
|
||||
return int(n)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func floatValue(raw map[string]any, key string) float64 {
|
||||
if raw == nil {
|
||||
return 0
|
||||
}
|
||||
switch n := raw[key].(type) {
|
||||
case float64:
|
||||
return n
|
||||
case int:
|
||||
return float64(n)
|
||||
case int64:
|
||||
return float64(n)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func boolValue(raw map[string]any, key string) bool {
|
||||
if raw == nil {
|
||||
return false
|
||||
}
|
||||
if b, ok := raw[key].(bool); ok {
|
||||
return b
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func stringListValue(raw map[string]any, key string) []string {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
v, ok := raw[key]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
list, ok := v.([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(list))
|
||||
for _, item := range list {
|
||||
if s, ok := item.(string); ok && s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func copyKeys(dst, src map[string]any, keys ...string) {
|
||||
if src == nil {
|
||||
return
|
||||
}
|
||||
for _, k := range keys {
|
||||
if v, ok := src[k]; ok {
|
||||
dst[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func copyHeaders(dst, src map[string]any, key string) {
|
||||
if src == nil {
|
||||
return
|
||||
}
|
||||
v, ok := src[key]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
switch headers := v.(type) {
|
||||
case map[string]any:
|
||||
if len(headers) == 0 {
|
||||
return
|
||||
}
|
||||
dst[key] = headers
|
||||
case map[any]any:
|
||||
if len(headers) == 0 {
|
||||
return
|
||||
}
|
||||
dst[key] = convertAnyMap(headers)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user