test(config): add migration tests for litellm provider
Some checks failed
CI / build-and-test (push) Failing after 32m22s
Some checks failed
CI / build-and-test (push) Failing after 32m22s
* Implement tests for migrating configuration from v1 to v2 for the litellm provider. * Validate the structure and values of the migrated configuration. * Ensure migration rejects newer versions of the configuration. fix(validate): enhance AI provider validation logic * Consolidate provider validation into a dedicated method. * Ensure at least one provider is specified and validate its type. * Check for required fields based on provider type. fix(mcpserver): update tool set to use new enrichment tool * Replace RetryMetadataTool with RetryEnrichmentTool in the ToolSet. fix(tools): refactor tools to use embedding and metadata runners * Update tools to utilize EmbeddingRunner and MetadataRunner instead of Provider. * Adjust method calls to align with the new runner interfaces.
This commit is contained in:
367
internal/ai/runner.go
Normal file
367
internal/ai/runner.go
Normal file
@@ -0,0 +1,367 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/ai/compat"
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
// Health TTLs per failure class. These are short enough that a healed target
// gets retried without manual intervention, but long enough to avoid hammering
// a broken provider every call.
const (
	// transientCooldown is the skip window applied by markTransient for
	// failures not classified as permanent model errors.
	transientCooldown = 30 * time.Second
	// permanentCooldown is the longer skip window applied by markPermanent
	// for errors that compat.IsPermanentModelError reports as permanent.
	permanentCooldown = 10 * time.Minute
	// emptyResponseThreshold is how many consecutive empty responses a
	// target may produce before markEmpty puts it on cooldown.
	emptyResponseThreshold = 3
	// emptyResponseCooldown is the skip window applied once the
	// empty-response threshold is reached.
	emptyResponseCooldown = 2 * time.Minute
	// dimensionMismatchWarning is the message prefix used when an embedding
	// vector does not have the configured length.
	dimensionMismatchWarning = "embedding dimension mismatch"
)
// EmbedResult carries the vector plus the (provider, model) that produced it —
// callers store the actual model so later searches against that row use the
// matching query embedding.
type EmbedResult struct {
	// Vector is the embedding produced for the input text.
	Vector []float32
	// Provider is the name of the provider that produced Vector.
	Provider string
	// Model is the exact model that produced Vector.
	Model string
}
// EmbeddingRunner executes the embeddings role chain with sequential fallback.
type EmbeddingRunner struct {
	registry   *Registry           // resolves provider names to clients
	chain      []config.RoleTarget // ordered fallback chain; chain[0] is the primary
	dimensions int                 // required embedding vector length
	health     *healthTracker      // per-target cooldown state
	log        *slog.Logger        // optional; nil disables failure logging
}
// MetadataRunner executes the metadata role chain with sequential fallback and
// a heuristic fallthrough when every target is unhealthy or fails.
type MetadataRunner struct {
	registry *Registry           // resolves provider names to clients
	chain    []config.RoleTarget // ordered fallback chain; chain[0] is the primary
	opts     metadataRunOpts     // per-call options shared by every target
	health   *healthTracker      // per-target cooldown state
	log      *slog.Logger        // optional; nil disables failure logging
}

