Files
amcs/internal/ai/registry.go
Hein 14e218d784
Some checks failed
CI / build-and-test (push) Failing after -32m22s
test(config): add migration tests for litellm provider
* Implement tests for migrating configuration from v1 to v2 for the litellm provider.
* Validate the structure and values of the migrated configuration.
* Ensure migration rejects newer versions of the configuration.
fix(validate): enhance AI provider validation logic
* Consolidate provider validation into a dedicated method.
* Ensure at least one provider is specified and validate its type.
* Check for required fields based on provider type.
fix(mcpserver): update tool set to use new enrichment tool
* Replace RetryMetadataTool with RetryEnrichmentTool in the ToolSet.
fix(tools): refactor tools to use embedding and metadata runners
* Update tools to utilize EmbeddingRunner and MetadataRunner instead of Provider.
* Adjust method calls to align with the new runner interfaces.
2026-04-21 21:14:28 +02:00

97 lines
2.5 KiB
Go

package ai
import (
"fmt"
"log/slog"
"net/http"
"strings"
"git.warky.dev/wdevs/amcs/internal/ai/compat"
"git.warky.dev/wdevs/amcs/internal/config"
)
// Registry holds one compat.Client per named provider. Runners look up clients
// by provider name when walking a role chain.
//
// Construct it with NewRegistry. The zero value has no clients, so every
// Client lookup on it reports the provider as not configured.
type Registry struct {
	// clients maps a provider name (the key it had in the providers map passed
	// to NewRegistry) to the client built for that provider.
	clients map[string]*compat.Client
}
// NewRegistry builds a Registry holding one compat.Client per entry in
// providers, keyed by the provider's name. Each client shares httpClient and
// log, and gets request headers derived from its provider type by
// providerHeaders (e.g. openrouter's HTTP-Referer / X-Title).
//
// It returns an error when httpClient is nil, when providers is empty, or when
// any provider's headers cannot be derived (such as an unsupported type). On
// error no Registry is returned.
func NewRegistry(providers map[string]config.ProviderConfig, httpClient *http.Client, log *slog.Logger) (*Registry, error) {
	switch {
	case httpClient == nil:
		return nil, fmt.Errorf("ai registry: http client is required")
	case len(providers) == 0:
		return nil, fmt.Errorf("ai registry: no providers configured")
	}

	reg := &Registry{clients: make(map[string]*compat.Client, len(providers))}
	for providerName, pc := range providers {
		hdrs, herr := providerHeaders(pc)
		if herr != nil {
			return nil, fmt.Errorf("ai registry: provider %q: %w", providerName, herr)
		}
		cfg := compat.Config{
			Name:       providerName,
			BaseURL:    pc.BaseURL,
			APIKey:     pc.APIKey,
			Headers:    hdrs,
			HTTPClient: httpClient,
			Log:        log,
		}
		reg.clients[providerName] = compat.New(cfg)
	}
	return reg, nil
}
// Client looks up the compat.Client registered for the provider called name.
// An unknown name yields a nil client and an error naming the provider.
func (r *Registry) Client(name string) (*compat.Client, error) {
	if c, found := r.clients[name]; found {
		return c, nil
	}
	return nil, fmt.Errorf("ai registry: provider %q is not configured", name)
}
// Names lists every registered provider name. The result is a fresh, non-nil
// slice (empty when nothing is registered). Its order is unspecified because
// it follows Go map iteration.
func (r *Registry) Names() []string {
	out := make([]string, len(r.clients))
	i := 0
	for n := range r.clients {
		out[i] = n
		i++
	}
	return out
}
// providerHeaders returns the extra request headers for provider p, chosen by
// p.Type:
//   - "litellm", "ollama": a copy of p.RequestHeaders (nil when none are set).
//   - "openrouter": a copy of p.RequestHeaders (never nil), plus HTTP-Referer
//     from p.SiteURL and X-Title from p.AppName when those are non-blank after
//     trimming. A derived value replaces any RequestHeaders entry with the same
//     name, compared case-insensitively.
//
// Any other type yields an "unsupported provider type" error. The configured
// RequestHeaders map itself is never modified; callers get their own copy.
func providerHeaders(p config.ProviderConfig) (map[string]string, error) {
	switch p.Type {
	case "litellm", "ollama":
		return cloneHeaders(p.RequestHeaders), nil
	case "openrouter":
		headers := cloneHeaders(p.RequestHeaders)
		if headers == nil {
			headers = map[string]string{}
		}
		if s := strings.TrimSpace(p.SiteURL); s != "" {
			overrideHeader(headers, "HTTP-Referer", s)
		}
		if s := strings.TrimSpace(p.AppName); s != "" {
			overrideHeader(headers, "X-Title", s)
		}
		return headers, nil
	default:
		return nil, fmt.Errorf("unsupported provider type %q", p.Type)
	}
}

// overrideHeader sets headers[key] = value after deleting every existing entry
// whose name equals key case-insensitively.
//
// HTTP field names are case-insensitive. Without this, a user-supplied
// "http-referer" and the derived "HTTP-Referer" would both survive in the map.
// Depending on how the client applies the map, that means either duplicate
// header lines are sent, or the surviving value depends on random map
// iteration order. NOTE(review): presumably compat applies the map via
// http.Header.Set, which canonicalizes names — confirm.
func overrideHeader(headers map[string]string, key, value string) {
	for k := range headers {
		if strings.EqualFold(k, key) {
			// Deleting the current key during range is permitted by the Go spec.
			delete(headers, k)
		}
	}
	headers[key] = value
}
// cloneHeaders returns an independent copy of in, so callers can add or
// overwrite entries without touching the source map. Both a nil and an empty
// input yield nil.
func cloneHeaders(in map[string]string) (out map[string]string) {
	if n := len(in); n > 0 {
		out = make(map[string]string, n)
		for key, val := range in {
			out[key] = val
		}
	}
	return out
}