Files
amcs/internal/ai/compat/client.go

394 lines
11 KiB
Go

package compat
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"strings"
"time"
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
)
// metadataSystemPrompt is the system message sent with every metadata
// extraction request. It describes the JSON shape the reply must follow; the
// reply is decoded into thoughttypes.ThoughtMetadata after any thinking
// blocks, code fences and surrounding prose are removed.
const metadataSystemPrompt = `You extract structured metadata from short notes.
Return only valid JSON matching this schema:
{
"people": ["string"],
"action_items": ["string"],
"dates_mentioned": ["string"],
"topics": ["string"],
"type": "observation|task|idea|reference|person_note",
"source": "string"
}
Rules:
- Keep arrays concise.
- Use lowercase for type.
- If unsure, prefer "observation".
- Do not include any text outside the JSON object.`
// Client talks to an OpenAI-compatible HTTP API for embeddings, metadata
// extraction and summarisation. Build one with New.
type Client struct {
	// name identifies the provider; it prefixes every error message and is
	// attached to log records.
	name string

	// baseURL is the API root; endpoint paths such as "/embeddings" are
	// appended after trailing slashes are trimmed.
	baseURL string

	// apiKey is sent as a bearer token.
	apiKey string

	// embeddingModel is the model used by Embed.
	embeddingModel string

	// metadataModel is the primary model for ExtractMetadata and the model
	// used by Summarize.
	metadataModel string

	// fallbackMetadataModel is tried by ExtractMetadata when the primary model
	// fails; empty disables the fallback.
	fallbackMetadataModel string

	// temperature is sent with metadata extraction requests (Summarize uses a
	// fixed value).
	temperature float64

	// headers are extra request headers; entries with a blank key or value
	// are skipped.
	headers map[string]string

	// httpClient sends every request.
	httpClient *http.Client

	// log is optional; nil disables the client's logging.
	log *slog.Logger

	// dimensions is the expected embedding length; 0 disables the check.
	dimensions int
}
// Config holds the settings New copies into a Client.
type Config struct {
	// Name identifies the provider in error messages and log records.
	Name string

	// BaseURL is the API root that endpoint paths are appended to.
	BaseURL string

	// APIKey is sent as a bearer token.
	APIKey string

	// EmbeddingModel is the model used for embedding requests.
	EmbeddingModel string

	// MetadataModel is the primary chat model for metadata extraction and
	// summarisation.
	MetadataModel string

	// FallbackMetadataModel is tried when metadata extraction with
	// MetadataModel fails; leave empty to disable.
	FallbackMetadataModel string

	// Temperature is sent with metadata extraction requests. A zero value is
	// omitted from the request body.
	Temperature float64

	// Headers are extra request headers; blank keys or values are ignored.
	Headers map[string]string

	// HTTPClient sends the requests.
	HTTPClient *http.Client

	// Log receives retry and fallback warnings; nil disables logging.
	Log *slog.Logger

	// Dimensions is the expected embedding length; 0 skips the check.
	Dimensions int
}
// Wire types for the OpenAI-compatible /embeddings and /chat/completions
// endpoints. Field order and JSON tags define the exact request bodies sent.
type (
	// embeddingsRequest is the /embeddings request body.
	embeddingsRequest struct {
		Input string `json:"input"`
		Model string `json:"model"`
	}

	// embeddingsResponse is the /embeddings response; only the first entry of
	// Data is consumed by the client.
	embeddingsResponse struct {
		Data []struct {
			Embedding []float32 `json:"embedding"`
		} `json:"data"`
		Error *providerError `json:"error,omitempty"`
	}

	// chatCompletionsRequest is the /chat/completions request body. Because of
	// omitempty, a zero Temperature and a nil ResponseFormat are not sent.
	chatCompletionsRequest struct {
		Model          string        `json:"model"`
		Temperature    float64       `json:"temperature,omitempty"`
		ResponseFormat *responseType `json:"response_format,omitempty"`
		Messages       []chatMessage `json:"messages"`
	}

	// responseType selects the provider's response format, e.g. "json_object".
	responseType struct {
		Type string `json:"type"`
	}

	// chatMessage is one role/content pair in a chat conversation.
	chatMessage struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}

	// chatCompletionsResponse is the /chat/completions response; only the
	// first choice is consumed by the client.
	chatCompletionsResponse struct {
		Choices []struct {
			Message chatMessage `json:"message"`
		} `json:"choices"`
		Error *providerError `json:"error,omitempty"`
	}

	// providerError is an error object embedded in a response body.
	providerError struct {
		Message string `json:"message"`
		Type    string `json:"type,omitempty"`
	}
)
// New builds a Client from cfg. Every field is copied as-is, except that a nil
// cfg.HTTPClient is replaced by http.DefaultClient so that requests never call
// Do on a nil client. http.DefaultClient has no overall timeout, so callers
// relying on the default should bound requests through ctx.
func New(cfg Config) *Client {
	httpClient := cfg.HTTPClient
	if httpClient == nil {
		httpClient = http.DefaultClient
	}
	return &Client{
		name:                  cfg.Name,
		baseURL:               cfg.BaseURL,
		apiKey:                cfg.APIKey,
		embeddingModel:        cfg.EmbeddingModel,
		metadataModel:         cfg.MetadataModel,
		fallbackMetadataModel: cfg.FallbackMetadataModel,
		temperature:           cfg.Temperature,
		headers:               cfg.Headers,
		httpClient:            httpClient,
		log:                   cfg.Log,
		dimensions:            cfg.Dimensions,
	}
}
// Embed returns the embedding vector for input, computed by the configured
// embedding model via the /embeddings endpoint.
//
// input is trimmed and must not be empty. It is an error for the provider to
// return no data, an empty vector, or (when the client was configured with a
// positive dimension count) a vector of a different length.
func (c *Client) Embed(ctx context.Context, input string) ([]float32, error) {
	input = strings.TrimSpace(input)
	if input == "" {
		return nil, fmt.Errorf("%s embed: input must not be empty", c.name)
	}

	var resp embeddingsResponse
	if err := c.doJSON(ctx, "/embeddings", embeddingsRequest{
		Input: input,
		Model: c.embeddingModel,
	}, &resp); err != nil {
		return nil, err
	}
	if resp.Error != nil {
		return nil, fmt.Errorf("%s embed error: %s", c.name, resp.Error.Message)
	}
	if len(resp.Data) == 0 {
		return nil, fmt.Errorf("%s embed: no embedding returned", c.name)
	}

	embedding := resp.Data[0].Embedding
	// A null or [] embedding would otherwise pass as success whenever no
	// expected dimension count is configured.
	if len(embedding) == 0 {
		return nil, fmt.Errorf("%s embed: empty embedding returned", c.name)
	}
	if c.dimensions > 0 && len(embedding) != c.dimensions {
		return nil, fmt.Errorf("%s embed: expected %d dimensions, got %d", c.name, c.dimensions, len(embedding))
	}
	return embedding, nil
}
// ExtractMetadata asks the metadata model to extract structured metadata
// (people, action items, dates, topics, type, source) from input.
//
// input is trimmed and must not be empty. If the primary model fails, a
// fallback model is configured, and ctx is still live, the request is retried
// once with the fallback model. If that also fails, the returned error wraps
// both failures, so errors.Is/As can match either one.
func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error) {
	input = strings.TrimSpace(input)
	if input == "" {
		return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s extract metadata: input must not be empty", c.name)
	}

	result, err := c.extractMetadataWithModel(ctx, input, c.metadataModel)
	if err == nil {
		return result, nil
	}
	// A cancelled or expired context would make the fallback fail too.
	if c.fallbackMetadataModel == "" || ctx.Err() != nil {
		return result, err
	}

	if c.log != nil {
		c.log.Warn("metadata extraction failed, trying fallback model",
			slog.String("provider", c.name),
			slog.String("primary_model", c.metadataModel),
			slog.String("fallback_model", c.fallbackMetadataModel),
			slog.String("error", err.Error()),
		)
	}
	fallbackResult, fallbackErr := c.extractMetadataWithModel(ctx, input, c.fallbackMetadataModel)
	if fallbackErr != nil {
		// Keep the primary model's failure in the chain alongside the fallback's.
		return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%w; fallback model %s: %w", err, c.fallbackMetadataModel, fallbackErr)
	}
	return fallbackResult, nil
}
// extractMetadataWithModel performs one metadata extraction request against
// model and decodes the reply into a ThoughtMetadata.
//
// The request uses the client's temperature, asks for a "json_object"
// response format and sends metadataSystemPrompt as the system message. The
// reply is cleaned in this order before decoding: surrounding whitespace,
// <think>/<thinking> blocks, a Markdown code fence, then everything outside
// the JSON object. On any failure the zero ThoughtMetadata is returned.
func (c *Client) extractMetadataWithModel(ctx context.Context, input, model string) (thoughttypes.ThoughtMetadata, error) {
	var (
		empty thoughttypes.ThoughtMetadata
		resp  chatCompletionsResponse
	)

	req := chatCompletionsRequest{
		Model:          model,
		Temperature:    c.temperature,
		ResponseFormat: &responseType{Type: "json_object"},
		Messages: []chatMessage{
			{Role: "system", Content: metadataSystemPrompt},
			{Role: "user", Content: input},
		},
	}
	if err := c.doJSON(ctx, "/chat/completions", req, &resp); err != nil {
		return empty, err
	}

	switch {
	case resp.Error != nil:
		return empty, fmt.Errorf("%s metadata error: %s", c.name, resp.Error.Message)
	case len(resp.Choices) == 0:
		return empty, fmt.Errorf("%s metadata: no choices returned", c.name)
	}

	raw := resp.Choices[0].Message.Content
	jsonText := extractJSONObject(stripCodeFence(stripThinkingBlocks(strings.TrimSpace(raw))))
	if jsonText == "" {
		return empty, fmt.Errorf("%s metadata: response contains no JSON object", c.name)
	}

	var decoded thoughttypes.ThoughtMetadata
	if err := json.Unmarshal([]byte(jsonText), &decoded); err != nil {
		return empty, fmt.Errorf("%s metadata: parse json: %w", c.name, err)
	}
	return decoded, nil
}
// Summarize sends systemPrompt and userPrompt to the chat-completions endpoint
// using the metadata model and returns the reply text.
//
// Temperature is fixed at 0.2 regardless of the client's configured value.
// <think>/<thinking> blocks emitted by reasoning models are removed from the
// reply, as metadata extraction already does, so they do not leak into
// summaries. Surrounding whitespace is trimmed.
func (c *Client) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) {
	req := chatCompletionsRequest{
		Model:       c.metadataModel,
		Temperature: 0.2,
		Messages: []chatMessage{
			{Role: "system", Content: systemPrompt},
			{Role: "user", Content: userPrompt},
		},
	}
	var resp chatCompletionsResponse
	if err := c.doJSON(ctx, "/chat/completions", req, &resp); err != nil {
		return "", err
	}
	if resp.Error != nil {
		return "", fmt.Errorf("%s summarize error: %s", c.name, resp.Error.Message)
	}
	if len(resp.Choices) == 0 {
		return "", fmt.Errorf("%s summarize: no choices returned", c.name)
	}
	// stripThinkingBlocks also trims surrounding whitespace.
	return stripThinkingBlocks(resp.Choices[0].Message.Content), nil
}
// Name returns the provider name the client was configured with; the same
// value prefixes every error message the client produces.
func (c *Client) Name() string { return c.name }

