1# mypy: allow-untyped-defs 2 3import enum 4from typing import Tuple 5 6import torch 7import torch.distributed.rpc as rpc 8import torch.testing._internal.dist_utils as dist_utils 9from torch import Tensor, nn 10from torch._jit_internal import Future 11from torch.distributed.nn import RemoteModule 12from torch.distributed.nn.api.remote_module import _REMOTE_MODULE_PICKLED_ATTRIBUTES 13from torch.distributed.nn.api.remote_module import _RemoteModule 14from torch.testing._internal.common_distributed import skip_if_lt_x_gpu 15from torch.testing._internal.common_utils import TemporaryFileName 16from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( 17 RpcAgentTestFixture, 18) 19 20 21_PARAM_VAL = torch.nn.Parameter(torch.ones(1)) 22 23 24# RPC handler for querying the device on the destination worker. 25def remote_device(module_rref): 26 for param in module_rref.local_value().parameters(): 27 return param.device 28 29 30# RPC handler for querying __dict__ on the destination worker. 31def remote_module_attributes(remote_module): 32 return remote_module.__dict__ 33 34 35# RPC handler for running forward on the destination worker. 36def remote_forward(remote_module, args): 37 return remote_module.forward(*args) 38 39# RPC handler for running forward_async on the destination worker. 40def remote_forward_async(remote_module, args): 41 # Since future cannot be pickled and sent over the RPC layer, 42 # have to wait and behave just like ``forward_sync``. 43 return remote_module.forward_async(*args).wait() 44 45# RPC handler for getting training mode on the destination worker. 46def get_remote_training_arg(module_rref): 47 return module_rref.local_value().training 48 49class ModuleCreationMode(enum.Enum): 50 MODULE_CTOR_WITH_INTERFACE = "module_ctor_with_interface" 51 MODULE_CTOR = "module_ctor" 52 53 54@torch.jit.interface 55class MyModuleInterface: 56 def forward( 57 self, tensor: Tensor, number: int, word: str = "default" 58 ) -> Tuple[str, int, Tensor]: 59 # pyre-ignore[7]: Pyre and torch.jit.interface don't mix well 60 pass 61 62 63@torch.jit.interface 64class RemoteMyModuleInterface: 65 def forward( 66 self, tensor: Tensor, number: int, word: str = "default" 67 ) -> Tuple[str, int, Tensor]: 68 # pyre-ignore[7]: Pyre and torch.jit.interface don't mix well 69 pass 70 71 def forward_async( 72 self, tensor: Tensor, number: int, word: str = "default" 73 ) -> Future[Tuple[str, int, Tensor]]: 74 pass 75 76 77class MyModule(nn.Module): 78 def __init__(self, first_arg, first_kwarg=-1): 79 super().__init__() 80 self.param1 = _PARAM_VAL 81 82 def forward( 83 self, tensor: Tensor, number: int, word: str = "default" 84 ) -> Tuple[str, int, Tensor]: 85 return word, number, tensor 86 87 88class BadModule: 89 def __init__(self, first_arg, first_kwarg=-1): 90 pass 91 92 93def create_scripted_module(first_arg, first_kwarg=-1): 94 module = MyModule(first_arg, first_kwarg=first_kwarg) 95 scripted_module = torch.jit.script(module) 96 return scripted_module 97 98 99# Common utils for both CPU and CUDA test suites 100class CommonRemoteModuleTest(RpcAgentTestFixture): 101 @property 102 def world_size(self): # Override setting in RpcAgentTestFixture 103 return 2 104 105 @staticmethod 106 def _create_remote_module_iter(remote_device, modes=None): 107 if modes is None: 108 modes = ModuleCreationMode.__members__.values() 109 110 args = (1,) 111 kwargs = dict(first_kwarg=2) 112 113 if ModuleCreationMode.MODULE_CTOR in modes: 114 remote_module = RemoteModule(remote_device, MyModule, args, kwargs) 115 yield remote_module 116 117 if ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE in modes: 118 remote_module = _RemoteModule( 119 remote_device, 120 create_scripted_module, 121 args, 122 kwargs, 123 _module_interface_cls=MyModuleInterface, 124 ) 125 scripted_remote_module = torch.jit.script(remote_module) 126 yield scripted_remote_module 127 128 129class RemoteModuleTest(CommonRemoteModuleTest): 130 @dist_utils.dist_init 131 def test_bad_module(self): 132 if self.rank != 0: 133 return 134 dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) 135 remote_device = f"{dst_worker_name}/cpu" 136 args = (1,) 137 kwargs = dict(first_kwarg=2) 138 139 with self.assertRaisesRegex( 140 ValueError, 141 r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of <class nn.Module>,", 142 ): 143 RemoteModule(remote_device, BadModule, args, kwargs).forward() 144 145 with self.assertRaisesRegex( 146 ValueError, 147 r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of <class nn.Module>,", 148 ): 149 RemoteModule(remote_device, BadModule, args, kwargs).forward() 150 151 152 @dist_utils.dist_init 153 def test_forward_async(self): 154 if self.rank != 0: 155 return 156 dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) 157 args = (torch.ones(1), 2, "3") 158 for remote_module in self._create_remote_module_iter(dst_worker_name): 159 ret_fut = remote_module.forward_async(*args) 160 ret = ret_fut.wait() 161 self.assertEqual(ret, tuple(reversed(args))) 162 163 @dist_utils.dist_init 164 def test_forward_async_script(self): 165 if self.rank != 0: 166 return 167 dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) 168 169 scripted_remote_module = next( 170 self._create_remote_module_iter( 171 dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE] 172 ) 173 ) 174 175 @torch.jit.script 176 def run_forward_async(scripted_remote_module: RemoteMyModuleInterface): 177 ret_fut = scripted_remote_module.forward_async(torch.ones(1), 2, "3") 178 ret = ret_fut.wait() 179 return ret 180 181 ret = run_forward_async(scripted_remote_module) 182 183 self.assertEqual(ret, ("3", 2, torch.ones(1))) 184 185 @dist_utils.dist_init 186 def test_forward_sync(self): 187 if self.rank != 0: 188 return 189 dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) 190 args = (torch.ones(1), 2, "3") 191 for remote_module in self._create_remote_module_iter(dst_worker_name): 192 ret = remote_module.forward(*args) 193 self.assertEqual(ret, tuple(reversed(args))) 194 195 @dist_utils.dist_init 196 def test_forward_sync_script(self): 197 if self.rank != 0: 198 return 199 dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) 200 201 scripted_remote_module = next( 202 self._create_remote_module_iter( 203 dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE] 204 ) 205 ) 206 207 @torch.jit.script 208 def run_forward(scripted_remote_module: MyModuleInterface): 209 ret = scripted_remote_module.forward(torch.ones(1), 2, "3") 210 return ret 211 212 ret = run_forward(scripted_remote_module) 213 214 self.assertEqual(ret, ("3", 2, torch.ones(1))) 215 216 @dist_utils.dist_init 217 def test_forward_with_kwargs(self): 218 if self.rank != 0: 219 return 220 dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) 221 args = (torch.ones(1), 2) 222 kwargs = dict(word="3") 223 # Only test Python nn.Module, because script module methods don't support taking kwargs. 224 for remote_module in self._create_remote_module_iter( 225 dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR] 226 ): 227 ret_fut = remote_module.forward_async(*args, **kwargs) 228 ret = ret_fut.wait() 229 self.assertEqual(ret, tuple(reversed(args + ("3",)))) 230 231 ret = remote_module.forward(*args, **kwargs) 232 self.assertEqual(ret, tuple(reversed(args + ("3",)))) 233 234 @dist_utils.dist_init 235 def test_remote_parameters(self): 236 if self.rank != 0: 237 return 238 dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) 239 240 # Only test Python nn.Module, because script module methods don't support ``remote_parameters``. 241 for remote_module in self._create_remote_module_iter( 242 dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR] 243 ): 244 param_rrefs = remote_module.remote_parameters() 245 self.assertEqual(len(param_rrefs), 1) 246 self.assertTrue(torch.equal(param_rrefs[0].to_here(), _PARAM_VAL)) 247 248 @dist_utils.dist_init 249 def test_get_module_rref(self): 250 if self.rank != 0: 251 return 252 dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) 253 254 # Only test Python nn.Module, because script module methods don't support ``get_module_rref``. 255 for remote_module in self._create_remote_module_iter( 256 dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR] 257 ): 258 rref = remote_module.get_module_rref() 259 self.assertEqual(rref, remote_module.module_rref) 260 for param in rref.to_here().parameters(): 261 self.assertTrue(torch.equal(param, _PARAM_VAL)) 262 263 @dist_utils.dist_init 264 def test_train_eval(self): 265 if self.rank != 0: 266 return 267 dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) 268 269 for remote_module in self._create_remote_module_iter( 270 dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR] 271 ): 272 remote_module.train() 273 ret1 = rpc.rpc_sync(dst_worker_name, get_remote_training_arg, args=(remote_module.get_module_rref(),)) 274 self.assertEqual(ret1, True) 275 276 remote_module.eval() 277 ret2 = rpc.rpc_sync(dst_worker_name, get_remote_training_arg, args=(remote_module.get_module_rref(),)) 278 self.assertEqual(ret2, False) 279 280 @dist_utils.dist_init 281 def test_unsupported_methods(self): 282 if self.rank != 0: 283 return 284 dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) 285 286 for remote_module in self._create_remote_module_iter( 287 dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR] 288 ): 289 with self.assertRaisesRegex( 290 ValueError, r"Method ``register_buffer`` not supported for RemoteModule" 291 ): 292 remote_module.register_buffer("buffer", torch.ones(5)) 293 with self.assertRaisesRegex( 294 ValueError, 295 r"Method ``register_parameter`` not supported for RemoteModule", 296 ): 297 remote_module.register_parameter( 298 "param", torch.nn.Parameter(torch.ones(1)) 299 ) 300 with self.assertRaisesRegex( 301 ValueError, r"Method ``add_module`` not supported for RemoteModule" 302 ): 303 remote_module.add_module("empty", None) 304 305 with self.assertRaisesRegex( 306 ValueError, r"Method ``apply`` not supported for RemoteModule" 307 ): 308 fn = torch.rand((3, 3), requires_grad=False) 309 remote_module.apply(fn) 310 311 with self.assertRaisesRegex( 312 ValueError, r"Method ``cuda`` not supported for RemoteModule" 313 ): 314 remote_module.cuda() 315 with self.assertRaisesRegex( 316 ValueError, r"Method ``cpu`` not supported for RemoteModule" 317 ): 318 remote_module.cpu() 319 with self.assertRaisesRegex( 320 ValueError, r"Method ``type`` not supported for RemoteModule" 321 ): 322 remote_module.type(torch.FloatTensor) 323 with self.assertRaisesRegex( 324 ValueError, r"Method ``float`` not supported for RemoteModule" 325 ): 326 remote_module.float() 327 with self.assertRaisesRegex( 328 ValueError, r"Method ``double`` not supported for RemoteModule" 329 ): 330 remote_module.double() 331 with self.assertRaisesRegex( 332 ValueError, r"Method ``bfloat16`` not supported for RemoteModule" 333 ): 334 remote_module.bfloat16() 335 with self.assertRaisesRegex( 336 ValueError, r"Method ``to`` not supported for RemoteModule" 337 ): 338 remote_module.to("cpu", dtype=torch.int32) 339 340 def hook(module, grad_input, grad_output): 341 pass 342 343 with self.assertRaisesRegex( 344 ValueError, 345 r"Method ``register_backward_hook`` not supported for RemoteModule", 346 ): 347 remote_module.register_backward_hook(hook) 348 with self.assertRaisesRegex( 349 ValueError, 350 r"Method ``register_forward_pre_hook`` not supported for RemoteModule", 351 ): 352 remote_module.register_forward_pre_hook(hook) 353 with self.assertRaisesRegex( 354 ValueError, 355 r"Method ``register_forward_hook`` not supported for RemoteModule", 356 ): 357 remote_module.register_forward_hook(hook) 358 359 with self.assertRaisesRegex( 360 ValueError, r"Method ``state_dict`` not supported for RemoteModule" 361 ): 362 remote_module.state_dict() 363 with self.assertRaisesRegex( 364 ValueError, r"Method ``load_state_dict`` not supported for RemoteModule" 365 ): 366 remote_module.load_state_dict({}) 367 368 with self.assertRaisesRegex( 369 ValueError, 370 r"Method ``parameters`` not supported for RemoteModule. Please use ``remote_parameters`` instead.", 371 ): 372 remote_module.parameters() 373 with self.assertRaisesRegex( 374 ValueError, 375 r"Method ``named_parameters`` not supported for RemoteModule", 376 ): 377 remote_module.named_parameters() 378 with self.assertRaisesRegex( 379 ValueError, r"Method ``buffers`` not supported for RemoteModule" 380 ): 381 remote_module.buffers() 382 with self.assertRaisesRegex( 383 ValueError, r"Method ``named_buffers`` not supported for RemoteModule" 384 ): 385 remote_module.named_buffers() 386 with self.assertRaisesRegex( 387 ValueError, r"Method ``children`` not supported for RemoteModule" 388 ): 389 remote_module.children() 390 with self.assertRaisesRegex( 391 ValueError, r"Method ``named_children`` not supported for RemoteModule" 392 ): 393 remote_module.named_children() 394 with self.assertRaisesRegex( 395 ValueError, r"Method ``modules`` not supported for RemoteModule" 396 ): 397 remote_module.modules() 398 with self.assertRaisesRegex( 399 ValueError, r"Method ``named_modules`` not supported for RemoteModule" 400 ): 401 remote_module.named_modules() 402 403 with self.assertRaisesRegex( 404 ValueError, r"Method ``requires_grad_`` not supported for RemoteModule" 405 ): 406 remote_module.requires_grad_() 407 with self.assertRaisesRegex( 408 ValueError, r"Method ``zero_grad`` not supported for RemoteModule" 409 ): 410 remote_module.zero_grad() 411 with self.assertRaisesRegex( 412 ValueError, r"Method ``share_memory`` not supported for RemoteModule" 413 ): 414 remote_module.share_memory() 415 with self.assertRaisesRegex( 416 ValueError, r"Method ``extra_repr`` not supported for RemoteModule" 417 ): 418 remote_module.extra_repr() 419 420 @dist_utils.dist_init 421 def test_send_remote_module_with_a_new_attribute_not_pickled_over_the_wire(self): 422 if self.rank != 0: 423 return 424 dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) 425 426 # If a new attribute is added to this RemoteModule after the initialization, 427 # and it will be sent over the wire by RPC, 428 # this new field will not be pickled, because it's not specified in _REMOTE_MODULE_PICKLED_ATTRIBUTES. 429 # Note that adding a new attribute out of constructor should rarely happen. 430 # If a new attribute is added to RemoteModule constructor, 431 # there is a sanity check to enforce developers to add this attribute to either 432 # _REMOTE_MODULE_PICKLED_ATTRIBUTES or _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING. 433 for remote_module in self._create_remote_module_iter( 434 dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR] 435 ): 436 new_attr_name = "new_attr" 437 setattr(remote_module, new_attr_name, 1) 438 439 attrs = rpc.rpc_sync( 440 dst_worker_name, remote_module_attributes, (remote_module,) 441 ) 442 self.assertNotIn(new_attr_name, attrs) 443 444 @dist_utils.dist_init 445 def test_remote_module_py_pickle_not_supported(self): 446 if self.rank != 0: 447 return 448 dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) 449 450 for remote_module in self._create_remote_module_iter( 451 dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR] 452 ): 453 with TemporaryFileName() as fname: 454 with self.assertRaisesRegex( 455 RuntimeError, 456 "Cannot pickle RemoteModule in python pickler. RemoteModule can only be pickled when using RPC", 457 ): 458 torch.save(remote_module, fname) 459 460 @dist_utils.dist_init 461 def test_remote_module_py_pickle_not_supported_script(self): 462 if self.rank != 0: 463 return 464 dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) 465 466 for remote_module in self._create_remote_module_iter( 467 dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE] 468 ): 469 with TemporaryFileName() as fname: 470 with self.assertRaisesRegex(torch.jit.Error, "can only be pickled when using RPC"): 471 torch.save(remote_module, fname) 472 473 474class ThreeWorkersRemoteModuleTest(CommonRemoteModuleTest): 475 @property 476 def world_size(self): # Override setting in CommonRemoteModuleTest 477 return 3 478 479 @dist_utils.dist_init 480 def test_send_remote_module_over_the_wire(self): 481 if self.rank != 0: 482 return 483 dst_worker1_name = dist_utils.worker_name((self.rank + 1) % self.world_size) 484 dst_worker2_name = dist_utils.worker_name((self.rank + 2) % self.world_size) 485 486 # Unpickled attributes include both the inherent attributes of RemoteModule 487 # (not inherited from the superclass) and two installed methods. 488 expected_unpickled_attrs = list(_REMOTE_MODULE_PICKLED_ATTRIBUTES) 489 expected_unpickled_attrs.append("forward_async") 490 expected_unpickled_attrs.append("forward") 491 492 # Create a remote module on worker1 and then pass it to worker2 over the RPC layer. 493 for remote_module in self._create_remote_module_iter( 494 dst_worker1_name, modes=[ModuleCreationMode.MODULE_CTOR] 495 ): 496 # Test querying some simple attributes from worker2. 497 attrs = rpc.rpc_sync( 498 dst_worker2_name, remote_module_attributes, (remote_module,) 499 ) 500 self.assertListEqual(list(attrs.keys()), expected_unpickled_attrs) 501 self.assertEqual(attrs["on"], "worker1") 502 self.assertEqual(attrs["device"], "cpu") 503 self.assertFalse(attrs["is_device_map_set"]) 504 self.assertFalse(attrs["is_scriptable"]) 505 506 # Test the installed methods on worker1's can be initiated by worker2 over RPC layer. 507 # NOTE: In practice a remote module should be directly stored on the worker that runs ``forward``` or ``forward_async``, 508 # not have another worker to initiate forward over the RPC layer. 509 args = (torch.ones(1), 2, "3") 510 ret1 = rpc.rpc_sync(dst_worker2_name, remote_forward, (remote_module, args)) 511 self.assertEqual(ret1, tuple(reversed(args))) 512 ret2 = rpc.rpc_sync( 513 dst_worker2_name, remote_forward_async, (remote_module, args) 514 ) 515 self.assertEqual(ret2, tuple(reversed(args))) 516 517 @dist_utils.dist_init 518 def test_send_remote_module_over_the_wire_script_not_supported(self): 519 if self.rank != 0: 520 return 521 dst_worker1_name = dist_utils.worker_name((self.rank + 1) % self.world_size) 522 dst_worker2_name = dist_utils.worker_name((self.rank + 2) % self.world_size) 523 524 # Unpickled attributes include both the inherent attributes of RemoteModule 525 # (not inherited from the superclass) and two installed methods. 526 expected_unpickled_attrs = list(_REMOTE_MODULE_PICKLED_ATTRIBUTES) 527 expected_unpickled_attrs.append("forward_async") 528 expected_unpickled_attrs.append("forward") 529 530 with self.assertRaisesRegex( 531 RuntimeError, "Passing a script RemoteModule over RPC is not supported." 532 ): 533 # Create a remote module on worker1 and then pass it to worker2 over the RPC layer. 534 for remote_module in self._create_remote_module_iter( 535 dst_worker1_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE] 536 ): 537 # Test querying some simple attributes from worker2. 538 attrs = rpc.rpc_sync( 539 dst_worker2_name, remote_module_attributes, (remote_module,) 540 ) 541 542 @dist_utils.dist_init 543 def test_create_remote_module_from_module_rref(self): 544 if self.rank != 0: 545 return 546 dst_worker1_name = dist_utils.worker_name((self.rank + 1) % self.world_size) 547 dst_worker2_name = dist_utils.worker_name((self.rank + 2) % self.world_size) 548 549 # Create a remote module on worker1 and then pass its `module_rref` to worker2 over the RPC layer. 550 for remote_module in self._create_remote_module_iter( 551 dst_worker1_name, modes=[ModuleCreationMode.MODULE_CTOR] 552 ): 553 remote_module2 = rpc.rpc_sync( 554 dst_worker2_name, 555 RemoteModule.init_from_module_rref, 556 (dst_worker2_name, remote_module.get_module_rref()), 557 ) 558 559 args = (torch.ones(1), 2, "3") 560 ret1 = rpc.rpc_sync( 561 dst_worker1_name, remote_forward, (remote_module, args) 562 ) 563 ret2 = rpc.rpc_sync( 564 dst_worker2_name, remote_forward, (remote_module2, args) 565 ) 566 self.assertEqual(ret2, ret2) 567 568 569class CudaRemoteModuleTest(CommonRemoteModuleTest): 570 @skip_if_lt_x_gpu(1) 571 @dist_utils.dist_init 572 def test_valid_device(self): 573 if self.rank != 0: 574 return 575 dst_rank = (self.rank + 1) % self.world_size 576 dst_worker_name = dist_utils.worker_name(dst_rank) 577 578 for remote_module in self._create_remote_module_iter( 579 f"{dst_worker_name}/cuda:0", modes=[ModuleCreationMode.MODULE_CTOR] 580 ): 581 device = rpc.rpc_sync( 582 dst_worker_name, remote_device, (remote_module.module_rref,) 583 ) 584 self.assertEqual(device.type, "cuda") 585 self.assertEqual(device.index, 0) 586 587 # Test rank works as well. 588 for remote_module in self._create_remote_module_iter( 589 f"rank:{dst_rank}/cuda:0", modes=[ModuleCreationMode.MODULE_CTOR] 590 ): 591 device = rpc.rpc_sync( 592 dst_worker_name, remote_device, (remote_module.module_rref,) 593 ) 594 self.assertEqual(device.type, "cuda") 595 self.assertEqual(device.index, 0) 596 597 @skip_if_lt_x_gpu(1) 598 @dist_utils.dist_init 599 def test_invalid_devices(self): 600 if self.rank != 0: 601 return 602 dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) 603 604 with self.assertRaisesRegex( 605 RuntimeError, 606 r"Expected one of .+ device type at start of device string", 607 ): 608 [ 609 m.forward() 610 for m in self._create_remote_module_iter( 611 f"{dst_worker_name}/foo", 612 modes=[ModuleCreationMode.MODULE_CTOR], 613 ) 614 ] 615 616 with self.assertRaisesRegex( 617 RuntimeError, r"CUDA error: invalid device ordinal" 618 ): 619 [ 620 m.forward() 621 for m in self._create_remote_module_iter( 622 f"{dst_worker_name}/cuda:100", 623 modes=[ModuleCreationMode.MODULE_CTOR], 624 ) 625 ] 626 627 with self.assertRaisesRegex(RuntimeError, r"Invalid device string: 'cpu2'"): 628 [ 629 m.forward() 630 for m in self._create_remote_module_iter( 631 f"{dst_worker_name}/cpu2", 632 modes=[ModuleCreationMode.MODULE_CTOR], 633 ) 634 ] 635 636 with self.assertRaisesRegex(RuntimeError, r"Device string must not be empty"): 637 [ 638 m.forward() 639 for m in self._create_remote_module_iter( 640 f"{dst_worker_name}/", 641 modes=[ModuleCreationMode.MODULE_CTOR], 642 ) 643 ] 644 645 with self.assertRaisesRegex( 646 ValueError, 647 r"Could not parse remote_device: worker1/cuda:0/cuda:1. The valid format is '<workername>/<device>'", 648 ): 649 [ 650 m.forward() 651 for m in self._create_remote_module_iter( 652 f"{dst_worker_name}/cuda:0/cuda:1", 653 modes=[ModuleCreationMode.MODULE_CTOR], 654 ) 655 ] 656 657 with self.assertRaisesRegex( 658 ValueError, 659 r"Could not parse remote_device: /. The valid format is '<workername>/<device>'", 660 ): 661 [ 662 m.forward() 663 for m in self._create_remote_module_iter( 664 "/", 665 modes=[ModuleCreationMode.MODULE_CTOR], 666 ) 667 ] 668 669 with self.assertRaisesRegex( 670 ValueError, 671 r"Could not parse remote_device: /cuda:0. The valid format is '<workername>/<device>'", 672 ): 673 [ 674 m.forward() 675 for m in self._create_remote_module_iter( 676 "/cuda:0", 677 modes=[ModuleCreationMode.MODULE_CTOR], 678 ) 679 ] 680 681 @skip_if_lt_x_gpu(1) 682 @dist_utils.dist_init 683 def test_input_moved_to_cuda_device(self): 684 if self.rank != 0: 685 return 686 dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) 687 688 # These two CPU tensors (in args and kwargs) should be implicitly moved to an appropriate cuda device. 689 t1 = torch.ones(1) 690 args = (t1, 2) 691 t2 = t1 * 2 692 kwargs = dict(word=t2) 693 694 # Only test Python nn.Module, because script module methods don't support taking kwargs. 695 for remote_module in self._create_remote_module_iter( 696 f"{dst_worker_name}/cuda:0", modes=[ModuleCreationMode.MODULE_CTOR] 697 ): 698 ret_fut = remote_module.forward_async(*args, **kwargs) 699 ret = ret_fut.wait() 700 self.assertEqual(ret, tuple(reversed(args + (t2,)))) 701 # TODO: Once the RPC backend can support directly sending GPU tensors, the expected device type should be "cuda:0". 702 self.assertEqual(ret[0].device.type, "cpu") 703 self.assertEqual(ret[2].device.type, "cpu") 704 705 ret = remote_module.forward(*args, **kwargs) 706 self.assertEqual(ret, tuple(reversed(args + (t2,)))) 707 # TODO: Once the RPC backend can support directly sending GPU tensors, the expected device type should be "cuda:0". 708 self.assertEqual(ret[0].device.type, "cpu") 709 self.assertEqual(ret[2].device.type, "cpu") 710 711 @skip_if_lt_x_gpu(1) 712 @dist_utils.dist_init 713 def test_input_moved_to_cuda_device_script(self): 714 if self.rank != 0: 715 return 716 dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size) 717 718 scripted_remote_module = next( 719 self._create_remote_module_iter( 720 f"{dst_worker_name}/cuda:0", 721 modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE], 722 ) 723 ) 724 725 @torch.jit.script 726 def run_forward(scripted_remote_module: MyModuleInterface): 727 ret = scripted_remote_module.forward(torch.ones(1), 2, "3") 728 return ret 729 730 ret = run_forward(scripted_remote_module) 731 732 self.assertEqual(ret, ("3", 2, torch.ones(1))) 733 # TODO: Once the RPC backend can support directly sending GPU tensors, the expected device type should be "cuda:0". 734 self.assertEqual(ret[2].device.type, "cpu") 735