// metadataRunOpts holds the call options forwarded to every chain target.
type metadataRunOpts struct {
	temperature      float64 // sampling temperature passed to the provider
	logConversations bool    // forwarded to compat.MetadataOptions
}
// NewEmbeddingRunner builds a runner for the embeddings role. chain must be
|
||||
// non-empty and every target must be registered.
|
||||
func NewEmbeddingRunner(registry *Registry, chain []config.RoleTarget, dimensions int, log *slog.Logger) (*EmbeddingRunner, error) {
|
||||
if registry == nil {
|
||||
return nil, fmt.Errorf("embedding runner: registry is required")
|
||||
}
|
||||
if len(chain) == 0 {
|
||||
return nil, fmt.Errorf("embedding runner: chain is empty")
|
||||
}
|
||||
if dimensions <= 0 {
|
||||
return nil, fmt.Errorf("embedding runner: dimensions must be > 0")
|
||||
}
|
||||
for i, t := range chain {
|
||||
if _, err := registry.Client(t.Provider); err != nil {
|
||||
return nil, fmt.Errorf("embedding runner: chain[%d]: %w", i, err)
|
||||
}
|
||||
}
|
||||
return &EmbeddingRunner{
|
||||
registry: registry,
|
||||
chain: chain,
|
||||
dimensions: dimensions,
|
||||
health: newHealthTracker(),
|
||||
log: log,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewMetadataRunner builds a runner for the metadata role.
|
||||
func NewMetadataRunner(registry *Registry, chain []config.RoleTarget, temperature float64, logConversations bool, log *slog.Logger) (*MetadataRunner, error) {
|
||||
if registry == nil {
|
||||
return nil, fmt.Errorf("metadata runner: registry is required")
|
||||
}
|
||||
if len(chain) == 0 {
|
||||
return nil, fmt.Errorf("metadata runner: chain is empty")
|
||||
}
|
||||
for i, t := range chain {
|
||||
if _, err := registry.Client(t.Provider); err != nil {
|
||||
return nil, fmt.Errorf("metadata runner: chain[%d]: %w", i, err)
|
||||
}
|
||||
}
|
||||
return &MetadataRunner{
|
||||
registry: registry,
|
||||
chain: chain,
|
||||
opts: metadataRunOpts{
|
||||
temperature: temperature,
|
||||
logConversations: logConversations,
|
||||
},
|
||||
health: newHealthTracker(),
|
||||
log: log,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// PrimaryProvider returns the first provider in the chain.
func (r *EmbeddingRunner) PrimaryProvider() string { return r.chain[0].Provider }

// PrimaryModel returns the first model in the chain — the one used as the
// storage key for search matching (see EmbedPrimary).
func (r *EmbeddingRunner) PrimaryModel() string { return r.chain[0].Model }

// Dimensions returns the required vector dimension every target must produce.
func (r *EmbeddingRunner) Dimensions() int { return r.dimensions }
// Embed walks the chain and returns the first successful embedding. The
|
||||
// returned EmbedResult names the actual (provider, model) that produced the
|
||||
// vector — callers use that when recording the row.
|
||||
func (r *EmbeddingRunner) Embed(ctx context.Context, input string) (EmbedResult, error) {
|
||||
var errs []error
|
||||
for _, target := range r.chain {
|
||||
if r.health.skip(target) {
|
||||
continue
|
||||
}
|
||||
client, err := r.registry.Client(target.Provider)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
continue
|
||||
}
|
||||
vec, err := client.EmbedWith(ctx, target.Model, input)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return EmbedResult{}, ctx.Err()
|
||||
}
|
||||
r.classify(target, err)
|
||||
r.logFailure("embed", target, err)
|
||||
errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, err))
|
||||
continue
|
||||
}
|
||||
if len(vec) != r.dimensions {
|
||||
dimErr := fmt.Errorf("%s: expected %d, got %d", dimensionMismatchWarning, r.dimensions, len(vec))
|
||||
r.health.markTransient(target)
|
||||
r.logFailure("embed", target, dimErr)
|
||||
errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, dimErr))
|
||||
continue
|
||||
}
|
||||
r.health.markHealthy(target)
|
||||
return EmbedResult{Vector: vec, Provider: target.Provider, Model: target.Model}, nil
|
||||
}
|
||||
return EmbedResult{}, fmt.Errorf("all embedding targets failed: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
// EmbedPrimary embeds using only the primary target — used for search queries
|
||||
// so the query vector matches rows stored under the primary model. Falls back
|
||||
// to returning the error without walking the chain.
|
||||
func (r *EmbeddingRunner) EmbedPrimary(ctx context.Context, input string) ([]float32, error) {
|
||||
target := r.chain[0]
|
||||
client, err := r.registry.Client(target.Provider)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
vec, err := client.EmbedWith(ctx, target.Model, input)
|
||||
if err != nil {
|
||||
r.classify(target, err)
|
||||
return nil, err
|
||||
}
|
||||
if len(vec) != r.dimensions {
|
||||
return nil, fmt.Errorf("%s: expected %d, got %d", dimensionMismatchWarning, r.dimensions, len(vec))
|
||||
}
|
||||
r.health.markHealthy(target)
|
||||
return vec, nil
|
||||
}
|
||||
|
||||
// PrimaryProvider returns the first provider in the metadata chain, mirroring
// the embedding runner's accessor.
func (r *MetadataRunner) PrimaryProvider() string { return r.chain[0].Provider }

// PrimaryModel returns the first model in the metadata chain.
func (r *MetadataRunner) PrimaryModel() string { return r.chain[0].Model }
// ExtractMetadata walks the chain sequentially. If every target fails or is
|
||||
// unhealthy, it returns a heuristic metadata so capture never hard-fails.
|
||||
func (r *MetadataRunner) ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error) {
|
||||
var errs []error
|
||||
for _, target := range r.chain {
|
||||
if r.health.skip(target) {
|
||||
continue
|
||||
}
|
||||
client, err := r.registry.Client(target.Provider)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
continue
|
||||
}
|
||||
md, err := client.ExtractMetadataWith(ctx, compat.MetadataOptions{
|
||||
Model: target.Model,
|
||||
Temperature: r.opts.temperature,
|
||||
LogConversations: r.opts.logConversations,
|
||||
}, input)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return thoughttypes.ThoughtMetadata{}, ctx.Err()
|
||||
}
|
||||
r.classify(target, err)
|
||||
r.logFailure("metadata", target, err)
|
||||
errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, err))
|
||||
continue
|
||||
}
|
||||
r.health.markHealthy(target)
|
||||
return md, nil
|
||||
}
|
||||
if r.log != nil {
|
||||
r.log.Warn("metadata chain exhausted, using heuristic fallback",
|
||||
slog.Int("targets", len(r.chain)),
|
||||
slog.String("error", errors.Join(errs...).Error()),
|
||||
)
|
||||
}
|
||||
return compat.HeuristicMetadataFromInput(input), nil
|
||||
}
|
||||
|
||||
// Summarize walks the chain; unlike metadata, there is no heuristic fallback —
|
||||
// returns the joined error when everything fails.
|
||||
func (r *MetadataRunner) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) {
|
||||
var errs []error
|
||||
for _, target := range r.chain {
|
||||
if r.health.skip(target) {
|
||||
continue
|
||||
}
|
||||
client, err := r.registry.Client(target.Provider)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
continue
|
||||
}
|
||||
out, err := client.SummarizeWith(ctx, compat.SummarizeOptions{
|
||||
Model: target.Model,
|
||||
Temperature: r.opts.temperature,
|
||||
}, systemPrompt, userPrompt)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return "", ctx.Err()
|
||||
}
|
||||
r.classify(target, err)
|
||||
r.logFailure("summarize", target, err)
|
||||
errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, err))
|
||||
continue
|
||||
}
|
||||
r.health.markHealthy(target)
|
||||
return out, nil
|
||||
}
|
||||
return "", fmt.Errorf("all summarize targets failed: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
func (r *EmbeddingRunner) classify(target config.RoleTarget, err error) {
|
||||
switch {
|
||||
case compat.IsPermanentModelError(err):
|
||||
r.health.markPermanent(target)
|
||||
default:
|
||||
r.health.markTransient(target)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MetadataRunner) classify(target config.RoleTarget, err error) {
|
||||
switch {
|
||||
case compat.IsPermanentModelError(err):
|
||||
r.health.markPermanent(target)
|
||||
case errors.Is(err, compat.ErrEmptyResponse):
|
||||
r.health.markEmpty(target)
|
||||
default:
|
||||
r.health.markTransient(target)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *EmbeddingRunner) logFailure(role string, target config.RoleTarget, err error) {
|
||||
if r.log == nil {
|
||||
return
|
||||
}
|
||||
r.log.Warn("ai target failed",
|
||||
slog.String("role", role),
|
||||
slog.String("provider", target.Provider),
|
||||
slog.String("model", target.Model),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
|
||||
func (r *MetadataRunner) logFailure(role string, target config.RoleTarget, err error) {
|
||||
if r.log == nil {
|
||||
return
|
||||
}
|
||||
r.log.Warn("ai target failed",
|
||||
slog.String("role", role),
|
||||
slog.String("provider", target.Provider),
|
||||
slog.String("model", target.Model),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
|
||||
// healthTracker records per-(provider, model) failure state. skip returns true
// when a target is still inside its cooldown window; the caller then tries the
// next target in the chain.
type healthTracker struct {
	mu     sync.Mutex // guards states
	states map[config.RoleTarget]*healthState
}

// healthState is the tracked state for a single (provider, model) target.
type healthState struct {
	unhealthyUntil time.Time // target is skipped until this instant; zero means healthy
	emptyCount     int       // consecutive empty responses observed so far
}
func newHealthTracker() *healthTracker {
|
||||
return &healthTracker{states: map[config.RoleTarget]*healthState{}}
|
||||
}
|
||||
|
||||
func (h *healthTracker) skip(target config.RoleTarget) bool {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
s, ok := h.states[target]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return time.Now().Before(s.unhealthyUntil)
|
||||
}
|
||||
|
||||
// markTransient puts target on the short cooldown used for failures that are
// not classified as permanent model errors.
func (h *healthTracker) markTransient(target config.RoleTarget) {
	h.setCooldown(target, transientCooldown)
}

// markPermanent puts target on the long cooldown used for permanent model
// errors (see compat.IsPermanentModelError).
func (h *healthTracker) markPermanent(target config.RoleTarget) {
	h.setCooldown(target, permanentCooldown)
}
func (h *healthTracker) markEmpty(target config.RoleTarget) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
s := h.states[target]
|
||||
if s == nil {
|
||||
s = &healthState{}
|
||||
h.states[target] = s
|
||||
}
|
||||
s.emptyCount++
|
||||
if s.emptyCount >= emptyResponseThreshold {
|
||||
s.unhealthyUntil = time.Now().Add(emptyResponseCooldown)
|
||||
s.emptyCount = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (h *healthTracker) markHealthy(target config.RoleTarget) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if s, ok := h.states[target]; ok {
|
||||
s.unhealthyUntil = time.Time{}
|
||||
s.emptyCount = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (h *healthTracker) setCooldown(target config.RoleTarget, d time.Duration) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
s := h.states[target]
|
||||
if s == nil {
|
||||
s = &healthState{}
|
||||
h.states[target] = s
|
||||
}
|
||||
s.unhealthyUntil = time.Now().Add(d)
|
||||
s.emptyCount = 0
|
||||
}
|
||||
Reference in New Issue
Block a user