78 lines
2.3 KiB
Go
78 lines
2.3 KiB
Go
package mcpserver
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log/slog"
|
|
"reflect"
|
|
"time"
|
|
|
|
"github.com/google/jsonschema-go/jsonschema"
|
|
"github.com/google/uuid"
|
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
|
)
|
|
|
|
var toolSchemaOptions = &jsonschema.ForOptions{
|
|
TypeSchemas: map[reflect.Type]*jsonschema.Schema{
|
|
reflect.TypeFor[uuid.UUID](): {
|
|
Type: "string",
|
|
Format: "uuid",
|
|
},
|
|
},
|
|
}
|
|
|
|
func addTool[In any, Out any](server *mcp.Server, logger *slog.Logger, tool *mcp.Tool, handler func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error)) {
|
|
if err := setToolSchemas[In, Out](tool); err != nil {
|
|
panic(fmt.Sprintf("configure MCP tool %q schemas: %v", tool.Name, err))
|
|
}
|
|
mcp.AddTool(server, tool, logToolCall(logger, tool.Name, handler))
|
|
}
|
|
|
|
func logToolCall[In any, Out any](logger *slog.Logger, toolName string, handler func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error)) func(context.Context, *mcp.CallToolRequest, In) (*mcp.CallToolResult, Out, error) {
|
|
if logger == nil {
|
|
return handler
|
|
}
|
|
|
|
return func(ctx context.Context, req *mcp.CallToolRequest, in In) (*mcp.CallToolResult, Out, error) {
|
|
start := time.Now()
|
|
attrs := []any{slog.String("tool", toolName)}
|
|
if req != nil && req.Params != nil {
|
|
attrs = append(attrs, slog.Any("arguments", req.Params.Arguments))
|
|
}
|
|
|
|
logger.Info("mcp tool started", attrs...)
|
|
result, out, err := handler(ctx, req, in)
|
|
|
|
completionAttrs := append([]any{}, attrs...)
|
|
completionAttrs = append(completionAttrs, slog.Duration("duration", time.Since(start)))
|
|
if err != nil {
|
|
completionAttrs = append(completionAttrs, slog.String("error", err.Error()))
|
|
logger.Error("mcp tool completed", completionAttrs...)
|
|
return result, out, err
|
|
}
|
|
|
|
logger.Info("mcp tool completed", completionAttrs...)
|
|
return result, out, nil
|
|
}
|
|
}
|
|
|
|
func setToolSchemas[In any, Out any](tool *mcp.Tool) error {
|
|
if tool.InputSchema == nil {
|
|
inputSchema, err := jsonschema.For[In](toolSchemaOptions)
|
|
if err != nil {
|
|
return fmt.Errorf("infer input schema: %w", err)
|
|
}
|
|
tool.InputSchema = inputSchema
|
|
}
|
|
|
|
if tool.OutputSchema == nil {
|
|
outputSchema, err := jsonschema.For[Out](toolSchemaOptions)
|
|
if err != nil {
|
|
return fmt.Errorf("infer output schema: %w", err)
|
|
}
|
|
tool.OutputSchema = outputSchema
|
|
}
|
|
|
|
return nil
|
|
}
|