Some checks failed
CI / build-and-test (push) Failing after -32m22s
* Implement tests for migrating configuration from v1 to v2 for the litellm provider. * Validate the structure and values of the migrated configuration. * Ensure migration rejects newer versions of the configuration. fix(validate): enhance AI provider validation logic * Consolidate provider validation into a dedicated method. * Ensure at least one provider is specified and validate its type. * Check for required fields based on provider type. fix(mcpserver): update tool set to use new enrichment tool * Replace RetryMetadataTool with RetryEnrichmentTool in the ToolSet. fix(tools): refactor tools to use embedding and metadata runners * Update tools to utilize EmbeddingRunner and MetadataRunner instead of Provider. * Adjust method calls to align with the new runner interfaces.
140 lines
3.8 KiB
Go
140 lines
3.8 KiB
Go
package ai
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"sync"
|
|
"testing"
|
|
|
|
"git.warky.dev/wdevs/amcs/internal/config"
|
|
)
|
|
|
|
func TestEmbeddingRunnerFallsBackAndSkipsUnhealthyPrimary(t *testing.T) {
|
|
var (
|
|
mu sync.Mutex
|
|
primaryCalls int
|
|
)
|
|
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/embeddings" {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
var req struct {
|
|
Model string `json:"model"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
switch req.Model {
|
|
case "embed-primary":
|
|
mu.Lock()
|
|
primaryCalls++
|
|
mu.Unlock()
|
|
http.Error(w, "upstream down", http.StatusBadGateway)
|
|
case "embed-fallback":
|
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
|
"data": []map[string]any{{"embedding": []float32{0.1, 0.2, 0.3}}},
|
|
})
|
|
default:
|
|
http.Error(w, "unknown model", http.StatusBadRequest)
|
|
}
|
|
}))
|
|
defer srv.Close()
|
|
|
|
reg, err := NewRegistry(map[string]config.ProviderConfig{
|
|
"p1": {Type: "litellm", BaseURL: srv.URL, APIKey: "k1"},
|
|
"p2": {Type: "litellm", BaseURL: srv.URL, APIKey: "k2"},
|
|
}, srv.Client(), nil)
|
|
if err != nil {
|
|
t.Fatalf("NewRegistry() error = %v", err)
|
|
}
|
|
|
|
runner, err := NewEmbeddingRunner(reg, []config.RoleTarget{
|
|
{Provider: "p1", Model: "embed-primary"},
|
|
{Provider: "p2", Model: "embed-fallback"},
|
|
}, 3, nil)
|
|
if err != nil {
|
|
t.Fatalf("NewEmbeddingRunner() error = %v", err)
|
|
}
|
|
|
|
res, err := runner.Embed(context.Background(), "hello")
|
|
if err != nil {
|
|
t.Fatalf("Embed() first call error = %v", err)
|
|
}
|
|
if res.Provider != "p2" || res.Model != "embed-fallback" {
|
|
t.Fatalf("Embed() first call target = %s/%s, want p2/embed-fallback", res.Provider, res.Model)
|
|
}
|
|
|
|
res, err = runner.Embed(context.Background(), "hello again")
|
|
if err != nil {
|
|
t.Fatalf("Embed() second call error = %v", err)
|
|
}
|
|
if res.Provider != "p2" || res.Model != "embed-fallback" {
|
|
t.Fatalf("Embed() second call target = %s/%s, want p2/embed-fallback", res.Provider, res.Model)
|
|
}
|
|
|
|
mu.Lock()
|
|
calls := primaryCalls
|
|
mu.Unlock()
|
|
if calls != 3 {
|
|
t.Fatalf("primary calls = %d, want 3 (first request retries 3x; second call should skip unhealthy primary)", calls)
|
|
}
|
|
}
|
|
|
|
func TestMetadataRunnerSummarizeFallsBack(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/chat/completions" {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
var req struct {
|
|
Model string `json:"model"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
switch req.Model {
|
|
case "sum-primary":
|
|
http.Error(w, "provider error", http.StatusBadGateway)
|
|
case "sum-fallback":
|
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
|
"choices": []map[string]any{{
|
|
"message": map[string]any{"role": "assistant", "content": "fallback summary"},
|
|
}},
|
|
})
|
|
default:
|
|
http.Error(w, "unknown model", http.StatusBadRequest)
|
|
}
|
|
}))
|
|
defer srv.Close()
|
|
|
|
reg, err := NewRegistry(map[string]config.ProviderConfig{
|
|
"p1": {Type: "litellm", BaseURL: srv.URL, APIKey: "k1"},
|
|
"p2": {Type: "litellm", BaseURL: srv.URL, APIKey: "k2"},
|
|
}, srv.Client(), nil)
|
|
if err != nil {
|
|
t.Fatalf("NewRegistry() error = %v", err)
|
|
}
|
|
|
|
runner, err := NewMetadataRunner(reg, []config.RoleTarget{
|
|
{Provider: "p1", Model: "sum-primary"},
|
|
{Provider: "p2", Model: "sum-fallback"},
|
|
}, 0.1, false, nil)
|
|
if err != nil {
|
|
t.Fatalf("NewMetadataRunner() error = %v", err)
|
|
}
|
|
|
|
summary, err := runner.Summarize(context.Background(), "system", "user")
|
|
if err != nil {
|
|
t.Fatalf("Summarize() error = %v", err)
|
|
}
|
|
if summary != "fallback summary" {
|
|
t.Fatalf("summary = %q, want %q", summary, "fallback summary")
|
|
}
|
|
}
|