package mcpserver import ( "context" "net/http/httptest" "testing" "time" "github.com/modelcontextprotocol/go-sdk/jsonrpc" "github.com/modelcontextprotocol/go-sdk/mcp" "git.warky.dev/wdevs/amcs/internal/buildinfo" "git.warky.dev/wdevs/amcs/internal/config" "git.warky.dev/wdevs/amcs/internal/mcperrors" "git.warky.dev/wdevs/amcs/internal/tools" ) func TestStreamableHTTPReturnsStructuredToolErrors(t *testing.T) { handler, err := New(config.MCPConfig{ ServerName: "test", Version: "0.0.1", SessionTimeout: time.Minute, }, nil, streamableTestToolSet(), nil) if err != nil { t.Fatalf("mcpserver.New() error = %v", err) } httpServer := httptest.NewServer(handler) defer httpServer.Close() client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "0.0.1"}, nil) cs, err := client.Connect(context.Background(), &mcp.StreamableClientTransport{Endpoint: httpServer.URL}, nil) if err != nil { t.Fatalf("connect client: %v", err) } defer func() { _ = cs.Close() }() t.Run("schema_validation", func(t *testing.T) { _, err := cs.CallTool(context.Background(), &mcp.CallToolParams{ Name: "create_project", Arguments: map[string]any{}, }) if err == nil { t.Fatal("CallTool(create_project) error = nil, want error") } rpcErr, data := requireWireError(t, err) if rpcErr.Code != jsonrpc.CodeInvalidParams { t.Fatalf("create_project code = %d, want %d", rpcErr.Code, jsonrpc.CodeInvalidParams) } if data.Type != mcperrors.TypeInvalidArguments { t.Fatalf("create_project data.type = %q, want %q", data.Type, mcperrors.TypeInvalidArguments) } if data.Field != "name" { t.Fatalf("create_project data.field = %q, want %q", data.Field, "name") } }) t.Run("project_required", func(t *testing.T) { _, err := cs.CallTool(context.Background(), &mcp.CallToolParams{ Name: "get_project_context", Arguments: map[string]any{}, }) if err == nil { t.Fatal("CallTool(get_project_context) error = nil, want error") } rpcErr, data := requireWireError(t, err) if rpcErr.Code != mcperrors.CodeProjectRequired { t.Fatalf("get_project_context code = %d, want %d", rpcErr.Code, mcperrors.CodeProjectRequired) } if data.Type != mcperrors.TypeProjectRequired { t.Fatalf("get_project_context data.type = %q, want %q", data.Type, mcperrors.TypeProjectRequired) } if data.Field != "project" { t.Fatalf("get_project_context data.field = %q, want %q", data.Field, "project") } }) t.Run("version_info", func(t *testing.T) { result, err := cs.CallTool(context.Background(), &mcp.CallToolParams{ Name: "get_version_info", Arguments: map[string]any{}, }) if err != nil { t.Fatalf("CallTool(get_version_info) error = %v", err) } got, ok := result.StructuredContent.(map[string]any) if !ok { t.Fatalf("structured content type = %T, want map[string]any", result.StructuredContent) } if got["server_name"] != "test" { t.Fatalf("server_name = %#v, want %q", got["server_name"], "test") } if got["version"] != "0.0.1" { t.Fatalf("version = %#v, want %q", got["version"], "0.0.1") } if got["tag_name"] != "v0.0.1" { t.Fatalf("tag_name = %#v, want %q", got["tag_name"], "v0.0.1") } if got["build_date"] != "2026-03-31T00:00:00Z" { t.Fatalf("build_date = %#v, want %q", got["build_date"], "2026-03-31T00:00:00Z") } }) } func streamableTestToolSet() ToolSet { return ToolSet{ Version: tools.NewVersionTool("test", buildinfo.Info{Version: "0.0.1", TagName: "v0.0.1", Commit: "test", BuildDate: "2026-03-31T00:00:00Z"}), Capture: new(tools.CaptureTool), Search: new(tools.SearchTool), List: new(tools.ListTool), Stats: new(tools.StatsTool), Get: new(tools.GetTool), Update: new(tools.UpdateTool), Delete: new(tools.DeleteTool), Archive: new(tools.ArchiveTool), Projects: new(tools.ProjectsTool), Context: new(tools.ContextTool), Recall: new(tools.RecallTool), Summarize: new(tools.SummarizeTool), Links: new(tools.LinksTool), Files: new(tools.FilesTool), Backfill: new(tools.BackfillTool), Reparse: new(tools.ReparseMetadataTool), RetryMetadata: new(tools.RetryMetadataTool), Maintenance: new(tools.MaintenanceTool), Skills: new(tools.SkillsTool), } }