Some checks failed
CI / build-and-test (push) Failing after -32m22s
* Implement tests for migrating configuration from v1 to v2 for the litellm provider.
* Validate the structure and values of the migrated configuration.
* Ensure migration rejects newer versions of the configuration.
fix(validate): enhance AI provider validation logic
* Consolidate provider validation into a dedicated method.
* Ensure at least one provider is specified and validate its type.
* Check for required fields based on provider type.
fix(mcpserver): update tool set to use new enrichment tool
* Replace RetryMetadataTool with RetryEnrichmentTool in the ToolSet.
fix(tools): refactor tools to use embedding and metadata runners
* Update tools to utilize EmbeddingRunner and MetadataRunner instead of Provider.
* Adjust method calls to align with the new runner interfaces.
342 lines
7.8 KiB
Go
342 lines
7.8 KiB
Go
package config
|
|
|
|
import (
	"fmt"
	"sort"
	"strconv"
)
|
|
|
|
const (
	// CurrentConfigVersion is the schema version this binary expects. Files
	// at a lower version are migrated automatically when loaded; Migrate
	// rejects files that declare a higher version.
	CurrentConfigVersion = 2
)
|
|
|
|
// ConfigMigration is one rung of the migration ladder: it upgrades a raw
// (decoded YAML) configuration map from schema version From to version To.
type ConfigMigration struct {
	// From is the schema version this step accepts as input.
	From int
	// To is the schema version the map is at once Apply succeeds; Migrate
	// writes it to raw["version"] after a successful Apply.
	To int
	// Describe is a short human-readable summary of what the step changes.
	Describe string
	// Apply rewrites the raw map in place. A non-nil error aborts migration.
	Apply func(map[string]any) error
}
|
|
|
|
// migrations is the ordered ladder of upgrades. Add new entries at the end.
|
|
var migrations = []ConfigMigration{
|
|
{From: 1, To: 2, Describe: "named providers + role chains", Apply: migrateV1toV2},
|
|
}
|
|
|
|
// Migrate brings raw up to CurrentConfigVersion in place. Returns the list of
|
|
// migrations that were applied (may be empty if already current).
|
|
func Migrate(raw map[string]any) ([]ConfigMigration, error) {
|
|
if raw == nil {
|
|
return nil, fmt.Errorf("migrate: raw config is nil")
|
|
}
|
|
|
|
version := readVersion(raw)
|
|
if version > CurrentConfigVersion {
|
|
return nil, fmt.Errorf("migrate: config version %d is newer than supported version %d", version, CurrentConfigVersion)
|
|
}
|
|
|
|
applied := make([]ConfigMigration, 0)
|
|
for {
|
|
if version >= CurrentConfigVersion {
|
|
break
|
|
}
|
|
step, ok := findMigration(version)
|
|
if !ok {
|
|
return nil, fmt.Errorf("migrate: no migration registered from version %d", version)
|
|
}
|
|
if err := step.Apply(raw); err != nil {
|
|
return nil, fmt.Errorf("migrate v%d->v%d: %w", step.From, step.To, err)
|
|
}
|
|
raw["version"] = step.To
|
|
version = step.To
|
|
applied = append(applied, step)
|
|
}
|
|
return applied, nil
|
|
}
|
|
|
|
func findMigration(from int) (ConfigMigration, bool) {
|
|
for _, m := range migrations {
|
|
if m.From == from {
|
|
return m, true
|
|
}
|
|
}
|
|
return ConfigMigration{}, false
|
|
}
|
|
|
|
// readVersion returns the schema version stored in raw["version"].
//
// A file without a version field is treated as version 1 (the original
// schema). Accepted value types are int, int64 and float64 (typical YAML/JSON
// decoder outputs). Numeric strings such as "2" are also accepted, because a
// quoted version in YAML decodes as a string. Treating "2" as version 1 would
// make Migrate re-run the v1->v2 migration over an already-migrated file and
// overwrite its ai block. Float values are truncated toward zero. Any other
// value, including nil and unparsable strings, is treated as version 1 for
// backward compatibility.
func readVersion(raw map[string]any) int {
	v, ok := raw["version"]
	if !ok {
		return 1
	}
	switch n := v.(type) {
	case int:
		return n
	case int64:
		return int(n)
	case float64:
		return int(n)
	case string:
		if parsed, err := strconv.Atoi(n); err == nil {
			return parsed
		}
	}
	return 1
}
|
|
|
|
// migrateV1toV2 lifts the single-provider config into the named-providers +
|
|
// role-chains layout. The pre-v2 config implicitly used one provider for both
|
|
// embeddings and metadata; we materialise that as a provider named "default".
|
|
func migrateV1toV2(raw map[string]any) error {
|
|
aiRaw := mapValue(raw, "ai")
|
|
if aiRaw == nil {
|
|
aiRaw = map[string]any{}
|
|
}
|
|
|
|
providerType := stringValue(aiRaw, "provider")
|
|
if providerType == "" {
|
|
providerType = "litellm"
|
|
}
|
|
|
|
providers, embeddingModel, metadataModel, fallbackModels := buildV1Provider(aiRaw, providerType)
|
|
|
|
embeddingsOld := mapValue(aiRaw, "embeddings")
|
|
dimensions := intValue(embeddingsOld, "dimensions")
|
|
if dimensions <= 0 {
|
|
dimensions = 1536
|
|
}
|
|
if embeddingModel == "" {
|
|
embeddingModel = stringValue(embeddingsOld, "model")
|
|
}
|
|
|
|
metadataOld := mapValue(aiRaw, "metadata")
|
|
if metadataModel == "" {
|
|
metadataModel = stringValue(metadataOld, "model")
|
|
}
|
|
temperature := floatValue(metadataOld, "temperature")
|
|
logConversations := boolValue(metadataOld, "log_conversations")
|
|
timeoutStr := stringValue(metadataOld, "timeout")
|
|
|
|
if list := stringListValue(metadataOld, "fallback_models"); len(list) > 0 {
|
|
fallbackModels = append(fallbackModels, list...)
|
|
}
|
|
if v := stringValue(metadataOld, "fallback_model"); v != "" {
|
|
fallbackModels = append(fallbackModels, v)
|
|
}
|
|
|
|
embeddings := map[string]any{
|
|
"dimensions": dimensions,
|
|
"primary": map[string]any{"provider": "default", "model": embeddingModel},
|
|
}
|
|
|
|
metadata := map[string]any{
|
|
"temperature": temperature,
|
|
"log_conversations": logConversations,
|
|
"primary": map[string]any{"provider": "default", "model": metadataModel},
|
|
}
|
|
if timeoutStr != "" {
|
|
metadata["timeout"] = timeoutStr
|
|
}
|
|
if fallbacks := chainTargets("default", fallbackModels); len(fallbacks) > 0 {
|
|
metadata["fallbacks"] = fallbacks
|
|
}
|
|
|
|
raw["ai"] = map[string]any{
|
|
"providers": providers,
|
|
"embeddings": embeddings,
|
|
"metadata": metadata,
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func buildV1Provider(aiRaw map[string]any, providerType string) (map[string]any, string, string, []string) {
|
|
providers := map[string]any{}
|
|
defaultEntry := map[string]any{"type": providerType}
|
|
embedModel := ""
|
|
metaModel := ""
|
|
var fallbacks []string
|
|
|
|
switch providerType {
|
|
case "litellm":
|
|
block := mapValue(aiRaw, "litellm")
|
|
copyKeys(defaultEntry, block, "base_url", "api_key")
|
|
copyHeaders(defaultEntry, block, "request_headers")
|
|
embedModel = stringValue(block, "embedding_model")
|
|
metaModel = stringValue(block, "metadata_model")
|
|
if list := stringListValue(block, "fallback_metadata_models"); len(list) > 0 {
|
|
fallbacks = append(fallbacks, list...)
|
|
}
|
|
if v := stringValue(block, "fallback_metadata_model"); v != "" {
|
|
fallbacks = append(fallbacks, v)
|
|
}
|
|
case "ollama":
|
|
block := mapValue(aiRaw, "ollama")
|
|
copyKeys(defaultEntry, block, "base_url", "api_key")
|
|
copyHeaders(defaultEntry, block, "request_headers")
|
|
case "openrouter":
|
|
block := mapValue(aiRaw, "openrouter")
|
|
copyKeys(defaultEntry, block, "base_url", "api_key", "app_name", "site_url")
|
|
copyHeaders(defaultEntry, block, "extra_headers")
|
|
// rename: extra_headers → request_headers
|
|
if hdr, ok := defaultEntry["extra_headers"]; ok {
|
|
defaultEntry["request_headers"] = hdr
|
|
delete(defaultEntry, "extra_headers")
|
|
}
|
|
}
|
|
|
|
providers["default"] = defaultEntry
|
|
return providers, embedModel, metaModel, fallbacks
|
|
}
|
|
|
|
// chainTargets turns a list of model names into role-chain targets of the
// form {"provider": provider, "model": name}. Input order is kept, empty
// names are skipped, and repeated names are kept only once. The result is
// never nil.
func chainTargets(provider string, models []string) []any {
	targets := make([]any, 0, len(models))
	added := make(map[string]bool, len(models))
	for _, model := range models {
		id := provider + "|" + model
		if model == "" || added[id] {
			continue
		}
		added[id] = true
		targets = append(targets, map[string]any{"provider": provider, "model": model})
	}
	return targets
}
|
|
|
|
func mapValue(raw map[string]any, key string) map[string]any {
|
|
if raw == nil {
|
|
return nil
|
|
}
|
|
v, ok := raw[key]
|
|
if !ok {
|
|
return nil
|
|
}
|
|
switch m := v.(type) {
|
|
case map[string]any:
|
|
return m
|
|
case map[any]any:
|
|
return convertAnyMap(m)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// convertAnyMap copies the string-keyed entries of in into a new
// map[string]any. Entries with non-string keys are dropped. Only the top
// level is converted; nested values are copied as-is.
// NOTE(review): presumably for maps decoded with interface{} keys, e.g. by
// gopkg.in/yaml.v2 — confirm against the loader.
func convertAnyMap(in map[any]any) map[string]any {
	names := make([]string, 0, len(in))
	for k := range in {
		if name, ok := k.(string); ok {
			names = append(names, name)
		}
	}
	// Visit keys in a stable order; the resulting Go map is unordered anyway.
	sort.Strings(names)

	out := make(map[string]any, len(in))
	for _, name := range names {
		out[name] = in[name]
	}
	return out
}
|
|
|
|
// stringValue returns raw[key] when it holds a string. It returns "" when raw
// is nil, the key is absent, or the value has another type.
func stringValue(raw map[string]any, key string) string {
	// Indexing a nil map is legal and yields nil, so the failed assertion
	// also covers the nil-map and missing-key cases.
	s, _ := raw[key].(string)
	return s
}
|
|
|
|
// intValue returns raw[key] as an int when it holds an int, int64 or float64
// (floats are truncated toward zero). Any other case yields 0: nil raw,
// missing key, or another type.
func intValue(raw map[string]any, key string) int {
	v := raw[key] // nil map or missing key -> nil, matching nothing below
	if n, ok := v.(int); ok {
		return n
	}
	if n, ok := v.(int64); ok {
		return int(n)
	}
	if n, ok := v.(float64); ok {
		return int(n)
	}
	return 0
}
|
|
|
|
// floatValue returns raw[key] as a float64 when it holds a float64, int or
// int64. Any other case yields 0: nil raw, missing key, or another type.
func floatValue(raw map[string]any, key string) float64 {
	v := raw[key] // nil map or missing key -> nil, matching nothing below
	if f, ok := v.(float64); ok {
		return f
	}
	if n, ok := v.(int); ok {
		return float64(n)
	}
	if n, ok := v.(int64); ok {
		return float64(n)
	}
	return 0
}
|
|
|
|
// boolValue returns raw[key] when it holds a bool. It returns false when raw
// is nil, the key is absent, or the value has another type.
func boolValue(raw map[string]any, key string) bool {
	// A nil map or missing key yields nil, for which the assertion fails.
	b, _ := raw[key].(bool)
	return b
}
|
|
|
|
// stringListValue returns the non-empty string elements of raw[key].
//
// The value may be a []any, the shape YAML/JSON decoders produce, in which
// case non-string elements are skipped. It may also be a []string, as found
// in maps built programmatically. Empty strings are dropped in both cases.
// It returns nil when raw is nil, the key is absent, or the value is not a
// list. When the value is a list, the result is non-nil (possibly empty).
func stringListValue(raw map[string]any, key string) []string {
	// Indexing a nil map yields nil, which matches neither case.
	switch list := raw[key].(type) {
	case []any:
		out := make([]string, 0, len(list))
		for _, item := range list {
			if s, ok := item.(string); ok && s != "" {
				out = append(out, s)
			}
		}
		return out
	case []string:
		out := make([]string, 0, len(list))
		for _, s := range list {
			if s != "" {
				out = append(out, s)
			}
		}
		return out
	}
	return nil
}
|
|
|
|
// copyKeys copies each of keys from src into dst when src contains it. Keys
// present with a nil value are copied too. A nil src copies nothing.
func copyKeys(dst, src map[string]any, keys ...string) {
	for _, k := range keys {
		v, present := src[k] // nil src: present is false
		if !present {
			continue
		}
		dst[k] = v
	}
}
|
|
|
|
func copyHeaders(dst, src map[string]any, key string) {
|
|
if src == nil {
|
|
return
|
|
}
|
|
v, ok := src[key]
|
|
if !ok {
|
|
return
|
|
}
|
|
switch headers := v.(type) {
|
|
case map[string]any:
|
|
if len(headers) == 0 {
|
|
return
|
|
}
|
|
dst[key] = headers
|
|
case map[any]any:
|
|
if len(headers) == 0 {
|
|
return
|
|
}
|
|
dst[key] = convertAnyMap(headers)
|
|
}
|
|
}
|