// EmbeddingModel returns the model identifier sent with embedding requests.
func (c *Client) EmbeddingModel() string { return c.embeddingModel }
// doJSON POSTs requestBody, encoded as JSON, to baseURL+path and decodes a
// successful response body into dest.
//
// Up to three attempts are made. An attempt is retried after sleepRetry's
// backoff when the transport error is retryable (isRetryableError, and only
// while ctx is live), when reading the body fails, or when the status is
// retryable (isRetryableStatus). Any other status >= 400 is returned at once,
// with at most maxErrorBody bytes of the response body in the error. A 2xx
// body that cannot be decoded into dest is not retried.
func (c *Client) doJSON(ctx context.Context, path string, requestBody any, dest any) error {
	body, err := json.Marshal(requestBody)
	if err != nil {
		return fmt.Errorf("%s request marshal: %w", c.name, err)
	}

	// Guard against a Client constructed without an HTTP client.
	httpClient := c.httpClient
	if httpClient == nil {
		httpClient = http.DefaultClient
	}

	const (
		maxAttempts = 3
		// maxErrorBody caps how much of a failed response body is copied into
		// the returned error, so large HTML error pages do not flood logs.
		maxErrorBody = 4096
	)
	var lastErr error
	for attempt := 1; attempt <= maxAttempts; attempt++ {
		req, err := c.newJSONRequest(ctx, path, body)
		if err != nil {
			return err
		}

		resp, err := httpClient.Do(req)
		if err != nil {
			lastErr = fmt.Errorf("%s request failed: %w", c.name, err)
			if attempt < maxAttempts && ctx.Err() == nil && isRetryableError(err) {
				if retryErr := sleepRetry(ctx, attempt, c.log, c.name); retryErr != nil {
					return retryErr
				}
				continue
			}
			return lastErr
		}

		payload, readErr := io.ReadAll(resp.Body)
		// The body is fully consumed above; a close error carries no extra information.
		_ = resp.Body.Close()
		if readErr != nil {
			lastErr = fmt.Errorf("%s read response: %w", c.name, readErr)
			if attempt < maxAttempts {
				if retryErr := sleepRetry(ctx, attempt, c.log, c.name); retryErr != nil {
					return retryErr
				}
				continue
			}
			return lastErr
		}

		if resp.StatusCode >= http.StatusBadRequest {
			detail := strings.TrimSpace(string(payload))
			if len(detail) > maxErrorBody {
				// ToValidUTF8 drops a rune that the byte cut may have split.
				detail = strings.ToValidUTF8(detail[:maxErrorBody], "") + "...(truncated)"
			}
			lastErr = fmt.Errorf("%s request failed with status %d: %s", c.name, resp.StatusCode, detail)
			if attempt < maxAttempts && isRetryableStatus(resp.StatusCode) {
				if retryErr := sleepRetry(ctx, attempt, c.log, c.name); retryErr != nil {
					return retryErr
				}
				continue
			}
			return lastErr
		}

		if err := json.Unmarshal(payload, dest); err != nil {
			if c.log != nil {
				c.log.Debug("provider response body", slog.String("provider", c.name), slog.String("body", string(payload)))
			}
			return fmt.Errorf("%s decode response: %w", c.name, err)
		}
		return nil
	}
	return lastErr
}

