package mcpserver

import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"reflect"
	"time"
	"unicode/utf8"

	"github.com/google/jsonschema-go/jsonschema"
	"github.com/google/uuid"
	"github.com/modelcontextprotocol/go-sdk/mcp"
)

// maxLoggedArgBytes caps how many bytes of JSON-encoded tool arguments are
// copied into a log record before the value is truncated.
const maxLoggedArgBytes = 512

// toolSchemaOptions is passed to every jsonschema.For call made by
// setToolSchemas. It maps uuid.UUID to a string schema with format "uuid".
var toolSchemaOptions = &jsonschema.ForOptions{
	TypeSchemas: map[reflect.Type]*jsonschema.Schema{
		reflect.TypeFor[uuid.UUID](): {
			Type:   "string",
			Format: "uuid",
		},
	},
}

// addTool registers tool on server, wrapping handler with logToolCall so each
// invocation is logged through logger.
//
// Before registration, any schema that tool does not already carry is inferred
// from In and Out by setToolSchemas. addTool panics if that inference fails;
// the panic message names the tool and the underlying error.
func addTool[In any, Out any](
	server *mcp.Server,
	logger *slog.Logger,
	tool *mcp.Tool,
	handler func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error),
) {
	if err := setToolSchemas[In, Out](tool); err != nil {
		panic(fmt.Sprintf("configure MCP tool %q schemas: %v", tool.Name, err))
	}
	mcp.AddTool(server, tool, logToolCall(logger, tool.Name, handler))
}

// logToolCall returns a handler that logs the start and completion of every
// call to handler.
//
// If logger is nil, handler is returned unchanged. Otherwise each call emits
// an Info record "mcp tool started" and, after handler returns, a record
// "mcp tool completed" at Info level on success or Error level on failure.
// Both records carry the tool name and, when the request has params, the
// truncated JSON arguments; the completion record adds the formatted duration
// and, on failure, the error text. The handler's results are passed through
// as returned.
func logToolCall[In any, Out any](
	logger *slog.Logger,
	toolName string,
	handler func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error),
) func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error) {
	if logger == nil {
		return handler
	}

	return func(ctx context.Context, req *mcp.CallToolRequest, in In) (*mcp.CallToolResult, Out, error) {
		start := time.Now()

		// Room for: tool, arguments, duration, error.
		attrs := make([]any, 0, 4)
		attrs = append(attrs, slog.String("tool", toolName))
		if req != nil && req.Params != nil {
			attrs = append(attrs, slog.String("arguments", truncateArgs(req.Params.Arguments)))
		}
		// Use the context-aware variants so slog handlers can see ctx.
		logger.InfoContext(ctx, "mcp tool started", attrs...)

		result, out, err := handler(ctx, req, in)

		// slog copies the arguments into the record, so the same per-call
		// slice can be extended for the completion record.
		attrs = append(attrs, slog.String("duration", formatLogDuration(time.Since(start))))
		if err != nil {
			attrs = append(attrs, slog.String("error", err.Error()))
			logger.ErrorContext(ctx, "mcp tool completed", attrs...)
			return result, out, err
		}

		logger.InfoContext(ctx, "mcp tool completed", attrs...)
		return result, out, nil
	}
}

// truncateArgs JSON-encodes args for logging and limits the encoded payload
// to at most maxLoggedArgBytes bytes.
//
// When the encoding is longer than the limit, the result is the leading
// portion followed by "… (N bytes total)", where N is the full encoded
// length. The cut point is moved back to the nearest UTF-8 rune boundary so
// a multi-byte character is never split. If args cannot be marshalled, the
// empty string is returned.
func truncateArgs(args any) string {
	b, err := json.Marshal(args)
	if err != nil {
		// Best effort: logging must not fail the tool call.
		return ""
	}
	if len(b) <= maxLoggedArgBytes {
		return string(b)
	}

	// len(b) > maxLoggedArgBytes, so b[cut] is in range. If it is a UTF-8
	// continuation byte, the rune it belongs to straddles the limit; back up
	// until the cut sits on the first byte of a rune.
	cut := maxLoggedArgBytes
	for cut > 0 && !utf8.RuneStart(b[cut]) {
		cut--
	}
	return fmt.Sprintf("%s… (%d bytes total)", b[:cut], len(b))
}

// formatLogDuration renders the magnitude of d as "MM:SS:mmm"
// (minutes:seconds:milliseconds), truncated to whole milliseconds.
//
// The minutes field is not wrapped at 60, so durations of an hour or more
// simply produce a larger minute count. Negative durations are formatted by
// their absolute value.
func formatLogDuration(d time.Duration) string {
	// Abs saturates instead of overflowing for the minimum Duration value.
	ms := d.Abs().Milliseconds()
	return fmt.Sprintf("%02d:%02d:%03d", ms/60000, (ms/1000)%60, ms%1000)
}

// setToolSchemas fills in the input and output schemas of tool when they are
// not already set.
//
// A nil tool.InputSchema is replaced by the schema inferred from In, and a nil
// tool.OutputSchema by the schema inferred from Out, both using
// toolSchemaOptions. Schemas that are already set are left untouched. The
// returned error wraps the inference failure and says which schema it was for.
func setToolSchemas[In any, Out any](tool *mcp.Tool) error {
	if tool.InputSchema == nil {
		inputSchema, err := jsonschema.For[In](toolSchemaOptions)
		if err != nil {
			return fmt.Errorf("infer input schema: %w", err)
		}
		tool.InputSchema = inputSchema
	}

	if tool.OutputSchema == nil {
		outputSchema, err := jsonschema.For[Out](toolSchemaOptions)
		if err != nil {
			return fmt.Errorf("infer output schema: %w", err)
		}
		tool.OutputSchema = outputSchema
	}

	return nil
}