package ai import ( "context" "encoding/json" "net/http" "net/http/httptest" "sync" "testing" "git.warky.dev/wdevs/amcs/internal/config" ) func TestEmbeddingRunnerFallsBackAndSkipsUnhealthyPrimary(t *testing.T) { var ( mu sync.Mutex primaryCalls int ) srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/embeddings" { http.NotFound(w, r) return } var req struct { Model string `json:"model"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } switch req.Model { case "embed-primary": mu.Lock() primaryCalls++ mu.Unlock() http.Error(w, "upstream down", http.StatusBadGateway) case "embed-fallback": _ = json.NewEncoder(w).Encode(map[string]any{ "data": []map[string]any{{"embedding": []float32{0.1, 0.2, 0.3}}}, }) default: http.Error(w, "unknown model", http.StatusBadRequest) } })) defer srv.Close() reg, err := NewRegistry(map[string]config.ProviderConfig{ "p1": {Type: "litellm", BaseURL: srv.URL, APIKey: "k1"}, "p2": {Type: "litellm", BaseURL: srv.URL, APIKey: "k2"}, }, srv.Client(), nil) if err != nil { t.Fatalf("NewRegistry() error = %v", err) } runner, err := NewEmbeddingRunner(reg, []config.RoleTarget{ {Provider: "p1", Model: "embed-primary"}, {Provider: "p2", Model: "embed-fallback"}, }, 3, nil) if err != nil { t.Fatalf("NewEmbeddingRunner() error = %v", err) } res, err := runner.Embed(context.Background(), "hello") if err != nil { t.Fatalf("Embed() first call error = %v", err) } if res.Provider != "p2" || res.Model != "embed-fallback" { t.Fatalf("Embed() first call target = %s/%s, want p2/embed-fallback", res.Provider, res.Model) } res, err = runner.Embed(context.Background(), "hello again") if err != nil { t.Fatalf("Embed() second call error = %v", err) } if res.Provider != "p2" || res.Model != "embed-fallback" { t.Fatalf("Embed() second call target = %s/%s, want p2/embed-fallback", res.Provider, res.Model) } mu.Lock() calls := primaryCalls mu.Unlock() if calls != 3 { t.Fatalf("primary calls = %d, want 3 (first request retries 3x; second call should skip unhealthy primary)", calls) } } func TestMetadataRunnerSummarizeFallsBack(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/chat/completions" { http.NotFound(w, r) return } var req struct { Model string `json:"model"` } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } switch req.Model { case "sum-primary": http.Error(w, "provider error", http.StatusBadGateway) case "sum-fallback": _ = json.NewEncoder(w).Encode(map[string]any{ "choices": []map[string]any{{ "message": map[string]any{"role": "assistant", "content": "fallback summary"}, }}, }) default: http.Error(w, "unknown model", http.StatusBadRequest) } })) defer srv.Close() reg, err := NewRegistry(map[string]config.ProviderConfig{ "p1": {Type: "litellm", BaseURL: srv.URL, APIKey: "k1"}, "p2": {Type: "litellm", BaseURL: srv.URL, APIKey: "k2"}, }, srv.Client(), nil) if err != nil { t.Fatalf("NewRegistry() error = %v", err) } runner, err := NewMetadataRunner(reg, []config.RoleTarget{ {Provider: "p1", Model: "sum-primary"}, {Provider: "p2", Model: "sum-fallback"}, }, 0.1, false, nil) if err != nil { t.Fatalf("NewMetadataRunner() error = %v", err) } summary, err := runner.Summarize(context.Background(), "system", "user") if err != nil { t.Fatalf("Summarize() error = %v", err) } if summary != "fallback summary" { t.Fatalf("summary = %q, want %q", summary, "fallback summary") } }