Files
amcs/internal/mcpserver/schema.go

104 lines
2.9 KiB
Go

package mcpserver
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"reflect"
"time"
"github.com/google/jsonschema-go/jsonschema"
"github.com/google/uuid"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// maxLoggedArgBytes is the upper bound on how many bytes of JSON-encoded tool
// arguments truncateArgs keeps before it appends a total-size marker.
const maxLoggedArgBytes = 512
// toolSchemaOptions is handed to jsonschema.For by setToolSchemas whenever a
// tool schema is inferred. Its TypeSchemas entry maps uuid.UUID to a schema of
// type "string" with format "uuid".
var toolSchemaOptions = &jsonschema.ForOptions{
TypeSchemas: map[reflect.Type]*jsonschema.Schema{
reflect.TypeFor[uuid.UUID](): {
Type: "string",
Format: "uuid",
},
},
}
// addTool fills in any missing schemas on tool, wraps handler with call
// logging, and registers the result on server via mcp.AddTool.
//
// It panics if a schema cannot be inferred from In or Out; the panic message
// names the tool and carries the underlying error.
func addTool[In any, Out any](server *mcp.Server, logger *slog.Logger, tool *mcp.Tool, handler func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error)) {
	err := setToolSchemas[In, Out](tool)
	if err != nil {
		panic(fmt.Sprintf("configure MCP tool %q schemas: %v", tool.Name, err))
	}

	// logToolCall returns handler untouched when logger is nil.
	wrapped := logToolCall(logger, tool.Name, handler)
	mcp.AddTool(server, tool, wrapped)
}
// logToolCall decorates handler with structured start/completion logging.
//
// When logger is nil the handler is returned unchanged. Otherwise each call
// emits an Info record "mcp tool started" carrying the tool name and, when the
// request has params, the truncated JSON arguments. After the handler returns
// it emits "mcp tool completed" with the same attributes plus the elapsed
// duration: at Info level on success, or at Error level with an added "error"
// attribute on failure. The handler's results are passed through untouched.
func logToolCall[In any, Out any](logger *slog.Logger, toolName string, handler func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error)) func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error) {
	if logger == nil {
		return handler
	}

	return func(ctx context.Context, req *mcp.CallToolRequest, in In) (*mcp.CallToolResult, Out, error) {
		began := time.Now()

		startFields := []any{slog.String("tool", toolName)}
		if req != nil && req.Params != nil {
			startFields = append(startFields, slog.String("arguments", truncateArgs(req.Params.Arguments)))
		}
		logger.Info("mcp tool started", startFields...)

		result, out, err := handler(ctx, req, in)

		// Build the completion record on a fresh slice so the start
		// attributes are never aliased or mutated.
		doneFields := make([]any, 0, len(startFields)+2)
		doneFields = append(doneFields, startFields...)
		doneFields = append(doneFields, slog.String("duration", formatLogDuration(time.Since(began))))

		if err == nil {
			logger.Info("mcp tool completed", doneFields...)
			return result, out, nil
		}

		doneFields = append(doneFields, slog.String("error", err.Error()))
		logger.Error("mcp tool completed", doneFields...)
		return result, out, err
	}
}
// truncateArgs renders args as JSON for logging, keeping at most
// maxLoggedArgBytes bytes of the encoding.
//
// It returns "<unserializable>" when args cannot be marshaled. Encodings that
// fit within the limit are returned whole. Longer encodings are cut and
// suffixed with "… (N bytes total)", where N is the full encoded length. The
// cut is moved back to the nearest UTF-8 rune boundary so the logged string
// never ends in a partial multi-byte character.
func truncateArgs(args any) string {
	encoded, err := json.Marshal(args)
	if err != nil {
		return "<unserializable>"
	}
	if len(encoded) <= maxLoggedArgBytes {
		return string(encoded)
	}

	// encoded[cut] is the first dropped byte. While it is a UTF-8
	// continuation byte (10xxxxxx) the cut sits inside a rune, so step back
	// until the first dropped byte starts a rune.
	cut := maxLoggedArgBytes
	for cut > 0 && encoded[cut]&0xC0 == 0x80 {
		cut--
	}
	return fmt.Sprintf("%s… (%d bytes total)", encoded[:cut], len(encoded))
}
// formatLogDuration renders the magnitude of d as "MM:SS:mmm": whole minutes
// (zero-padded to at least two digits and not capped at 59, since there is no
// hours field), seconds within the minute, and milliseconds within the second.
// Sub-millisecond precision is truncated and the sign of d is discarded.
func formatLogDuration(d time.Duration) string {
	const (
		msPerSecond = 1000
		msPerMinute = 60 * msPerSecond
	)

	// Duration.Abs saturates math.MinInt64 to math.MaxInt64. Plain negation
	// overflows for that value and would leave every field negative.
	total := d.Abs().Milliseconds()

	minutes := total / msPerMinute
	seconds := (total / msPerSecond) % 60
	millis := total % msPerSecond
	return fmt.Sprintf("%02d:%02d:%03d", minutes, seconds, millis)
}
// setToolSchemas populates whichever of tool.InputSchema and tool.OutputSchema
// is still nil, inferring it from the In and Out type parameters with
// jsonschema.For and the package-level toolSchemaOptions. Schemas already set
// on tool are left alone.
//
// The input schema is handled first. If output inference then fails, an
// inferred input schema has already been stored on tool. Errors are wrapped
// with which side ("input" or "output") failed.
func setToolSchemas[In any, Out any](tool *mcp.Tool) error {
	if tool.InputSchema == nil {
		inferred, err := jsonschema.For[In](toolSchemaOptions)
		if err != nil {
			return fmt.Errorf("infer input schema: %w", err)
		}
		tool.InputSchema = inferred
	}

	if tool.OutputSchema == nil {
		inferred, err := jsonschema.For[Out](toolSchemaOptions)
		if err != nil {
			return fmt.Errorf("infer output schema: %w", err)
		}
		tool.OutputSchema = inferred
	}

	return nil
}