feat(tools): implement CRUD operations for thoughts and projects
* Add tools for creating, retrieving, updating, and deleting thoughts. * Implement project management tools for creating and listing projects. * Introduce linking functionality between thoughts. * Add search and recall capabilities for thoughts based on semantic queries. * Implement statistics and summarization tools for thought analysis. * Create database migrations for thoughts, projects, and links. * Add helper functions for UUID parsing and project resolution.
This commit is contained in:
330
internal/ai/compat/client.go
Normal file
330
internal/ai/compat/client.go
Normal file
@@ -0,0 +1,330 @@
|
||||
package compat
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
const metadataSystemPrompt = `You extract structured metadata from short notes.
|
||||
Return only valid JSON matching this schema:
|
||||
{
|
||||
"people": ["string"],
|
||||
"action_items": ["string"],
|
||||
"dates_mentioned": ["string"],
|
||||
"topics": ["string"],
|
||||
"type": "observation|task|idea|reference|person_note",
|
||||
"source": "string"
|
||||
}
|
||||
Rules:
|
||||
- Keep arrays concise.
|
||||
- Use lowercase for type.
|
||||
- If unsure, prefer "observation".
|
||||
- Do not include any text outside the JSON object.`
|
||||
|
||||
type Client struct {
|
||||
name string
|
||||
baseURL string
|
||||
apiKey string
|
||||
embeddingModel string
|
||||
metadataModel string
|
||||
temperature float64
|
||||
headers map[string]string
|
||||
httpClient *http.Client
|
||||
log *slog.Logger
|
||||
dimensions int
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string
|
||||
BaseURL string
|
||||
APIKey string
|
||||
EmbeddingModel string
|
||||
MetadataModel string
|
||||
Temperature float64
|
||||
Headers map[string]string
|
||||
HTTPClient *http.Client
|
||||
Log *slog.Logger
|
||||
Dimensions int
|
||||
}
|
||||
|
||||
type embeddingsRequest struct {
|
||||
Input string `json:"input"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
type embeddingsResponse struct {
|
||||
Data []struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
} `json:"data"`
|
||||
Error *providerError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type chatCompletionsRequest struct {
|
||||
Model string `json:"model"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
ResponseFormat *responseType `json:"response_format,omitempty"`
|
||||
Messages []chatMessage `json:"messages"`
|
||||
}
|
||||
|
||||
type responseType struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type chatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type chatCompletionsResponse struct {
|
||||
Choices []struct {
|
||||
Message chatMessage `json:"message"`
|
||||
} `json:"choices"`
|
||||
Error *providerError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type providerError struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type,omitempty"`
|
||||
}
|
||||
|
||||
func New(cfg Config) *Client {
|
||||
return &Client{
|
||||
name: cfg.Name,
|
||||
baseURL: cfg.BaseURL,
|
||||
apiKey: cfg.APIKey,
|
||||
embeddingModel: cfg.EmbeddingModel,
|
||||
metadataModel: cfg.MetadataModel,
|
||||
temperature: cfg.Temperature,
|
||||
headers: cfg.Headers,
|
||||
httpClient: cfg.HTTPClient,
|
||||
log: cfg.Log,
|
||||
dimensions: cfg.Dimensions,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Embed(ctx context.Context, input string) ([]float32, error) {
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
return nil, fmt.Errorf("%s embed: input must not be empty", c.name)
|
||||
}
|
||||
|
||||
var resp embeddingsResponse
|
||||
err := c.doJSON(ctx, "/embeddings", embeddingsRequest{
|
||||
Input: input,
|
||||
Model: c.embeddingModel,
|
||||
}, &resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Error != nil {
|
||||
return nil, fmt.Errorf("%s embed error: %s", c.name, resp.Error.Message)
|
||||
}
|
||||
if len(resp.Data) == 0 {
|
||||
return nil, fmt.Errorf("%s embed: no embedding returned", c.name)
|
||||
}
|
||||
if c.dimensions > 0 && len(resp.Data[0].Embedding) != c.dimensions {
|
||||
return nil, fmt.Errorf("%s embed: expected %d dimensions, got %d", c.name, c.dimensions, len(resp.Data[0].Embedding))
|
||||
}
|
||||
|
||||
return resp.Data[0].Embedding, nil
|
||||
}
|
||||
|
||||
func (c *Client) ExtractMetadata(ctx context.Context, input string) (thoughttypes.ThoughtMetadata, error) {
|
||||
input = strings.TrimSpace(input)
|
||||
if input == "" {
|
||||
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s extract metadata: input must not be empty", c.name)
|
||||
}
|
||||
|
||||
req := chatCompletionsRequest{
|
||||
Model: c.metadataModel,
|
||||
Temperature: c.temperature,
|
||||
ResponseFormat: &responseType{
|
||||
Type: "json_object",
|
||||
},
|
||||
Messages: []chatMessage{
|
||||
{Role: "system", Content: metadataSystemPrompt},
|
||||
{Role: "user", Content: input},
|
||||
},
|
||||
}
|
||||
|
||||
var resp chatCompletionsResponse
|
||||
if err := c.doJSON(ctx, "/chat/completions", req, &resp); err != nil {
|
||||
return thoughttypes.ThoughtMetadata{}, err
|
||||
}
|
||||
if resp.Error != nil {
|
||||
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata error: %s", c.name, resp.Error.Message)
|
||||
}
|
||||
if len(resp.Choices) == 0 {
|
||||
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: no choices returned", c.name)
|
||||
}
|
||||
|
||||
metadataText := strings.TrimSpace(resp.Choices[0].Message.Content)
|
||||
metadataText = stripCodeFence(metadataText)
|
||||
|
||||
var metadata thoughttypes.ThoughtMetadata
|
||||
if err := json.Unmarshal([]byte(metadataText), &metadata); err != nil {
|
||||
return thoughttypes.ThoughtMetadata{}, fmt.Errorf("%s metadata: parse json: %w", c.name, err)
|
||||
}
|
||||
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
func (c *Client) Summarize(ctx context.Context, systemPrompt, userPrompt string) (string, error) {
|
||||
req := chatCompletionsRequest{
|
||||
Model: c.metadataModel,
|
||||
Temperature: 0.2,
|
||||
Messages: []chatMessage{
|
||||
{Role: "system", Content: systemPrompt},
|
||||
{Role: "user", Content: userPrompt},
|
||||
},
|
||||
}
|
||||
|
||||
var resp chatCompletionsResponse
|
||||
if err := c.doJSON(ctx, "/chat/completions", req, &resp); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp.Error != nil {
|
||||
return "", fmt.Errorf("%s summarize error: %s", c.name, resp.Error.Message)
|
||||
}
|
||||
if len(resp.Choices) == 0 {
|
||||
return "", fmt.Errorf("%s summarize: no choices returned", c.name)
|
||||
}
|
||||
|
||||
return strings.TrimSpace(resp.Choices[0].Message.Content), nil
|
||||
}
|
||||
|
||||
func (c *Client) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
func (c *Client) doJSON(ctx context.Context, path string, requestBody any, dest any) error {
|
||||
body, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s request marshal: %w", c.name, err)
|
||||
}
|
||||
|
||||
const maxAttempts = 3
|
||||
|
||||
var lastErr error
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimRight(c.baseURL, "/")+path, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s build request: %w", c.name, err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
for key, value := range c.headers {
|
||||
if strings.TrimSpace(key) == "" || strings.TrimSpace(value) == "" {
|
||||
continue
|
||||
}
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
lastErr = fmt.Errorf("%s request failed: %w", c.name, err)
|
||||
if attempt < maxAttempts && isRetryableError(err) {
|
||||
if retryErr := sleepRetry(ctx, attempt, c.log, c.name); retryErr != nil {
|
||||
return retryErr
|
||||
}
|
||||
continue
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
payload, readErr := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if readErr != nil {
|
||||
lastErr = fmt.Errorf("%s read response: %w", c.name, readErr)
|
||||
if attempt < maxAttempts {
|
||||
if retryErr := sleepRetry(ctx, attempt, c.log, c.name); retryErr != nil {
|
||||
return retryErr
|
||||
}
|
||||
continue
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
if resp.StatusCode >= http.StatusBadRequest {
|
||||
lastErr = fmt.Errorf("%s request failed with status %d: %s", c.name, resp.StatusCode, strings.TrimSpace(string(payload)))
|
||||
if attempt < maxAttempts && isRetryableStatus(resp.StatusCode) {
|
||||
if retryErr := sleepRetry(ctx, attempt, c.log, c.name); retryErr != nil {
|
||||
return retryErr
|
||||
}
|
||||
continue
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(payload, dest); err != nil {
|
||||
if c.log != nil {
|
||||
c.log.Debug("provider response body", slog.String("provider", c.name), slog.String("body", string(payload)))
|
||||
}
|
||||
return fmt.Errorf("%s decode response: %w", c.name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func stripCodeFence(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if !strings.HasPrefix(value, "```") {
|
||||
return value
|
||||
}
|
||||
|
||||
value = strings.TrimPrefix(value, "```json")
|
||||
value = strings.TrimPrefix(value, "```")
|
||||
value = strings.TrimSuffix(value, "```")
|
||||
return strings.TrimSpace(value)
|
||||
}
|
||||
|
||||
func isRetryableStatus(status int) bool {
|
||||
switch status {
|
||||
case http.StatusTooManyRequests, http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
|
||||
return true
|
||||
}
|
||||
var netErr net.Error
|
||||
return errors.As(err, &netErr)
|
||||
}
|
||||
|
||||
func sleepRetry(ctx context.Context, attempt int, log *slog.Logger, provider string) error {
|
||||
delay := time.Duration(attempt*attempt) * 200 * time.Millisecond
|
||||
if log != nil {
|
||||
log.Warn("retrying provider request", slog.String("provider", provider), slog.Duration("delay", delay), slog.Int("attempt", attempt+1))
|
||||
}
|
||||
timer := time.NewTimer(delay)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
91
internal/ai/compat/client_test.go
Normal file
91
internal/ai/compat/client_test.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package compat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func discardLogger() *slog.Logger {
|
||||
return slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
}
|
||||
|
||||
func TestEmbedRetriesTransientFailures(t *testing.T) {
|
||||
var calls atomic.Int32
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if calls.Add(1) < 3 {
|
||||
http.Error(w, "temporary failure", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"data": []map[string]any{
|
||||
{"embedding": []float32{1, 2, 3}},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := New(Config{
|
||||
Name: "test",
|
||||
BaseURL: server.URL,
|
||||
APIKey: "secret",
|
||||
EmbeddingModel: "embed-model",
|
||||
MetadataModel: "meta-model",
|
||||
HTTPClient: server.Client(),
|
||||
Log: discardLogger(),
|
||||
Dimensions: 3,
|
||||
})
|
||||
|
||||
embedding, err := client.Embed(context.Background(), "hello")
|
||||
if err != nil {
|
||||
t.Fatalf("Embed() error = %v", err)
|
||||
}
|
||||
if len(embedding) != 3 {
|
||||
t.Fatalf("embedding len = %d, want 3", len(embedding))
|
||||
}
|
||||
if got := calls.Load(); got != 3 {
|
||||
t.Fatalf("call count = %d, want 3", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractMetadataParsesCodeFencedJSON(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"message": map[string]any{
|
||||
"content": "```json\n{\"people\":[\"Alice\"],\"action_items\":[],\"dates_mentioned\":[],\"topics\":[\"memory\"],\"type\":\"idea\",\"source\":\"mcp\"}\n```",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := New(Config{
|
||||
Name: "test",
|
||||
BaseURL: server.URL,
|
||||
APIKey: "secret",
|
||||
EmbeddingModel: "embed-model",
|
||||
MetadataModel: "meta-model",
|
||||
HTTPClient: server.Client(),
|
||||
Log: discardLogger(),
|
||||
})
|
||||
|
||||
metadata, err := client.ExtractMetadata(context.Background(), "hello")
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractMetadata() error = %v", err)
|
||||
}
|
||||
if metadata.Type != "idea" {
|
||||
t.Fatalf("metadata type = %q, want idea", metadata.Type)
|
||||
}
|
||||
if len(metadata.People) != 1 || metadata.People[0] != "Alice" {
|
||||
t.Fatalf("metadata people = %#v, want [Alice]", metadata.People)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user