Some checks failed
CI / build-and-test (push) Failing after -32m22s
* Implement tests for migrating configuration from v1 to v2 for the litellm provider. * Validate the structure and values of the migrated configuration. * Ensure migration rejects newer versions of the configuration. fix(validate): enhance AI provider validation logic * Consolidate provider validation into a dedicated method. * Ensure at least one provider is specified and validate its type. * Check for required fields based on provider type. fix(mcpserver): update tool set to use new enrichment tool * Replace RetryMetadataTool with RetryEnrichmentTool in the ToolSet. fix(tools): refactor tools to use embedding and metadata runners * Update tools to utilize EmbeddingRunner and MetadataRunner instead of Provider. * Adjust method calls to align with the new runner interfaces.
97 lines
2.5 KiB
Go
97 lines
2.5 KiB
Go
package ai
|
|
|
|
import (
	"fmt"
	"log/slog"
	"net/http"
	"sort"
	"strings"

	"git.warky.dev/wdevs/amcs/internal/ai/compat"
	"git.warky.dev/wdevs/amcs/internal/config"
)
|
|
|
|
// Registry holds one compat.Client per named provider. Runners look up clients
|
|
// by provider name when walking a role chain.
|
|
type Registry struct {
|
|
clients map[string]*compat.Client
|
|
}
|
|
|
|
// NewRegistry builds a Registry from the configured providers. Each provider
|
|
// type maps onto a compat.Client with type-specific header plumbing (e.g.
|
|
// openrouter's HTTP-Referer / X-Title).
|
|
func NewRegistry(providers map[string]config.ProviderConfig, httpClient *http.Client, log *slog.Logger) (*Registry, error) {
|
|
if httpClient == nil {
|
|
return nil, fmt.Errorf("ai registry: http client is required")
|
|
}
|
|
if len(providers) == 0 {
|
|
return nil, fmt.Errorf("ai registry: no providers configured")
|
|
}
|
|
|
|
clients := make(map[string]*compat.Client, len(providers))
|
|
for name, p := range providers {
|
|
headers, err := providerHeaders(p)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ai registry: provider %q: %w", name, err)
|
|
}
|
|
clients[name] = compat.New(compat.Config{
|
|
Name: name,
|
|
BaseURL: p.BaseURL,
|
|
APIKey: p.APIKey,
|
|
Headers: headers,
|
|
HTTPClient: httpClient,
|
|
Log: log,
|
|
})
|
|
}
|
|
return &Registry{clients: clients}, nil
|
|
}
|
|
|
|
// Client returns the compat.Client registered under name.
|
|
func (r *Registry) Client(name string) (*compat.Client, error) {
|
|
c, ok := r.clients[name]
|
|
if !ok {
|
|
return nil, fmt.Errorf("ai registry: provider %q is not configured", name)
|
|
}
|
|
return c, nil
|
|
}
|
|
|
|
// Names returns the registered provider names.
|
|
func (r *Registry) Names() []string {
|
|
names := make([]string, 0, len(r.clients))
|
|
for name := range r.clients {
|
|
names = append(names, name)
|
|
}
|
|
return names
|
|
}
|
|
|
|
func providerHeaders(p config.ProviderConfig) (map[string]string, error) {
|
|
switch p.Type {
|
|
case "litellm", "ollama":
|
|
return cloneHeaders(p.RequestHeaders), nil
|
|
case "openrouter":
|
|
headers := cloneHeaders(p.RequestHeaders)
|
|
if headers == nil {
|
|
headers = map[string]string{}
|
|
}
|
|
if s := strings.TrimSpace(p.SiteURL); s != "" {
|
|
headers["HTTP-Referer"] = s
|
|
}
|
|
if s := strings.TrimSpace(p.AppName); s != "" {
|
|
headers["X-Title"] = s
|
|
}
|
|
return headers, nil
|
|
default:
|
|
return nil, fmt.Errorf("unsupported provider type %q", p.Type)
|
|
}
|
|
}
|
|
|
|
// cloneHeaders returns an independent shallow copy of in, so later changes to
// the caller's map do not leak into a client's headers. A nil or empty input
// yields nil.
func cloneHeaders(in map[string]string) (out map[string]string) {
	if n := len(in); n > 0 {
		out = make(map[string]string, n)
		for key, val := range in {
			out[key] = val
		}
	}
	return out
}
|