package mcpserver

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"reflect"
	"regexp"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/google/jsonschema-go/jsonschema"
	"github.com/google/uuid"
	"github.com/modelcontextprotocol/go-sdk/jsonrpc"
	"github.com/modelcontextprotocol/go-sdk/mcp"

	"git.warky.dev/wdevs/amcs/internal/mcperrors"
)

// maxLoggedArgBytes caps how many bytes of the redacted, JSON-encoded tool
// arguments are written to the log for a single call.
const maxLoggedArgBytes = 512

// sensitiveToolArgKeys lists argument keys whose values are blanked out before
// arguments are logged. Keys are matched after lower-casing and trimming
// surrounding whitespace (see isSensitiveToolArgKey).
var sensitiveToolArgKeys = map[string]struct{}{
	"client_secret":  {},
	"content":        {},
	"content_base64": {},
	"content_path":   {},
	"email":          {},
	"notes":          {},
	"phone":          {},
	"value":          {},
}

// toolSchemaOptions customises schema inference so that uuid.UUID values are
// described as JSON strings with the "uuid" format.
var toolSchemaOptions = &jsonschema.ForOptions{
	TypeSchemas: map[reflect.Type]*jsonschema.Schema{
		reflect.TypeFor[uuid.UUID](): {
			Type:   "string",
			Format: "uuid",
		},
	},
}

// Patterns used by validationField to extract the offending argument name
// from schema-validation and JSON-decoding error messages.
var (
	missingPropertyPattern = regexp.MustCompile(`missing properties: \["([^"]+)"\]`)
	propertyPathPattern    = regexp.MustCompile(`validating /properties/([^:]+):`)
	structFieldPattern     = regexp.MustCompile(`\.([A-Za-z0-9_]+) of type`)
)

// addTool registers tool on server, wrapping handler so that:
//   - input/output schemas are inferred from In/Out when the tool does not
//     declare them, and are resolved once at registration time;
//   - incoming arguments get schema defaults applied and are validated before
//     being decoded into In; failures are returned as JSON-RPC invalid-params
//     errors (see invalidArgumentsError);
//   - a handler error that is, or wraps, a *jsonrpc.Error is returned to the
//     client as a protocol error; any other handler error becomes a tool
//     result with its error set via SetError;
//   - a non-nil Out value is validated against the output schema and exposed
//     as StructuredContent (and as text content when the handler set none);
//   - each call is logged through logToolCall when logger is non-nil.
//
// It returns an error if the schemas cannot be inferred or resolved.
func addTool[In any, Out any](server *mcp.Server, logger *slog.Logger, tool *mcp.Tool, handler func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error)) error {
	if err := setToolSchemas[In, Out](tool); err != nil {
		return fmt.Errorf("configure MCP tool %q schemas: %w", tool.Name, err)
	}
	inputResolved, err := resolveToolSchema(tool.InputSchema)
	if err != nil {
		return fmt.Errorf("resolve MCP tool %q input schema: %w", tool.Name, err)
	}
	outputResolved, err := resolveToolSchema(tool.OutputSchema)
	if err != nil {
		return fmt.Errorf("resolve MCP tool %q output schema: %w", tool.Name, err)
	}

	server.AddTool(tool, logToolCall(logger, tool.Name, func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		var rawInput json.RawMessage
		if req != nil && req.Params != nil && req.Params.Arguments != nil {
			rawInput = req.Params.Arguments
		}

		// err must be local to this closure: the closure is invoked for every
		// tool call (potentially concurrently), so assigning to the err
		// captured from addTool would be a data race.
		input, err := applyToolSchema(rawInput, inputResolved)
		if err != nil {
			return nil, invalidArgumentsError(err)
		}

		var in In
		if input != nil {
			if err := json.Unmarshal(input, &in); err != nil {
				return nil, invalidArgumentsError(err)
			}
		}

		result, out, err := handler(ctx, req, in)
		if err != nil {
			// Protocol-level errors (possibly wrapped) go back as JSON-RPC
			// errors; everything else is reported as a failed tool result.
			var wireErr *jsonrpc.Error
			if errors.As(err, &wireErr) {
				return nil, wireErr
			}
			var errResult mcp.CallToolResult
			errResult.SetError(err)
			return &errResult, nil
		}

		if result == nil {
			result = &mcp.CallToolResult{}
		}
		if outputResolved == nil {
			return result, nil
		}

		outValue := normalizeTypedNil[Out](out)
		if outValue == nil {
			return result, nil
		}
		outBytes, err := json.Marshal(outValue)
		if err != nil {
			return nil, fmt.Errorf("marshaling output: %w", err)
		}
		outJSON, err := applyToolSchema(json.RawMessage(outBytes), outputResolved)
		if err != nil {
			return nil, fmt.Errorf("validating tool output: %w", err)
		}
		result.StructuredContent = outJSON
		if result.Content == nil {
			result.Content = []mcp.Content{&mcp.TextContent{Text: string(outJSON)}}
		}
		return result, nil
	}))
	return nil
}

