test(config): add migration tests for litellm provider
Some checks failed
CI / build-and-test (push) Failing after 32m22s

* Implement tests for migrating configuration from v1 to v2 for the litellm provider.
* Validate the structure and values of the migrated configuration.
* Ensure migration rejects newer versions of the configuration.
fix(validate): enhance AI provider validation logic
* Consolidate provider validation into a dedicated method.
* Ensure at least one provider is specified and validate its type.
* Check for required fields based on provider type.
fix(mcpserver): update tool set to use new enrichment tool
* Replace RetryMetadataTool with RetryEnrichmentTool in the ToolSet.
fix(tools): refactor tools to use embedding and metadata runners
* Update tools to utilize EmbeddingRunner and MetadataRunner instead of Provider.
* Adjust method calls to align with the new runner interfaces.
This commit is contained in:
2026-04-21 21:14:28 +02:00
parent 532d1560a3
commit 14e218d784
39 changed files with 2062 additions and 901 deletions

View File

@@ -18,10 +18,10 @@ import (
const backfillConcurrency = 4
type BackfillTool struct {
store *store.DB
provider ai.Provider
sessions *session.ActiveProjects
logger *slog.Logger
store *store.DB
embeddings *ai.EmbeddingRunner
sessions *session.ActiveProjects
logger *slog.Logger
}
type BackfillInput struct {
@@ -47,15 +47,15 @@ type BackfillOutput struct {
Failures []BackfillFailure `json:"failures,omitempty"`
}
func NewBackfillTool(db *store.DB, provider ai.Provider, sessions *session.ActiveProjects, logger *slog.Logger) *BackfillTool {
return &BackfillTool{store: db, provider: provider, sessions: sessions, logger: logger}
func NewBackfillTool(db *store.DB, embeddings *ai.EmbeddingRunner, sessions *session.ActiveProjects, logger *slog.Logger) *BackfillTool {
return &BackfillTool{store: db, embeddings: embeddings, sessions: sessions, logger: logger}
}
// QueueThought queues a single thought for background embedding generation.
// It is used by capture when the embedding provider is temporarily unavailable.
func (t *BackfillTool) QueueThought(ctx context.Context, id uuid.UUID, content string) {
go func() {
vec, err := t.provider.Embed(ctx, content)
result, err := t.embeddings.Embed(ctx, content)
if err != nil {
t.logger.Warn("background embedding retry failed",
slog.String("thought_id", id.String()),
@@ -63,15 +63,17 @@ func (t *BackfillTool) QueueThought(ctx context.Context, id uuid.UUID, content s
)
return
}
model := t.provider.EmbeddingModel()
if err := t.store.UpsertEmbedding(ctx, id, model, vec); err != nil {
if err := t.store.UpsertEmbedding(ctx, id, result.Model, result.Vector); err != nil {
t.logger.Warn("background embedding upsert failed",
slog.String("thought_id", id.String()),
slog.String("error", err.Error()),
)
return
}
t.logger.Info("background embedding retry succeeded", slog.String("thought_id", id.String()))
t.logger.Info("background embedding retry succeeded",
slog.String("thought_id", id.String()),
slog.String("model", result.Model),
)
}()
}
@@ -91,15 +93,15 @@ func (t *BackfillTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
projectID = &project.ID
}
model := t.provider.EmbeddingModel()
primaryModel := t.embeddings.PrimaryModel()
thoughts, err := t.store.ListThoughtsMissingEmbedding(ctx, model, limit, projectID, in.IncludeArchived, in.OlderThanDays)
thoughts, err := t.store.ListThoughtsMissingEmbedding(ctx, primaryModel, limit, projectID, in.IncludeArchived, in.OlderThanDays)
if err != nil {
return nil, BackfillOutput{}, err
}
out := BackfillOutput{
Model: model,
Model: primaryModel,
Scanned: len(thoughts),
DryRun: in.DryRun,
}
@@ -125,7 +127,7 @@ func (t *BackfillTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
defer wg.Done()
defer sem.Release(1)
vec, embedErr := t.provider.Embed(ctx, content)
result, embedErr := t.embeddings.Embed(ctx, content)
if embedErr != nil {
mu.Lock()
out.Failures = append(out.Failures, BackfillFailure{ID: id.String(), Error: embedErr.Error()})
@@ -134,7 +136,7 @@ func (t *BackfillTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
return
}
if upsertErr := t.store.UpsertEmbedding(ctx, id, model, vec); upsertErr != nil {
if upsertErr := t.store.UpsertEmbedding(ctx, id, result.Model, result.Vector); upsertErr != nil {
mu.Lock()
out.Failures = append(out.Failures, BackfillFailure{ID: id.String(), Error: upsertErr.Error()})
mu.Unlock()
@@ -154,7 +156,7 @@ func (t *BackfillTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
out.Skipped = out.Scanned - out.Embedded - out.Failed
t.logger.Info("backfill completed",
slog.String("model", model),
slog.String("model", primaryModel),
slog.Int("scanned", out.Scanned),
slog.Int("embedded", out.Embedded),
slog.Int("failed", out.Failed),

View File

@@ -22,13 +22,20 @@ type EmbeddingQueuer interface {
QueueThought(ctx context.Context, id uuid.UUID, content string)
}
// MetadataQueuer queues a thought for background metadata retry. Both
// MetadataRetryer and EnrichmentRetryer satisfy this.
type MetadataQueuer interface {
QueueThought(id uuid.UUID)
}
type CaptureTool struct {
store *store.DB
provider ai.Provider
embeddings *ai.EmbeddingRunner
metadata *ai.MetadataRunner
capture config.CaptureConfig
sessions *session.ActiveProjects
metadataTimeout time.Duration
retryer *MetadataRetryer
retryer MetadataQueuer
embedRetryer EmbeddingQueuer
log *slog.Logger
}
@@ -42,8 +49,8 @@ type CaptureOutput struct {
Thought thoughttypes.Thought `json:"thought"`
}
func NewCaptureTool(db *store.DB, provider ai.Provider, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, retryer *MetadataRetryer, embedRetryer EmbeddingQueuer, log *slog.Logger) *CaptureTool {
return &CaptureTool{store: db, provider: provider, capture: capture, sessions: sessions, metadataTimeout: metadataTimeout, retryer: retryer, embedRetryer: embedRetryer, log: log}
func NewCaptureTool(db *store.DB, embeddings *ai.EmbeddingRunner, metadata *ai.MetadataRunner, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, retryer MetadataQueuer, embedRetryer EmbeddingQueuer, log *slog.Logger) *CaptureTool {
return &CaptureTool{store: db, embeddings: embeddings, metadata: metadata, capture: capture, sessions: sessions, metadataTimeout: metadataTimeout, retryer: retryer, embedRetryer: embedRetryer, log: log}
}
func (t *CaptureTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in CaptureInput) (*mcp.CallToolResult, CaptureOutput, error) {
@@ -66,7 +73,7 @@ func (t *CaptureTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in C
thought.ProjectID = &project.ID
}
created, err := t.store.InsertThought(ctx, thought, t.provider.EmbeddingModel())
created, err := t.store.InsertThought(ctx, thought, t.embeddings.PrimaryModel())
if err != nil {
return nil, CaptureOutput{}, err
}
@@ -89,7 +96,7 @@ func (t *CaptureTool) launchEnrichment(id uuid.UUID, content string) {
if t.retryer != nil {
attemptedAt := time.Now().UTC()
rawMetadata := metadata.Fallback(t.capture)
extracted, err := t.provider.ExtractMetadata(ctx, content)
extracted, err := t.metadata.ExtractMetadata(ctx, content)
if err != nil {
failed := metadata.MarkMetadataFailed(rawMetadata, t.capture, attemptedAt, err)
if _, updateErr := t.store.UpdateThoughtMetadata(ctx, id, failed); updateErr != nil {
@@ -100,7 +107,7 @@ func (t *CaptureTool) launchEnrichment(id uuid.UUID, content string) {
}
t.log.Warn("deferred metadata extraction failed",
slog.String("thought_id", id.String()),
slog.String("provider", t.provider.Name()),
slog.String("provider", t.metadata.PrimaryProvider()),
slog.String("error", err.Error()),
)
t.retryer.QueueThought(id)
@@ -116,10 +123,10 @@ func (t *CaptureTool) launchEnrichment(id uuid.UUID, content string) {
}
if t.embedRetryer != nil {
if _, err := t.provider.Embed(ctx, content); err != nil {
if _, err := t.embeddings.Embed(ctx, content); err != nil {
t.log.Warn("deferred embedding failed",
slog.String("thought_id", id.String()),
slog.String("provider", t.provider.Name()),
slog.String("provider", t.embeddings.PrimaryProvider()),
slog.String("error", err.Error()),
)
}

View File

@@ -15,10 +15,10 @@ import (
)
type ContextTool struct {
store *store.DB
provider ai.Provider
search config.SearchConfig
sessions *session.ActiveProjects
store *store.DB
embeddings *ai.EmbeddingRunner
search config.SearchConfig
sessions *session.ActiveProjects
}
type ProjectContextInput struct {
@@ -41,8 +41,8 @@ type ProjectContextOutput struct {
Items []ContextItem `json:"items"`
}
func NewContextTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *ContextTool {
return &ContextTool{store: db, provider: provider, search: search, sessions: sessions}
func NewContextTool(db *store.DB, embeddings *ai.EmbeddingRunner, search config.SearchConfig, sessions *session.ActiveProjects) *ContextTool {
return &ContextTool{store: db, embeddings: embeddings, search: search, sessions: sessions}
}
func (t *ContextTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in ProjectContextInput) (*mcp.CallToolResult, ProjectContextOutput, error) {
@@ -72,7 +72,7 @@ func (t *ContextTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in P
query := strings.TrimSpace(in.Query)
if query != "" {
semantic, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, t.search.DefaultThreshold, &project.ID, nil)
semantic, err := semanticSearch(ctx, t.store, t.embeddings, t.search, query, limit, t.search.DefaultThreshold, &project.ID, nil)
if err != nil {
return nil, ProjectContextOutput{}, err
}

View File

@@ -32,7 +32,7 @@ var enrichmentRetryBackoff = []time.Duration{
type EnrichmentRetryer struct {
backgroundCtx context.Context
store *store.DB
provider ai.Provider
metadata *ai.MetadataRunner
capture config.CaptureConfig
sessions *session.ActiveProjects
metadataTimeout time.Duration
@@ -66,14 +66,14 @@ type RetryEnrichmentOutput struct {
Failures []RetryEnrichmentFailure `json:"failures,omitempty"`
}
func NewEnrichmentRetryer(backgroundCtx context.Context, db *store.DB, provider ai.Provider, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, logger *slog.Logger) *EnrichmentRetryer {
func NewEnrichmentRetryer(backgroundCtx context.Context, db *store.DB, metadataRunner *ai.MetadataRunner, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, logger *slog.Logger) *EnrichmentRetryer {
if backgroundCtx == nil {
backgroundCtx = context.Background()
}
return &EnrichmentRetryer{
backgroundCtx: backgroundCtx,
store: db,
provider: provider,
metadata: metadataRunner,
capture: capture,
sessions: sessions,
metadataTimeout: metadataTimeout,
@@ -190,7 +190,7 @@ func (r *EnrichmentRetryer) retryOne(ctx context.Context, id uuid.UUID) (bool, e
}
attemptedAt := time.Now().UTC()
extracted, extractErr := r.provider.ExtractMetadata(attemptCtx, thought.Content)
extracted, extractErr := r.metadata.ExtractMetadata(attemptCtx, thought.Content)
if extractErr != nil {
failedMetadata := metadata.MarkMetadataFailed(thought.Metadata, r.capture, attemptedAt, extractErr)
if _, updateErr := r.store.UpdateThoughtMetadata(ctx, thought.ID, failedMetadata); updateErr != nil {

View File

@@ -13,9 +13,9 @@ import (
)
type LinksTool struct {
store *store.DB
provider ai.Provider
search config.SearchConfig
store *store.DB
embeddings *ai.EmbeddingRunner
search config.SearchConfig
}
type LinkInput struct {
@@ -47,8 +47,8 @@ type RelatedOutput struct {
Related []RelatedThought `json:"related"`
}
func NewLinksTool(db *store.DB, provider ai.Provider, search config.SearchConfig) *LinksTool {
return &LinksTool{store: db, provider: provider, search: search}
func NewLinksTool(db *store.DB, embeddings *ai.EmbeddingRunner, search config.SearchConfig) *LinksTool {
return &LinksTool{store: db, embeddings: embeddings, search: search}
}
func (t *LinksTool) Link(ctx context.Context, _ *mcp.CallToolRequest, in LinkInput) (*mcp.CallToolResult, LinkOutput, error) {
@@ -117,7 +117,7 @@ func (t *LinksTool) Related(ctx context.Context, _ *mcp.CallToolRequest, in Rela
}
if includeSemantic {
semantic, err := semanticSearch(ctx, t.store, t.provider, t.search, thought.Content, t.search.DefaultLimit, t.search.DefaultThreshold, thought.ProjectID, &thought.ID)
semantic, err := semanticSearch(ctx, t.store, t.embeddings, t.search, thought.Content, t.search.DefaultLimit, t.search.DefaultThreshold, thought.ProjectID, &thought.ID)
if err != nil {
return nil, RelatedOutput{}, err
}

View File

@@ -23,7 +23,7 @@ const metadataRetryConcurrency = 4
type MetadataRetryer struct {
backgroundCtx context.Context
store *store.DB
provider ai.Provider
metadata *ai.MetadataRunner
capture config.CaptureConfig
sessions *session.ActiveProjects
metadataTimeout time.Duration
@@ -87,14 +87,14 @@ type RetryMetadataOutput struct {
Failures []RetryMetadataFailure `json:"failures,omitempty"`
}
func NewMetadataRetryer(backgroundCtx context.Context, db *store.DB, provider ai.Provider, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, logger *slog.Logger) *MetadataRetryer {
func NewMetadataRetryer(backgroundCtx context.Context, db *store.DB, metadataRunner *ai.MetadataRunner, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, logger *slog.Logger) *MetadataRetryer {
if backgroundCtx == nil {
backgroundCtx = context.Background()
}
return &MetadataRetryer{
backgroundCtx: backgroundCtx,
store: db,
provider: provider,
metadata: metadataRunner,
capture: capture,
sessions: sessions,
metadataTimeout: metadataTimeout,
@@ -223,7 +223,7 @@ func (r *MetadataRetryer) retryOne(ctx context.Context, id uuid.UUID) (bool, err
}
attemptedAt := time.Now().UTC()
extracted, extractErr := r.provider.ExtractMetadata(attemptCtx, thought.Content)
extracted, extractErr := r.metadata.ExtractMetadata(attemptCtx, thought.Content)
if extractErr != nil {
failedMetadata := metadata.MarkMetadataFailed(thought.Metadata, r.capture, attemptedAt, extractErr)
if _, updateErr := r.store.UpdateThoughtMetadata(ctx, thought.ID, failedMetadata); updateErr != nil {

View File

@@ -15,10 +15,10 @@ import (
)
type RecallTool struct {
store *store.DB
provider ai.Provider
search config.SearchConfig
sessions *session.ActiveProjects
store *store.DB
embeddings *ai.EmbeddingRunner
search config.SearchConfig
sessions *session.ActiveProjects
}
type RecallInput struct {
@@ -32,8 +32,8 @@ type RecallOutput struct {
Items []ContextItem `json:"items"`
}
func NewRecallTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *RecallTool {
return &RecallTool{store: db, provider: provider, search: search, sessions: sessions}
func NewRecallTool(db *store.DB, embeddings *ai.EmbeddingRunner, search config.SearchConfig, sessions *session.ActiveProjects) *RecallTool {
return &RecallTool{store: db, embeddings: embeddings, search: search, sessions: sessions}
}
func (t *RecallTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in RecallInput) (*mcp.CallToolResult, RecallOutput, error) {
@@ -54,7 +54,7 @@ func (t *RecallTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in Re
projectID = &project.ID
}
semantic, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, t.search.DefaultThreshold, projectID, nil)
semantic, err := semanticSearch(ctx, t.store, t.embeddings, t.search, query, limit, t.search.DefaultThreshold, projectID, nil)
if err != nil {
return nil, RecallOutput{}, err
}

View File

@@ -23,7 +23,7 @@ const metadataReparseConcurrency = 4
type ReparseMetadataTool struct {
store *store.DB
provider ai.Provider
metadata *ai.MetadataRunner
capture config.CaptureConfig
sessions *session.ActiveProjects
logger *slog.Logger
@@ -53,8 +53,8 @@ type ReparseMetadataOutput struct {
Failures []ReparseMetadataFailure `json:"failures,omitempty"`
}
func NewReparseMetadataTool(db *store.DB, provider ai.Provider, capture config.CaptureConfig, sessions *session.ActiveProjects, logger *slog.Logger) *ReparseMetadataTool {
return &ReparseMetadataTool{store: db, provider: provider, capture: capture, sessions: sessions, logger: logger}
func NewReparseMetadataTool(db *store.DB, metadataRunner *ai.MetadataRunner, capture config.CaptureConfig, sessions *session.ActiveProjects, logger *slog.Logger) *ReparseMetadataTool {
return &ReparseMetadataTool{store: db, metadata: metadataRunner, capture: capture, sessions: sessions, logger: logger}
}
func (t *ReparseMetadataTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in ReparseMetadataInput) (*mcp.CallToolResult, ReparseMetadataOutput, error) {
@@ -107,7 +107,7 @@ func (t *ReparseMetadataTool) Handle(ctx context.Context, req *mcp.CallToolReque
normalizedCurrent := metadata.Normalize(thought.Metadata, t.capture)
attemptedAt := time.Now().UTC()
extracted, extractErr := t.provider.ExtractMetadata(ctx, thought.Content)
extracted, extractErr := t.metadata.ExtractMetadata(ctx, thought.Content)
normalizedTarget := normalizedCurrent
if extractErr != nil {
normalizedTarget = metadata.MarkMetadataFailed(normalizedCurrent, t.capture, attemptedAt, extractErr)

View File

@@ -11,12 +11,14 @@ import (
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
)
// semanticSearch runs vector similarity search if embeddings exist for the active model
// in the given scope, otherwise falls back to Postgres full-text search.
// semanticSearch runs vector similarity search if embeddings exist for the
// primary embedding model in the given scope, otherwise falls back to Postgres
// full-text search. Search always uses the primary model so query vectors
// match rows stored under the primary model name.
func semanticSearch(
ctx context.Context,
db *store.DB,
provider ai.Provider,
embeddings *ai.EmbeddingRunner,
search config.SearchConfig,
query string,
limit int,
@@ -24,17 +26,18 @@ func semanticSearch(
projectID *uuid.UUID,
excludeID *uuid.UUID,
) ([]thoughttypes.SearchResult, error) {
hasEmbeddings, err := db.HasEmbeddingsForModel(ctx, provider.EmbeddingModel(), projectID)
model := embeddings.PrimaryModel()
hasEmbeddings, err := db.HasEmbeddingsForModel(ctx, model, projectID)
if err != nil {
return nil, err
}
if hasEmbeddings {
embedding, err := provider.Embed(ctx, query)
embedding, err := embeddings.EmbedPrimary(ctx, query)
if err != nil {
return nil, err
}
return db.SearchSimilarThoughts(ctx, embedding, provider.EmbeddingModel(), threshold, limit, projectID, excludeID)
return db.SearchSimilarThoughts(ctx, embedding, model, threshold, limit, projectID, excludeID)
}
return db.SearchThoughtsText(ctx, query, limit, projectID, excludeID)

View File

@@ -15,10 +15,10 @@ import (
)
type SearchTool struct {
store *store.DB
provider ai.Provider
search config.SearchConfig
sessions *session.ActiveProjects
store *store.DB
embeddings *ai.EmbeddingRunner
search config.SearchConfig
sessions *session.ActiveProjects
}
type SearchInput struct {
@@ -32,8 +32,8 @@ type SearchOutput struct {
Results []thoughttypes.SearchResult `json:"results"`
}
func NewSearchTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *SearchTool {
return &SearchTool{store: db, provider: provider, search: search, sessions: sessions}
func NewSearchTool(db *store.DB, embeddings *ai.EmbeddingRunner, search config.SearchConfig, sessions *session.ActiveProjects) *SearchTool {
return &SearchTool{store: db, embeddings: embeddings, search: search, sessions: sessions}
}
func (t *SearchTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in SearchInput) (*mcp.CallToolResult, SearchOutput, error) {
@@ -56,7 +56,7 @@ func (t *SearchTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in Se
_ = t.store.TouchProject(ctx, project.ID)
}
results, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, threshold, projectID, nil)
results, err := semanticSearch(ctx, t.store, t.embeddings, t.search, query, limit, threshold, projectID, nil)
if err != nil {
return nil, SearchOutput{}, err
}

View File

@@ -14,10 +14,11 @@ import (
)
type SummarizeTool struct {
store *store.DB
provider ai.Provider
search config.SearchConfig
sessions *session.ActiveProjects
store *store.DB
embeddings *ai.EmbeddingRunner
metadata *ai.MetadataRunner
search config.SearchConfig
sessions *session.ActiveProjects
}
type SummarizeInput struct {
@@ -32,8 +33,8 @@ type SummarizeOutput struct {
Count int `json:"count"`
}
func NewSummarizeTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *SummarizeTool {
return &SummarizeTool{store: db, provider: provider, search: search, sessions: sessions}
func NewSummarizeTool(db *store.DB, embeddings *ai.EmbeddingRunner, metadata *ai.MetadataRunner, search config.SearchConfig, sessions *session.ActiveProjects) *SummarizeTool {
return &SummarizeTool{store: db, embeddings: embeddings, metadata: metadata, search: search, sessions: sessions}
}
func (t *SummarizeTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in SummarizeInput) (*mcp.CallToolResult, SummarizeOutput, error) {
@@ -52,7 +53,7 @@ func (t *SummarizeTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
if project != nil {
projectID = &project.ID
}
results, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, t.search.DefaultThreshold, projectID, nil)
results, err := semanticSearch(ctx, t.store, t.embeddings, t.search, query, limit, t.search.DefaultThreshold, projectID, nil)
if err != nil {
return nil, SummarizeOutput{}, err
}
@@ -77,7 +78,7 @@ func (t *SummarizeTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
userPrompt := formatContextBlock("Summarize the following thoughts into concise prose with themes, action items, and notable people.", lines)
systemPrompt := "You summarize note collections. Be concise, concrete, and structured in plain prose."
summary, err := t.provider.Summarize(ctx, systemPrompt, userPrompt)
summary, err := t.metadata.Summarize(ctx, systemPrompt, userPrompt)
if err != nil {
return nil, SummarizeOutput{}, err
}

View File

@@ -16,10 +16,11 @@ import (
)
type UpdateTool struct {
store *store.DB
provider ai.Provider
capture config.CaptureConfig
log *slog.Logger
store *store.DB
embeddings *ai.EmbeddingRunner
metadata *ai.MetadataRunner
capture config.CaptureConfig
log *slog.Logger
}
type UpdateInput struct {
@@ -33,8 +34,8 @@ type UpdateOutput struct {
Thought thoughttypes.Thought `json:"thought"`
}
func NewUpdateTool(db *store.DB, provider ai.Provider, capture config.CaptureConfig, log *slog.Logger) *UpdateTool {
return &UpdateTool{store: db, provider: provider, capture: capture, log: log}
func NewUpdateTool(db *store.DB, embeddings *ai.EmbeddingRunner, metadata *ai.MetadataRunner, capture config.CaptureConfig, log *slog.Logger) *UpdateTool {
return &UpdateTool{store: db, embeddings: embeddings, metadata: metadata, capture: capture, log: log}
}
func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in UpdateInput) (*mcp.CallToolResult, UpdateOutput, error) {
@@ -50,6 +51,7 @@ func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in Upda
content := current.Content
var embedding []float32
embeddingModel := ""
mergedMetadata := current.Metadata
projectID := current.ProjectID
@@ -58,11 +60,13 @@ func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in Upda
if content == "" {
return nil, UpdateOutput{}, errInvalidInput("content must not be empty")
}
embedding, err = t.provider.Embed(ctx, content)
embedResult, err := t.embeddings.Embed(ctx, content)
if err != nil {
return nil, UpdateOutput{}, err
}
extracted, extractErr := t.provider.ExtractMetadata(ctx, content)
embedding = embedResult.Vector
embeddingModel = embedResult.Model
extracted, extractErr := t.metadata.ExtractMetadata(ctx, content)
if extractErr != nil {
t.log.Warn("metadata extraction failed during update, keeping current metadata", slog.String("error", extractErr.Error()))
mergedMetadata = metadata.MarkMetadataFailed(mergedMetadata, t.capture, time.Now().UTC(), extractErr)
@@ -82,7 +86,7 @@ func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in Upda
projectID = &project.ID
}
updated, err := t.store.UpdateThought(ctx, id, content, embedding, t.provider.EmbeddingModel(), mergedMetadata, projectID)
updated, err := t.store.UpdateThought(ctx, id, content, embedding, embeddingModel, mergedMetadata, projectID)
if err != nil {
return nil, UpdateOutput{}, err
}