test(config): add migration tests for litellm provider
Some checks failed
CI / build-and-test (push) Failing after -32m22s
Some checks failed
CI / build-and-test (push) Failing after -32m22s
* Implement tests for migrating configuration from v1 to v2 for the litellm provider. * Validate the structure and values of the migrated configuration. * Ensure migration rejects newer versions of the configuration. fix(validate): enhance AI provider validation logic * Consolidate provider validation into a dedicated method. * Ensure at least one provider is specified and validate its type. * Check for required fields based on provider type. fix(mcpserver): update tool set to use new enrichment tool * Replace RetryMetadataTool with RetryEnrichmentTool in the ToolSet. fix(tools): refactor tools to use embedding and metadata runners * Update tools to utilize EmbeddingRunner and MetadataRunner instead of Provider. * Adjust method calls to align with the new runner interfaces.
This commit is contained in:
139
internal/ai/runner_test.go
Normal file
139
internal/ai/runner_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
)
|
||||
|
||||
func TestEmbeddingRunnerFallsBackAndSkipsUnhealthyPrimary(t *testing.T) {
|
||||
var (
|
||||
mu sync.Mutex
|
||||
primaryCalls int
|
||||
)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/embeddings" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
switch req.Model {
|
||||
case "embed-primary":
|
||||
mu.Lock()
|
||||
primaryCalls++
|
||||
mu.Unlock()
|
||||
http.Error(w, "upstream down", http.StatusBadGateway)
|
||||
case "embed-fallback":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"data": []map[string]any{{"embedding": []float32{0.1, 0.2, 0.3}}},
|
||||
})
|
||||
default:
|
||||
http.Error(w, "unknown model", http.StatusBadRequest)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
reg, err := NewRegistry(map[string]config.ProviderConfig{
|
||||
"p1": {Type: "litellm", BaseURL: srv.URL, APIKey: "k1"},
|
||||
"p2": {Type: "litellm", BaseURL: srv.URL, APIKey: "k2"},
|
||||
}, srv.Client(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRegistry() error = %v", err)
|
||||
}
|
||||
|
||||
runner, err := NewEmbeddingRunner(reg, []config.RoleTarget{
|
||||
{Provider: "p1", Model: "embed-primary"},
|
||||
{Provider: "p2", Model: "embed-fallback"},
|
||||
}, 3, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewEmbeddingRunner() error = %v", err)
|
||||
}
|
||||
|
||||
res, err := runner.Embed(context.Background(), "hello")
|
||||
if err != nil {
|
||||
t.Fatalf("Embed() first call error = %v", err)
|
||||
}
|
||||
if res.Provider != "p2" || res.Model != "embed-fallback" {
|
||||
t.Fatalf("Embed() first call target = %s/%s, want p2/embed-fallback", res.Provider, res.Model)
|
||||
}
|
||||
|
||||
res, err = runner.Embed(context.Background(), "hello again")
|
||||
if err != nil {
|
||||
t.Fatalf("Embed() second call error = %v", err)
|
||||
}
|
||||
if res.Provider != "p2" || res.Model != "embed-fallback" {
|
||||
t.Fatalf("Embed() second call target = %s/%s, want p2/embed-fallback", res.Provider, res.Model)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
calls := primaryCalls
|
||||
mu.Unlock()
|
||||
if calls != 3 {
|
||||
t.Fatalf("primary calls = %d, want 3 (first request retries 3x; second call should skip unhealthy primary)", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetadataRunnerSummarizeFallsBack(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/chat/completions" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
switch req.Model {
|
||||
case "sum-primary":
|
||||
http.Error(w, "provider error", http.StatusBadGateway)
|
||||
case "sum-fallback":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{{
|
||||
"message": map[string]any{"role": "assistant", "content": "fallback summary"},
|
||||
}},
|
||||
})
|
||||
default:
|
||||
http.Error(w, "unknown model", http.StatusBadRequest)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
reg, err := NewRegistry(map[string]config.ProviderConfig{
|
||||
"p1": {Type: "litellm", BaseURL: srv.URL, APIKey: "k1"},
|
||||
"p2": {Type: "litellm", BaseURL: srv.URL, APIKey: "k2"},
|
||||
}, srv.Client(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRegistry() error = %v", err)
|
||||
}
|
||||
|
||||
runner, err := NewMetadataRunner(reg, []config.RoleTarget{
|
||||
{Provider: "p1", Model: "sum-primary"},
|
||||
{Provider: "p2", Model: "sum-fallback"},
|
||||
}, 0.1, false, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewMetadataRunner() error = %v", err)
|
||||
}
|
||||
|
||||
summary, err := runner.Summarize(context.Background(), "system", "user")
|
||||
if err != nil {
|
||||
t.Fatalf("Summarize() error = %v", err)
|
||||
}
|
||||
if summary != "fallback summary" {
|
||||
t.Fatalf("summary = %q, want %q", summary, "fallback summary")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user