Some checks failed
CI / build-and-test (push) Failing after -32m22s
* Implement tests for migrating configuration from v1 to v2 for the litellm provider. * Validate the structure and values of the migrated configuration. * Ensure migration rejects newer versions of the configuration. fix(validate): enhance AI provider validation logic * Consolidate provider validation into a dedicated method. * Ensure at least one provider is specified and validate its type. * Check for required fields based on provider type. fix(mcpserver): update tool set to use new enrichment tool * Replace RetryMetadataTool with RetryEnrichmentTool in the ToolSet. fix(tools): refactor tools to use embedding and metadata runners * Update tools to utilize EmbeddingRunner and MetadataRunner instead of Provider. * Adjust method calls to align with the new runner interfaces.
368 lines
11 KiB
Go
368 lines
11 KiB
Go
package ai
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"sync"
|
|
"time"
|
|
|
|
"git.warky.dev/wdevs/amcs/internal/ai/compat"
|
|
"git.warky.dev/wdevs/amcs/internal/config"
|
|
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
|
)
|
|
|
|
// Health TTLs per failure class. These are short enough that a healed target
// gets retried without manual intervention, but long enough to avoid hammering
// a broken provider every call.
const (
	// transientCooldown is the skip window after a retryable failure
	// (network blip, rate limit, 5xx — anything not classed permanent).
	transientCooldown = 30 * time.Second
	// permanentCooldown is the skip window after a permanent model error
	// (e.g. unknown model); long, since these rarely heal on their own.
	permanentCooldown = 10 * time.Minute
	// emptyResponseThreshold is how many consecutive empty responses a
	// target may return before it is put on cooldown.
	emptyResponseThreshold = 3
	// emptyResponseCooldown is the skip window applied once the empty
	// response threshold is reached.
	emptyResponseCooldown = 2 * time.Minute

	// dimensionMismatchWarning prefixes errors where a provider returned a
	// vector whose length differs from the configured dimensions.
	dimensionMismatchWarning = "embedding dimension mismatch"
)
|
|
|
|
// EmbedResult carries the vector plus the (provider, model) that produced it —
// callers store the actual model so later searches against that row use the
// matching query embedding.
type EmbedResult struct {
	// Vector is the embedding; its length matches the runner's dimensions.
	Vector []float32
	// Provider names the provider that produced the vector.
	Provider string
	// Model names the model that produced the vector — the storage key used
	// to match query embeddings at search time.
	Model string
}
|
|
|
|
// EmbeddingRunner executes the embeddings role chain with sequential fallback.
type EmbeddingRunner struct {
	registry *Registry           // resolves provider names to clients
	chain    []config.RoleTarget // ordered (provider, model) fallback chain; never empty
	// dimensions is the required vector length; results of any other length
	// are rejected so all stored rows share one dimension.
	dimensions int
	health     *healthTracker // per-target cooldown state
	log        *slog.Logger   // may be nil; failures are then not logged
}
|
|
|
|
// MetadataRunner executes the metadata role chain with sequential fallback and
// a heuristic fallthrough when every target is unhealthy or fails.
type MetadataRunner struct {
	registry *Registry           // resolves provider names to clients
	chain    []config.RoleTarget // ordered (provider, model) fallback chain; never empty
	opts     metadataRunOpts     // shared per-call options (temperature, logging)
	health   *healthTracker      // per-target cooldown state
	log      *slog.Logger        // may be nil; failures are then not logged
}
|
|
|
|
// metadataRunOpts bundles the per-call options every metadata/summarize
// request shares, fixed at construction time.
type metadataRunOpts struct {
	temperature      float64 // sampling temperature passed to each target
	logConversations bool    // whether providers should log full conversations
}
|
|
|
|
// NewEmbeddingRunner builds a runner for the embeddings role. chain must be
|
|
// non-empty and every target must be registered.
|
|
func NewEmbeddingRunner(registry *Registry, chain []config.RoleTarget, dimensions int, log *slog.Logger) (*EmbeddingRunner, error) {
|
|
if registry == nil {
|
|
return nil, fmt.Errorf("embedding runner: registry is required")
|
|
}
|
|
if len(chain) == 0 {
|
|
return nil, fmt.Errorf("embedding runner: chain is empty")
|
|
}
|
|
if dimensions <= 0 {
|
|
return nil, fmt.Errorf("embedding runner: dimensions must be > 0")
|
|
}
|
|
for i, t := range chain {
|
|
if _, err := registry.Client(t.Provider); err != nil {
|
|
return nil, fmt.Errorf("embedding runner: chain[%d]: %w", i, err)
|
|
}
|
|
}
|
|
return &EmbeddingRunner{
|
|
registry: registry,
|
|
chain: chain,
|
|
dimensions: dimensions,
|
|
health: newHealthTracker(),
|
|
log: log,
|
|
}, nil
|
|
}
|
|
|
|
// NewMetadataRunner builds a runner for the metadata role.
|
|
func NewMetadataRunner(registry *Registry, chain []config.RoleTarget, temperature float64, logConversations bool, log *slog.Logger) (*MetadataRunner, error) {
|
|
if registry == nil {
|
|
return nil, fmt.Errorf("metadata runner: registry is required")
|
|
}
|
|
if len(chain) == 0 {
|
|
return nil, fmt.Errorf("metadata runner: chain is empty")
|
|
}
|
|
for i, t := range chain {
|
|
if _, err := registry.Client(t.Provider); err != nil {
|
|
return nil, fmt.Errorf("metadata runner: chain[%d]: %w", i, err)
|
|
}
|
|
}
|
|
return &MetadataRunner{
|
|
registry: registry,
|
|
chain: chain,
|
|
opts: metadataRunOpts{
|
|
temperature: temperature,
|
|
logConversations: logConversations,
|
|
},
|
|
health: newHealthTracker(),
|
|
log: log,
|
|
}, nil
|
|
}
|
|
|
|
// PrimaryProvider returns the first provider in the chain.
|
|
func (r *EmbeddingRunner) PrimaryProvider() string { return r.chain[0].Provider }
|
|
|
|
// PrimaryModel returns the first model in the chain — the one used as the
|
|
// storage key for search matching.
|
|
func (r *EmbeddingRunner) PrimaryModel() string { return r.chain[0].Model }
|
|
|
|
// Dimensions returns the required vector dimension.
|
|
func (r *EmbeddingRunner) Dimensions() int { return r.dimensions }
|
|
|
|
// Embed walks the chain and returns the first successful embedding. The
|
|
// returned EmbedResult names the actual (provider, model) that produced the
|
|
// vector — callers use that when recording the row.
|
|
func (r *EmbeddingRunner) Embed(ctx context.Context, input string) (EmbedResult, error) {
|
|
var errs []error
|
|
for _, target := range r.chain {
|
|
if r.health.skip(target) {
|
|
continue
|
|
}
|
|
client, err := r.registry.Client(target.Provider)
|
|
if err != nil {
|
|
errs = append(errs, err)
|
|
continue
|
|
}
|
|
vec, err := client.EmbedWith(ctx, target.Model, input)
|
|
if err != nil {
|
|
if ctx.Err() != nil {
|
|
return EmbedResult{}, ctx.Err()
|
|
}
|
|
r.classify(target, err)
|
|
r.logFailure("embed", target, err)
|
|
errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, err))
|
|
continue
|
|
}
|
|
if len(vec) != r.dimensions {
|
|
dimErr := fmt.Errorf("%s: expected %d, got %d", dimensionMismatchWarning, r.dimensions, len(vec))
|
|
r.health.markTransient(target)
|
|
r.logFailure("embed", target, dimErr)
|
|
errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, dimErr))
|
|
continue
|
|
}
|
|
r.health.markHealthy(target)
|
|
return EmbedResult{Vector: vec, Provider: target.Provider, Model: target.Model}, nil
|
|
}
|
|
return EmbedResult{}, fmt.Errorf("all embedding targets failed: %w", errors.Join(errs...))
|
|
}
|
|
|
|
// EmbedPrimary embeds using only the primary target — used for search queries
|
|
// so the query vector matches rows stored under the primary model. Falls back
|
|
// to returning the error without walking the chain.
|
|
func (r *EmbeddingRunner) EmbedPrimary(ctx context.Context, input string) ([]float32, error) {
|
|
target := r.chain[0]
|
|
client, err := r.registry.Client(target.Provider)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
vec, err := client.EmbedWith(ctx, target.Model, input)
|
|
if err != nil {
|
|
r.classify(target, err)
|
|
return nil, err
|
|
}
|
|
if len(vec) != r.dimensions {
|
|
return nil, fmt.Errorf("%s: expected %d, got %d", dimensionMismatchWarning, r.dimensions, len(vec))
|
|
}
|
|
r.health.markHealthy(target)
|
|
return vec, nil
|
|
}
|
|
|
|
// PrimaryProvider / PrimaryModel for metadata mirror the embedding runner.
|
|
func (r *MetadataRunner) PrimaryProvider() string { return r.chain[0].Provider }
|
|
func (r *MetadataRunner) PrimaryModel() string { return r.chain[0].Model }
|
|
|
|
// ExtractMetadata walks the chain sequentially. If every target fails or is
|
|
// unhealthy, it returns a heuristic metadata so capture never hard-fails.
|
|
func (r *MetadataRunner) ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error) {
|
|
var errs []error
|
|
for _, target := range r.chain {
|
|
if r.health.skip(target) {
|
|
continue
|
|
}
|
|
client, err := r.registry.Client(target.Provider)
|
|
if err != nil {
|
|
errs = append(errs, err)
|
|
continue
|
|
}
|
|
md, err := client.ExtractMetadataWith(ctx, compat.MetadataOptions{
|
|
Model: target.Model,
|
|
Temperature: r.opts.temperature,
|
|
LogConversations: r.opts.logConversations,
|
|
}, input)
|
|
if err != nil {
|
|
if ctx.Err() != nil {
|
|
return thoughttypes.ThoughtMetadata{}, ctx.Err()
|
|
}
|
|
r.classify(target, err)
|
|
r.logFailure("metadata", target, err)
|
|
errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, err))
|
|
continue
|
|
}
|
|
r.health.markHealthy(target)
|
|
return md, nil
|
|
}
|
|
if r.log != nil {
|
|
r.log.Warn("metadata chain exhausted, using heuristic fallback",
|
|
slog.Int("targets", len(r.chain)),
|
|
slog.String("error", errors.Join(errs...).Error()),
|
|
)
|
|
}
|
|
return compat.HeuristicMetadataFromInput(input), nil
|
|
}
|
|
|
|
// Summarize walks the chain; unlike metadata, there is no heuristic fallback —
|
|
// returns the joined error when everything fails.
|
|
func (r *MetadataRunner) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) {
|
|
var errs []error
|
|
for _, target := range r.chain {
|
|
if r.health.skip(target) {
|
|
continue
|
|
}
|
|
client, err := r.registry.Client(target.Provider)
|
|
if err != nil {
|
|
errs = append(errs, err)
|
|
continue
|
|
}
|
|
out, err := client.SummarizeWith(ctx, compat.SummarizeOptions{
|
|
Model: target.Model,
|
|
Temperature: r.opts.temperature,
|
|
}, systemPrompt, userPrompt)
|
|
if err != nil {
|
|
if ctx.Err() != nil {
|
|
return "", ctx.Err()
|
|
}
|
|
r.classify(target, err)
|
|
r.logFailure("summarize", target, err)
|
|
errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, err))
|
|
continue
|
|
}
|
|
r.health.markHealthy(target)
|
|
return out, nil
|
|
}
|
|
return "", fmt.Errorf("all summarize targets failed: %w", errors.Join(errs...))
|
|
}
|
|
|
|
func (r *EmbeddingRunner) classify(target config.RoleTarget, err error) {
|
|
switch {
|
|
case compat.IsPermanentModelError(err):
|
|
r.health.markPermanent(target)
|
|
default:
|
|
r.health.markTransient(target)
|
|
}
|
|
}
|
|
|
|
func (r *MetadataRunner) classify(target config.RoleTarget, err error) {
|
|
switch {
|
|
case compat.IsPermanentModelError(err):
|
|
r.health.markPermanent(target)
|
|
case errors.Is(err, compat.ErrEmptyResponse):
|
|
r.health.markEmpty(target)
|
|
default:
|
|
r.health.markTransient(target)
|
|
}
|
|
}
|
|
|
|
func (r *EmbeddingRunner) logFailure(role string, target config.RoleTarget, err error) {
|
|
if r.log == nil {
|
|
return
|
|
}
|
|
r.log.Warn("ai target failed",
|
|
slog.String("role", role),
|
|
slog.String("provider", target.Provider),
|
|
slog.String("model", target.Model),
|
|
slog.String("error", err.Error()),
|
|
)
|
|
}
|
|
|
|
func (r *MetadataRunner) logFailure(role string, target config.RoleTarget, err error) {
|
|
if r.log == nil {
|
|
return
|
|
}
|
|
r.log.Warn("ai target failed",
|
|
slog.String("role", role),
|
|
slog.String("provider", target.Provider),
|
|
slog.String("model", target.Model),
|
|
slog.String("error", err.Error()),
|
|
)
|
|
}
|
|
|
|
// healthTracker records per-(provider, model) failure state. skip returns true
// when a target is still inside its cooldown window; the caller then tries the
// next target in the chain.
type healthTracker struct {
	mu sync.Mutex // guards states; runners may be called concurrently
	// states maps each chain target to its cooldown/empty-count state;
	// absent keys mean the target has never failed.
	states map[config.RoleTarget]*healthState
}
|
|
|
|
// healthState is the per-target record kept by healthTracker.
type healthState struct {
	unhealthyUntil time.Time // target is skipped while now < unhealthyUntil
	emptyCount     int       // consecutive empty responses toward emptyResponseThreshold
}
|
|
|
|
func newHealthTracker() *healthTracker {
|
|
return &healthTracker{states: map[config.RoleTarget]*healthState{}}
|
|
}
|
|
|
|
func (h *healthTracker) skip(target config.RoleTarget) bool {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
s, ok := h.states[target]
|
|
if !ok {
|
|
return false
|
|
}
|
|
return time.Now().Before(s.unhealthyUntil)
|
|
}
|
|
|
|
func (h *healthTracker) markTransient(target config.RoleTarget) {
|
|
h.setCooldown(target, transientCooldown)
|
|
}
|
|
|
|
func (h *healthTracker) markPermanent(target config.RoleTarget) {
|
|
h.setCooldown(target, permanentCooldown)
|
|
}
|
|
|
|
func (h *healthTracker) markEmpty(target config.RoleTarget) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
s := h.states[target]
|
|
if s == nil {
|
|
s = &healthState{}
|
|
h.states[target] = s
|
|
}
|
|
s.emptyCount++
|
|
if s.emptyCount >= emptyResponseThreshold {
|
|
s.unhealthyUntil = time.Now().Add(emptyResponseCooldown)
|
|
s.emptyCount = 0
|
|
}
|
|
}
|
|
|
|
func (h *healthTracker) markHealthy(target config.RoleTarget) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
if s, ok := h.states[target]; ok {
|
|
s.unhealthyUntil = time.Time{}
|
|
s.emptyCount = 0
|
|
}
|
|
}
|
|
|
|
func (h *healthTracker) setCooldown(target config.RoleTarget, d time.Duration) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
s := h.states[target]
|
|
if s == nil {
|
|
s = &healthState{}
|
|
h.states[target] = s
|
|
}
|
|
s.unhealthyUntil = time.Now().Add(d)
|
|
s.emptyCount = 0
|
|
}
|