// logToolCall wraps handler so each invocation is logged when it starts and
// when it completes. Entries carry the tool name, the redacted/truncated
// arguments (when the request has params) and, on completion, the elapsed
// time. A returned Go error is logged at Error level with its message; a
// result flagged IsError is logged at Info level with tool_error=true.
// When logger is nil, handler is returned unchanged.
func logToolCall(logger *slog.Logger, toolName string, handler mcp.ToolHandler) mcp.ToolHandler {
	if logger == nil {
		return handler
	}
	return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		start := time.Now()
		attrs := []any{slog.String("tool", toolName)}
		if req != nil && req.Params != nil {
			attrs = append(attrs, slog.String("arguments", truncateArgs(req.Params.Arguments)))
		}
		logger.Info("mcp tool started", attrs...)

		result, err := handler(ctx, req)

		// Copy attrs so the completion entry never aliases the start entry's slice.
		completionAttrs := make([]any, 0, len(attrs)+2)
		completionAttrs = append(completionAttrs, attrs...)
		completionAttrs = append(completionAttrs, slog.String("duration", formatLogDuration(time.Since(start))))
		if err != nil {
			completionAttrs = append(completionAttrs, slog.String("error", err.Error()))
			logger.Error("mcp tool completed", completionAttrs...)
			return result, err
		}
		if result != nil && result.IsError {
			// Handler failures converted into tool results would otherwise be
			// indistinguishable from successful calls in the log.
			completionAttrs = append(completionAttrs, slog.Bool("tool_error", true))
		}
		logger.Info("mcp tool completed", completionAttrs...)
		return result, nil
	}
}

// truncateArgs renders args as redacted JSON for logging. Output longer than
// maxLoggedArgBytes is cut (on a UTF-8 rune boundary) and suffixed with the
// full encoded length. It returns "" if args cannot be encoded, because
// logging is best-effort and must not fail the call.
func truncateArgs(args any) string {
	redacted, err := redactArguments(args)
	if err != nil {
		return ""
	}
	b, err := json.Marshal(redacted)
	if err != nil {
		return ""
	}
	if len(b) <= maxLoggedArgBytes {
		return string(b)
	}
	// json.Marshal emits non-ASCII text as raw UTF-8, so a fixed byte cut
	// could split a multi-byte sequence; back off to the start of a rune.
	cut := maxLoggedArgBytes
	for cut > 0 && !utf8.RuneStart(b[cut]) {
		cut--
	}
	return string(b[:cut]) + fmt.Sprintf("… (%d bytes total)", len(b))
}

// redactArguments round-trips args through JSON into generic values
// (maps, slices, scalars) and blanks every sensitive key via redactJSONValue.
// A nil args yields (nil, nil).
func redactArguments(args any) (any, error) {
	if args == nil {
		return nil, nil
	}
	b, err := json.Marshal(args)
	if err != nil {
		return nil, err
	}
	var decoded any
	if err := json.Unmarshal(b, &decoded); err != nil {
		return nil, err
	}
	return redactJSONValue("", decoded), nil
}

// redactJSONValue returns a copy of value in which every object member whose
// key is sensitive (see isSensitiveToolArgKey) is replaced by "", recursing
// into nested objects and arrays. Array elements carry no key of their own.
func redactJSONValue(key string, value any) any {
	if isSensitiveToolArgKey(key) {
		return ""
	}
	switch typed := value.(type) {
	case map[string]any:
		redacted := make(map[string]any, len(typed))
		for childKey, childValue := range typed {
			redacted[childKey] = redactJSONValue(childKey, childValue)
		}
		return redacted
	case []any:
		redacted := make([]any, len(typed))
		for i, item := range typed {
			redacted[i] = redactJSONValue("", item)
		}
		return redacted
	default:
		return value
	}
}

// isSensitiveToolArgKey reports whether key, lower-cased and trimmed, is
// listed in sensitiveToolArgKeys.
func isSensitiveToolArgKey(key string) bool {
	_, ok := sensitiveToolArgKeys[strings.ToLower(strings.TrimSpace(key))]
	return ok
}

// formatLogDuration formats the absolute value of d as "MM:SS:mmm"
// (minutes:seconds:milliseconds). Minutes are not wrapped into hours, and
// sub-millisecond precision is truncated.
func formatLogDuration(d time.Duration) string {
	if d < 0 {
		d = -d
	}
	totalMilliseconds := d.Milliseconds()
	minutes := totalMilliseconds / 60000
	seconds := (totalMilliseconds / 1000) % 60
	milliseconds := totalMilliseconds % 1000
	return fmt.Sprintf("%02d:%02d:%03d", minutes, seconds, milliseconds)
}

// normalizeObjectSchema gives an object schema without properties an explicit
// empty properties map, so it serialises with a "properties" member.
func normalizeObjectSchema(schema *jsonschema.Schema) {
	if schema != nil && schema.Type == "object" && schema.Properties == nil {
		schema.Properties = map[string]*jsonschema.Schema{}
	}
}

