package mcpserver

import (
	"context"
	"encoding/json"
	"errors"
	"testing"

	"github.com/modelcontextprotocol/go-sdk/jsonrpc"
	"github.com/modelcontextprotocol/go-sdk/mcp"

	"git.warky.dev/wdevs/amcs/internal/config"
	"git.warky.dev/wdevs/amcs/internal/mcperrors"
	"git.warky.dev/wdevs/amcs/internal/session"
	"git.warky.dev/wdevs/amcs/internal/tools"
)

// TestToolValidationErrorsRoundTripAsStructuredJSONRPC registers the
// create_project and get_project_context tools on an in-memory MCP server,
// calls them from a real MCP client with invalid arguments, and checks that
// each failure reaches the client as a *jsonrpc.Error carrying the expected
// code and structured data payload (type, field, detail/hint).
func TestToolValidationErrorsRoundTripAsStructuredJSONRPC(t *testing.T) {
	server := mcp.NewServer(&mcp.Implementation{Name: "test", Version: "0.0.1"}, nil)

	// The tools are built with nil dependencies. NOTE(review): this assumes every
	// case below is rejected during argument validation, before any of those
	// dependencies would be touched.
	projects := tools.NewProjectsTool(nil, session.NewActiveProjects())
	contextTool := tools.NewContextTool(nil, nil, toolsSearchConfig(), session.NewActiveProjects())

	if err := addTool(server, nil, &mcp.Tool{
		Name:        "create_project",
		Description: "Create a named project container for thoughts.",
	}, projects.Create); err != nil {
		t.Fatalf("add create_project tool: %v", err)
	}
	if err := addTool(server, nil, &mcp.Tool{
		Name:        "get_project_context",
		Description: "Get recent and semantic context for a project.",
	}, contextTool.Handle); err != nil {
		t.Fatalf("add get_project_context tool: %v", err)
	}

	ct, st := mcp.NewInMemoryTransports()
	if _, err := server.Connect(context.Background(), st, nil); err != nil {
		t.Fatalf("connect server: %v", err)
	}

	client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "0.0.1"}, nil)
	cs, err := client.Connect(context.Background(), ct, nil)
	if err != nil {
		t.Fatalf("connect client: %v", err)
	}
	// Teardown only: a close failure does not affect any assertion below.
	defer func() { _ = cs.Close() }()

	// A present-but-empty required field is rejected by the tool handler itself.
	t.Run("required_field", func(t *testing.T) {
		_, err := cs.CallTool(context.Background(), &mcp.CallToolParams{
			Name:      "create_project",
			Arguments: map[string]any{"name": ""},
		})
		if err == nil {
			t.Fatal("CallTool(create_project) error = nil, want error")
		}

		rpcErr, data := requireWireError(t, err)
		if rpcErr.Code != jsonrpc.CodeInvalidParams {
			t.Fatalf("create_project code = %d, want %d", rpcErr.Code, jsonrpc.CodeInvalidParams)
		}
		if data.Type != mcperrors.TypeInvalidInput {
			t.Fatalf("create_project data.type = %q, want %q", data.Type, mcperrors.TypeInvalidInput)
		}
		if data.Field != "name" {
			t.Fatalf("create_project data.field = %q, want %q", data.Field, "name")
		}
	})

	// A missing required field is expected to be reported as an
	// invalid-arguments error with a validation detail.
	t.Run("schema_required_field", func(t *testing.T) {
		_, err := cs.CallTool(context.Background(), &mcp.CallToolParams{
			Name:      "create_project",
			Arguments: map[string]any{},
		})
		if err == nil {
			t.Fatal("CallTool(create_project missing field) error = nil, want error")
		}

		rpcErr, data := requireWireError(t, err)
		if rpcErr.Code != jsonrpc.CodeInvalidParams {
			t.Fatalf("create_project schema code = %d, want %d", rpcErr.Code, jsonrpc.CodeInvalidParams)
		}
		if data.Type != mcperrors.TypeInvalidArguments {
			t.Fatalf("create_project schema data.type = %q, want %q", data.Type, mcperrors.TypeInvalidArguments)
		}
		if data.Field != "name" {
			t.Fatalf("create_project schema data.field = %q, want %q", data.Field, "name")
		}
		if data.Detail == "" {
			t.Fatal("create_project schema data.detail = empty, want validation detail")
		}
	})

	// Calling get_project_context without a project uses a dedicated
	// project-required error code and includes a hint for the caller.
	t.Run("project_required", func(t *testing.T) {
		_, err := cs.CallTool(context.Background(), &mcp.CallToolParams{
			Name:      "get_project_context",
			Arguments: map[string]any{},
		})
		if err == nil {
			t.Fatal("CallTool(get_project_context) error = nil, want error")
		}

		rpcErr, data := requireWireError(t, err)
		if rpcErr.Code != mcperrors.CodeProjectRequired {
			t.Fatalf("get_project_context code = %d, want %d", rpcErr.Code, mcperrors.CodeProjectRequired)
		}
		if data.Type != mcperrors.TypeProjectRequired {
			t.Fatalf("get_project_context data.type = %q, want %q", data.Type, mcperrors.TypeProjectRequired)
		}
		if data.Field != "project" {
			t.Fatalf("get_project_context data.field = %q, want %q", data.Field, "project")
		}
		if data.Hint == "" {
			t.Fatal("get_project_context data.hint = empty, want guidance")
		}
	})
}

// wireErrorData is the subset of a JSON-RPC error's "data" object that these
// tests decode and assert on.
type wireErrorData struct {
	Type   string   `json:"type"`
	Field  string   `json:"field,omitempty"`
	Fields []string `json:"fields,omitempty"`
	Detail string   `json:"detail,omitempty"`
	Hint   string   `json:"hint,omitempty"`
}

// requireWireError asserts that err wraps a *jsonrpc.Error whose Data payload
// is present and decodes as wireErrorData, and returns both. It fails the test
// immediately otherwise, so callers can assert on the decoded fields directly.
func requireWireError(t *testing.T, err error) (*jsonrpc.Error, wireErrorData) {
	t.Helper()

	var rpcErr *jsonrpc.Error
	if !errors.As(err, &rpcErr) {
		t.Fatalf("error type = %T (%v), want *jsonrpc.Error", err, err)
	}

	// Every caller asserts on data fields; without this check a missing payload
	// would surface later as a confusing zero-value mismatch (data.type = "").
	if len(rpcErr.Data) == 0 {
		t.Fatalf("jsonrpc error %v has no data, want structured error data", rpcErr)
	}

	var data wireErrorData
	if unmarshalErr := json.Unmarshal(rpcErr.Data, &data); unmarshalErr != nil {
		t.Fatalf("unmarshal wire error data %s: %v", rpcErr.Data, unmarshalErr)
	}
	return rpcErr, data
}

// toolsSearchConfig returns a fixed search configuration used to construct the
// context tool under test.
func toolsSearchConfig() config.SearchConfig {
	return config.SearchConfig{
		DefaultLimit:     10,
		MaxLimit:         50,
		DefaultThreshold: 0.7,
	}
}