package ai

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"sync"
	"time"

	"git.warky.dev/wdevs/amcs/internal/ai/compat"
	"git.warky.dev/wdevs/amcs/internal/config"
	thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
)

// Health TTLs per failure class. These are short enough that a healed target
// gets retried without manual intervention, but long enough to avoid hammering
// a broken provider every call.
const (
	// transientCooldown is how long a target is skipped after markTransient.
	transientCooldown = 30 * time.Second

	// permanentCooldown is how long a target is skipped after markPermanent.
	permanentCooldown = 10 * time.Minute

	// emptyResponseThreshold is the number of markEmpty calls, counted since
	// the last reset, that puts a target into cooldown.
	emptyResponseThreshold = 3

	// emptyResponseCooldown is how long a target is skipped once the empty
	// response threshold has been reached.
	emptyResponseCooldown = 2 * time.Minute

	// dimensionMismatchWarning prefixes the error produced when a returned
	// vector's length differs from the configured dimensions.
	dimensionMismatchWarning = "embedding dimension mismatch"
)

// EmbedResult carries the vector plus the (provider, model) that produced it —
// callers store the actual model so later searches against that row use the
// matching query embedding.
type EmbedResult struct {
	// Vector is the embedding returned by the target.
	Vector []float32

	// Provider is the provider name of the chain target that succeeded.
	Provider string

	// Model is the model name of the chain target that succeeded.
	Model string
}

// EmbeddingRunner executes the embeddings role chain with sequential fallback.
type EmbeddingRunner struct {
	// registry resolves a target's provider name to a client.
	registry *Registry

	// chain lists the targets in the order they are tried.
	chain []config.RoleTarget

	// dimensions is the vector length every embedding must have.
	dimensions int

	// health tracks which targets are currently in cooldown.
	health *healthTracker

	// log receives failure warnings; when nil, logging is skipped.
	log *slog.Logger
}

// MetadataRunner executes the metadata role chain with sequential fallback and
// a heuristic fallthrough when every target is unhealthy or fails.
type MetadataRunner struct {
	// registry resolves a target's provider name to a client.
	registry *Registry

	// chain lists the targets in the order they are tried.
	chain []config.RoleTarget

	// opts holds the per-call options forwarded to the client.
	opts metadataRunOpts

	// health tracks which targets are currently in cooldown.
	health *healthTracker

	// log receives failure warnings; when nil, logging is skipped.
	log *slog.Logger
}

// metadataRunOpts groups the request options a MetadataRunner forwards to the
// provider client on each call.
type metadataRunOpts struct {
	// temperature is passed through as the Temperature option.
	temperature float64

	// logConversations is passed through as the LogConversations option of
	// metadata extraction requests.
	logConversations bool
}

