feat(cli): add verbose logging option for CLI commands
Some checks failed
CI / build-and-test (push) Failing after -32m43s
Some checks failed
CI / build-and-test (push) Failing after -32m43s
* Introduced a new flag `--verbose` to enable detailed logging. * Implemented logging for connection events in SSE and stdio commands. * Added a utility function to handle verbose logging.
This commit is contained in:
@@ -17,6 +17,7 @@ var (
|
||||
serverFlag string
|
||||
tokenFlag string
|
||||
outputFlag string
|
||||
verbose bool
|
||||
cfg Config
|
||||
)
|
||||
|
||||
@@ -42,6 +43,7 @@ func init() {
|
||||
rootCmd.PersistentFlags().StringVar(&serverFlag, "server", "", "AMCS server URL")
|
||||
rootCmd.PersistentFlags().StringVar(&tokenFlag, "token", "", "AMCS bearer token")
|
||||
rootCmd.PersistentFlags().StringVar(&outputFlag, "output", "json", "Output format: json or yaml")
|
||||
rootCmd.PersistentFlags().BoolVar(&verbose, "verbose", false, "Enable verbose logging to stderr")
|
||||
}
|
||||
|
||||
func loadConfig() error {
|
||||
@@ -122,6 +124,7 @@ func connectRemote(ctx context.Context) (*mcp.ClientSession, error) {
|
||||
if err := requireServer(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
verboseLogf("connecting to %s", endpointURL())
|
||||
client := mcp.NewClient(&mcp.Implementation{Name: "amcs-cli", Version: "0.0.1"}, nil)
|
||||
transport := &mcp.StreamableClientTransport{
|
||||
Endpoint: endpointURL(),
|
||||
@@ -133,5 +136,13 @@ func connectRemote(ctx context.Context) (*mcp.ClientSession, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to AMCS server: %w", err)
|
||||
}
|
||||
verboseLogf("connected to %s", endpointURL())
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func verboseLogf(format string, args ...any) {
|
||||
if !verbose {
|
||||
return
|
||||
}
|
||||
_, _ = fmt.Fprintf(os.Stderr, "[amcs-cli] "+format+"\n", args...)
|
||||
}
|
||||
|
||||
31
cmd/amcs-cli/cmd/root_test.go
Normal file
31
cmd/amcs-cli/cmd/root_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBearerTransportFormatsBearerToken(t *testing.T) {
|
||||
const want = "Bearer X"
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if got := r.Header.Get("Authorization"); got != want {
|
||||
t.Fatalf("Authorization header = %q, want %q", got, want)
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
client := &http.Client{Transport: &bearerTransport{token: "X"}}
|
||||
req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest() error = %v", err)
|
||||
}
|
||||
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("client.Do() error = %v", err)
|
||||
}
|
||||
_ = res.Body.Close()
|
||||
}
|
||||
@@ -29,11 +29,13 @@ var sseCmd = &cobra.Command{
|
||||
connectCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
verboseLogf("connecting to SSE endpoint %s", sseEndpointURL())
|
||||
remote, err := client.Connect(connectCtx, transport, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to AMCS SSE endpoint: %w", err)
|
||||
}
|
||||
defer func() { _ = remote.Close() }()
|
||||
verboseLogf("connected to SSE endpoint %s", sseEndpointURL())
|
||||
|
||||
tools, err := remote.ListTools(ctx, nil)
|
||||
if err != nil {
|
||||
@@ -67,6 +69,8 @@ var sseCmd = &cobra.Command{
|
||||
return fmt.Errorf("start stdio bridge: %w", err)
|
||||
}
|
||||
defer func() { _ = session.Close() }()
|
||||
verboseLogf("sse stdio bridge ready")
|
||||
verboseLogf("waiting for MCP commands on stdin")
|
||||
|
||||
<-ctx.Done()
|
||||
return nil
|
||||
|
||||
@@ -51,6 +51,8 @@ var stdioCmd = &cobra.Command{
|
||||
return fmt.Errorf("start stdio bridge: %w", err)
|
||||
}
|
||||
defer func() { _ = session.Close() }()
|
||||
verboseLogf("stdio bridge connected to remote AMCS and ready")
|
||||
verboseLogf("waiting for MCP commands on stdin")
|
||||
|
||||
<-ctx.Done()
|
||||
return nil
|
||||
|
||||
@@ -193,7 +193,7 @@ func routes(logger *slog.Logger, cfg *config.Config, info buildinfo.Info, db *st
|
||||
backfillTool := tools.NewBackfillTool(db, bgEmbeddings, activeProjects, logger)
|
||||
|
||||
toolSet := mcpserver.ToolSet{
|
||||
Capture: tools.NewCaptureTool(db, embeddings, metadata, cfg.Capture, cfg.AI.Metadata.Timeout, activeProjects, nil, backfillTool, logger),
|
||||
Capture: tools.NewCaptureTool(db, embeddings, cfg.Capture, activeProjects, enrichmentRetryer, backfillTool),
|
||||
Search: tools.NewSearchTool(db, embeddings, cfg.Search, activeProjects),
|
||||
List: tools.NewListTool(db, cfg.Search, activeProjects),
|
||||
Stats: tools.NewStatsTool(db),
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/auth"
|
||||
"git.warky.dev/wdevs/amcs/internal/requestip"
|
||||
)
|
||||
|
||||
// --- JSON types ---
|
||||
@@ -261,7 +262,7 @@ func handleClientCredentials(w http.ResponseWriter, r *http.Request, oauthRegist
|
||||
}
|
||||
keyID, ok := oauthRegistry.Lookup(clientID, clientSecret)
|
||||
if !ok {
|
||||
log.Warn("oauth token: invalid client credentials", slog.String("remote_addr", r.RemoteAddr))
|
||||
log.Warn("oauth token: invalid client credentials", slog.String("remote_addr", requestip.FromRequest(r)))
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="oauth"`)
|
||||
writeTokenError(w, "invalid_client", http.StatusUnauthorized)
|
||||
return
|
||||
@@ -290,7 +291,7 @@ func handleAuthorizationCode(w http.ResponseWriter, r *http.Request, authCodes *
|
||||
return
|
||||
}
|
||||
if !verifyPKCE(codeVerifier, entry.CodeChallenge, entry.CodeChallengeMethod) {
|
||||
log.Warn("oauth token: PKCE verification failed", slog.String("remote_addr", r.RemoteAddr))
|
||||
log.Warn("oauth token: PKCE verification failed", slog.String("remote_addr", requestip.FromRequest(r)))
|
||||
writeTokenError(w, "invalid_grant", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -157,3 +157,34 @@ func TestMiddlewareRejectsMissingOrInvalidKey(t *testing.T) {
|
||||
t.Fatalf("invalid key status = %d, want %d", rec.Code, http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareRecordsForwardedRemoteAddr(t *testing.T) {
|
||||
keyring, err := NewKeyring([]config.APIKey{{ID: "client-a", Value: "secret"}})
|
||||
if err != nil {
|
||||
t.Fatalf("NewKeyring() error = %v", err)
|
||||
}
|
||||
tracker := NewAccessTracker()
|
||||
|
||||
handler := Middleware(config.AuthConfig{HeaderName: "x-brain-key"}, keyring, nil, nil, tracker, testLogger())(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/mcp", nil)
|
||||
req.RemoteAddr = "10.0.0.5:2222"
|
||||
req.Header.Set("x-brain-key", "secret")
|
||||
req.Header.Set("X-Real-IP", "203.0.113.99")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusNoContent {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent)
|
||||
}
|
||||
|
||||
snap := tracker.Snapshot()
|
||||
if len(snap) != 1 {
|
||||
t.Fatalf("len(snapshot) = %d, want 1", len(snap))
|
||||
}
|
||||
if snap[0].RemoteAddr != "203.0.113.99" {
|
||||
t.Fatalf("snapshot remote_addr = %q, want %q", snap[0].RemoteAddr, "203.0.113.99")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/config"
|
||||
"git.warky.dev/wdevs/amcs/internal/requestip"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
@@ -22,17 +23,18 @@ func Middleware(cfg config.AuthConfig, keyring *Keyring, oauthRegistry *OAuthReg
|
||||
}
|
||||
recordAccess := func(r *http.Request, keyID string) {
|
||||
if tracker != nil {
|
||||
tracker.Record(keyID, r.URL.Path, r.RemoteAddr, r.UserAgent(), time.Now())
|
||||
tracker.Record(keyID, r.URL.Path, requestip.FromRequest(r), r.UserAgent(), time.Now())
|
||||
}
|
||||
}
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
remoteAddr := requestip.FromRequest(r)
|
||||
// 1. Custom header → keyring only.
|
||||
if keyring != nil {
|
||||
if token := strings.TrimSpace(r.Header.Get(headerName)); token != "" {
|
||||
keyID, ok := keyring.Lookup(token)
|
||||
if !ok {
|
||||
log.Warn("authentication failed", slog.String("remote_addr", r.RemoteAddr))
|
||||
log.Warn("authentication failed", slog.String("remote_addr", remoteAddr))
|
||||
http.Error(w, "invalid API key", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
@@ -58,7 +60,7 @@ func Middleware(cfg config.AuthConfig, keyring *Keyring, oauthRegistry *OAuthReg
|
||||
return
|
||||
}
|
||||
}
|
||||
log.Warn("bearer token rejected", slog.String("remote_addr", r.RemoteAddr))
|
||||
log.Warn("bearer token rejected", slog.String("remote_addr", remoteAddr))
|
||||
http.Error(w, "invalid token or API key", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
@@ -71,7 +73,7 @@ func Middleware(cfg config.AuthConfig, keyring *Keyring, oauthRegistry *OAuthReg
|
||||
}
|
||||
keyID, ok := oauthRegistry.Lookup(clientID, clientSecret)
|
||||
if !ok {
|
||||
log.Warn("oauth client authentication failed", slog.String("remote_addr", r.RemoteAddr))
|
||||
log.Warn("oauth client authentication failed", slog.String("remote_addr", remoteAddr))
|
||||
http.Error(w, "invalid OAuth client credentials", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
@@ -85,7 +87,7 @@ func Middleware(cfg config.AuthConfig, keyring *Keyring, oauthRegistry *OAuthReg
|
||||
if token := strings.TrimSpace(r.URL.Query().Get(cfg.QueryParam)); token != "" {
|
||||
keyID, ok := keyring.Lookup(token)
|
||||
if !ok {
|
||||
log.Warn("authentication failed", slog.String("remote_addr", r.RemoteAddr))
|
||||
log.Warn("authentication failed", slog.String("remote_addr", remoteAddr))
|
||||
http.Error(w, "invalid API key", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -3,12 +3,13 @@ package observability
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"git.warky.dev/wdevs/amcs/internal/requestip"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
@@ -67,7 +68,7 @@ func AccessLog(log *slog.Logger) func(http.Handler) http.Handler {
|
||||
slog.String("path", r.URL.Path),
|
||||
slog.Int("status", recorder.status),
|
||||
slog.Duration("duration", time.Since(started)),
|
||||
slog.String("remote_addr", stripPort(r.RemoteAddr)),
|
||||
slog.String("remote_addr", requestip.FromRequest(r)),
|
||||
)
|
||||
})
|
||||
}
|
||||
@@ -100,11 +101,3 @@ func (s *statusRecorder) WriteHeader(statusCode int) {
|
||||
s.status = statusCode
|
||||
s.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func stripPort(remote string) string {
|
||||
host, _, err := net.SplitHostPort(remote)
|
||||
if err != nil {
|
||||
return remote
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package observability
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -57,3 +59,24 @@ func TestRecoverHandlesPanic(t *testing.T) {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLogUsesForwardedClientIP(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
handler := AccessLog(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/mcp", nil)
|
||||
req.RemoteAddr = "10.0.0.10:1234"
|
||||
req.Header.Set("X-Real-IP", "203.0.113.7")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusNoContent {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent)
|
||||
}
|
||||
if !strings.Contains(buf.String(), "remote_addr=203.0.113.7") {
|
||||
t.Fatalf("log output = %q, want remote_addr=203.0.113.7", buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
76
internal/requestip/requestip.go
Normal file
76
internal/requestip/requestip.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package requestip
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// FromRequest returns the best-effort client IP/host for a request, preferring
|
||||
// proxy headers before falling back to RemoteAddr.
|
||||
//
|
||||
// Header precedence:
|
||||
// 1) X-Real-IP
|
||||
// 2) X-Forwarded-Host
|
||||
// 3) X-Forwarded-For (first value)
|
||||
// 4) Forwarded (for=...)
|
||||
// 5) RemoteAddr (host part)
|
||||
func FromRequest(r *http.Request) string {
|
||||
if r == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if v := firstAddressToken(r.Header.Get("X-Real-IP")); v != "" {
|
||||
return stripPort(v)
|
||||
}
|
||||
if v := firstAddressToken(r.Header.Get("X-Forwarded-Host")); v != "" {
|
||||
return stripPort(v)
|
||||
}
|
||||
if v := firstAddressToken(r.Header.Get("X-Forwarded-For")); v != "" {
|
||||
return stripPort(v)
|
||||
}
|
||||
if v := forwardedForValue(r.Header.Get("Forwarded")); v != "" {
|
||||
return stripPort(v)
|
||||
}
|
||||
return stripPort(strings.TrimSpace(r.RemoteAddr))
|
||||
}
|
||||
|
||||
func firstAddressToken(v string) string {
|
||||
if v == "" {
|
||||
return ""
|
||||
}
|
||||
part := strings.TrimSpace(strings.Split(v, ",")[0])
|
||||
part = strings.Trim(part, `"`)
|
||||
return strings.TrimSpace(part)
|
||||
}
|
||||
|
||||
func forwardedForValue(v string) string {
|
||||
for _, part := range strings.Split(v, ",") {
|
||||
for _, kv := range strings.Split(part, ";") {
|
||||
k, raw, ok := strings.Cut(strings.TrimSpace(kv), "=")
|
||||
if !ok || !strings.EqualFold(strings.TrimSpace(k), "for") {
|
||||
continue
|
||||
}
|
||||
candidate := strings.Trim(strings.TrimSpace(raw), `"`)
|
||||
if candidate == "" {
|
||||
continue
|
||||
}
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func stripPort(addr string) string {
|
||||
addr = strings.TrimSpace(addr)
|
||||
if addr == "" {
|
||||
return ""
|
||||
}
|
||||
// RFC 7239 quoted values may wrap IPv6 with brackets.
|
||||
addr = strings.Trim(addr, "[]")
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err == nil {
|
||||
return host
|
||||
}
|
||||
return addr
|
||||
}
|
||||
47
internal/requestip/requestip_test.go
Normal file
47
internal/requestip/requestip_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package requestip
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFromRequestPrefersXRealIP(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "10.0.0.10:5555"
|
||||
req.Header.Set("X-Forwarded-Host", "proxy.example.com")
|
||||
req.Header.Set("X-Real-IP", "203.0.113.10")
|
||||
|
||||
if got := FromRequest(req); got != "203.0.113.10" {
|
||||
t.Fatalf("FromRequest() = %q, want %q", got, "203.0.113.10")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromRequestUsesXForwardedHostWhenRealIPMissing(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "10.0.0.10:5555"
|
||||
req.Header.Set("X-Forwarded-Host", "203.0.113.22")
|
||||
|
||||
if got := FromRequest(req); got != "203.0.113.22" {
|
||||
t.Fatalf("FromRequest() = %q, want %q", got, "203.0.113.22")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromRequestUsesXForwardedForFirstValue(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "10.0.0.10:5555"
|
||||
req.Header.Set("X-Forwarded-For", "198.51.100.7, 10.1.1.2")
|
||||
|
||||
if got := FromRequest(req); got != "198.51.100.7" {
|
||||
t.Fatalf("FromRequest() = %q, want %q", got, "198.51.100.7")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromRequestFallsBackToRemoteAddr(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = "192.0.2.5:1234"
|
||||
|
||||
if got := FromRequest(req); got != "192.0.2.5" {
|
||||
t.Fatalf("FromRequest() = %q, want %q", got, "192.0.2.5")
|
||||
}
|
||||
}
|
||||
@@ -55,24 +55,41 @@ func NewBackfillTool(db *store.DB, embeddings *ai.EmbeddingRunner, sessions *ses
|
||||
// It is used by capture when the embedding provider is temporarily unavailable.
|
||||
func (t *BackfillTool) QueueThought(ctx context.Context, id uuid.UUID, content string) {
|
||||
go func() {
|
||||
started := time.Now()
|
||||
t.logger.Info("background embedding started",
|
||||
slog.String("thought_id", id.String()),
|
||||
slog.String("provider", t.embeddings.PrimaryProvider()),
|
||||
slog.String("model", t.embeddings.PrimaryModel()),
|
||||
)
|
||||
|
||||
result, err := t.embeddings.Embed(ctx, content)
|
||||
if err != nil {
|
||||
t.logger.Warn("background embedding retry failed",
|
||||
t.logger.Warn("background embedding error",
|
||||
slog.String("thought_id", id.String()),
|
||||
slog.String("provider", t.embeddings.PrimaryProvider()),
|
||||
slog.String("model", t.embeddings.PrimaryModel()),
|
||||
slog.String("stage", "embed"),
|
||||
slog.Duration("duration", time.Since(started)),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
return
|
||||
}
|
||||
if err := t.store.UpsertEmbedding(ctx, id, result.Model, result.Vector); err != nil {
|
||||
t.logger.Warn("background embedding upsert failed",
|
||||
t.logger.Warn("background embedding error",
|
||||
slog.String("thought_id", id.String()),
|
||||
slog.String("provider", t.embeddings.PrimaryProvider()),
|
||||
slog.String("model", result.Model),
|
||||
slog.String("stage", "upsert"),
|
||||
slog.Duration("duration", time.Since(started)),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
return
|
||||
}
|
||||
t.logger.Info("background embedding retry succeeded",
|
||||
t.logger.Info("background embedding complete",
|
||||
slog.String("thought_id", id.String()),
|
||||
slog.String("provider", t.embeddings.PrimaryProvider()),
|
||||
slog.String("model", result.Model),
|
||||
slog.Duration("duration", time.Since(started)),
|
||||
)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -2,9 +2,7 @@ package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
@@ -29,15 +27,12 @@ type MetadataQueuer interface {
|
||||
}
|
||||
|
||||
type CaptureTool struct {
|
||||
store *store.DB
|
||||
embeddings *ai.EmbeddingRunner
|
||||
metadata *ai.MetadataRunner
|
||||
capture config.CaptureConfig
|
||||
sessions *session.ActiveProjects
|
||||
metadataTimeout time.Duration
|
||||
retryer MetadataQueuer
|
||||
embedRetryer EmbeddingQueuer
|
||||
log *slog.Logger
|
||||
store *store.DB
|
||||
embeddings *ai.EmbeddingRunner
|
||||
capture config.CaptureConfig
|
||||
sessions *session.ActiveProjects
|
||||
retryer MetadataQueuer
|
||||
embedRetryer EmbeddingQueuer
|
||||
}
|
||||
|
||||
type CaptureInput struct {
|
||||
@@ -49,8 +44,8 @@ type CaptureOutput struct {
|
||||
Thought thoughttypes.Thought `json:"thought"`
|
||||
}
|
||||
|
||||
func NewCaptureTool(db *store.DB, embeddings *ai.EmbeddingRunner, metadata *ai.MetadataRunner, capture config.CaptureConfig, metadataTimeout time.Duration, sessions *session.ActiveProjects, retryer MetadataQueuer, embedRetryer EmbeddingQueuer, log *slog.Logger) *CaptureTool {
|
||||
return &CaptureTool{store: db, embeddings: embeddings, metadata: metadata, capture: capture, sessions: sessions, metadataTimeout: metadataTimeout, retryer: retryer, embedRetryer: embedRetryer, log: log}
|
||||
func NewCaptureTool(db *store.DB, embeddings *ai.EmbeddingRunner, capture config.CaptureConfig, sessions *session.ActiveProjects, retryer MetadataQueuer, embedRetryer EmbeddingQueuer) *CaptureTool {
|
||||
return &CaptureTool{store: db, embeddings: embeddings, capture: capture, sessions: sessions, retryer: retryer, embedRetryer: embedRetryer}
|
||||
}
|
||||
|
||||
func (t *CaptureTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in CaptureInput) (*mcp.CallToolResult, CaptureOutput, error) {
|
||||
@@ -65,6 +60,7 @@ func (t *CaptureTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in C
|
||||
}
|
||||
|
||||
rawMetadata := metadata.Fallback(t.capture)
|
||||
rawMetadata.MetadataStatus = metadata.MetadataStatusPending
|
||||
thought := thoughttypes.Thought{
|
||||
Content: content,
|
||||
Metadata: rawMetadata,
|
||||
@@ -81,56 +77,12 @@ func (t *CaptureTool) Handle(ctx context.Context, req *mcp.CallToolRequest, in C
|
||||
_ = t.store.TouchProject(ctx, project.ID)
|
||||
}
|
||||
|
||||
if t.retryer != nil || t.embedRetryer != nil {
|
||||
t.launchEnrichment(created.ID, content)
|
||||
if t.retryer != nil {
|
||||
t.retryer.QueueThought(created.ID)
|
||||
}
|
||||
if t.embedRetryer != nil {
|
||||
t.embedRetryer.QueueThought(ctx, created.ID, content)
|
||||
}
|
||||
|
||||
return nil, CaptureOutput{Thought: created}, nil
|
||||
}
|
||||
|
||||
func (t *CaptureTool) launchEnrichment(id uuid.UUID, content string) {
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
if t.retryer != nil {
|
||||
attemptedAt := time.Now().UTC()
|
||||
rawMetadata := metadata.Fallback(t.capture)
|
||||
extracted, err := t.metadata.ExtractMetadata(ctx, content)
|
||||
if err != nil {
|
||||
failed := metadata.MarkMetadataFailed(rawMetadata, t.capture, attemptedAt, err)
|
||||
if _, updateErr := t.store.UpdateThoughtMetadata(ctx, id, failed); updateErr != nil {
|
||||
t.log.Warn("deferred metadata failure could not be persisted",
|
||||
slog.String("thought_id", id.String()),
|
||||
slog.String("error", updateErr.Error()),
|
||||
)
|
||||
}
|
||||
t.log.Warn("deferred metadata extraction failed",
|
||||
slog.String("thought_id", id.String()),
|
||||
slog.String("provider", t.metadata.PrimaryProvider()),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
t.retryer.QueueThought(id)
|
||||
} else {
|
||||
completed := metadata.MarkMetadataComplete(extracted, t.capture, attemptedAt)
|
||||
if _, updateErr := t.store.UpdateThoughtMetadata(ctx, id, completed); updateErr != nil {
|
||||
t.log.Warn("deferred metadata completion could not be persisted",
|
||||
slog.String("thought_id", id.String()),
|
||||
slog.String("error", updateErr.Error()),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if t.embedRetryer != nil {
|
||||
if _, err := t.embeddings.Embed(ctx, content); err != nil {
|
||||
t.log.Warn("deferred embedding failed",
|
||||
slog.String("thought_id", id.String()),
|
||||
slog.String("provider", t.embeddings.PrimaryProvider()),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
}
|
||||
t.embedRetryer.QueueThought(ctx, id, content)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -91,12 +91,30 @@ func (t *RetryEnrichmentTool) Handle(ctx context.Context, req *mcp.CallToolReque
|
||||
|
||||
func (r *EnrichmentRetryer) QueueThought(id uuid.UUID) {
|
||||
go func() {
|
||||
if _, err := r.retryOne(r.backgroundCtx, id); err != nil {
|
||||
r.logger.Warn("background metadata retry failed",
|
||||
started := time.Now()
|
||||
r.logger.Info("background metadata started",
|
||||
slog.String("thought_id", id.String()),
|
||||
slog.String("provider", r.metadata.PrimaryProvider()),
|
||||
slog.String("model", r.metadata.PrimaryModel()),
|
||||
)
|
||||
updated, err := r.retryOne(r.backgroundCtx, id)
|
||||
if err != nil {
|
||||
r.logger.Warn("background metadata error",
|
||||
slog.String("thought_id", id.String()),
|
||||
slog.String("provider", r.metadata.PrimaryProvider()),
|
||||
slog.String("model", r.metadata.PrimaryModel()),
|
||||
slog.Duration("duration", time.Since(started)),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
return
|
||||
}
|
||||
r.logger.Info("background metadata complete",
|
||||
slog.String("thought_id", id.String()),
|
||||
slog.String("provider", r.metadata.PrimaryProvider()),
|
||||
slog.String("model", r.metadata.PrimaryModel()),
|
||||
slog.Bool("updated", updated),
|
||||
slog.Duration("duration", time.Since(started)),
|
||||
)
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
@@ -113,13 +113,35 @@ func (t *RetryMetadataTool) Handle(ctx context.Context, req *mcp.CallToolRequest
|
||||
|
||||
func (r *MetadataRetryer) QueueThought(id uuid.UUID) {
|
||||
go func() {
|
||||
started := time.Now()
|
||||
if !r.lock.Acquire(id, 15*time.Minute) {
|
||||
return
|
||||
}
|
||||
defer r.lock.Release(id)
|
||||
if _, err := r.retryOne(r.backgroundCtx, id); err != nil {
|
||||
r.logger.Warn("background metadata retry failed", slog.String("thought_id", id.String()), slog.String("error", err.Error()))
|
||||
|
||||
r.logger.Info("background metadata started",
|
||||
slog.String("thought_id", id.String()),
|
||||
slog.String("provider", r.metadata.PrimaryProvider()),
|
||||
slog.String("model", r.metadata.PrimaryModel()),
|
||||
)
|
||||
updated, err := r.retryOne(r.backgroundCtx, id)
|
||||
if err != nil {
|
||||
r.logger.Warn("background metadata error",
|
||||
slog.String("thought_id", id.String()),
|
||||
slog.String("provider", r.metadata.PrimaryProvider()),
|
||||
slog.String("model", r.metadata.PrimaryModel()),
|
||||
slog.Duration("duration", time.Since(started)),
|
||||
slog.String("error", err.Error()),
|
||||
)
|
||||
return
|
||||
}
|
||||
r.logger.Info("background metadata complete",
|
||||
slog.String("thought_id", id.String()),
|
||||
slog.String("provider", r.metadata.PrimaryProvider()),
|
||||
slog.String("model", r.metadata.PrimaryModel()),
|
||||
slog.Bool("updated", updated),
|
||||
slog.Duration("duration", time.Since(started)),
|
||||
)
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user