test(config): add migration tests for litellm provider
Some checks failed
CI / build-and-test (push) Failing after -32m22s
Some checks failed
CI / build-and-test (push) Failing after -32m22s
* Implement tests for migrating configuration from v1 to v2 for the litellm provider. * Validate the structure and values of the migrated configuration. * Ensure migration rejects newer versions of the configuration. fix(validate): enhance AI provider validation logic * Consolidate provider validation into a dedicated method. * Ensure at least one provider is specified and validate its type. * Check for required fields based on provider type. fix(mcpserver): update tool set to use new enrichment tool * Replace RetryMetadataTool with RetryEnrichmentTool in the ToolSet. fix(tools): refactor tools to use embedding and metadata runners * Update tools to utilize EmbeddingRunner and MetadataRunner instead of Provider. * Adjust method calls to align with the new runner interfaces.
This commit is contained in:
@@ -14,7 +14,6 @@ import (
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
@@ -36,36 +35,39 @@ Rules:
|
||||
- If unsure, prefer "observation".
|
||||
- Do not include any text outside the JSON object.`
|
||||
|
||||
// Client is a low-level OpenAI-compatible HTTP client. It knows nothing about
|
||||
// role chains, fallbacks, or health — those concerns belong to ai.Runner. Each
|
||||
// method takes the model name per-call so a single Client instance can service
|
||||
// many different models on the same base URL.
|
||||
type Client struct {
|
||||
name string
|
||||
baseURL string
|
||||
apiKey string
|
||||
embeddingModel string
|
||||
metadataModel string
|
||||
fallbackMetadataModels []string
|
||||
temperature float64
|
||||
headers map[string]string
|
||||
httpClient *http.Client
|
||||
log *slog.Logger
|
||||
dimensions int
|
||||
logConversations bool
|
||||
modelHealthMu sync.Mutex
|
||||
modelHealth map[string]modelHealthState
|
||||
name string
|
||||
baseURL string
|
||||
apiKey string
|
||||
headers map[string]string
|
||||
httpClient *http.Client
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string
|
||||
BaseURL string
|
||||
APIKey string
|
||||
EmbeddingModel string
|
||||
MetadataModel string
|
||||
FallbackMetadataModels []string
|
||||
Temperature float64
|
||||
Headers map[string]string
|
||||
HTTPClient *http.Client
|
||||
Log *slog.Logger
|
||||
Dimensions int
|
||||
LogConversations bool
|
||||
Name string
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Headers map[string]string
|
||||
HTTPClient *http.Client
|
||||
Log *slog.Logger
|
||||
}
|
||||
|
||||
// MetadataOptions control a single ExtractMetadataWith call.
|
||||
type MetadataOptions struct {
|
||||
Model string
|
||||
Temperature float64
|
||||
LogConversations bool
|
||||
}
|
||||
|
||||
// SummarizeOptions control a single SummarizeWith call.
|
||||
type SummarizeOptions struct {
|
||||
Model string
|
||||
Temperature float64
|
||||
}
|
||||
|
||||
type embeddingsRequest struct {
|
||||
@@ -127,65 +129,38 @@ type providerError struct {
|
||||
|
||||
const maxMetadataAttempts = 3
|
||||
|
||||
const (
|
||||
emptyResponseCircuitThreshold = 3
|
||||
emptyResponseCircuitTTL = 5 * time.Minute
|
||||
permanentModelFailureTTL = 24 * time.Hour
|
||||
)
|
||||
|
||||
// ErrEmptyResponse and ErrNoJSONObject are sentinel errors callers can inspect
|
||||
// to classify metadata failures (e.g. bump empty-response health counters).
|
||||
var (
|
||||
errMetadataEmptyResponse = errors.New("metadata empty response")
|
||||
errMetadataNoJSONObject = errors.New("metadata response contains no JSON object")
|
||||
ErrEmptyResponse = errors.New("metadata empty response")
|
||||
ErrNoJSONObject = errors.New("metadata response contains no JSON object")
|
||||
)
|
||||
|
||||
type modelHealthState struct {
|
||||
consecutiveEmpty int
|
||||
unhealthyUntil time.Time
|
||||
}
|
||||
|
||||
func New(cfg Config) *Client {
|
||||
fallbacks := make([]string, 0, len(cfg.FallbackMetadataModels))
|
||||
seen := make(map[string]struct{}, len(cfg.FallbackMetadataModels))
|
||||
for _, model := range cfg.FallbackMetadataModels {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[model]; ok {
|
||||
continue
|
||||
}
|
||||
seen[model] = struct{}{}
|
||||
fallbacks = append(fallbacks, model)
|
||||
}
|
||||
|
||||
return &Client{
|
||||
name: cfg.Name,
|
||||
baseURL: cfg.BaseURL,
|
||||
apiKey: cfg.APIKey,
|
||||
embeddingModel: cfg.EmbeddingModel,
|
||||
metadataModel: cfg.MetadataModel,
|
||||
fallbackMetadataModels: fallbacks,
|
||||
temperature: cfg.Temperature,
|
||||
headers: cfg.Headers,
|
||||
httpClient: cfg.HTTPClient,
|
||||
log: cfg.Log,
|
||||
dimensions: cfg.Dimensions,
|
||||
logConversations: cfg.LogConversations,
|
||||
modelHealth: make(map[string]modelHealthState),
|
||||
name: cfg.Name,
|
||||
baseURL: cfg.BaseURL,
|
||||
apiKey: cfg.APIKey,
|
||||
headers: cfg.Headers,
|
||||
httpClient: cfg.HTTPClient,
|
||||
log: cfg.Log,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Embed(ctx context.Context, input string) ([]float32, error) {
|
||||
func (c *Client) Name() string { return c.name }
|
||||
|
||||
// EmbedWith generates an embedding for the given input using model.
|
||||
func (c *Client) EmbedWith(ctx context.Context, model, input string) ([]float32, error) {
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
return nil, fmt.Errorf("%s embed: input must not be empty", c.name)
|
||||
}
|
||||
if strings.TrimSpace(model) == "" {
|
||||
return nil, fmt.Errorf("%s embed: model is required", c.name)
|
||||
}
|
||||
|
||||
var resp embeddingsResponse
|
||||
err := c.doJSON(ctx, "/embeddings", embeddingsRequest{
|
||||
Input: input,
|
||||
Model: c.embeddingModel,
|
||||
}, &resp)
|
||||
err := c.doJSON(ctx, "/embeddings", embeddingsRequest{Input: input, Model: model}, &resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -195,141 +170,34 @@ func (c *Client) Embed(ctx context.Context, input string) ([]float32, error) {
|
||||
if len(resp.Data) == 0 {
|
||||
return nil, fmt.Errorf("%s embed: no embedding returned", c.name)
|
||||
}
|
||||
if c.dimensions > 0 && len(resp.Data[0].Embedding) != c.dimensions {
|
||||
return nil, fmt.Errorf("%s embed: expected %d dimensions, got %d", c.name, c.dimensions, len(resp.Data[0].Embedding))
|
||||
}
|
||||
|
||||
return resp.Data[0].Embedding, nil
|
||||
}
|
||||
|
||||
func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error) {
|
||||
// ExtractMetadataWith extracts structured metadata for input using opts.Model.
|
||||
// Returns compat.ErrEmptyResponse / ErrNoJSONObject wrapped when the model
|
||||
// produces unusable output so callers can classify the failure.
|
||||
func (c *Client) ExtractMetadataWith(ctx context.Context, opts MetadataOptions, input string) (thoughttypes.ThoughtMetadata, error) {
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s extract metadata: input must not be empty", c.name)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
if c.log != nil {
|
||||
c.log.Info("metadata client started",
|
||||
slog.String("provider", c.name),
|
||||
slog.String("model", c.metadataModel),
|
||||
)
|
||||
}
|
||||
|
||||
logCompletion := func(model string, err error) {
|
||||
if c.log == nil {
|
||||
return
|
||||
}
|
||||
|
||||
attrs := []any{
|
||||
slog.String("provider", c.name),
|
||||
slog.String("model", model),
|
||||
slog.String("duration", formatLogDuration(time.Since(start))),
|
||||
}
|
||||
if err != nil {
|
||||
attrs = append(attrs, slog.String("error", err.Error()))
|
||||
c.log.Error("metadata client completed", attrs...)
|
||||
return
|
||||
}
|
||||
|
||||
c.log.Info("metadata client completed", attrs...)
|
||||
}
|
||||
|
||||
result, err := c.extractMetadataWithModel(ctx, input, c.metadataModel)
|
||||
if errors.Is(err, errMetadataEmptyResponse) {
|
||||
c.noteEmptyResponse(c.metadataModel)
|
||||
}
|
||||
if isPermanentModelError(err) {
|
||||
c.notePermanentModelFailure(c.metadataModel, err)
|
||||
}
|
||||
if err == nil {
|
||||
c.noteModelSuccess(c.metadataModel)
|
||||
logCompletion(c.metadataModel, nil)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
for _, fallbackModel := range c.fallbackMetadataModels {
|
||||
if ctx.Err() != nil {
|
||||
break
|
||||
}
|
||||
if fallbackModel == "" || fallbackModel == c.metadataModel {
|
||||
continue
|
||||
}
|
||||
if c.shouldBypassModel(fallbackModel) {
|
||||
continue
|
||||
}
|
||||
if c.log != nil {
|
||||
c.log.Warn("metadata extraction failed, trying fallback model",
|
||||
slog.String("provider", c.name),
|
||||
slog.String("primary_model", c.metadataModel),
|
||||
slog.String("fallback_model", fallbackModel),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
fallbackResult, fallbackErr := c.extractMetadataWithModel(ctx, input, fallbackModel)
|
||||
if errors.Is(fallbackErr, errMetadataEmptyResponse) {
|
||||
c.noteEmptyResponse(fallbackModel)
|
||||
}
|
||||
if isPermanentModelError(fallbackErr) {
|
||||
c.notePermanentModelFailure(fallbackModel, fallbackErr)
|
||||
}
|
||||
if fallbackErr == nil {
|
||||
c.noteModelSuccess(fallbackModel)
|
||||
logCompletion(fallbackModel, nil)
|
||||
return fallbackResult, nil
|
||||
}
|
||||
err = fallbackErr
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
err = fmt.Errorf("%s metadata: %w", c.name, ctx.Err())
|
||||
logCompletion(c.metadataModel, err)
|
||||
return thoughttypes.ThoughtMetadata{}, err
|
||||
}
|
||||
|
||||
heuristic := heuristicMetadataFromInput(input)
|
||||
if c.log != nil {
|
||||
c.log.Warn("metadata extraction failed for all models, using heuristic fallback",
|
||||
slog.String("provider", c.name),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
logCompletion(c.metadataModel, nil)
|
||||
return heuristic, nil
|
||||
}
|
||||
|
||||
func formatLogDuration(d time.Duration) string {
|
||||
if d < 0 {
|
||||
d = -d
|
||||
}
|
||||
|
||||
totalMilliseconds := d.Milliseconds()
|
||||
minutes := totalMilliseconds / 60000
|
||||
seconds := (totalMilliseconds / 1000) % 60
|
||||
milliseconds := totalMilliseconds % 1000
|
||||
return fmt.Sprintf("%02d:%02d:%03d", minutes, seconds, milliseconds)
|
||||
}
|
||||
|
||||
func (c *Client) extractMetadataWithModel(ctx context.Context, input, model string) (thoughttypes.ThoughtMetadata, error) {
|
||||
if c.shouldBypassModel(model) {
|
||||
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: model %q temporarily bypassed after repeated empty responses", c.name, model)
|
||||
if strings.TrimSpace(opts.Model) == "" {
|
||||
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s extract metadata: model is required", c.name)
|
||||
}
|
||||
|
||||
stream := true
|
||||
req := chatCompletionsRequest{
|
||||
Model: model,
|
||||
Temperature: c.temperature,
|
||||
ResponseFormat: &responseType{
|
||||
Type: "json_object",
|
||||
},
|
||||
Stream: &stream,
|
||||
Model: opts.Model,
|
||||
Temperature: opts.Temperature,
|
||||
ResponseFormat: &responseType{Type: "json_object"},
|
||||
Stream: &stream,
|
||||
Messages: []chatMessage{
|
||||
{Role: "system", Content: metadataSystemPrompt},
|
||||
{Role: "user", Content: input},
|
||||
},
|
||||
}
|
||||
|
||||
metadata, err := c.extractMetadataWithRequest(ctx, req, input, model)
|
||||
metadata, err := c.extractMetadataWithRequest(ctx, req, input, opts)
|
||||
if err == nil || !shouldRetryWithoutJSONMode(err) {
|
||||
return metadata, err
|
||||
}
|
||||
@@ -337,23 +205,22 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri
|
||||
if c.log != nil {
|
||||
c.log.Warn("metadata json mode failed, retrying without response_format",
|
||||
slog.String("provider", c.name),
|
||||
slog.String("model", model),
|
||||
slog.String("model", opts.Model),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
|
||||
req.ResponseFormat = nil
|
||||
return c.extractMetadataWithRequest(ctx, req, input, model)
|
||||
return c.extractMetadataWithRequest(ctx, req, input, opts)
|
||||
}
|
||||
|
||||
func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatCompletionsRequest, input, model string) (thoughttypes.ThoughtMetadata, error) {
|
||||
|
||||
func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatCompletionsRequest, input string, opts MetadataOptions) (thoughttypes.ThoughtMetadata, error) {
|
||||
var lastErr error
|
||||
for attempt := 1; attempt <= maxMetadataAttempts; attempt++ {
|
||||
if c.logConversations && c.log != nil {
|
||||
if opts.LogConversations && c.log != nil {
|
||||
c.log.Info("metadata conversation request",
|
||||
slog.String("provider", c.name),
|
||||
slog.String("model", model),
|
||||
slog.String("model", opts.Model),
|
||||
slog.Int("attempt", attempt),
|
||||
slog.String("system", metadataSystemPrompt),
|
||||
slog.String("input", input),
|
||||
@@ -373,10 +240,10 @@ func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatComplet
|
||||
|
||||
rawResponse := extractChoiceText(resp.Choices[0].Message, resp.Choices[0].Text)
|
||||
|
||||
if c.logConversations && c.log != nil {
|
||||
if opts.LogConversations && c.log != nil {
|
||||
c.log.Info("metadata conversation response",
|
||||
slog.String("provider", c.name),
|
||||
slog.String("model", model),
|
||||
slog.String("model", opts.Model),
|
||||
slog.Int("attempt", attempt),
|
||||
slog.String("response", rawResponse),
|
||||
)
|
||||
@@ -387,13 +254,13 @@ func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatComplet
|
||||
metadataText = stripCodeFence(metadataText)
|
||||
metadataText = extractJSONObject(metadataText)
|
||||
if metadataText == "" {
|
||||
lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataNoJSONObject)
|
||||
lastErr = fmt.Errorf("%s metadata: %w", c.name, ErrNoJSONObject)
|
||||
if strings.TrimSpace(rawResponse) == "" && attempt < maxMetadataAttempts && ctx.Err() == nil {
|
||||
lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataEmptyResponse)
|
||||
lastErr = fmt.Errorf("%s metadata: %w", c.name, ErrEmptyResponse)
|
||||
if c.log != nil {
|
||||
c.log.Warn("metadata response empty, waiting and retrying",
|
||||
slog.String("provider", c.name),
|
||||
slog.String("model", model),
|
||||
slog.String("model", opts.Model),
|
||||
slog.Int("attempt", attempt+1),
|
||||
)
|
||||
}
|
||||
@@ -403,7 +270,7 @@ func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatComplet
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(rawResponse) == "" {
|
||||
lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataEmptyResponse)
|
||||
lastErr = fmt.Errorf("%s metadata: %w", c.name, ErrEmptyResponse)
|
||||
}
|
||||
return thoughttypes.ThoughtMetadata{}, lastErr
|
||||
}
|
||||
@@ -420,13 +287,17 @@ func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatComplet
|
||||
if lastErr != nil {
|
||||
return thoughttypes.ThoughtMetadata{}, lastErr
|
||||
}
|
||||
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: %w", c.name, errMetadataNoJSONObject)
|
||||
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: %w", c.name, ErrNoJSONObject)
|
||||
}
|
||||
|
||||
func (c *Client) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) {
|
||||
// SummarizeWith runs a chat-completion summarisation using opts.Model.
|
||||
func (c *Client) SummarizeWith(ctx context.Context, opts SummarizeOptions, systemPrompt, userPrompt string) (string, error) {
|
||||
if strings.TrimSpace(opts.Model) == "" {
|
||||
return "", fmt.Errorf("%s summarize: model is required", c.name)
|
||||
}
|
||||
req := chatCompletionsRequest{
|
||||
Model: c.metadataModel,
|
||||
Temperature: 0.2,
|
||||
Model: opts.Model,
|
||||
Temperature: opts.Temperature,
|
||||
Messages: []chatMessage{
|
||||
{Role: "system", Content: systemPrompt},
|
||||
{Role: "user", Content: userPrompt},
|
||||
@@ -447,12 +318,49 @@ func (c *Client) Summarize(ctx context.Context, systemPrompt, userPrompt string)
|
||||
return extractChoiceText(resp.Choices[0].Message, resp.Choices[0].Text), nil
|
||||
}
|
||||
|
||||
func (c *Client) Name() string {
|
||||
return c.name
|
||||
// IsPermanentModelError reports whether err indicates the model itself is
|
||||
// invalid or missing (vs. a transient outage). Runners use this to mark a
|
||||
// target unhealthy for longer.
|
||||
func IsPermanentModelError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
lower := strings.ToLower(err.Error())
|
||||
for _, marker := range []string{
|
||||
"invalid model name",
|
||||
"model_not_found",
|
||||
"model not found",
|
||||
"unknown model",
|
||||
"no such model",
|
||||
"does not exist",
|
||||
} {
|
||||
if strings.Contains(lower, marker) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Client) EmbeddingModel() string {
|
||||
return c.embeddingModel
|
||||
// HeuristicMetadataFromInput produces best-effort metadata from the note text
|
||||
// when every model in the chain has failed. Exported so ai.Runner can use it.
|
||||
func HeuristicMetadataFromInput(input string) thoughttypes.ThoughtMetadata {
|
||||
text := strings.TrimSpace(input)
|
||||
lower := strings.ToLower(text)
|
||||
|
||||
metadata := thoughttypes.ThoughtMetadata{
|
||||
People: heuristicPeople(text),
|
||||
ActionItems: heuristicActionItems(text),
|
||||
DatesMentioned: heuristicDates(text),
|
||||
Topics: heuristicTopics(lower),
|
||||
Type: heuristicType(lower),
|
||||
}
|
||||
if len(metadata.Topics) == 0 {
|
||||
metadata.Topics = []string{"uncategorized"}
|
||||
}
|
||||
if metadata.Type == "" {
|
||||
metadata.Type = "observation"
|
||||
}
|
||||
return metadata
|
||||
}
|
||||
|
||||
func (c *Client) doJSON(ctx context.Context, path string, requestBody any, dest any) error {
|
||||
@@ -724,8 +632,6 @@ func isRetryableChatResponseError(err error) bool {
|
||||
return strings.Contains(lower, "read response") || strings.Contains(lower, "read stream response")
|
||||
}
|
||||
|
||||
// extractJSONObject finds the first complete {...} block in s.
|
||||
// It handles models that prepend prose to a JSON response despite json_object mode.
|
||||
func extractJSONObject(s string) string {
|
||||
for start := 0; start < len(s); start++ {
|
||||
if s[start] != '{' {
|
||||
@@ -768,10 +674,6 @@ func extractJSONObject(s string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// stripThinkingBlocks removes <think>...</think> and <thinking>...</thinking>
|
||||
// blocks produced by reasoning models (DeepSeek R1, QwQ, etc.) so that the
|
||||
// remaining text can be parsed as JSON without interference from thinking content
|
||||
// that may itself contain braces.
|
||||
func stripThinkingBlocks(s string) string {
|
||||
for _, tag := range []string{"think", "thinking"} {
|
||||
open := "<" + tag + ">"
|
||||
@@ -857,7 +759,6 @@ func extractTextFromAny(value any) string {
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
case map[string]any:
|
||||
// Common provider shapes for chat content parts.
|
||||
for _, key := range []string{"text", "output_text", "content", "value"} {
|
||||
if nested, ok := typed[key]; ok {
|
||||
if text := strings.TrimSpace(extractTextFromAny(nested)); text != "" {
|
||||
@@ -875,28 +776,6 @@ var (
|
||||
wordPattern = regexp.MustCompile(`[a-zA-Z][a-zA-Z0-9_/-]{2,}`)
|
||||
)
|
||||
|
||||
func heuristicMetadataFromInput(input string) thoughttypes.ThoughtMetadata {
|
||||
text := strings.TrimSpace(input)
|
||||
lower := strings.ToLower(text)
|
||||
|
||||
metadata := thoughttypes.ThoughtMetadata{
|
||||
People: heuristicPeople(text),
|
||||
ActionItems: heuristicActionItems(text),
|
||||
DatesMentioned: heuristicDates(text),
|
||||
Topics: heuristicTopics(lower),
|
||||
Type: heuristicType(lower),
|
||||
Source: "",
|
||||
}
|
||||
|
||||
if len(metadata.Topics) == 0 {
|
||||
metadata.Topics = []string{"uncategorized"}
|
||||
}
|
||||
if metadata.Type == "" {
|
||||
metadata.Type = "observation"
|
||||
}
|
||||
return metadata
|
||||
}
|
||||
|
||||
func heuristicType(lower string) string {
|
||||
switch {
|
||||
case strings.Contains(lower, "preferred name"), strings.Contains(lower, "personal profile"), strings.Contains(lower, "wife:"), strings.Contains(lower, "daughter:"), strings.Contains(lower, "born:"):
|
||||
@@ -1055,7 +934,7 @@ func shouldRetryWithoutJSONMode(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, errMetadataEmptyResponse) || errors.Is(err, errMetadataNoJSONObject) {
|
||||
if errors.Is(err, ErrEmptyResponse) || errors.Is(err, ErrNoJSONObject) {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -1063,27 +942,6 @@ func shouldRetryWithoutJSONMode(err error) bool {
|
||||
return strings.Contains(lower, "parse json")
|
||||
}
|
||||
|
||||
func isPermanentModelError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
lower := strings.ToLower(err.Error())
|
||||
for _, marker := range []string{
|
||||
"invalid model name",
|
||||
"model_not_found",
|
||||
"model not found",
|
||||
"unknown model",
|
||||
"no such model",
|
||||
"does not exist",
|
||||
} {
|
||||
if strings.Contains(lower, marker) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func sleepRetry(ctx context.Context, attempt int, log *slog.Logger, provider string) error {
|
||||
delay := time.Duration(attempt*attempt) * 200 * time.Millisecond
|
||||
if log != nil {
|
||||
@@ -1110,59 +968,3 @@ func sleepMetadataRetry(ctx context.Context, attempt int) error {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) shouldBypassModel(model string) bool {
|
||||
c.modelHealthMu.Lock()
|
||||
defer c.modelHealthMu.Unlock()
|
||||
|
||||
state, ok := c.modelHealth[model]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return !state.unhealthyUntil.IsZero() && time.Now().Before(state.unhealthyUntil)
|
||||
}
|
||||
|
||||
func (c *Client) noteEmptyResponse(model string) {
|
||||
c.modelHealthMu.Lock()
|
||||
defer c.modelHealthMu.Unlock()
|
||||
|
||||
state := c.modelHealth[model]
|
||||
state.consecutiveEmpty++
|
||||
if state.consecutiveEmpty >= emptyResponseCircuitThreshold {
|
||||
state.unhealthyUntil = time.Now().Add(emptyResponseCircuitTTL)
|
||||
if c.log != nil {
|
||||
c.log.Warn("metadata model marked temporarily unhealthy after repeated empty responses",
|
||||
slog.String("provider", c.name),
|
||||
slog.String("model", model),
|
||||
slog.Time("until", state.unhealthyUntil),
|
||||
)
|
||||
}
|
||||
}
|
||||
c.modelHealth[model] = state
|
||||
}
|
||||
|
||||
func (c *Client) noteModelSuccess(model string) {
|
||||
c.modelHealthMu.Lock()
|
||||
defer c.modelHealthMu.Unlock()
|
||||
|
||||
delete(c.modelHealth, model)
|
||||
}
|
||||
|
||||
func (c *Client) notePermanentModelFailure(model string, err error) {
|
||||
c.modelHealthMu.Lock()
|
||||
defer c.modelHealthMu.Unlock()
|
||||
|
||||
state := c.modelHealth[model]
|
||||
state.consecutiveEmpty = emptyResponseCircuitThreshold
|
||||
state.unhealthyUntil = time.Now().Add(permanentModelFailureTTL)
|
||||
c.modelHealth[model] = state
|
||||
|
||||
if c.log != nil {
|
||||
c.log.Warn("metadata model marked unhealthy after permanent failure",
|
||||
slog.String("provider", c.name),
|
||||
slog.String("model", model),
|
||||
slog.String("error", err.Error()),
|
||||
slog.Time("until", state.unhealthyUntil),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,17 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func newTestClient(t *testing.T, url string) *Client {
|
||||
t.Helper()
|
||||
return New(Config{
|
||||
Name: "litellm",
|
||||
BaseURL: url,
|
||||
APIKey: "test-key",
|
||||
HTTPClient: http.DefaultClient,
|
||||
Log: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractMetadataFromStreamingResponse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -26,6 +37,9 @@ func TestExtractMetadataFromStreamingResponse(t *testing.T) {
|
||||
if req.Stream == nil || !*req.Stream {
|
||||
t.Fatalf("stream flag = %v, want true", req.Stream)
|
||||
}
|
||||
if req.Model != "qwen3.5:latest" {
|
||||
t.Fatalf("model = %q, want qwen3.5:latest", req.Model)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = io.WriteString(w, "data: {\"choices\":[{\"delta\":{\"content\":\"{\\\"people\\\":[],\"}}]}\n\n")
|
||||
@@ -35,20 +49,13 @@ func TestExtractMetadataFromStreamingResponse(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := New(Config{
|
||||
Name: "litellm",
|
||||
BaseURL: server.URL,
|
||||
APIKey: "test-key",
|
||||
MetadataModel: "qwen3.5:latest",
|
||||
Temperature: 0.1,
|
||||
HTTPClient: server.Client(),
|
||||
Log: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||
EmbeddingModel: "unused",
|
||||
})
|
||||
|
||||
metadata, err := client.ExtractMetadata(context.Background(), "Project idea: Build an Android companion app.")
|
||||
client := newTestClient(t, server.URL)
|
||||
metadata, err := client.ExtractMetadataWith(context.Background(), MetadataOptions{
|
||||
Model: "qwen3.5:latest",
|
||||
Temperature: 0.1,
|
||||
}, "Project idea: Build an Android companion app.")
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractMetadata() error = %v", err)
|
||||
t.Fatalf("ExtractMetadataWith() error = %v", err)
|
||||
}
|
||||
|
||||
if metadata.Type != "idea" {
|
||||
@@ -94,20 +101,13 @@ func TestExtractMetadataRetriesWithoutJSONMode(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := New(Config{
|
||||
Name: "litellm",
|
||||
BaseURL: server.URL,
|
||||
APIKey: "test-key",
|
||||
MetadataModel: "qwen3.5:latest",
|
||||
Temperature: 0.1,
|
||||
HTTPClient: server.Client(),
|
||||
Log: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||
EmbeddingModel: "unused",
|
||||
})
|
||||
|
||||
metadata, err := client.ExtractMetadata(context.Background(), "Project idea: Build an Android companion app.")
|
||||
client := newTestClient(t, server.URL)
|
||||
metadata, err := client.ExtractMetadataWith(context.Background(), MetadataOptions{
|
||||
Model: "qwen3.5:latest",
|
||||
Temperature: 0.1,
|
||||
}, "Project idea: Build an Android companion app.")
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractMetadata() error = %v", err)
|
||||
t.Fatalf("ExtractMetadataWith() error = %v", err)
|
||||
}
|
||||
|
||||
if metadata.Type != "idea" {
|
||||
@@ -127,71 +127,33 @@ func TestExtractMetadataRetriesWithoutJSONMode(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractMetadataBypassesInvalidFallbackModelAfterFirstFailure(t *testing.T) {
|
||||
func TestIsPermanentModelError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var mu sync.Mutex
|
||||
primaryCalls := 0
|
||||
invalidFallbackCalls := 0
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
_ = r.Body.Close()
|
||||
}()
|
||||
|
||||
var req chatCompletionsRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
}
|
||||
|
||||
switch req.Model {
|
||||
case "empty-primary":
|
||||
_, _ = io.WriteString(w, `{"choices":[{"message":{"role":"assistant","content":""}}]}`)
|
||||
case "qwen3.5:latest":
|
||||
mu.Lock()
|
||||
primaryCalls++
|
||||
mu.Unlock()
|
||||
_, _ = io.WriteString(w, `{"choices":[{"message":{"role":"assistant","content":"{\"people\":[],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"metadata\"],\"type\":\"observation\",\"source\":\"primary\"}"}}]}`)
|
||||
case "qwen3":
|
||||
mu.Lock()
|
||||
invalidFallbackCalls++
|
||||
mu.Unlock()
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = io.WriteString(w, "{\"error\":{\"message\":\"{'error': '/chat/completions: Invalid model name passed in model=qwen3. Call `/v1/models` to view available models for your key.'}\"}}")
|
||||
default:
|
||||
t.Fatalf("unexpected model %q", req.Model)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := New(Config{
|
||||
Name: "litellm",
|
||||
BaseURL: server.URL,
|
||||
APIKey: "test-key",
|
||||
MetadataModel: "empty-primary",
|
||||
FallbackMetadataModels: []string{"qwen3", "qwen3.5:latest"},
|
||||
Temperature: 0.1,
|
||||
HTTPClient: server.Client(),
|
||||
Log: slog.New(slog.NewTextHandler(io.Discard, nil)),
|
||||
EmbeddingModel: "unused",
|
||||
})
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
metadata, err := client.ExtractMetadata(context.Background(), "A short note about metadata.")
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractMetadata() error = %v", err)
|
||||
}
|
||||
if metadata.Source != "primary" {
|
||||
t.Fatalf("metadata source = %q, want primary", metadata.Source)
|
||||
}
|
||||
cases := []struct {
|
||||
name string
|
||||
err error
|
||||
want bool
|
||||
}{
|
||||
{"nil", nil, false},
|
||||
{"invalid model", errMsg("Invalid model name passed in model=qwen3"), true},
|
||||
{"model not found", errMsg("model_not_found"), true},
|
||||
{"no such model", errMsg("no such model"), true},
|
||||
{"transient", errMsg("connection refused"), false},
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if invalidFallbackCalls != 1 {
|
||||
t.Fatalf("invalid fallback calls = %d, want 1", invalidFallbackCalls)
|
||||
}
|
||||
if primaryCalls != 2 {
|
||||
t.Fatalf("valid fallback calls = %d, want 2", primaryCalls)
|
||||
for _, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := IsPermanentModelError(tc.err); got != tc.want {
|
||||
t.Fatalf("IsPermanentModelError(%v) = %v, want %v", tc.err, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type stringError string
|
||||
|
||||
func (s stringError) Error() string { return string(s) }
|
||||
|
||||
func errMsg(s string) error { return stringError(s) }
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/ai/litellm"
|
||||
"git.warky.dev/wdevs/amcs/internal/ai/ollama"
|
||||
"git.warky.dev/wdevs/amcs/internal/ai/openrouter"
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
)
|
||||
|
||||
func NewProvider(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (Provider, error) {
|
||||
switch cfg.Provider {
|
||||
case "litellm":
|
||||
return litellm.New(cfg, httpClient, log)
|
||||
case "ollama":
|
||||
return ollama.New(cfg, httpClient, log)
|
||||
case "openrouter":
|
||||
return openrouter.New(cfg, httpClient, log)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported ai.provider: %s", cfg.Provider)
|
||||
}
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
)
|
||||
|
||||
func TestNewProviderSupportsOllama(t *testing.T) {
|
||||
provider, err := NewProvider(config.AIConfig{
|
||||
Provider: "ollama",
|
||||
Embeddings: config.AIEmbeddingConfig{
|
||||
Model: "nomic-embed-text",
|
||||
Dimensions: 768,
|
||||
},
|
||||
Metadata: config.AIMetadataConfig{
|
||||
Model: "llama3.2",
|
||||
},
|
||||
Ollama: config.OllamaConfig{
|
||||
BaseURL: "http://localhost:11434/v1",
|
||||
APIKey: "ollama",
|
||||
},
|
||||
}, &http.Client{}, slog.New(slog.NewTextHandler(io.Discard, nil)))
|
||||
if err != nil {
|
||||
t.Fatalf("NewProvider() error = %v", err)
|
||||
}
|
||||
if provider.Name() != "ollama" {
|
||||
t.Fatalf("provider name = %q, want ollama", provider.Name())
|
||||
}
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
package litellm
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/ai/compat"
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
)
|
||||
|
||||
func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) {
|
||||
fallbacks := cfg.LiteLLM.EffectiveFallbackMetadataModels()
|
||||
if len(fallbacks) == 0 {
|
||||
fallbacks = cfg.Metadata.EffectiveFallbackModels()
|
||||
}
|
||||
return compat.New(compat.Config{
|
||||
Name: "litellm",
|
||||
BaseURL: cfg.LiteLLM.BaseURL,
|
||||
APIKey: cfg.LiteLLM.APIKey,
|
||||
EmbeddingModel: cfg.LiteLLM.EmbeddingModel,
|
||||
MetadataModel: cfg.LiteLLM.MetadataModel,
|
||||
FallbackMetadataModels: fallbacks,
|
||||
Temperature: cfg.Metadata.Temperature,
|
||||
Headers: cfg.LiteLLM.RequestHeaders,
|
||||
HTTPClient: httpClient,
|
||||
Log: log,
|
||||
Dimensions: cfg.Embeddings.Dimensions,
|
||||
LogConversations: cfg.Metadata.LogConversations,
|
||||
}), nil
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/ai/compat"
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
)
|
||||
|
||||
func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) {
|
||||
return compat.New(compat.Config{
|
||||
Name: "ollama",
|
||||
BaseURL: cfg.Ollama.BaseURL,
|
||||
APIKey: cfg.Ollama.APIKey,
|
||||
EmbeddingModel: cfg.Embeddings.Model,
|
||||
MetadataModel: cfg.Metadata.Model,
|
||||
FallbackMetadataModels: cfg.Metadata.EffectiveFallbackModels(),
|
||||
Temperature: cfg.Metadata.Temperature,
|
||||
Headers: cfg.Ollama.RequestHeaders,
|
||||
HTTPClient: httpClient,
|
||||
Log: log,
|
||||
Dimensions: cfg.Embeddings.Dimensions,
|
||||
LogConversations: cfg.Metadata.LogConversations,
|
||||
}), nil
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
package openrouter
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/ai/compat"
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
)
|
||||
|
||||
func New(cfg config.AIConfig, httpClient *http.Client, log *slog.Logger) (*compat.Client, error) {
|
||||
headers := make(map[string]string, len(cfg.OpenRouter.ExtraHeaders)+2)
|
||||
for key, value := range cfg.OpenRouter.ExtraHeaders {
|
||||
headers[key] = value
|
||||
}
|
||||
if cfg.OpenRouter.SiteURL != "" {
|
||||
headers["HTTP-Referer"] = cfg.OpenRouter.SiteURL
|
||||
}
|
||||
if cfg.OpenRouter.AppName != "" {
|
||||
headers["X-Title"] = cfg.OpenRouter.AppName
|
||||
}
|
||||
|
||||
return compat.New(compat.Config{
|
||||
Name: "openrouter",
|
||||
BaseURL: cfg.OpenRouter.BaseURL,
|
||||
APIKey: cfg.OpenRouter.APIKey,
|
||||
EmbeddingModel: cfg.Embeddings.Model,
|
||||
MetadataModel: cfg.Metadata.Model,
|
||||
FallbackMetadataModels: cfg.Metadata.EffectiveFallbackModels(),
|
||||
Temperature: cfg.Metadata.Temperature,
|
||||
Headers: headers,
|
||||
HTTPClient: httpClient,
|
||||
Log: log,
|
||||
Dimensions: cfg.Embeddings.Dimensions,
|
||||
LogConversations: cfg.Metadata.LogConversations,
|
||||
}), nil
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
type Provider interface {
|
||||
Embed(ctx context.Context, input string) ([]float32, error)
|
||||
ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error)
|
||||
Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error)
|
||||
Name() string
|
||||
EmbeddingModel() string
|
||||
}
|
||||
96
internal/ai/registry.go
Normal file
96
internal/ai/registry.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/ai/compat"
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
)
|
||||
|
||||
// Registry holds one compat.Client per named provider. Runners look up clients
|
||||
// by provider name when walking a role chain.
|
||||
type Registry struct {
|
||||
clients map[string]*compat.Client
|
||||
}
|
||||
|
||||
// NewRegistry builds a Registry from the configured providers. Each provider
|
||||
// type maps onto a compat.Client with type-specific header plumbing (e.g.
|
||||
// openrouter's HTTP-Referer / X-Title).
|
||||
func NewRegistry(providers map[string]config.ProviderConfig, httpClient *http.Client, log *slog.Logger) (*Registry, error) {
|
||||
if httpClient == nil {
|
||||
return nil, fmt.Errorf("ai registry: http client is required")
|
||||
}
|
||||
if len(providers) == 0 {
|
||||
return nil, fmt.Errorf("ai registry: no providers configured")
|
||||
}
|
||||
|
||||
clients := make(map[string]*compat.Client, len(providers))
|
||||
for name, p := range providers {
|
||||
headers, err := providerHeaders(p)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ai registry: provider %q: %w", name, err)
|
||||
}
|
||||
clients[name] = compat.New(compat.Config{
|
||||
Name: name,
|
||||
BaseURL: p.BaseURL,
|
||||
APIKey: p.APIKey,
|
||||
Headers: headers,
|
||||
HTTPClient: httpClient,
|
||||
Log: log,
|
||||
})
|
||||
}
|
||||
return &Registry{clients: clients}, nil
|
||||
}
|
||||
|
||||
// Client returns the compat.Client registered under name.
|
||||
func (r *Registry) Client(name string) (*compat.Client, error) {
|
||||
c, ok := r.clients[name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("ai registry: provider %q is not configured", name)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Names returns the registered provider names.
|
||||
func (r *Registry) Names() []string {
|
||||
names := make([]string, 0, len(r.clients))
|
||||
for name := range r.clients {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
func providerHeaders(p config.ProviderConfig) (map[string]string, error) {
|
||||
switch p.Type {
|
||||
case "litellm", "ollama":
|
||||
return cloneHeaders(p.RequestHeaders), nil
|
||||
case "openrouter":
|
||||
headers := cloneHeaders(p.RequestHeaders)
|
||||
if headers == nil {
|
||||
headers = map[string]string{}
|
||||
}
|
||||
if s := strings.TrimSpace(p.SiteURL); s != "" {
|
||||
headers["HTTP-Referer"] = s
|
||||
}
|
||||
if s := strings.TrimSpace(p.AppName); s != "" {
|
||||
headers["X-Title"] = s
|
||||
}
|
||||
return headers, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported provider type %q", p.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func cloneHeaders(in map[string]string) map[string]string {
|
||||
if len(in) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]string, len(in))
|
||||
for k, v := range in {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
80
internal/ai/registry_test.go
Normal file
80
internal/ai/registry_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/ai/compat"
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
)
|
||||
|
||||
func TestNewRegistryOpenRouterHeaders(t *testing.T) {
|
||||
var (
|
||||
gotReferer string
|
||||
gotTitle string
|
||||
gotCustom string
|
||||
)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotReferer = r.Header.Get("HTTP-Referer")
|
||||
gotTitle = r.Header.Get("X-Title")
|
||||
gotCustom = r.Header.Get("X-Custom")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{{"message": map[string]any{"role": "assistant", "content": "ok"}}},
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
providers := map[string]config.ProviderConfig{
|
||||
"router": {
|
||||
Type: "openrouter",
|
||||
BaseURL: srv.URL,
|
||||
APIKey: "secret",
|
||||
RequestHeaders: map[string]string{
|
||||
"X-Custom": "value",
|
||||
},
|
||||
AppName: "amcs",
|
||||
SiteURL: "https://example.com",
|
||||
},
|
||||
}
|
||||
|
||||
reg, err := NewRegistry(providers, srv.Client(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRegistry() error = %v", err)
|
||||
}
|
||||
|
||||
client, err := reg.Client("router")
|
||||
if err != nil {
|
||||
t.Fatalf("Client(router) error = %v", err)
|
||||
}
|
||||
|
||||
if _, err := client.SummarizeWith(context.Background(), compat.SummarizeOptions{Model: "gpt-4.1-mini"}, "system", "user"); err != nil {
|
||||
t.Fatalf("SummarizeWith() error = %v", err)
|
||||
}
|
||||
if gotReferer != "https://example.com" {
|
||||
t.Fatalf("HTTP-Referer = %q, want https://example.com", gotReferer)
|
||||
}
|
||||
if gotTitle != "amcs" {
|
||||
t.Fatalf("X-Title = %q, want amcs", gotTitle)
|
||||
}
|
||||
if gotCustom != "value" {
|
||||
t.Fatalf("X-Custom = %q, want value", gotCustom)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRegistryRejectsUnsupportedProviderType(t *testing.T) {
|
||||
providers := map[string]config.ProviderConfig{
|
||||
"bad": {
|
||||
Type: "unknown",
|
||||
BaseURL: "http://localhost:4000/v1",
|
||||
APIKey: "secret",
|
||||
},
|
||||
}
|
||||
|
||||
_, err := NewRegistry(providers, &http.Client{}, nil)
|
||||
if err == nil {
|
||||
t.Fatal("NewRegistry() error = nil, want unsupported provider type error")
|
||||
}
|
||||
}
|
||||
367
internal/ai/runner.go
Normal file
367
internal/ai/runner.go
Normal file
@@ -0,0 +1,367 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/ai/compat"
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
// Health TTLs per failure class. These are short enough that a healed target
|
||||
// gets retried without manual intervention, but long enough to avoid hammering
|
||||
// a broken provider every call.
|
||||
const (
|
||||
transientCooldown = 30 * time.Second
|
||||
permanentCooldown = 10 * time.Minute
|
||||
emptyResponseThreshold = 3
|
||||
emptyResponseCooldown = 2 * time.Minute
|
||||
dimensionMismatchWarning = "embedding dimension mismatch"
|
||||
)
|
||||
|
||||
// EmbedResult carries the vector plus the (provider, model) that produced it —
|
||||
// callers store the actual model so later searches against that row use the
|
||||
// matching query embedding.
|
||||
type EmbedResult struct {
|
||||
Vector []float32
|
||||
Provider string
|
||||
Model string
|
||||
}
|
||||
|
||||
// EmbeddingRunner executes the embeddings role chain with sequential fallback.
|
||||
type EmbeddingRunner struct {
|
||||
registry *Registry
|
||||
chain []config.RoleTarget
|
||||
dimensions int
|
||||
health *healthTracker
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
// MetadataRunner executes the metadata role chain with sequential fallback and
|
||||
// a heuristic fallthrough when every target is unhealthy or fails.
|
||||
type MetadataRunner struct {
|
||||
registry *Registry
|
||||
chain []config.RoleTarget
|
||||
opts metadataRunOpts
|
||||
health *healthTracker
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
type metadataRunOpts struct {
|
||||
temperature float64
|
||||
logConversations bool
|
||||
}
|
||||
|
||||
// NewEmbeddingRunner builds a runner for the embeddings role. chain must be
|
||||
// non-empty and every target must be registered.
|
||||
func NewEmbeddingRunner(registry *Registry, chain []config.RoleTarget, dimensions int, log *slog.Logger) (*EmbeddingRunner, error) {
|
||||
if registry == nil {
|
||||
return nil, fmt.Errorf("embedding runner: registry is required")
|
||||
}
|
||||
if len(chain) == 0 {
|
||||
return nil, fmt.Errorf("embedding runner: chain is empty")
|
||||
}
|
||||
if dimensions <= 0 {
|
||||
return nil, fmt.Errorf("embedding runner: dimensions must be > 0")
|
||||
}
|
||||
for i, t := range chain {
|
||||
if _, err := registry.Client(t.Provider); err != nil {
|
||||
return nil, fmt.Errorf("embedding runner: chain[%d]: %w", i, err)
|
||||
}
|
||||
}
|
||||
return &EmbeddingRunner{
|
||||
registry: registry,
|
||||
chain: chain,
|
||||
dimensions: dimensions,
|
||||
health: newHealthTracker(),
|
||||
log: log,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewMetadataRunner builds a runner for the metadata role.
|
||||
func NewMetadataRunner(registry *Registry, chain []config.RoleTarget, temperature float64, logConversations bool, log *slog.Logger) (*MetadataRunner, error) {
|
||||
if registry == nil {
|
||||
return nil, fmt.Errorf("metadata runner: registry is required")
|
||||
}
|
||||
if len(chain) == 0 {
|
||||
return nil, fmt.Errorf("metadata runner: chain is empty")
|
||||
}
|
||||
for i, t := range chain {
|
||||
if _, err := registry.Client(t.Provider); err != nil {
|
||||
return nil, fmt.Errorf("metadata runner: chain[%d]: %w", i, err)
|
||||
}
|
||||
}
|
||||
return &MetadataRunner{
|
||||
registry: registry,
|
||||
chain: chain,
|
||||
opts: metadataRunOpts{
|
||||
temperature: temperature,
|
||||
logConversations: logConversations,
|
||||
},
|
||||
health: newHealthTracker(),
|
||||
log: log,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// PrimaryProvider returns the first provider in the chain.
|
||||
func (r *EmbeddingRunner) PrimaryProvider() string { return r.chain[0].Provider }
|
||||
|
||||
// PrimaryModel returns the first model in the chain — the one used as the
|
||||
// storage key for search matching.
|
||||
func (r *EmbeddingRunner) PrimaryModel() string { return r.chain[0].Model }
|
||||
|
||||
// Dimensions returns the required vector dimension.
|
||||
func (r *EmbeddingRunner) Dimensions() int { return r.dimensions }
|
||||
|
||||
// Embed walks the chain and returns the first successful embedding. The
|
||||
// returned EmbedResult names the actual (provider, model) that produced the
|
||||
// vector — callers use that when recording the row.
|
||||
func (r *EmbeddingRunner) Embed(ctx context.Context, input string) (EmbedResult, error) {
|
||||
var errs []error
|
||||
for _, target := range r.chain {
|
||||
if r.health.skip(target) {
|
||||
continue
|
||||
}
|
||||
client, err := r.registry.Client(target.Provider)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
continue
|
||||
}
|
||||
vec, err := client.EmbedWith(ctx, target.Model, input)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return EmbedResult{}, ctx.Err()
|
||||
}
|
||||
r.classify(target, err)
|
||||
r.logFailure("embed", target, err)
|
||||
errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, err))
|
||||
continue
|
||||
}
|
||||
if len(vec) != r.dimensions {
|
||||
dimErr := fmt.Errorf("%s: expected %d, got %d", dimensionMismatchWarning, r.dimensions, len(vec))
|
||||
r.health.markTransient(target)
|
||||
r.logFailure("embed", target, dimErr)
|
||||
errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, dimErr))
|
||||
continue
|
||||
}
|
||||
r.health.markHealthy(target)
|
||||
return EmbedResult{Vector: vec, Provider: target.Provider, Model: target.Model}, nil
|
||||
}
|
||||
return EmbedResult{}, fmt.Errorf("all embedding targets failed: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
// EmbedPrimary embeds using only the primary target — used for search queries
|
||||
// so the query vector matches rows stored under the primary model. Falls back
|
||||
// to returning the error without walking the chain.
|
||||
func (r *EmbeddingRunner) EmbedPrimary(ctx context.Context, input string) ([]float32, error) {
|
||||
target := r.chain[0]
|
||||
client, err := r.registry.Client(target.Provider)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
vec, err := client.EmbedWith(ctx, target.Model, input)
|
||||
if err != nil {
|
||||
r.classify(target, err)
|
||||
return nil, err
|
||||
}
|
||||
if len(vec) != r.dimensions {
|
||||
return nil, fmt.Errorf("%s: expected %d, got %d", dimensionMismatchWarning, r.dimensions, len(vec))
|
||||
}
|
||||
r.health.markHealthy(target)
|
||||
return vec, nil
|
||||
}
|
||||
|
||||
// PrimaryProvider / PrimaryModel for metadata mirror the embedding runner.
|
||||
func (r *MetadataRunner) PrimaryProvider() string { return r.chain[0].Provider }
|
||||
func (r *MetadataRunner) PrimaryModel() string { return r.chain[0].Model }
|
||||
|
||||
// ExtractMetadata walks the chain sequentially. If every target fails or is
|
||||
// unhealthy, it returns a heuristic metadata so capture never hard-fails.
|
||||
func (r *MetadataRunner) ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error) {
|
||||
var errs []error
|
||||
for _, target := range r.chain {
|
||||
if r.health.skip(target) {
|
||||
continue
|
||||
}
|
||||
client, err := r.registry.Client(target.Provider)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
continue
|
||||
}
|
||||
md, err := client.ExtractMetadataWith(ctx, compat.MetadataOptions{
|
||||
Model: target.Model,
|
||||
Temperature: r.opts.temperature,
|
||||
LogConversations: r.opts.logConversations,
|
||||
}, input)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return thoughttypes.ThoughtMetadata{}, ctx.Err()
|
||||
}
|
||||
r.classify(target, err)
|
||||
r.logFailure("metadata", target, err)
|
||||
errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, err))
|
||||
continue
|
||||
}
|
||||
r.health.markHealthy(target)
|
||||
return md, nil
|
||||
}
|
||||
if r.log != nil {
|
||||
r.log.Warn("metadata chain exhausted, using heuristic fallback",
|
||||
slog.Int("targets", len(r.chain)),
|
||||
slog.String("error", errors.Join(errs...).Error()),
|
||||
)
|
||||
}
|
||||
return compat.HeuristicMetadataFromInput(input), nil
|
||||
}
|
||||
|
||||
// Summarize walks the chain; unlike metadata, there is no heuristic fallback —
|
||||
// returns the joined error when everything fails.
|
||||
func (r *MetadataRunner) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) {
|
||||
var errs []error
|
||||
for _, target := range r.chain {
|
||||
if r.health.skip(target) {
|
||||
continue
|
||||
}
|
||||
client, err := r.registry.Client(target.Provider)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
continue
|
||||
}
|
||||
out, err := client.SummarizeWith(ctx, compat.SummarizeOptions{
|
||||
Model: target.Model,
|
||||
Temperature: r.opts.temperature,
|
||||
}, systemPrompt, userPrompt)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return "", ctx.Err()
|
||||
}
|
||||
r.classify(target, err)
|
||||
r.logFailure("summarize", target, err)
|
||||
errs = append(errs, fmt.Errorf("%s/%s: %w", target.Provider, target.Model, err))
|
||||
continue
|
||||
}
|
||||
r.health.markHealthy(target)
|
||||
return out, nil
|
||||
}
|
||||
return "", fmt.Errorf("all summarize targets failed: %w", errors.Join(errs...))
|
||||
}
|
||||
|
||||
func (r *EmbeddingRunner) classify(target config.RoleTarget, err error) {
|
||||
switch {
|
||||
case compat.IsPermanentModelError(err):
|
||||
r.health.markPermanent(target)
|
||||
default:
|
||||
r.health.markTransient(target)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MetadataRunner) classify(target config.RoleTarget, err error) {
|
||||
switch {
|
||||
case compat.IsPermanentModelError(err):
|
||||
r.health.markPermanent(target)
|
||||
case errors.Is(err, compat.ErrEmptyResponse):
|
||||
r.health.markEmpty(target)
|
||||
default:
|
||||
r.health.markTransient(target)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *EmbeddingRunner) logFailure(role string, target config.RoleTarget, err error) {
|
||||
if r.log == nil {
|
||||
return
|
||||
}
|
||||
r.log.Warn("ai target failed",
|
||||
slog.String("role", role),
|
||||
slog.String("provider", target.Provider),
|
||||
slog.String("model", target.Model),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
|
||||
func (r *MetadataRunner) logFailure(role string, target config.RoleTarget, err error) {
|
||||
if r.log == nil {
|
||||
return
|
||||
}
|
||||
r.log.Warn("ai target failed",
|
||||
slog.String("role", role),
|
||||
slog.String("provider", target.Provider),
|
||||
slog.String("model", target.Model),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
|
||||
// healthTracker records per-(provider, model) failure state. skip returns true
|
||||
// when a target is still inside its cooldown window; the caller then tries the
|
||||
// next target in the chain.
|
||||
type healthTracker struct {
|
||||
mu sync.Mutex
|
||||
states map[config.RoleTarget]*healthState
|
||||
}
|
||||
|
||||
type healthState struct {
|
||||
unhealthyUntil time.Time
|
||||
emptyCount int
|
||||
}
|
||||
|
||||
func newHealthTracker() *healthTracker {
|
||||
return &healthTracker{states: map[config.RoleTarget]*healthState{}}
|
||||
}
|
||||
|
||||
func (h *healthTracker) skip(target config.RoleTarget) bool {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
s, ok := h.states[target]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return time.Now().Before(s.unhealthyUntil)
|
||||
}
|
||||
|
||||
func (h *healthTracker) markTransient(target config.RoleTarget) {
|
||||
h.setCooldown(target, transientCooldown)
|
||||
}
|
||||
|
||||
func (h *healthTracker) markPermanent(target config.RoleTarget) {
|
||||
h.setCooldown(target, permanentCooldown)
|
||||
}
|
||||
|
||||
func (h *healthTracker) markEmpty(target config.RoleTarget) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
s := h.states[target]
|
||||
if s == nil {
|
||||
s = &healthState{}
|
||||
h.states[target] = s
|
||||
}
|
||||
s.emptyCount++
|
||||
if s.emptyCount >= emptyResponseThreshold {
|
||||
s.unhealthyUntil = time.Now().Add(emptyResponseCooldown)
|
||||
s.emptyCount = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (h *healthTracker) markHealthy(target config.RoleTarget) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if s, ok := h.states[target]; ok {
|
||||
s.unhealthyUntil = time.Time{}
|
||||
s.emptyCount = 0
|
||||
}
|
||||
}
|
||||
|
||||
func (h *healthTracker) setCooldown(target config.RoleTarget, d time.Duration) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
s := h.states[target]
|
||||
if s == nil {
|
||||
s = &healthState{}
|
||||
h.states[target] = s
|
||||
}
|
||||
s.unhealthyUntil = time.Now().Add(d)
|
||||
s.emptyCount = 0
|
||||
}
|
||||
139
internal/ai/runner_test.go
Normal file
139
internal/ai/runner_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
)
|
||||
|
||||
func TestEmbeddingRunnerFallsBackAndSkipsUnhealthyPrimary(t *testing.T) {
|
||||
var (
|
||||
mu sync.Mutex
|
||||
primaryCalls int
|
||||
)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/embeddings" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
switch req.Model {
|
||||
case "embed-primary":
|
||||
mu.Lock()
|
||||
primaryCalls++
|
||||
mu.Unlock()
|
||||
http.Error(w, "upstream down", http.StatusBadGateway)
|
||||
case "embed-fallback":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"data": []map[string]any{{"embedding": []float32{0.1, 0.2, 0.3}}},
|
||||
})
|
||||
default:
|
||||
http.Error(w, "unknown model", http.StatusBadRequest)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
reg, err := NewRegistry(map[string]config.ProviderConfig{
|
||||
"p1": {Type: "litellm", BaseURL: srv.URL, APIKey: "k1"},
|
||||
"p2": {Type: "litellm", BaseURL: srv.URL, APIKey: "k2"},
|
||||
}, srv.Client(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRegistry() error = %v", err)
|
||||
}
|
||||
|
||||
runner, err := NewEmbeddingRunner(reg, []config.RoleTarget{
|
||||
{Provider: "p1", Model: "embed-primary"},
|
||||
{Provider: "p2", Model: "embed-fallback"},
|
||||
}, 3, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewEmbeddingRunner() error = %v", err)
|
||||
}
|
||||
|
||||
res, err := runner.Embed(context.Background(), "hello")
|
||||
if err != nil {
|
||||
t.Fatalf("Embed() first call error = %v", err)
|
||||
}
|
||||
if res.Provider != "p2" || res.Model != "embed-fallback" {
|
||||
t.Fatalf("Embed() first call target = %s/%s, want p2/embed-fallback", res.Provider, res.Model)
|
||||
}
|
||||
|
||||
res, err = runner.Embed(context.Background(), "hello again")
|
||||
if err != nil {
|
||||
t.Fatalf("Embed() second call error = %v", err)
|
||||
}
|
||||
if res.Provider != "p2" || res.Model != "embed-fallback" {
|
||||
t.Fatalf("Embed() second call target = %s/%s, want p2/embed-fallback", res.Provider, res.Model)
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
calls := primaryCalls
|
||||
mu.Unlock()
|
||||
if calls != 3 {
|
||||
t.Fatalf("primary calls = %d, want 3 (first request retries 3x; second call should skip unhealthy primary)", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMetadataRunnerSummarizeFallsBack(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/chat/completions" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
switch req.Model {
|
||||
case "sum-primary":
|
||||
http.Error(w, "provider error", http.StatusBadGateway)
|
||||
case "sum-fallback":
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{{
|
||||
"message": map[string]any{"role": "assistant", "content": "fallback summary"},
|
||||
}},
|
||||
})
|
||||
default:
|
||||
http.Error(w, "unknown model", http.StatusBadRequest)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
reg, err := NewRegistry(map[string]config.ProviderConfig{
|
||||
"p1": {Type: "litellm", BaseURL: srv.URL, APIKey: "k1"},
|
||||
"p2": {Type: "litellm", BaseURL: srv.URL, APIKey: "k2"},
|
||||
}, srv.Client(), nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRegistry() error = %v", err)
|
||||
}
|
||||
|
||||
runner, err := NewMetadataRunner(reg, []config.RoleTarget{
|
||||
{Provider: "p1", Model: "sum-primary"},
|
||||
{Provider: "p2", Model: "sum-fallback"},
|
||||
}, 0.1, false, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewMetadataRunner() error = %v", err)
|
||||
}
|
||||
|
||||
summary, err := runner.Summarize(context.Background(), "system", "user")
|
||||
if err != nil {
|
||||
t.Fatalf("Summarize() error = %v", err)
|
||||
}
|
||||
if summary != "fallback summary" {
|
||||
t.Fatalf("summary = %q, want %q", summary, "fallback summary")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user