test(config): add migration tests for litellm provider
Some checks failed
CI / build-and-test (push) Failing after -32m22s
Some checks failed
CI / build-and-test (push) Failing after -32m22s
* Implement tests for migrating configuration from v1 to v2 for the litellm provider.
* Validate the structure and values of the migrated configuration.
* Ensure migration rejects newer versions of the configuration.

fix(validate): enhance AI provider validation logic

* Consolidate provider validation into a dedicated method.
* Ensure at least one provider is specified and validate its type.
* Check for required fields based on provider type.

fix(mcpserver): update tool set to use new enrichment tool

* Replace RetryMetadataTool with RetryEnrichmentTool in the ToolSet.

fix(tools): refactor tools to use embedding and metadata runners

* Update tools to utilize EmbeddingRunner and MetadataRunner instead of Provider.
* Adjust method calls to align with the new runner interfaces.
This commit is contained in:
@@ -18,10 +18,10 @@ import (
|
||||
const backfillConcurrency = 4
|
||||
|
||||
type BackfillTool struct {
|
||||
store *store.DB
|
||||
provider ai.Provider
|
||||
sessions *session.ActiveProjects
|
||||
logger *slog.Logger
|
||||
store *store.DB
|
||||
embeddings *ai.EmbeddingRunner
|
||||
sessions *session.ActiveProjects
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
type BackfillInput struct {
|
||||
@@ -47,15 +47,15 @@ type BackfillOutput struct {
|
||||
Failures []BackfillFailure `json:"failures,omitempty"`
|
||||
}
|
||||
|
||||
func NewBackfillTool(db *store.DB, provider ai.Provider, sessions *session.ActiveProjects, logger *slog.Logger) *BackfillTool {
|
||||
return &BackfillTool{store: db, provider: provider, sessions: sessions, logger: logger}
|
||||
func NewBackfillTool(db *store.DB, embeddings *ai.EmbeddingRunner, sessions *session.ActiveProjects, logger *slog.Logger) *BackfillTool {
|
||||
return &BackfillTool{store: db, embeddings: embeddings, sessions: sessions, logger: logger}
|
||||
}
|
||||
|
||||
// QueueThought queues a single thought for background embedding generation.
|
||||
// It is used by capture when the embedding provider is temporarily unavailable.
|
||||
func (t *BackfillTool) QueueThought(ctx context.Context, id uuid.UUID, content string) {
|
||||
go func() {
|
||||
vec, err := t.provider.Embed(ctx, content)
|
||||
result, err := t.embeddings.Embed(ctx, content)
|
||||
if err != nil {
|
||||
t.logger.Warn("background embedding retry failed",
|
||||
slog.String("thought_id", id.String()),
|
||||
@@ -63,15 +63,17 @@ func (t *BackfillTool) QueueThought(ctx context.Context, id uuid.UUID, content s
|
||||
)
|
||||
return
|
||||
}
|
||||
model := t.provider.EmbeddingModel()
|
||||
if err := t.store.UpsertEmbedding(ctx, id, model, vec); err != nil {
|
||||
if err := t.store.UpsertEmbedding(ctx, id, result.Model, result.Vector); err != nil {
|
||||
t.logger.Warn("background embedding upsert failed",
|
||||
slog.String("thought_id", id.String()),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
return
|
||||
}
|
||||
t.logger.Info("background embedding retry succeeded", slog.String("thought_id", id.String()))
|
||||
t.logger.Info("background embedding retry succeeded",
|
||||
slog.String("thought_id", id.String()),
|
||||
slog.String("model", result.Model),
|
||||
)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -91,15 +93,15 @@ func (t *BackfillTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
|
||||
projectID = &project.ID
|
||||
}
|
||||
|
||||
model := t.provider.EmbeddingModel()
|
||||
primaryModel := t.embeddings.PrimaryModel()
|
||||
|
||||
thoughts, err := t.store.ListThoughtsMissingEmbedding(ctx, model, limit, projectID, in.IncludeArchived, in.OlderThanDays)
|
||||
thoughts, err := t.store.ListThoughtsMissingEmbedding(ctx, primaryModel, limit, projectID, in.IncludeArchived, in.OlderThanDays)
|
||||
if err != nil {
|
||||
return nil, BackfillOutput{}, err
|
||||
}
|
||||
|
||||
out := BackfillOutput{
|
||||
Model: model,
|
||||
Model: primaryModel,
|
||||
Scanned: len(thoughts),
|
||||
DryRun: in.DryRun,
|
||||
}
|
||||
@@ -125,7 +127,7 @@ func (t *BackfillTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
|
||||
defer wg.Done()
|
||||
defer sem.Release(1)
|
||||
|
||||
vec, embedErr := t.provider.Embed(ctx, content)
|
||||
result, embedErr := t.embeddings.Embed(ctx, content)
|
||||
if embedErr != nil {
|
||||
mu.Lock()
|
||||
out.Failures = append(out.Failures, BackfillFailure{ID: id.String(), Error: embedErr.Error()})
|
||||
@@ -134,7 +136,7 @@ func (t *BackfillTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
|
||||
return
|
||||
}
|
||||
|
||||
if upsertErr := t.store.UpsertEmbedding(ctx, id, model, vec); upsertErr != nil {
|
||||
if upsertErr := t.store.UpsertEmbedding(ctx, id, result.Model, result.Vector); upsertErr != nil {
|
||||
mu.Lock()
|
||||
out.Failures = append(out.Failures, BackfillFailure{ID: id.String(), Error: upsertErr.Error()})
|
||||
mu.Unlock()
|
||||
@@ -154,7 +156,7 @@ func (t *BackfillTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
|
||||
out.Skipped = out.Scanned - out.Embedded - out.Failed
|
||||
|
||||
t.logger.Info("backfill completed",
|
||||
slog.String("model", model),
|
||||
slog.String("model", primaryModel),
|
||||
slog.Int("scanned", out.Scanned),
|
||||
slog.Int("embedded", out.Embedded),
|
||||
slog.Int("failed", out.Failed),
|
||||
|
||||
@@ -22,13 +22,20 @@ type EmbeddingQueuer interface {
|
||||
QueueThought(ctx context.Context, id uuid.UUID, content string)
|
||||
}
|
||||
|
||||
// MetadataQueuer queues a thought for background metadata retry. Both
|
||||
// MetadataRetryer and EnrichmentRetryer satisfy this.
|
||||
type MetadataQueuer interface {
|
||||
QueueThought(id uuid.UUID)
|
||||
}
|
||||
|
||||
type CaptureTool struct {
|
||||
store *store.DB
|
||||
provider ai.Provider
|
||||
embeddings *ai.EmbeddingRunner
|
||||
metadata *ai.MetadataRunner
|
||||
capture config.CaptureConfig
|
||||
sessions *session.ActiveProjects
|
||||
metadataTimeout time.Duration
|
||||
retryer *MetadataRetryer
|
||||
retryer MetadataQueuer
|
||||
embedRetryer EmbeddingQueuer
|
||||
log *slog.Logger
|
||||
}
|
||||
@@ -42,8 +49,8 @@ type CaptureOutput struct {
|
||||
Thought thoughttypes.Thought `json:"thought"`
|
||||
}
|
||||
|
||||
func NewCaptureTool(db *store.DB, provider ai.Provider, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, retryer *MetadataRetryer, embedRetryer EmbeddingQueuer, log *slog.Logger) *CaptureTool {
|
||||
return &CaptureTool{store: db, provider: provider, capture: capture, sessions: sessions, metadataTimeout: metadataTimeout, retryer: retryer, embedRetryer: embedRetryer, log: log}
|
||||
func NewCaptureTool(db *store.DB, embeddings *ai.EmbeddingRunner, metadata *ai.MetadataRunner, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, retryer MetadataQueuer, embedRetryer EmbeddingQueuer, log *slog.Logger) *CaptureTool {
|
||||
return &CaptureTool{store: db, embeddings: embeddings, metadata: metadata, capture: capture, sessions: sessions, metadataTimeout: metadataTimeout, retryer: retryer, embedRetryer: embedRetryer, log: log}
|
||||
}
|
||||
|
||||
func (t *CaptureTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in CaptureInput) (*mcp.CallToolResult, CaptureOutput, error) {
|
||||
@@ -66,7 +73,7 @@ func (t *CaptureTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in C
|
||||
thought.ProjectID = &project.ID
|
||||
}
|
||||
|
||||
created, err := t.store.InsertThought(ctx, thought, t.provider.EmbeddingModel())
|
||||
created, err := t.store.InsertThought(ctx, thought, t.embeddings.PrimaryModel())
|
||||
if err != nil {
|
||||
return nil, CaptureOutput{}, err
|
||||
}
|
||||
@@ -89,7 +96,7 @@ func (t *CaptureTool) launchEnrichment(id uuid.UUID, content string) {
|
||||
if t.retryer != nil {
|
||||
attemptedAt := time.Now().UTC()
|
||||
rawMetadata := metadata.Fallback(t.capture)
|
||||
extracted, err := t.provider.ExtractMetadata(ctx, content)
|
||||
extracted, err := t.metadata.ExtractMetadata(ctx, content)
|
||||
if err != nil {
|
||||
failed := metadata.MarkMetadataFailed(rawMetadata, t.capture, attemptedAt, err)
|
||||
if _, updateErr := t.store.UpdateThoughtMetadata(ctx, id, failed); updateErr != nil {
|
||||
@@ -100,7 +107,7 @@ func (t *CaptureTool) launchEnrichment(id uuid.UUID, content string) {
|
||||
}
|
||||
t.log.Warn("deferred metadata extraction failed",
|
||||
slog.String("thought_id", id.String()),
|
||||
slog.String("provider", t.provider.Name()),
|
||||
slog.String("provider", t.metadata.PrimaryProvider()),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
t.retryer.QueueThought(id)
|
||||
@@ -116,10 +123,10 @@ func (t *CaptureTool) launchEnrichment(id uuid.UUID, content string) {
|
||||
}
|
||||
|
||||
if t.embedRetryer != nil {
|
||||
if _, err := t.provider.Embed(ctx, content); err != nil {
|
||||
if _, err := t.embeddings.Embed(ctx, content); err != nil {
|
||||
t.log.Warn("deferred embedding failed",
|
||||
slog.String("thought_id", id.String()),
|
||||
slog.String("provider", t.provider.Name()),
|
||||
slog.String("provider", t.embeddings.PrimaryProvider()),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -15,10 +15,10 @@ import (
|
||||
)
|
||||
|
||||
type ContextTool struct {
|
||||
store *store.DB
|
||||
provider ai.Provider
|
||||
search config.SearchConfig
|
||||
sessions *session.ActiveProjects
|
||||
store *store.DB
|
||||
embeddings *ai.EmbeddingRunner
|
||||
search config.SearchConfig
|
||||
sessions *session.ActiveProjects
|
||||
}
|
||||
|
||||
type ProjectContextInput struct {
|
||||
@@ -41,8 +41,8 @@ type ProjectContextOutput struct {
|
||||
Items []ContextItem `json:"items"`
|
||||
}
|
||||
|
||||
func NewContextTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *ContextTool {
|
||||
return &ContextTool{store: db, provider: provider, search: search, sessions: sessions}
|
||||
func NewContextTool(db *store.DB, embeddings *ai.EmbeddingRunner, search config.SearchConfig, sessions *session.ActiveProjects) *ContextTool {
|
||||
return &ContextTool{store: db, embeddings: embeddings, search: search, sessions: sessions}
|
||||
}
|
||||
|
||||
func (t *ContextTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in ProjectContextInput) (*mcp.CallToolResult, ProjectContextOutput, error) {
|
||||
@@ -72,7 +72,7 @@ func (t *ContextTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in P
|
||||
|
||||
query := strings.TrimSpace(in.Query)
|
||||
if query != "" {
|
||||
semantic, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, t.search.DefaultThreshold, &project.ID, nil)
|
||||
semantic, err := semanticSearch(ctx, t.store, t.embeddings, t.search, query, limit, t.search.DefaultThreshold, &project.ID, nil)
|
||||
if err != nil {
|
||||
return nil, ProjectContextOutput{}, err
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ var enrichmentRetryBackoff = []time.Duration{
|
||||
type EnrichmentRetryer struct {
|
||||
backgroundCtx context.Context
|
||||
store *store.DB
|
||||
provider ai.Provider
|
||||
metadata *ai.MetadataRunner
|
||||
capture config.CaptureConfig
|
||||
sessions *session.ActiveProjects
|
||||
metadataTimeout time.Duration
|
||||
@@ -66,14 +66,14 @@ type RetryEnrichmentOutput struct {
|
||||
Failures []RetryEnrichmentFailure `json:"failures,omitempty"`
|
||||
}
|
||||
|
||||
func NewEnrichmentRetryer(backgroundCtx context.Context, db *store.DB, provider ai.Provider, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, logger *slog.Logger) *EnrichmentRetryer {
|
||||
func NewEnrichmentRetryer(backgroundCtx context.Context, db *store.DB, metadataRunner *ai.MetadataRunner, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, logger *slog.Logger) *EnrichmentRetryer {
|
||||
if backgroundCtx == nil {
|
||||
backgroundCtx = context.Background()
|
||||
}
|
||||
return &EnrichmentRetryer{
|
||||
backgroundCtx: backgroundCtx,
|
||||
store: db,
|
||||
provider: provider,
|
||||
metadata: metadataRunner,
|
||||
capture: capture,
|
||||
sessions: sessions,
|
||||
metadataTimeout: metadataTimeout,
|
||||
@@ -190,7 +190,7 @@ func (r *EnrichmentRetryer) retryOne(ctx context.Context, id uuid.UUID) (bool, e
|
||||
}
|
||||
|
||||
attemptedAt := time.Now().UTC()
|
||||
extracted, extractErr := r.provider.ExtractMetadata(attemptCtx, thought.Content)
|
||||
extracted, extractErr := r.metadata.ExtractMetadata(attemptCtx, thought.Content)
|
||||
if extractErr != nil {
|
||||
failedMetadata := metadata.MarkMetadataFailed(thought.Metadata, r.capture, attemptedAt, extractErr)
|
||||
if _, updateErr := r.store.UpdateThoughtMetadata(ctx, thought.ID, failedMetadata); updateErr != nil {
|
||||
|
||||
@@ -13,9 +13,9 @@ import (
|
||||
)
|
||||
|
||||
type LinksTool struct {
|
||||
store *store.DB
|
||||
provider ai.Provider
|
||||
search config.SearchConfig
|
||||
store *store.DB
|
||||
embeddings *ai.EmbeddingRunner
|
||||
search config.SearchConfig
|
||||
}
|
||||
|
||||
type LinkInput struct {
|
||||
@@ -47,8 +47,8 @@ type RelatedOutput struct {
|
||||
Related []RelatedThought `json:"related"`
|
||||
}
|
||||
|
||||
func NewLinksTool(db *store.DB, provider ai.Provider, search config.SearchConfig) *LinksTool {
|
||||
return &LinksTool{store: db, provider: provider, search: search}
|
||||
func NewLinksTool(db *store.DB, embeddings *ai.EmbeddingRunner, search config.SearchConfig) *LinksTool {
|
||||
return &LinksTool{store: db, embeddings: embeddings, search: search}
|
||||
}
|
||||
|
||||
func (t *LinksTool) Link(ctx context.Context, _ *mcp.CallToolRequest, in LinkInput) (*mcp.CallToolResult, LinkOutput, error) {
|
||||
@@ -117,7 +117,7 @@ func (t *LinksTool) Related(ctx context.Context, _ *mcp.CallToolRequest, in Rela
|
||||
}
|
||||
|
||||
if includeSemantic {
|
||||
semantic, err := semanticSearch(ctx, t.store, t.provider, t.search, thought.Content, t.search.DefaultLimit, t.search.DefaultThreshold, thought.ProjectID, &thought.ID)
|
||||
semantic, err := semanticSearch(ctx, t.store, t.embeddings, t.search, thought.Content, t.search.DefaultLimit, t.search.DefaultThreshold, thought.ProjectID, &thought.ID)
|
||||
if err != nil {
|
||||
return nil, RelatedOutput{}, err
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ const metadataRetryConcurrency = 4
|
||||
type MetadataRetryer struct {
|
||||
backgroundCtx context.Context
|
||||
store *store.DB
|
||||
provider ai.Provider
|
||||
metadata *ai.MetadataRunner
|
||||
capture config.CaptureConfig
|
||||
sessions *session.ActiveProjects
|
||||
metadataTimeout time.Duration
|
||||
@@ -87,14 +87,14 @@ type RetryMetadataOutput struct {
|
||||
Failures []RetryMetadataFailure `json:"failures,omitempty"`
|
||||
}
|
||||
|
||||
func NewMetadataRetryer(backgroundCtx context.Context, db *store.DB, provider ai.Provider, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, logger *slog.Logger) *MetadataRetryer {
|
||||
func NewMetadataRetryer(backgroundCtx context.Context, db *store.DB, metadataRunner *ai.MetadataRunner, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, logger *slog.Logger) *MetadataRetryer {
|
||||
if backgroundCtx == nil {
|
||||
backgroundCtx = context.Background()
|
||||
}
|
||||
return &MetadataRetryer{
|
||||
backgroundCtx: backgroundCtx,
|
||||
store: db,
|
||||
provider: provider,
|
||||
metadata: metadataRunner,
|
||||
capture: capture,
|
||||
sessions: sessions,
|
||||
metadataTimeout: metadataTimeout,
|
||||
@@ -223,7 +223,7 @@ func (r *MetadataRetryer) retryOne(ctx context.Context, id uuid.UUID) (bool, err
|
||||
}
|
||||
|
||||
attemptedAt := time.Now().UTC()
|
||||
extracted, extractErr := r.provider.ExtractMetadata(attemptCtx, thought.Content)
|
||||
extracted, extractErr := r.metadata.ExtractMetadata(attemptCtx, thought.Content)
|
||||
if extractErr != nil {
|
||||
failedMetadata := metadata.MarkMetadataFailed(thought.Metadata, r.capture, attemptedAt, extractErr)
|
||||
if _, updateErr := r.store.UpdateThoughtMetadata(ctx, thought.ID, failedMetadata); updateErr != nil {
|
||||
|
||||
@@ -15,10 +15,10 @@ import (
|
||||
)
|
||||
|
||||
type RecallTool struct {
|
||||
store *store.DB
|
||||
provider ai.Provider
|
||||
search config.SearchConfig
|
||||
sessions *session.ActiveProjects
|
||||
store *store.DB
|
||||
embeddings *ai.EmbeddingRunner
|
||||
search config.SearchConfig
|
||||
sessions *session.ActiveProjects
|
||||
}
|
||||
|
||||
type RecallInput struct {
|
||||
@@ -32,8 +32,8 @@ type RecallOutput struct {
|
||||
Items []ContextItem `json:"items"`
|
||||
}
|
||||
|
||||
func NewRecallTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *RecallTool {
|
||||
return &RecallTool{store: db, provider: provider, search: search, sessions: sessions}
|
||||
func NewRecallTool(db *store.DB, embeddings *ai.EmbeddingRunner, search config.SearchConfig, sessions *session.ActiveProjects) *RecallTool {
|
||||
return &RecallTool{store: db, embeddings: embeddings, search: search, sessions: sessions}
|
||||
}
|
||||
|
||||
func (t *RecallTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in RecallInput) (*mcp.CallToolResult, RecallOutput, error) {
|
||||
@@ -54,7 +54,7 @@ func (t *RecallTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in Re
|
||||
projectID = &project.ID
|
||||
}
|
||||
|
||||
semantic, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, t.search.DefaultThreshold, projectID, nil)
|
||||
semantic, err := semanticSearch(ctx, t.store, t.embeddings, t.search, query, limit, t.search.DefaultThreshold, projectID, nil)
|
||||
if err != nil {
|
||||
return nil, RecallOutput{}, err
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ const metadataReparseConcurrency = 4
|
||||
|
||||
type ReparseMetadataTool struct {
|
||||
store *store.DB
|
||||
provider ai.Provider
|
||||
metadata *ai.MetadataRunner
|
||||
capture config.CaptureConfig
|
||||
sessions *session.ActiveProjects
|
||||
logger *slog.Logger
|
||||
@@ -53,8 +53,8 @@ type ReparseMetadataOutput struct {
|
||||
Failures []ReparseMetadataFailure `json:"failures,omitempty"`
|
||||
}
|
||||
|
||||
func NewReparseMetadataTool(db *store.DB, provider ai.Provider, capture config.CaptureConfig, sessions *session.ActiveProjects, logger *slog.Logger) *ReparseMetadataTool {
|
||||
return &ReparseMetadataTool{store: db, provider: provider, capture: capture, sessions: sessions, logger: logger}
|
||||
func NewReparseMetadataTool(db *store.DB, metadataRunner *ai.MetadataRunner, capture config.CaptureConfig, sessions *session.ActiveProjects, logger *slog.Logger) *ReparseMetadataTool {
|
||||
return &ReparseMetadataTool{store: db, metadata: metadataRunner, capture: capture, sessions: sessions, logger: logger}
|
||||
}
|
||||
|
||||
func (t *ReparseMetadataTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in ReparseMetadataInput) (*mcp.CallToolResult, ReparseMetadataOutput, error) {
|
||||
@@ -107,7 +107,7 @@ func (t *ReparseMetadataTool) Handle(ctx context.Context, req *mcp.CallToolReque
|
||||
normalizedCurrent := metadata.Normalize(thought.Metadata, t.capture)
|
||||
|
||||
attemptedAt := time.Now().UTC()
|
||||
extracted, extractErr := t.provider.ExtractMetadata(ctx, thought.Content)
|
||||
extracted, extractErr := t.metadata.ExtractMetadata(ctx, thought.Content)
|
||||
normalizedTarget := normalizedCurrent
|
||||
if extractErr != nil {
|
||||
normalizedTarget = metadata.MarkMetadataFailed(normalizedCurrent, t.capture, attemptedAt, extractErr)
|
||||
|
||||
@@ -11,12 +11,14 @@ import (
|
||||
thoughttypes "git.warky.dev/wdevs/amcs/internal/types"
|
||||
)
|
||||
|
||||
// semanticSearch runs vector similarity search if embeddings exist for the active model
|
||||
// in the given scope, otherwise falls back to Postgres full-text search.
|
||||
// semanticSearch runs vector similarity search if embeddings exist for the
|
||||
// primary embedding model in the given scope, otherwise falls back to Postgres
|
||||
// full-text search. Search always uses the primary model so query vectors
|
||||
// match rows stored under the primary model name.
|
||||
func semanticSearch(
|
||||
ctx context.Context,
|
||||
db *store.DB,
|
||||
provider ai.Provider,
|
||||
embeddings *ai.EmbeddingRunner,
|
||||
search config.SearchConfig,
|
||||
query string,
|
||||
limit int,
|
||||
@@ -24,17 +26,18 @@ func semanticSearch(
|
||||
projectID *uuid.UUID,
|
||||
excludeID *uuid.UUID,
|
||||
) ([]thoughttypes.SearchResult, error) {
|
||||
hasEmbeddings, err := db.HasEmbeddingsForModel(ctx, provider.EmbeddingModel(), projectID)
|
||||
model := embeddings.PrimaryModel()
|
||||
hasEmbeddings, err := db.HasEmbeddingsForModel(ctx, model, projectID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if hasEmbeddings {
|
||||
embedding, err := provider.Embed(ctx, query)
|
||||
embedding, err := embeddings.EmbedPrimary(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db.SearchSimilarThoughts(ctx, embedding, provider.EmbeddingModel(), threshold, limit, projectID, excludeID)
|
||||
return db.SearchSimilarThoughts(ctx, embedding, model, threshold, limit, projectID, excludeID)
|
||||
}
|
||||
|
||||
return db.SearchThoughtsText(ctx, query, limit, projectID, excludeID)
|
||||
|
||||
@@ -15,10 +15,10 @@ import (
|
||||
)
|
||||
|
||||
type SearchTool struct {
|
||||
store *store.DB
|
||||
provider ai.Provider
|
||||
search config.SearchConfig
|
||||
sessions *session.ActiveProjects
|
||||
store *store.DB
|
||||
embeddings *ai.EmbeddingRunner
|
||||
search config.SearchConfig
|
||||
sessions *session.ActiveProjects
|
||||
}
|
||||
|
||||
type SearchInput struct {
|
||||
@@ -32,8 +32,8 @@ type SearchOutput struct {
|
||||
Results []thoughttypes.SearchResult `json:"results"`
|
||||
}
|
||||
|
||||
func NewSearchTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *SearchTool {
|
||||
return &SearchTool{store: db, provider: provider, search: search, sessions: sessions}
|
||||
func NewSearchTool(db *store.DB, embeddings *ai.EmbeddingRunner, search config.SearchConfig, sessions *session.ActiveProjects) *SearchTool {
|
||||
return &SearchTool{store: db, embeddings: embeddings, search: search, sessions: sessions}
|
||||
}
|
||||
|
||||
func (t *SearchTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in SearchInput) (*mcp.CallToolResult, SearchOutput, error) {
|
||||
@@ -56,7 +56,7 @@ func (t *SearchTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in Se
|
||||
_ = t.store.TouchProject(ctx, project.ID)
|
||||
}
|
||||
|
||||
results, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, threshold, projectID, nil)
|
||||
results, err := semanticSearch(ctx, t.store, t.embeddings, t.search, query, limit, threshold, projectID, nil)
|
||||
if err != nil {
|
||||
return nil, SearchOutput{}, err
|
||||
}
|
||||
|
||||
@@ -14,10 +14,11 @@ import (
|
||||
)
|
||||
|
||||
type SummarizeTool struct {
|
||||
store *store.DB
|
||||
provider ai.Provider
|
||||
search config.SearchConfig
|
||||
sessions *session.ActiveProjects
|
||||
store *store.DB
|
||||
embeddings *ai.EmbeddingRunner
|
||||
metadata *ai.MetadataRunner
|
||||
search config.SearchConfig
|
||||
sessions *session.ActiveProjects
|
||||
}
|
||||
|
||||
type SummarizeInput struct {
|
||||
@@ -32,8 +33,8 @@ type SummarizeOutput struct {
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
func NewSummarizeTool(db *store.DB, provider ai.Provider, search config.SearchConfig, sessions *session.ActiveProjects) *SummarizeTool {
|
||||
return &SummarizeTool{store: db, provider: provider, search: search, sessions: sessions}
|
||||
func NewSummarizeTool(db *store.DB, embeddings *ai.EmbeddingRunner, metadata *ai.MetadataRunner, search config.SearchConfig, sessions *session.ActiveProjects) *SummarizeTool {
|
||||
return &SummarizeTool{store: db, embeddings: embeddings, metadata: metadata, search: search, sessions: sessions}
|
||||
}
|
||||
|
||||
func (t *SummarizeTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in SummarizeInput) (*mcp.CallToolResult, SummarizeOutput, error) {
|
||||
@@ -52,7 +53,7 @@ func (t *SummarizeTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
|
||||
if project != nil {
|
||||
projectID = &project.ID
|
||||
}
|
||||
results, err := semanticSearch(ctx, t.store, t.provider, t.search, query, limit, t.search.DefaultThreshold, projectID, nil)
|
||||
results, err := semanticSearch(ctx, t.store, t.embeddings, t.search, query, limit, t.search.DefaultThreshold, projectID, nil)
|
||||
if err != nil {
|
||||
return nil, SummarizeOutput{}, err
|
||||
}
|
||||
@@ -77,7 +78,7 @@ func (t *SummarizeTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in
|
||||
|
||||
userPrompt := formatContextBlock("Summarize the following thoughts into concise prose with themes, action items, and notable people.", lines)
|
||||
systemPrompt := "You summarize note collections. Be concise, concrete, and structured in plain prose."
|
||||
summary, err := t.provider.Summarize(ctx, systemPrompt, userPrompt)
|
||||
summary, err := t.metadata.Summarize(ctx, systemPrompt, userPrompt)
|
||||
if err != nil {
|
||||
return nil, SummarizeOutput{}, err
|
||||
}
|
||||
|
||||
@@ -16,10 +16,11 @@ import (
|
||||
)
|
||||
|
||||
type UpdateTool struct {
|
||||
store *store.DB
|
||||
provider ai.Provider
|
||||
capture config.CaptureConfig
|
||||
log *slog.Logger
|
||||
store *store.DB
|
||||
embeddings *ai.EmbeddingRunner
|
||||
metadata *ai.MetadataRunner
|
||||
capture config.CaptureConfig
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
type UpdateInput struct {
|
||||
@@ -33,8 +34,8 @@ type UpdateOutput struct {
|
||||
Thought thoughttypes.Thought `json:"thought"`
|
||||
}
|
||||
|
||||
func NewUpdateTool(db *store.DB, provider ai.Provider, capture config.CaptureConfig, log *slog.Logger) *UpdateTool {
|
||||
return &UpdateTool{store: db, provider: provider, capture: capture, log: log}
|
||||
func NewUpdateTool(db *store.DB, embeddings *ai.EmbeddingRunner, metadata *ai.MetadataRunner, capture config.CaptureConfig, log *slog.Logger) *UpdateTool {
|
||||
return &UpdateTool{store: db, embeddings: embeddings, metadata: metadata, capture: capture, log: log}
|
||||
}
|
||||
|
||||
func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in UpdateInput) (*mcp.CallToolResult, UpdateOutput, error) {
|
||||
@@ -50,6 +51,7 @@ func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in Upda
|
||||
|
||||
content := current.Content
|
||||
var embedding []float32
|
||||
embeddingModel := ""
|
||||
mergedMetadata := current.Metadata
|
||||
projectID := current.ProjectID
|
||||
|
||||
@@ -58,11 +60,13 @@ func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in Upda
|
||||
if content == "" {
|
||||
return nil, UpdateOutput{}, errInvalidInput("content must not be empty")
|
||||
}
|
||||
embedding, err = t.provider.Embed(ctx, content)
|
||||
embedResult, err := t.embeddings.Embed(ctx, content)
|
||||
if err != nil {
|
||||
return nil, UpdateOutput{}, err
|
||||
}
|
||||
extracted, extractErr := t.provider.ExtractMetadata(ctx, content)
|
||||
embedding = embedResult.Vector
|
||||
embeddingModel = embedResult.Model
|
||||
extracted, extractErr := t.metadata.ExtractMetadata(ctx, content)
|
||||
if extractErr != nil {
|
||||
t.log.Warn("metadata extraction failed during update, keeping current metadata", slog.String("error", extractErr.Error()))
|
||||
mergedMetadata = metadata.MarkMetadataFailed(mergedMetadata, t.capture, time.Now().UTC(), extractErr)
|
||||
@@ -82,7 +86,7 @@ func (t *UpdateTool) Handle(ctx context.Context, _ *mcp.CallToolRequest, in Upda
|
||||
projectID = &project.ID
|
||||
}
|
||||
|
||||
updated, err := t.store.UpdateThought(ctx, id, content, embedding, t.provider.EmbeddingModel(), mergedMetadata, projectID)
|
||||
updated, err := t.store.UpdateThought(ctx, id, content, embedding, embeddingModel, mergedMetadata, projectID)
|
||||
if err != nil {
|
||||
return nil, UpdateOutput{}, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user