diff --git a/internal/app/app.go b/internal/app/app.go index 8a76186..e84c369 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -129,7 +129,7 @@ func routes(logger *slog.Logger, cfg *config.Config, db *store.DB, provider ai.P mux := http.NewServeMux() toolSet := mcpserver.ToolSet{ - Capture: tools.NewCaptureTool(db, provider, cfg.Capture, activeProjects, logger), + Capture: tools.NewCaptureTool(db, provider, cfg.Capture, cfg.AI.Metadata.Timeout, activeProjects, logger), Search: tools.NewSearchTool(db, provider, cfg.Search, activeProjects), List: tools.NewListTool(db, cfg.Search, activeProjects), Stats: tools.NewStatsTool(db), diff --git a/internal/config/config.go b/internal/config/config.go index da39545..84c0daa 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -84,11 +84,12 @@ type AIEmbeddingConfig struct { } type AIMetadataConfig struct { - Model string `yaml:"model"` - FallbackModels []string `yaml:"fallback_models"` - FallbackModel string `yaml:"fallback_model"` // legacy single fallback - Temperature float64 `yaml:"temperature"` - LogConversations bool `yaml:"log_conversations"` + Model string `yaml:"model"` + FallbackModels []string `yaml:"fallback_models"` + FallbackModel string `yaml:"fallback_model"` // legacy single fallback + Temperature float64 `yaml:"temperature"` + LogConversations bool `yaml:"log_conversations"` + Timeout time.Duration `yaml:"timeout"` } type LiteLLMConfig struct { diff --git a/internal/config/loader.go b/internal/config/loader.go index 0a27f04..2448248 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -73,6 +73,7 @@ func defaultConfig() Config { Metadata: AIMetadataConfig{ Model: "gpt-4o-mini", Temperature: 0.1, + Timeout: 10 * time.Second, }, Ollama: OllamaConfig{ BaseURL: "http://localhost:11434/v1", diff --git a/internal/tools/capture.go b/internal/tools/capture.go index 91d3e9a..3a9f879 100644 --- a/internal/tools/capture.go +++ b/internal/tools/capture.go @@ -4,6 +4,7 @@ import ( "context" "log/slog" "strings" + "time" "github.com/modelcontextprotocol/go-sdk/mcp" "golang.org/x/sync/errgroup" @@ -17,11 +18,12 @@ import ( ) type CaptureTool struct { - store *store.DB - provider ai.Provider - capture config.CaptureConfig - sessions *session.ActiveProjects - log *slog.Logger + store *store.DB + provider ai.Provider + capture config.CaptureConfig + sessions *session.ActiveProjects + metadataTimeout time.Duration + log *slog.Logger } type CaptureInput struct { @@ -33,8 +35,8 @@ type CaptureOutput struct { Thought thoughttypes.Thought `json:"thought"` } -func NewCaptureTool(db *store.DB, provider ai.Provider, capture config.CaptureConfig, sessions *session.ActiveProjects, log *slog.Logger) *CaptureTool { - return &CaptureTool{store: db, provider: provider, capture: capture, sessions: sessions, log: log} +func NewCaptureTool(db *store.DB, provider ai.Provider, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, log *slog.Logger) *CaptureTool { + return &CaptureTool{store: db, provider: provider, capture: capture, sessions: sessions, metadataTimeout: metadataTimeout, log: log} } func (t *CaptureTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in CaptureInput) (*mcp.CallToolResult, CaptureOutput, error) { @@ -61,7 +63,13 @@ func (t *CaptureTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in C return nil }) group.Go(func() error { - extracted, err := t.provider.ExtractMetadata(groupCtx, content) + metaCtx := groupCtx + if t.metadataTimeout > 0 { + var cancel context.CancelFunc + metaCtx, cancel = context.WithTimeout(groupCtx, t.metadataTimeout) + defer cancel() + } + extracted, err := t.provider.ExtractMetadata(metaCtx, content) if err != nil { t.log.Warn("metadata extraction failed, using fallback", slog.String("provider", t.provider.Name()), slog.String("error", err.Error())) return nil