248 lines
6.9 KiB
Go
248 lines
6.9 KiB
Go
package mcpserver
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/google/jsonschema-go/jsonschema"
|
|
"github.com/modelcontextprotocol/go-sdk/mcp"
|
|
|
|
"git.warky.dev/wdevs/amcs/internal/tools"
|
|
)
|
|
|
|
func TestSetToolSchemasAddsEmptyPropertiesForNoArgInput(t *testing.T) {
|
|
type noArgInput struct{}
|
|
type anyOutput struct{}
|
|
|
|
tool := &mcp.Tool{Name: "no_args"}
|
|
if err := setToolSchemas[noArgInput, anyOutput](tool); err != nil {
|
|
t.Fatalf("set tool schemas: %v", err)
|
|
}
|
|
|
|
schema, ok := tool.InputSchema.(*jsonschema.Schema)
|
|
if !ok {
|
|
t.Fatalf("input schema type = %T, want *jsonschema.Schema", tool.InputSchema)
|
|
}
|
|
if schema.Properties == nil {
|
|
t.Fatal("input schema missing properties: strict MCP clients require properties:{} on object schemas")
|
|
}
|
|
}
|
|
|
|
func TestSetToolSchemasUsesStringUUIDsInListOutput(t *testing.T) {
|
|
tool := &mcp.Tool{Name: "list_thoughts"}
|
|
|
|
if err := setToolSchemas[tools.ListInput, tools.ListOutput](tool); err != nil {
|
|
t.Fatalf("set tool schemas: %v", err)
|
|
}
|
|
|
|
schema, ok := tool.OutputSchema.(*jsonschema.Schema)
|
|
if !ok {
|
|
t.Fatalf("output schema type = %T, want *jsonschema.Schema", tool.OutputSchema)
|
|
}
|
|
|
|
thoughtsSchema := schema.Properties["thoughts"]
|
|
if thoughtsSchema == nil {
|
|
t.Fatal("missing thoughts schema")
|
|
}
|
|
if thoughtsSchema.Items == nil {
|
|
t.Fatal("missing thoughts item schema")
|
|
}
|
|
|
|
idSchema := thoughtsSchema.Items.Properties["id"]
|
|
if idSchema == nil {
|
|
t.Fatal("missing id schema")
|
|
}
|
|
if idSchema.Type != "string" {
|
|
t.Fatalf("id schema type = %q, want %q", idSchema.Type, "string")
|
|
}
|
|
if idSchema.Format != "uuid" {
|
|
t.Fatalf("id schema format = %q, want %q", idSchema.Format, "uuid")
|
|
}
|
|
}
|
|
|
|
func TestTruncateArgsRedactsSensitiveFields(t *testing.T) {
|
|
args := json.RawMessage(`{
|
|
"query": "todo from yesterday",
|
|
"content": "private thought body",
|
|
"content_base64": "c2VjcmV0LWJpbmFyeQ==",
|
|
"email": "user@example.com",
|
|
"nested": {
|
|
"notes": "private note",
|
|
"phone": "+1-555-0100",
|
|
"keep": "visible"
|
|
}
|
|
}`)
|
|
|
|
got := truncateArgs(args)
|
|
|
|
for _, secret := range []string{
|
|
"private thought body",
|
|
"c2VjcmV0LWJpbmFyeQ==",
|
|
"user@example.com",
|
|
"private note",
|
|
"+1-555-0100",
|
|
} {
|
|
if strings.Contains(got, secret) {
|
|
t.Fatalf("truncateArgs leaked sensitive value %q in %s", secret, got)
|
|
}
|
|
}
|
|
|
|
for _, expected := range []string{
|
|
`"query":"todo from yesterday"`,
|
|
`"keep":"visible"`,
|
|
`"content":"\u003credacted\u003e"`,
|
|
`"content_base64":"\u003credacted\u003e"`,
|
|
`"email":"\u003credacted\u003e"`,
|
|
`"notes":"\u003credacted\u003e"`,
|
|
`"phone":"\u003credacted\u003e"`,
|
|
} {
|
|
if !strings.Contains(got, expected) {
|
|
t.Fatalf("truncateArgs(%s) missing %q", got, expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestAddToolReturnsSchemaErrorInsteadOfPanicking(t *testing.T) {
|
|
server := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "0.0.1"}, nil)
|
|
|
|
err := addTool(server, nil, &mcp.Tool{Name: "broken"}, func(_ context.Context, _ *mcp.CallToolRequest, _ struct{}) (*mcp.CallToolResult, chan int, error) {
|
|
return nil, nil, nil
|
|
})
|
|
if err == nil {
|
|
t.Fatal("addTool() error = nil, want schema inference error")
|
|
}
|
|
if !strings.Contains(err.Error(), `configure MCP tool "broken" schemas`) {
|
|
t.Fatalf("addTool() error = %q, want tool context", err.Error())
|
|
}
|
|
}
|
|
|
|
func TestAddToolAppliesInputDefaultsAndSetsStructuredContent(t *testing.T) {
|
|
server := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "0.0.1"}, nil)
|
|
|
|
type helloInput struct {
|
|
Name string `json:"name,omitempty"`
|
|
}
|
|
type helloOutput struct {
|
|
Message string `json:"message"`
|
|
}
|
|
|
|
tool := &mcp.Tool{
|
|
Name: "hello",
|
|
InputSchema: &jsonschema.Schema{
|
|
Type: "object",
|
|
Properties: map[string]*jsonschema.Schema{
|
|
"name": {
|
|
Type: "string",
|
|
Default: json.RawMessage(`"world"`),
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
var gotInput helloInput
|
|
if err := addTool(server, nil, tool, func(_ context.Context, _ *mcp.CallToolRequest, in helloInput) (*mcp.CallToolResult, helloOutput, error) {
|
|
gotInput = in
|
|
return nil, helloOutput{Message: "hello " + in.Name}, nil
|
|
}); err != nil {
|
|
t.Fatalf("addTool() error = %v", err)
|
|
}
|
|
|
|
ct, st := mcp.NewInMemoryTransports()
|
|
_, err := server.Connect(context.Background(), st, nil)
|
|
if err != nil {
|
|
t.Fatalf("connect server: %v", err)
|
|
}
|
|
|
|
client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "0.0.1"}, nil)
|
|
cs, err := client.Connect(context.Background(), ct, nil)
|
|
if err != nil {
|
|
t.Fatalf("connect client: %v", err)
|
|
}
|
|
defer func() {
|
|
_ = cs.Close()
|
|
}()
|
|
|
|
result, err := cs.CallTool(context.Background(), &mcp.CallToolParams{
|
|
Name: "hello",
|
|
Arguments: map[string]any{},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("CallTool(hello) error = %v", err)
|
|
}
|
|
|
|
if gotInput.Name != "world" {
|
|
t.Fatalf("handler input name = %q, want %q", gotInput.Name, "world")
|
|
}
|
|
|
|
gotStructured, ok := result.StructuredContent.(map[string]any)
|
|
if !ok {
|
|
t.Fatalf("structured content type = %T, want map[string]any", result.StructuredContent)
|
|
}
|
|
if gotStructured["message"] != "hello world" {
|
|
t.Fatalf("structured content message = %#v, want %q", gotStructured["message"], "hello world")
|
|
}
|
|
|
|
if len(result.Content) != 1 {
|
|
t.Fatalf("content count = %d, want 1", len(result.Content))
|
|
}
|
|
|
|
textContent, ok := result.Content[0].(*mcp.TextContent)
|
|
if !ok {
|
|
t.Fatalf("content[0] type = %T, want *mcp.TextContent", result.Content[0])
|
|
}
|
|
if textContent.Text != `{"message":"hello world"}` {
|
|
t.Fatalf("content[0].Text = %q, want %q", textContent.Text, `{"message":"hello world"}`)
|
|
}
|
|
}
|
|
|
|
func TestAddToolWrapsRegularErrorsInToolResults(t *testing.T) {
|
|
server := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "0.0.1"}, nil)
|
|
|
|
toolErr := errors.New("boom")
|
|
if err := addTool(server, nil, &mcp.Tool{Name: "explode"}, func(_ context.Context, _ *mcp.CallToolRequest, _ struct{}) (*mcp.CallToolResult, struct{}, error) {
|
|
return nil, struct{}{}, toolErr
|
|
}); err != nil {
|
|
t.Fatalf("addTool() error = %v", err)
|
|
}
|
|
|
|
ct, st := mcp.NewInMemoryTransports()
|
|
_, err := server.Connect(context.Background(), st, nil)
|
|
if err != nil {
|
|
t.Fatalf("connect server: %v", err)
|
|
}
|
|
|
|
client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "0.0.1"}, nil)
|
|
cs, err := client.Connect(context.Background(), ct, nil)
|
|
if err != nil {
|
|
t.Fatalf("connect client: %v", err)
|
|
}
|
|
defer func() {
|
|
_ = cs.Close()
|
|
}()
|
|
|
|
result, err := cs.CallTool(context.Background(), &mcp.CallToolParams{Name: "explode"})
|
|
if err != nil {
|
|
t.Fatalf("CallTool(explode) error = %v, want nil transport error", err)
|
|
}
|
|
if !result.IsError {
|
|
t.Fatal("CallTool(explode) IsError = false, want true")
|
|
}
|
|
if result.StructuredContent != nil {
|
|
t.Fatalf("structured content = %#v, want nil", result.StructuredContent)
|
|
}
|
|
if len(result.Content) != 1 {
|
|
t.Fatalf("content count = %d, want 1", len(result.Content))
|
|
}
|
|
|
|
textContent, ok := result.Content[0].(*mcp.TextContent)
|
|
if !ok {
|
|
t.Fatalf("content[0] type = %T, want *mcp.TextContent", result.Content[0])
|
|
}
|
|
if textContent.Text != toolErr.Error() {
|
|
t.Fatalf("content[0].Text = %q, want %q", textContent.Text, toolErr.Error())
|
|
}
|
|
}
|