// NewEmbeddingRunner builds a runner for the embeddings role. chain must be
// non-empty and every target must be registered.
func NewEmbeddingRunner(registry *Registry, chain []config.RoleTarget, dimensions int, log *slog.Logger) (*EmbeddingRunner, error) { if registry == nil { return nil, fmt.Errorf("embedding runner: registry is required") } if len(chain) == 0 { return nil, fmt.Errorf("embedding runner: chain is empty") } if dimensions <= 0 { return nil, fmt.Errorf("embedding runner: dimensions must be > 0") } for i, t := range chain { if _, err := registry.Client(t.Provider); err != nil { return nil, fmt.Errorf("embedding runner: chain[%d]: %w", i, err) } } return &EmbeddingRunner{ registry: registry, chain: chain, dimensions: dimensions, health: newHealthTracker(), log: log, }, nil } // NewMetadataRunner builds a runner for the metadata role. func NewMetadataRunner(registry *Registry, chain []config.RoleTarget, temperature float64, logConversations bool, log *slog.Logger) (*MetadataRunner, error) { if registry == nil { return nil, fmt.Errorf("metadata runner: registry is required") } if len(chain) == 0 { return nil, fmt.Errorf("metadata runner: chain is empty") } for i, t := range chain { if _, err := registry.Client(t.Provider); err != nil { return nil, fmt.Errorf("metadata runner: chain[%d]: %w", i, err) } } return &MetadataRunner{ registry: registry, chain: chain, opts: metadataRunOpts{ temperature: temperature, logConversations: logConversations, }, health: newHealthTracker(), log: log, }, nil } // PrimaryProvider returns the first provider in the chain. func (r *EmbeddingRunner) PrimaryProvider() string { return r.chain[0].Provider } // PrimaryModel returns the first model in the chain — the one used as the // storage key for search matching. func (r *EmbeddingRunner) PrimaryModel() string { return r.chain[0].Model } // Dimensions returns the required vector dimension. func (r *EmbeddingRunner) Dimensions() int { return r.dimensions } // Embed walks the chain and returns the first successful embedding. 
The // returned EmbedResult names the actual (provider, model) that produced the // vector — callers use that when recording the row. func (r *EmbeddingRunner) Embed(ctx context.Context, input string) (EmbedResult, error) { var errs []error for _, target := range r.chain { if r.health.skip(target) { continue } client, err := r.registry.Client(target.Provider) if err != nil { errs = append(errs, err) continue } vec, err := client.EmbedWith(ctx, target.Model, input) if err != nil { if ctx.Err() != nil { return EmbedResult{}, ctx.Err() } r.classify(target, err) r.logFailure("embed", target, err) errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, err)) continue } if len(vec) != r.dimensions { dimErr := fmt.Errorf("%s: expected %d, got %d", dimensionMismatchWarning, r.dimensions, len(vec)) r.health.markTransient(target) r.logFailure("embed", target, dimErr) errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, dimErr)) continue } r.health.markHealthy(target) return EmbedResult{Vector: vec, Provider: target.Provider, Model: target.Model}, nil } return EmbedResult{}, fmt.Errorf("all embedding targets failed: %w", errors.Join(errs...)) } // EmbedPrimary embeds using only the primary target — used for search queries // so the query vector matches rows stored under the primary model. Falls back // to returning the error without walking the chain. func (r *EmbeddingRunner) EmbedPrimary(ctx context.Context, input string) ([]float32, error) { target := r.chain[0] client, err := r.registry.Client(target.Provider) if err != nil { return nil, err } vec, err := client.EmbedWith(ctx, target.Model, input) if err != nil { r.classify(target, err) return nil, err } if len(vec) != r.dimensions { return nil, fmt.Errorf("%s: expected %d, got %d", dimensionMismatchWarning, r.dimensions, len(vec)) } r.health.markHealthy(target) return vec, nil } // PrimaryProvider / PrimaryModel for metadata mirror the embedding runner. 
func (r *MetadataRunner) PrimaryProvider() string { return r.chain[0].Provider } func (r *MetadataRunner) PrimaryModel() string { return r.chain[0].Model } // ExtractMetadata walks the chain sequentially. If every target fails or is // unhealthy, it returns a heuristic metadata so capture never hard-fails. func (r *MetadataRunner) ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error) { var errs []error for _, target := range r.chain { if r.health.skip(target) { continue } client, err := r.registry.Client(target.Provider) if err != nil { errs = append(errs, err) continue } md, err := client.ExtractMetadataWith(ctx, compat.MetadataOptions{ Model: target.Model, Temperature: r.opts.temperature, LogConversations: r.opts.logConversations, }, input) if err != nil { if ctx.Err() != nil { return thoughttypes.ThoughtMetadata{}, ctx.Err() } r.classify(target, err) r.logFailure("metadata", target, err) errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, err)) continue } r.health.markHealthy(target) return md, nil } if r.log != nil { r.log.Warn("metadata chain exhausted, using heuristic fallback", slog.Int("targets", len(r.chain)), slog.String("error", errors.Join(errs...).Error()), ) } return compat.HeuristicMetadataFromInput(input), nil } // Summarize walks the chain; unlike metadata, there is no heuristic fallback — // returns the joined error when everything fails. 
func (r *MetadataRunner) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) { var errs []error for _, target := range r.chain { if r.health.skip(target) { continue } client, err := r.registry.Client(target.Provider) if err != nil { errs = append(errs, err) continue } out, err := client.SummarizeWith(ctx, compat.SummarizeOptions{ Model: target.Model, Temperature: r.opts.temperature, }, systemPrompt, userPrompt) if err != nil { if ctx.Err() != nil { return "", ctx.Err() } r.classify(target, err) r.logFailure("summarize", target, err) errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, err)) continue } r.health.markHealthy(target) return out, nil } return "", fmt.Errorf("all summarize targets failed: %w", errors.Join(errs...)) } func (r *EmbeddingRunner) classify(target config.RoleTarget, err error) { switch { case compat.IsPermanentModelError(err): r.health.markPermanent(target) default: r.health.markTransient(target) } } func (r *MetadataRunner) classify(target config.RoleTarget, err error) { switch { case compat.IsPermanentModelError(err): r.health.markPermanent(target) case errors.Is(err, compat.ErrEmptyResponse): r.health.markEmpty(target) default: r.health.markTransient(target) } } func (r *EmbeddingRunner) logFailure(role string, target config.RoleTarget, err error) { if r.log == nil { return } r.log.Warn("ai target failed", slog.String("role", role), slog.String("provider", target.Provider), slog.String("model", target.Model), slog.String("error", err.Error()), ) } func (r *MetadataRunner) logFailure(role string, target config.RoleTarget, err error) { if r.log == nil { return } r.log.Warn("ai target failed", slog.String("role", role), slog.String("provider", target.Provider), slog.String("model", target.Model), slog.String("error", err.Error()), ) } // healthTracker records per-(provider, model) failure state. 
skip returns true // when a target is still inside its cooldown window; the caller then tries the // next target in the chain. type healthTracker struct { mu sync.Mutex states map[config.RoleTarget]*healthState } type healthState struct { unhealthyUntil time.Time emptyCount int } func newHealthTracker() *healthTracker { return &healthTracker{states: map[config.RoleTarget]*healthState{}} } func (h *healthTracker) skip(target config.RoleTarget) bool { h.mu.Lock() defer h.mu.Unlock() s, ok := h.states[target] if !ok { return false } return time.Now().Before(s.unhealthyUntil) } func (h *healthTracker) markTransient(target config.RoleTarget) { h.setCooldown(target, transientCooldown) } func (h *healthTracker) markPermanent(target config.RoleTarget) { h.setCooldown(target, permanentCooldown) } func (h *healthTracker) markEmpty(target config.RoleTarget) { h.mu.Lock() defer h.mu.Unlock() s := h.states[target] if s == nil { s = &healthState{} h.states[target] = s } s.emptyCount++ if s.emptyCount >= emptyResponseThreshold { s.unhealthyUntil = time.Now().Add(emptyResponseCooldown) s.emptyCount = 0 } } func (h *healthTracker) markHealthy(target config.RoleTarget) { h.mu.Lock() defer h.mu.Unlock() if s, ok := h.states[target]; ok { s.unhealthyUntil = time.Time{} s.emptyCount = 0 } } func (h *healthTracker) setCooldown(target config.RoleTarget, d time.Duration) { h.mu.Lock() defer h.mu.Unlock() s := h.states[target] if s == nil { s = &healthState{} h.states[target] = s } s.unhealthyUntil = time.Now().Add(d) s.emptyCount = 0 }