// newJSONRequest builds one POST attempt for doJSON. A fresh request is built
// per attempt so every send starts from an unread copy of body.
func (c *Client) newJSONRequest(ctx context.Context, path string, body []byte) (*http.Request, error) {
	endpoint := strings.TrimRight(c.baseURL, "/") + path
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("%s build request: %w", c.name, err)
	}
	// With no key configured (e.g. a local server), omit the header instead of
	// sending a bearer scheme with an empty token.
	if strings.TrimSpace(c.apiKey) != "" {
		req.Header.Set("Authorization", "Bearer "+c.apiKey)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	// Custom headers are applied last, so they can override the defaults above.
	for key, value := range c.headers {
		if strings.TrimSpace(key) == "" || strings.TrimSpace(value) == "" {
			continue
		}
		req.Header.Set(key, value)
	}
	return req, nil
}
// extractJSONObject returns the first complete, valid JSON object embedded in
// s. It copes with models that wrap the object in prose despite json_object
// mode, including prose that itself contains braces before or after it.
//
// Every '{' in s is tried in order as a candidate start. The candidate runs to
// its matching '}', with braces inside JSON strings ignored, and is returned
// when json.Valid accepts it. If no candidate is valid, the span from the
// first '{' to the last '}' is returned so the caller's JSON decoder can report
// a meaningful error. "" is returned when s has no such span.
//
// Each candidate scan stops at its matching brace, so the cost only
// approaches quadratic for long runs of unbalanced braces.
func extractJSONObject(s string) string {
	offset := 0
	for {
		rel := strings.IndexByte(s[offset:], '{')
		if rel == -1 {
			break
		}
		start := offset + rel
		if end := jsonObjectEnd(s, start); end != -1 {
			if candidate := s[start : end+1]; json.Valid([]byte(candidate)) {
				return candidate
			}
		}
		offset = start + 1
	}

	start := strings.Index(s, "{")
	end := strings.LastIndex(s, "}")
	if start == -1 || end == -1 || end <= start {
		return ""
	}
	return s[start : end+1]
}

