351 lines
8.6 KiB
Go
351 lines
8.6 KiB
Go
package mcpserver
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log/slog"
|
|
"reflect"
|
|
"regexp"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/jsonschema-go/jsonschema"
|
|
"github.com/google/uuid"
|
|
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
|
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
|
|
|
"git.warky.dev/wdevs/amcs/internal/mcperrors"
|
|
)
|
|
|
|
// maxLoggedArgBytes is the maximum number of bytes of redacted, JSON-encoded
// tool arguments that truncateArgs emits before cutting the output short.
const (
	maxLoggedArgBytes = 512
)
|
|
|
|
// sensitiveToolArgKeys is the set of argument keys whose values are replaced
// with "<redacted>" before tool arguments are logged. Entries are lower-case;
// isSensitiveToolArgKey trims and lower-cases a key before looking it up, and
// redactJSONValue applies the check at every nesting depth.
var sensitiveToolArgKeys = map[string]struct{}{
	"client_secret":  {},
	"content":        {},
	"content_base64": {},
	"content_path":   {},
	"email":          {},
	"notes":          {},
	"phone":          {},
	"value":          {},
}
|
|
|
|
var toolSchemaOptions = &jsonschema.ForOptions{
|
|
TypeSchemas: map[reflect.Type]*jsonschema.Schema{
|
|
reflect.TypeFor[uuid.UUID](): {
|
|
Type: "string",
|
|
Format: "uuid",
|
|
},
|
|
},
|
|
}
|
|
|
|
// Patterns used by validationField to pull an argument name out of a schema
// validation or JSON decoding error message. Each has one capture group.
var (
	// missingPropertyPattern captures the first name listed in a
	// "missing properties: [...]" message. Only the first quoted name is
	// required, so a list with several missing properties (for example
	// `["title" "content"]`) still yields a field ("title") instead of no match.
	missingPropertyPattern = regexp.MustCompile(`missing properties: \["([^"]+)"`)

	// propertyPathPattern captures the property path that follows
	// "validating /properties/" up to the next colon.
	propertyPathPattern = regexp.MustCompile(`validating /properties/([^:]+):`)

	// structFieldPattern captures the identifier immediately preceding
	// " of type" after a dot, as in a Go struct-field decode error.
	structFieldPattern = regexp.MustCompile(`\.([A-Za-z0-9_]+) of type`)
)
|
|
|
|
func addTool[In any, Out any](server *mcp.Server, logger *slog.Logger, tool *mcp.Tool, handler func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error)) error {
|
|
if err := setToolSchemas[In, Out](tool); err != nil {
|
|
return fmt.Errorf("configure MCP tool %q schemas: %w", tool.Name, err)
|
|
}
|
|
|
|
inputResolved, err := resolveToolSchema(tool.InputSchema)
|
|
if err != nil {
|
|
return fmt.Errorf("resolve MCP tool %q input schema: %w", tool.Name, err)
|
|
}
|
|
outputResolved, err := resolveToolSchema(tool.OutputSchema)
|
|
if err != nil {
|
|
return fmt.Errorf("resolve MCP tool %q output schema: %w", tool.Name, err)
|
|
}
|
|
|
|
server.AddTool(tool, logToolCall(logger, tool.Name, func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
|
var input json.RawMessage
|
|
if req != nil && req.Params != nil && req.Params.Arguments != nil {
|
|
input = req.Params.Arguments
|
|
}
|
|
|
|
input, err = applyToolSchema(input, inputResolved)
|
|
if err != nil {
|
|
return nil, invalidArgumentsError(err)
|
|
}
|
|
|
|
var in In
|
|
if input != nil {
|
|
if err := json.Unmarshal(input, &in); err != nil {
|
|
return nil, invalidArgumentsError(err)
|
|
}
|
|
}
|
|
|
|
result, out, err := handler(ctx, req, in)
|
|
if err != nil {
|
|
if wireErr, ok := err.(*jsonrpc.Error); ok {
|
|
return nil, wireErr
|
|
}
|
|
var errResult mcp.CallToolResult
|
|
errResult.SetError(err)
|
|
return &errResult, nil
|
|
}
|
|
|
|
if result == nil {
|
|
result = &mcp.CallToolResult{}
|
|
}
|
|
|
|
if outputResolved == nil {
|
|
return result, nil
|
|
}
|
|
|
|
outValue := normalizeTypedNil[Out](out)
|
|
if outValue == nil {
|
|
return result, nil
|
|
}
|
|
|
|
outBytes, err := json.Marshal(outValue)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("marshaling output: %w", err)
|
|
}
|
|
outJSON, err := applyToolSchema(json.RawMessage(outBytes), outputResolved)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("validating tool output: %w", err)
|
|
}
|
|
|
|
result.StructuredContent = outJSON
|
|
if result.Content == nil {
|
|
result.Content = []mcp.Content{&mcp.TextContent{Text: string(outJSON)}}
|
|
}
|
|
|
|
return result, nil
|
|
}))
|
|
return nil
|
|
}
|
|
|
|
func logToolCall(logger *slog.Logger, toolName string, handler mcp.ToolHandler) mcp.ToolHandler {
|
|
if logger == nil {
|
|
return handler
|
|
}
|
|
|
|
return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
|
start := time.Now()
|
|
attrs := []any{slog.String("tool", toolName)}
|
|
if req != nil && req.Params != nil {
|
|
attrs = append(attrs, slog.String("arguments", truncateArgs(req.Params.Arguments)))
|
|
}
|
|
|
|
logger.Info("mcp tool started", attrs...)
|
|
result, err := handler(ctx, req)
|
|
|
|
completionAttrs := append([]any{}, attrs...)
|
|
completionAttrs = append(completionAttrs, slog.String("duration", formatLogDuration(time.Since(start))))
|
|
if err != nil {
|
|
completionAttrs = append(completionAttrs, slog.String("error", err.Error()))
|
|
logger.Error("mcp tool completed", completionAttrs...)
|
|
return result, err
|
|
}
|
|
|
|
logger.Info("mcp tool completed", completionAttrs...)
|
|
return result, nil
|
|
}
|
|
}
|
|
|
|
func truncateArgs(args any) string {
|
|
redacted, err := redactArguments(args)
|
|
if err != nil {
|
|
return "<unserializable>"
|
|
}
|
|
b, err := json.Marshal(redacted)
|
|
if err != nil {
|
|
return "<unserializable>"
|
|
}
|
|
if len(b) <= maxLoggedArgBytes {
|
|
return string(b)
|
|
}
|
|
return string(b[:maxLoggedArgBytes]) + fmt.Sprintf("… (%d bytes total)", len(b))
|
|
}
|
|
|
|
func redactArguments(args any) (any, error) {
|
|
if args == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
b, err := json.Marshal(args)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var decoded any
|
|
if err := json.Unmarshal(b, &decoded); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return redactJSONValue("", decoded), nil
|
|
}
|
|
|
|
func redactJSONValue(key string, value any) any {
|
|
if isSensitiveToolArgKey(key) {
|
|
return "<redacted>"
|
|
}
|
|
|
|
switch typed := value.(type) {
|
|
case map[string]any:
|
|
redacted := make(map[string]any, len(typed))
|
|
for childKey, childValue := range typed {
|
|
redacted[childKey] = redactJSONValue(childKey, childValue)
|
|
}
|
|
return redacted
|
|
case []any:
|
|
redacted := make([]any, len(typed))
|
|
for i, item := range typed {
|
|
redacted[i] = redactJSONValue("", item)
|
|
}
|
|
return redacted
|
|
default:
|
|
return value
|
|
}
|
|
}
|
|
|
|
func isSensitiveToolArgKey(key string) bool {
|
|
_, ok := sensitiveToolArgKeys[strings.ToLower(strings.TrimSpace(key))]
|
|
return ok
|
|
}
|
|
|
|
// formatLogDuration renders the absolute value of d as "MM:SS:mmm" (minutes,
// seconds, milliseconds). Sub-millisecond precision is truncated. Minutes are
// not folded into hours, so the minutes field can exceed two digits.
func formatLogDuration(d time.Duration) string {
	if d < 0 {
		d = -d
	}

	const (
		msPerSecond = 1000
		msPerMinute = 60 * msPerSecond
	)
	ms := d.Milliseconds()
	return fmt.Sprintf("%02d:%02d:%03d", ms/msPerMinute, ms%msPerMinute/msPerSecond, ms%msPerSecond)
}
|
|
|
|
func normalizeObjectSchema(schema *jsonschema.Schema) {
|
|
if schema != nil && schema.Type == "object" && schema.Properties == nil {
|
|
schema.Properties = map[string]*jsonschema.Schema{}
|
|
}
|
|
}
|
|
|
|
func setToolSchemas[In any, Out any](tool *mcp.Tool) error {
|
|
if tool.InputSchema == nil {
|
|
inputSchema, err := jsonschema.For[In](toolSchemaOptions)
|
|
if err != nil {
|
|
return fmt.Errorf("infer input schema: %w", err)
|
|
}
|
|
normalizeObjectSchema(inputSchema)
|
|
tool.InputSchema = inputSchema
|
|
}
|
|
|
|
if tool.OutputSchema == nil {
|
|
outputSchema, err := jsonschema.For[Out](toolSchemaOptions)
|
|
if err != nil {
|
|
return fmt.Errorf("infer output schema: %w", err)
|
|
}
|
|
tool.OutputSchema = outputSchema
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func resolveToolSchema(schema any) (*jsonschema.Resolved, error) {
|
|
if schema == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
if typed, ok := schema.(*jsonschema.Schema); ok {
|
|
return typed.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
|
|
}
|
|
|
|
var remarshalTarget *jsonschema.Schema
|
|
if err := remarshalJSON(schema, &remarshalTarget); err != nil {
|
|
return nil, err
|
|
}
|
|
return remarshalTarget.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
|
|
}
|
|
|
|
func applyToolSchema(data json.RawMessage, resolved *jsonschema.Resolved) (json.RawMessage, error) {
|
|
if resolved == nil {
|
|
return data, nil
|
|
}
|
|
|
|
value := make(map[string]any)
|
|
if len(data) > 0 {
|
|
if err := json.Unmarshal(data, &value); err != nil {
|
|
return nil, fmt.Errorf("unmarshaling arguments: %w", err)
|
|
}
|
|
}
|
|
if err := resolved.ApplyDefaults(&value); err != nil {
|
|
return nil, fmt.Errorf("applying schema defaults: %w", err)
|
|
}
|
|
if err := resolved.Validate(&value); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
normalized, err := json.Marshal(value)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("marshalling with defaults: %w", err)
|
|
}
|
|
return normalized, nil
|
|
}
|
|
|
|
// remarshalJSON copies src into dst by encoding src as JSON and decoding the
// bytes into dst. dst must be a pointer that json.Unmarshal accepts. The
// first encoding or decoding error is returned unchanged.
func remarshalJSON(src any, dst any) error {
	encoded, err := json.Marshal(src)
	if err == nil {
		err = json.Unmarshal(encoded, dst)
	}
	return err
}
|
|
|
|
func normalizeTypedNil[T any](value T) any {
|
|
rt := reflect.TypeFor[T]()
|
|
if rt.Kind() == reflect.Pointer {
|
|
var zero T
|
|
if any(value) == any(zero) {
|
|
return reflect.Zero(rt.Elem()).Interface()
|
|
}
|
|
}
|
|
return any(value)
|
|
}
|
|
|
|
func invalidArgumentsError(err error) error {
|
|
detail := err.Error()
|
|
field := validationField(detail)
|
|
|
|
payload := mcperrors.Data{
|
|
Type: mcperrors.TypeInvalidArguments,
|
|
Detail: detail,
|
|
}
|
|
if field != "" {
|
|
payload.Field = field
|
|
payload.Hint = "check the " + field + " argument"
|
|
}
|
|
|
|
data, marshalErr := json.Marshal(payload)
|
|
if marshalErr != nil {
|
|
return &jsonrpc.Error{
|
|
Code: jsonrpc.CodeInvalidParams,
|
|
Message: "invalid tool arguments",
|
|
}
|
|
}
|
|
|
|
return &jsonrpc.Error{
|
|
Code: jsonrpc.CodeInvalidParams,
|
|
Message: "invalid tool arguments",
|
|
Data: data,
|
|
}
|
|
}
|
|
|
|
func validationField(detail string) string {
|
|
if matches := missingPropertyPattern.FindStringSubmatch(detail); len(matches) == 2 {
|
|
return matches[1]
|
|
}
|
|
if matches := propertyPathPattern.FindStringSubmatch(detail); len(matches) == 2 {
|
|
return matches[1]
|
|
}
|
|
if matches := structFieldPattern.FindStringSubmatch(detail); len(matches) == 2 {
|
|
return strings.ToLower(matches[1])
|
|
}
|
|
return ""
|
|
}
|