// setToolSchemas fills in tool.InputSchema and tool.OutputSchema from In and
// Out when they are not already set. Inferred input schemas are normalised
// with normalizeObjectSchema.
func setToolSchemas[In any, Out any](tool *mcp.Tool) error {
	if tool.InputSchema == nil {
		inputSchema, err := jsonschema.For[In](toolSchemaOptions)
		if err != nil {
			return fmt.Errorf("infer input schema: %w", err)
		}
		normalizeObjectSchema(inputSchema)
		tool.InputSchema = inputSchema
	}
	if tool.OutputSchema == nil {
		outputSchema, err := jsonschema.For[Out](toolSchemaOptions)
		if err != nil {
			return fmt.Errorf("infer output schema: %w", err)
		}
		tool.OutputSchema = outputSchema
	}
	return nil
}

// resolveToolSchema resolves schema (with default validation enabled) for use
// by applyToolSchema. A nil schema yields (nil, nil). Schemas that are not
// already *jsonschema.Schema are converted by a JSON round-trip first.
func resolveToolSchema(schema any) (*jsonschema.Resolved, error) {
	if schema == nil {
		return nil, nil
	}
	if typed, ok := schema.(*jsonschema.Schema); ok {
		return typed.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
	}
	var remarshalTarget *jsonschema.Schema
	if err := remarshalJSON(schema, &remarshalTarget); err != nil {
		return nil, err
	}
	return remarshalTarget.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
}

// applyToolSchema decodes data as a JSON object, applies the schema's
// defaults, validates the result and returns it re-encoded. Empty data and a
// JSON null are both treated as an empty object. When resolved is nil, data
// is returned untouched.
func applyToolSchema(data json.RawMessage, resolved *jsonschema.Resolved) (json.RawMessage, error) {
	if resolved == nil {
		return data, nil
	}
	value := make(map[string]any)
	if len(data) > 0 {
		if err := json.Unmarshal(data, &value); err != nil {
			return nil, fmt.Errorf("unmarshaling arguments: %w", err)
		}
		// Unmarshaling JSON null sets the map to nil; handle it exactly like
		// absent data so defaults and validation run against an empty object.
		if value == nil {
			value = make(map[string]any)
		}
	}
	if err := resolved.ApplyDefaults(&value); err != nil {
		return nil, fmt.Errorf("applying schema defaults: %w", err)
	}
	if err := resolved.Validate(&value); err != nil {
		return nil, err
	}
	normalized, err := json.Marshal(value)
	if err != nil {
		return nil, fmt.Errorf("marshalling with defaults: %w", err)
	}
	return normalized, nil
}

// remarshalJSON copies src into dst by encoding src to JSON and decoding the
// result into dst.
func remarshalJSON(src any, dst any) error {
	b, err := json.Marshal(src)
	if err != nil {
		return err
	}
	return json.Unmarshal(b, dst)
}

// normalizeTypedNil boxes value as any. When T is a pointer type and value is
// nil, it returns the zero value of the pointed-to type instead, so a nil
// *Struct output still produces a (zero-valued) structured result.
func normalizeTypedNil[T any](value T) any {
	rt := reflect.TypeFor[T]()
	if rt.Kind() == reflect.Pointer {
		var zero T
		if any(value) == any(zero) {
			return reflect.Zero(rt.Elem()).Interface()
		}
	}
	return any(value)
}

// invalidArgumentsError converts an argument validation/decoding error into a
// JSON-RPC invalid-params error whose Data is an mcperrors.Data payload with
// the error detail and, when it can be identified, the offending field and a
// hint. If the payload cannot be encoded, Data is omitted.
func invalidArgumentsError(err error) error {
	detail := err.Error()
	field := validationField(detail)

	payload := mcperrors.Data{
		Type:   mcperrors.TypeInvalidArguments,
		Detail: detail,
	}
	if field != "" {
		payload.Field = field
		payload.Hint = "check the " + field + " argument"
	}

	data, marshalErr := json.Marshal(payload)
	if marshalErr != nil {
		return &jsonrpc.Error{
			Code:    jsonrpc.CodeInvalidParams,
			Message: "invalid tool arguments",
		}
	}
	return &jsonrpc.Error{
		Code:    jsonrpc.CodeInvalidParams,
		Message: "invalid tool arguments",
		Data:    data,
	}
}

// validationField extracts a best-effort argument name from an error message,
// trying in order: a single missing required property, a
// "/properties/<name>" validation path, and a Go struct field named in a JSON
// decoding error (lower-cased). It returns "" when none match.
func validationField(detail string) string {
	if matches := missingPropertyPattern.FindStringSubmatch(detail); len(matches) == 2 {
		return matches[1]
	}
	if matches := propertyPathPattern.FindStringSubmatch(detail); len(matches) == 2 {
		return matches[1]
	}
	if matches := structFieldPattern.FindStringSubmatch(detail); len(matches) == 2 {
		return strings.ToLower(matches[1])
	}
	return ""
}