// jsonObjectEnd returns the index of the '}' that balances the '{' at
// s[start], treating braces inside double-quoted strings (with backslash
// escapes) as literal text. It returns -1 if the object never closes.
func jsonObjectEnd(s string, start int) int {
	depth := 0
	inString, escaped := false, false
	for i := start; i < len(s); i++ {
		ch := s[i]
		if inString {
			switch {
			case escaped:
				escaped = false
			case ch == '\\':
				escaped = true
			case ch == '"':
				inString = false
			}
			continue
		}
		switch ch {
		case '"':
			inString = true
		case '{':
			depth++
		case '}':
			depth--
			if depth == 0 {
				return i
			}
		}
	}
	return -1
}
// stripThinkingBlocks deletes every <think>...</think> and
// <thinking>...</thinking> span that reasoning models (e.g. DeepSeek R1, QwQ)
// prepend to their answers. This keeps braces inside the reasoning from
// confusing later JSON extraction.
//
// An opening tag with no matching close tag discards everything from that
// tag onward. The result is trimmed of surrounding whitespace.
func stripThinkingBlocks(s string) string {
	for _, tag := range [...]string{"think", "thinking"} {
		openTag, closeTag := "<"+tag+">", "</"+tag+">"
		for {
			before, rest, found := strings.Cut(s, openTag)
			if !found {
				break
			}
			_, after, closed := strings.Cut(rest, closeTag)
			if !closed {
				s = before
				break
			}
			s = before + after
		}
	}
	return strings.TrimSpace(s)
}
// stripCodeFence removes a Markdown code fence wrapped around value.
//
// value is trimmed first. If it then starts with ```, a leading "```json"
// (or bare "```") and a trailing "```" are removed and the remainder is
// trimmed again. Any other info string, such as "```JSON", is left in the
// output. Unfenced input is returned trimmed but otherwise unchanged.
func stripCodeFence(value string) string {
	trimmed := strings.TrimSpace(value)
	if strings.HasPrefix(trimmed, "```") {
		for _, opener := range []string{"```json", "```"} {
			trimmed = strings.TrimPrefix(trimmed, opener)
		}
		trimmed = strings.TrimSpace(strings.TrimSuffix(trimmed, "```"))
	}
	return trimmed
}
// isRetryableStatus reports whether an HTTP status from the provider is
// worth retrying: 429 Too Many Requests, or a 500, 502, 503 or 504 server
// error. Every other status, including 501 Not Implemented, is treated as
// permanent.
func isRetryableStatus(status int) bool {
	return status == http.StatusTooManyRequests ||
		status == http.StatusInternalServerError ||
		status == http.StatusBadGateway ||
		status == http.StatusServiceUnavailable ||
		status == http.StatusGatewayTimeout
}
// isRetryableError reports whether a transport error from http.Client.Do
// looks transient and is worth another attempt.
//
// http.Client.Do always returns a *url.Error, and *url.Error itself satisfies
// net.Error. Testing for net.Error alone would therefore classify permanent
// failures, such as an unsupported URL scheme, as retryable. Only these cases
// are retried:
//   - context deadline or cancellation (doJSON separately checks that the
//     caller's ctx is still live before retrying);
//   - io.EOF / io.ErrUnexpectedEOF, e.g. a server closing a reused connection;
//   - *net.OpError, i.e. dial/read/write failures on the socket;
//   - any net.Error whose Timeout() reports true.
func isRetryableError(err error) bool {
	if err == nil {
		return false
	}
	if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
		return true
	}
	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
		return true
	}
	var opErr *net.OpError
	if errors.As(err, &opErr) {
		return true
	}
	var netErr net.Error
	return errors.As(err, &netErr) && netErr.Timeout()
}
// sleepRetry waits before retry number attempt+1, using a quadratic backoff
// of attempt² × 200ms (200ms after the first attempt, 800ms after the
// second). When log is non-nil, the upcoming retry is logged as a warning.
// It returns ctx.Err() if ctx is done before the delay elapses, otherwise nil.
func sleepRetry(ctx context.Context, attempt int, log *slog.Logger, provider string) error {
	backoff := 200 * time.Millisecond * time.Duration(attempt*attempt)
	if log != nil {
		log.Warn("retrying provider request",
			slog.String("provider", provider),
			slog.Duration("delay", backoff),
			slog.Int("attempt", attempt+1),
		)
	}

	wait := time.NewTimer(backoff)
	defer wait.Stop()

	select {
	case <-wait.C:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}