test(config): add migration tests for litellm provider
Some checks failed
CI / build-and-test (push) Failing after -32m22s

* Implement tests for migrating configuration from v1 to v2 for the litellm provider.
* Validate the structure and values of the migrated configuration.
* Ensure migration rejects newer versions of the configuration.
fix(validate): enhance AI provider validation logic
* Consolidate provider validation into a dedicated method.
* Ensure at least one provider is specified and validate its type.
* Check for required fields based on provider type.
fix(mcpserver): update tool set to use new enrichment tool
* Replace RetryMetadataTool with RetryEnrichmentTool in the ToolSet.
fix(tools): refactor tools to use embedding and metadata runners
* Update tools to utilize EmbeddingRunner and MetadataRunner instead of Provider.
* Adjust method calls to align with the new runner interfaces.
This commit is contained in:
2026-04-21 21:14:28 +02:00
parent 532d1560a3
commit 14e218d784
39 changed files with 2062 additions and 901 deletions

View File

@@ -14,7 +14,6 @@ import (
"regexp"
"slices"
"strings"
"sync"
"time"
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
@@ -36,36 +35,39 @@ Rules:
- If unsure, prefer "observation".
- Do not include any text outside the JSON object.`
// Client is a low-level OpenAI-compatible HTTP client. It knows nothing about
// role chains, fallbacks, or health — those concerns belong to ai.Runner. Each
// method takes the model name per-call so a single Client instance can service
// many different models on the same base URL.
type Client struct {
name string
baseURL string
apiKey string
embeddingModel string
metadataModel string
fallbackMetadataModels []string
temperature float64
headers map[string]string
httpClient *http.Client
log *slog.Logger
dimensions int
logConversations bool
modelHealthMu sync.Mutex
modelHealth map[string]modelHealthState
name string
baseURL string
apiKey string
headers map[string]string
httpClient *http.Client
log *slog.Logger
}
type Config struct {
Name string
BaseURL string
APIKey string
EmbeddingModel string
MetadataModel string
FallbackMetadataModels []string
Temperature float64
Headers map[string]string
HTTPClient *http.Client
Log *slog.Logger
Dimensions int
LogConversations bool
Name string
BaseURL string
APIKey string
Headers map[string]string
HTTPClient *http.Client
Log *slog.Logger
}
// MetadataOptions control a single ExtractMetadataWith call.
type MetadataOptions struct {
Model string
Temperature float64
LogConversations bool
}
// SummarizeOptions control a single SummarizeWith call.
type SummarizeOptions struct {
Model string
Temperature float64
}
type embeddingsRequest struct {
@@ -127,65 +129,38 @@ type providerError struct {
const maxMetadataAttempts = 3
const (
emptyResponseCircuitThreshold = 3
emptyResponseCircuitTTL = 5 * time.Minute
permanentModelFailureTTL = 24 * time.Hour
)
// ErrEmptyResponse and ErrNoJSONObject are sentinel errors callers can inspect
// to classify metadata failures (e.g. bump empty-response health counters).
var (
errMetadataEmptyResponse = errors.New("metadata empty response")
errMetadataNoJSONObject = errors.New("metadata response contains no JSON object")
ErrEmptyResponse = errors.New("metadata empty response")
ErrNoJSONObject = errors.New("metadata response contains no JSON object")
)
type modelHealthState struct {
consecutiveEmpty int
unhealthyUntil time.Time
}
func New(cfg Config) *Client {
fallbacks := make([]string, 0, len(cfg.FallbackMetadataModels))
seen := make(map[string]struct{}, len(cfg.FallbackMetadataModels))
for _, model := range cfg.FallbackMetadataModels {
model = strings.TrimSpace(model)
if model == "" {
continue
}
if _, ok := seen[model]; ok {
continue
}
seen[model] = struct{}{}
fallbacks = append(fallbacks, model)
}
return &Client{
name: cfg.Name,
baseURL: cfg.BaseURL,
apiKey: cfg.APIKey,
embeddingModel: cfg.EmbeddingModel,
metadataModel: cfg.MetadataModel,
fallbackMetadataModels: fallbacks,
temperature: cfg.Temperature,
headers: cfg.Headers,
httpClient: cfg.HTTPClient,
log: cfg.Log,
dimensions: cfg.Dimensions,
logConversations: cfg.LogConversations,
modelHealth: make(map[string]modelHealthState),
name: cfg.Name,
baseURL: cfg.BaseURL,
apiKey: cfg.APIKey,
headers: cfg.Headers,
httpClient: cfg.HTTPClient,
log: cfg.Log,
}
}
func (c *Client) Embed(ctx context.Context, input string) ([]float32, error) {
func (c *Client) Name() string { return c.name }
// EmbedWith generates an embedding for the given input using model.
func (c *Client) EmbedWith(ctx context.Context, model, input string) ([]float32, error) {
input = strings.TrimSpace(input)
if input == "" {
return nil, fmt.Errorf("%s embed: input must not be empty", c.name)
}
if strings.TrimSpace(model) == "" {
return nil, fmt.Errorf("%s embed: model is required", c.name)
}
var resp embeddingsResponse
err := c.doJSON(ctx, "/embeddings", embeddingsRequest{
Input: input,
Model: c.embeddingModel,
}, &resp)
err := c.doJSON(ctx, "/embeddings", embeddingsRequest{Input: input, Model: model}, &resp)
if err != nil {
return nil, err
}
@@ -195,141 +170,34 @@ func (c *Client) Embed(ctx context.Context, input string) ([]float32, error) {
if len(resp.Data) == 0 {
return nil, fmt.Errorf("%s embed: no embedding returned", c.name)
}
if c.dimensions > 0 && len(resp.Data[0].Embedding) != c.dimensions {
return nil, fmt.Errorf("%s embed: expected %d dimensions, got %d", c.name, c.dimensions, len(resp.Data[0].Embedding))
}
return resp.Data[0].Embedding, nil
}
func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error) {
// ExtractMetadataWith extracts structured metadata for input using opts.Model.
// Returns compat.ErrEmptyResponse / ErrNoJSONObject wrapped when the model
// produces unusable output so callers can classify the failure.
func (c *Client) ExtractMetadataWith(ctx context.Context, opts MetadataOptions, input string) (thoughttypes.ThoughtMetadata, error) {
input = strings.TrimSpace(input)
if input == "" {
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s extract metadata: input must not be empty", c.name)
}
start := time.Now()
if c.log != nil {
c.log.Info("metadata client started",
slog.String("provider", c.name),
slog.String("model", c.metadataModel),
)
}
logCompletion := func(model string, err error) {
if c.log == nil {
return
}
attrs := []any{
slog.String("provider", c.name),
slog.String("model", model),
slog.String("duration", formatLogDuration(time.Since(start))),
}
if err != nil {
attrs = append(attrs, slog.String("error", err.Error()))
c.log.Error("metadata client completed", attrs...)
return
}
c.log.Info("metadata client completed", attrs...)
}
result, err := c.extractMetadataWithModel(ctx, input, c.metadataModel)
if errors.Is(err, errMetadataEmptyResponse) {
c.noteEmptyResponse(c.metadataModel)
}
if isPermanentModelError(err) {
c.notePermanentModelFailure(c.metadataModel, err)
}
if err == nil {
c.noteModelSuccess(c.metadataModel)
logCompletion(c.metadataModel, nil)
return result, nil
}
for _, fallbackModel := range c.fallbackMetadataModels {
if ctx.Err() != nil {
break
}
if fallbackModel == "" || fallbackModel == c.metadataModel {
continue
}
if c.shouldBypassModel(fallbackModel) {
continue
}
if c.log != nil {
c.log.Warn("metadata extraction failed, trying fallback model",
slog.String("provider", c.name),
slog.String("primary_model", c.metadataModel),
slog.String("fallback_model", fallbackModel),
slog.String("error", err.Error()),
)
}
fallbackResult, fallbackErr := c.extractMetadataWithModel(ctx, input, fallbackModel)
if errors.Is(fallbackErr, errMetadataEmptyResponse) {
c.noteEmptyResponse(fallbackModel)
}
if isPermanentModelError(fallbackErr) {
c.notePermanentModelFailure(fallbackModel, fallbackErr)
}
if fallbackErr == nil {
c.noteModelSuccess(fallbackModel)
logCompletion(fallbackModel, nil)
return fallbackResult, nil
}
err = fallbackErr
}
if ctx.Err() != nil {
err = fmt.Errorf("%s metadata: %w", c.name, ctx.Err())
logCompletion(c.metadataModel, err)
return thoughttypes.ThoughtMetadata{}, err
}
heuristic := heuristicMetadataFromInput(input)
if c.log != nil {
c.log.Warn("metadata extraction failed for all models, using heuristic fallback",
slog.String("provider", c.name),
slog.String("error", err.Error()),
)
}
logCompletion(c.metadataModel, nil)
return heuristic, nil
}
func formatLogDuration(d time.Duration) string {
if d < 0 {
d = -d
}
totalMilliseconds := d.Milliseconds()
minutes := totalMilliseconds / 60000
seconds := (totalMilliseconds / 1000) % 60
milliseconds := totalMilliseconds % 1000
return fmt.Sprintf("%02d:%02d:%03d", minutes, seconds, milliseconds)
}
func (c *Client) extractMetadataWithModel(ctx context.Context, input, model string) (thoughttypes.ThoughtMetadata, error) {
if c.shouldBypassModel(model) {
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: model %q temporarily bypassed after repeated empty responses", c.name, model)
if strings.TrimSpace(opts.Model) == "" {
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s extract metadata: model is required", c.name)
}
stream := true
req := chatCompletionsRequest{
Model: model,
Temperature: c.temperature,
ResponseFormat: &responseType{
Type: "json_object",
},
Stream: &stream,
Model: opts.Model,
Temperature: opts.Temperature,
ResponseFormat: &responseType{Type: "json_object"},
Stream: &stream,
Messages: []chatMessage{
{Role: "system", Content: metadataSystemPrompt},
{Role: "user", Content: input},
},
}
metadata, err := c.extractMetadataWithRequest(ctx, req, input, model)
metadata, err := c.extractMetadataWithRequest(ctx, req, input, opts)
if err == nil || !shouldRetryWithoutJSONMode(err) {
return metadata, err
}
@@ -337,23 +205,22 @@ func (c *Client) extractMetadataWithModel(ctx context.Context, input, model stri
if c.log != nil {
c.log.Warn("metadata json mode failed, retrying without response_format",
slog.String("provider", c.name),
slog.String("model", model),
slog.String("model", opts.Model),
slog.String("error", err.Error()),
)
}
req.ResponseFormat = nil
return c.extractMetadataWithRequest(ctx, req, input, model)
return c.extractMetadataWithRequest(ctx, req, input, opts)
}
func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatCompletionsRequest, input, model string) (thoughttypes.ThoughtMetadata, error) {
func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatCompletionsRequest, input string, opts MetadataOptions) (thoughttypes.ThoughtMetadata, error) {
var lastErr error
for attempt := 1; attempt <= maxMetadataAttempts; attempt++ {
if c.logConversations && c.log != nil {
if opts.LogConversations && c.log != nil {
c.log.Info("metadata conversation request",
slog.String("provider", c.name),
slog.String("model", model),
slog.String("model", opts.Model),
slog.Int("attempt", attempt),
slog.String("system", metadataSystemPrompt),
slog.String("input", input),
@@ -373,10 +240,10 @@ func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatComplet
rawResponse := extractChoiceText(resp.Choices[0].Message, resp.Choices[0].Text)
if c.logConversations && c.log != nil {
if opts.LogConversations && c.log != nil {
c.log.Info("metadata conversation response",
slog.String("provider", c.name),
slog.String("model", model),
slog.String("model", opts.Model),
slog.Int("attempt", attempt),
slog.String("response", rawResponse),
)
@@ -387,13 +254,13 @@ func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatComplet
metadataText = stripCodeFence(metadataText)
metadataText = extractJSONObject(metadataText)
if metadataText == "" {
lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataNoJSONObject)
lastErr = fmt.Errorf("%s metadata: %w", c.name, ErrNoJSONObject)
if strings.TrimSpace(rawResponse) == "" && attempt < maxMetadataAttempts && ctx.Err() == nil {
lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataEmptyResponse)
lastErr = fmt.Errorf("%s metadata: %w", c.name, ErrEmptyResponse)
if c.log != nil {
c.log.Warn("metadata response empty, waiting and retrying",
slog.String("provider", c.name),
slog.String("model", model),
slog.String("model", opts.Model),
slog.Int("attempt", attempt+1),
)
}
@@ -403,7 +270,7 @@ func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatComplet
continue
}
if strings.TrimSpace(rawResponse) == "" {
lastErr = fmt.Errorf("%s metadata: %w", c.name, errMetadataEmptyResponse)
lastErr = fmt.Errorf("%s metadata: %w", c.name, ErrEmptyResponse)
}
return thoughttypes.ThoughtMetadata{}, lastErr
}
@@ -420,13 +287,17 @@ func (c *Client) extractMetadataWithRequest(ctx context.Context, req chatComplet
if lastErr != nil {
return thoughttypes.ThoughtMetadata{}, lastErr
}
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: %w", c.name, errMetadataNoJSONObject)
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: %w", c.name, ErrNoJSONObject)
}
func (c *Client) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) {
// SummarizeWith runs a chat-completion summarisation using opts.Model.
func (c *Client) SummarizeWith(ctx context.Context, opts SummarizeOptions, systemPrompt, userPrompt string) (string, error) {
if strings.TrimSpace(opts.Model) == "" {
return "", fmt.Errorf("%s summarize: model is required", c.name)
}
req := chatCompletionsRequest{
Model: c.metadataModel,
Temperature: 0.2,
Model: opts.Model,
Temperature: opts.Temperature,
Messages: []chatMessage{
{Role: "system", Content: systemPrompt},
{Role: "user", Content: userPrompt},
@@ -447,12 +318,49 @@ func (c *Client) Summarize(ctx context.Context, systemPrompt, userPrompt string)
return extractChoiceText(resp.Choices[0].Message, resp.Choices[0].Text), nil
}
func (c *Client) Name() string {
return c.name
// IsPermanentModelError reports whether err indicates the model itself is
// invalid or missing (vs. a transient outage). Runners use this to mark a
// target unhealthy for longer.
func IsPermanentModelError(err error) bool {
if err == nil {
return false
}
lower := strings.ToLower(err.Error())
for _, marker := range []string{
"invalid model name",
"model_not_found",
"model not found",
"unknown model",
"no such model",
"does not exist",
} {
if strings.Contains(lower, marker) {
return true
}
}
return false
}
func (c *Client) EmbeddingModel() string {
return c.embeddingModel
// HeuristicMetadataFromInput produces best-effort metadata from the note text
// when every model in the chain has failed. Exported so ai.Runner can use it.
func HeuristicMetadataFromInput(input string) thoughttypes.ThoughtMetadata {
text := strings.TrimSpace(input)
lower := strings.ToLower(text)
metadata := thoughttypes.ThoughtMetadata{
People: heuristicPeople(text),
ActionItems: heuristicActionItems(text),
DatesMentioned: heuristicDates(text),
Topics: heuristicTopics(lower),
Type: heuristicType(lower),
}
if len(metadata.Topics) == 0 {
metadata.Topics = []string{"uncategorized"}
}
if metadata.Type == "" {
metadata.Type = "observation"
}
return metadata
}
func (c *Client) doJSON(ctx context.Context, path string, requestBody any, dest any) error {
@@ -724,8 +632,6 @@ func isRetryableChatResponseError(err error) bool {
return strings.Contains(lower, "read response") || strings.Contains(lower, "read stream response")
}
// extractJSONObject finds the first complete {...} block in s.
// It handles models that prepend prose to a JSON response despite json_object mode.
func extractJSONObject(s string) string {
for start := 0; start < len(s); start++ {
if s[start] != '{' {
@@ -768,10 +674,6 @@ func extractJSONObject(s string) string {
return ""
}
// stripThinkingBlocks removes <think>...</think> and <thinking>...</thinking>
// blocks produced by reasoning models (DeepSeek R1, QwQ, etc.) so that the
// remaining text can be parsed as JSON without interference from thinking content
// that may itself contain braces.
func stripThinkingBlocks(s string) string {
for _, tag := range []string{"think", "thinking"} {
open := "<" + tag + ">"
@@ -857,7 +759,6 @@ func extractTextFromAny(value any) string {
}
return strings.Join(parts, "\n")
case map[string]any:
// Common provider shapes for chat content parts.
for _, key := range []string{"text", "output_text", "content", "value"} {
if nested, ok := typed[key]; ok {
if text := strings.TrimSpace(extractTextFromAny(nested)); text != "" {
@@ -875,28 +776,6 @@ var (
wordPattern = regexp.MustCompile(`[a-zA-Z][a-zA-Z0-9_/-]{2,}`)
)
func heuristicMetadataFromInput(input string) thoughttypes.ThoughtMetadata {
text := strings.TrimSpace(input)
lower := strings.ToLower(text)
metadata := thoughttypes.ThoughtMetadata{
People: heuristicPeople(text),
ActionItems: heuristicActionItems(text),
DatesMentioned: heuristicDates(text),
Topics: heuristicTopics(lower),
Type: heuristicType(lower),
Source: "",
}
if len(metadata.Topics) == 0 {
metadata.Topics = []string{"uncategorized"}
}
if metadata.Type == "" {
metadata.Type = "observation"
}
return metadata
}
func heuristicType(lower string) string {
switch {
case strings.Contains(lower, "preferred name"), strings.Contains(lower, "personal profile"), strings.Contains(lower, "wife:"), strings.Contains(lower, "daughter:"), strings.Contains(lower, "born:"):
@@ -1055,7 +934,7 @@ func shouldRetryWithoutJSONMode(err error) bool {
if err == nil {
return false
}
if errors.Is(err, errMetadataEmptyResponse) || errors.Is(err, errMetadataNoJSONObject) {
if errors.Is(err, ErrEmptyResponse) || errors.Is(err, ErrNoJSONObject) {
return true
}
@@ -1063,27 +942,6 @@ func shouldRetryWithoutJSONMode(err error) bool {
return strings.Contains(lower, "parse json")
}
func isPermanentModelError(err error) bool {
if err == nil {
return false
}
lower := strings.ToLower(err.Error())
for _, marker := range []string{
"invalid model name",
"model_not_found",
"model not found",
"unknown model",
"no such model",
"does not exist",
} {
if strings.Contains(lower, marker) {
return true
}
}
return false
}
func sleepRetry(ctx context.Context, attempt int, log *slog.Logger, provider string) error {
delay := time.Duration(attempt*attempt) * 200 * time.Millisecond
if log != nil {
@@ -1110,59 +968,3 @@ func sleepMetadataRetry(ctx context.Context, attempt int) error {
return nil
}
}
func (c *Client) shouldBypassModel(model string) bool {
c.modelHealthMu.Lock()
defer c.modelHealthMu.Unlock()
state, ok := c.modelHealth[model]
if !ok {
return false
}
return !state.unhealthyUntil.IsZero() && time.Now().Before(state.unhealthyUntil)
}
func (c *Client) noteEmptyResponse(model string) {
c.modelHealthMu.Lock()
defer c.modelHealthMu.Unlock()
state := c.modelHealth[model]
state.consecutiveEmpty++
if state.consecutiveEmpty >= emptyResponseCircuitThreshold {
state.unhealthyUntil = time.Now().Add(emptyResponseCircuitTTL)
if c.log != nil {
c.log.Warn("metadata model marked temporarily unhealthy after repeated empty responses",
slog.String("provider", c.name),
slog.String("model", model),
slog.Time("until", state.unhealthyUntil),
)
}
}
c.modelHealth[model] = state
}
func (c *Client) noteModelSuccess(model string) {
c.modelHealthMu.Lock()
defer c.modelHealthMu.Unlock()
delete(c.modelHealth, model)
}
func (c *Client) notePermanentModelFailure(model string, err error) {
c.modelHealthMu.Lock()
defer c.modelHealthMu.Unlock()
state := c.modelHealth[model]
state.consecutiveEmpty = emptyResponseCircuitThreshold
state.unhealthyUntil = time.Now().Add(permanentModelFailureTTL)
c.modelHealth[model] = state
if c.log != nil {
c.log.Warn("metadata model marked unhealthy after permanent failure",
slog.String("provider", c.name),
slog.String("model", model),
slog.String("error", err.Error()),
slog.Time("until", state.unhealthyUntil),
)
}
}