#!/usr/bin/python3
# mypy: allow-untyped-defs
import collections
import io
import sys
import types
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
)

import torch
import torch.distributed.rpc as rpc
from torch import device, dtype, nn, Tensor
from torch.distributed import _remote_device
from torch.distributed.nn.jit import instantiator
from torch.distributed.rpc.internal import _internal_rpc_pickler
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.utils.hooks import RemovableHandle


__all__ = ["RemoteModule"]

_grad_t = Union[Tuple[Tensor, ...], Tensor]
# See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use
# of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be
# the type of the subclass, not the looser type of `Module`.
T = TypeVar("T", bound="Module")

_NON_SCRIPTABLE_REMOTE_MODULE_MODULE = (
    instantiator.instantiate_non_scriptable_remote_module_template()
)

_REMOTE_MODULE_PICKLED_ATTRIBUTES = (
    "on",
    "device",
    "is_device_map_set",
    "is_scriptable",
    "generated_methods",
    "module_rref",
)

_SerializedRemoteModule = collections.namedtuple("_SerializedRemoteModule", _REMOTE_MODULE_PICKLED_ATTRIBUTES)  # type: ignore[misc]

# These attributes are mostly from RemoteModule's parent class and are intentionally not pickled.
# A new attribute of RemoteModule should be added to either _REMOTE_MODULE_PICKLED_ATTRIBUTES
# or _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING.
# Otherwise, it will not be pickled.
_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING = (
    "training",
    "_parameters",
    "_buffers",
    "_non_persistent_buffers_set",
    "_backward_hooks",
    "_backward_pre_hooks",
    "_is_full_backward_hook",
    "_forward_hooks",
    "_forward_hooks_with_kwargs",
    "_forward_hooks_always_called",
    "_forward_pre_hooks",
    "_forward_pre_hooks_with_kwargs",
    "_state_dict_hooks",
    "_state_dict_pre_hooks",
    "_load_state_dict_pre_hooks",
    "_load_state_dict_post_hooks",
    "_modules",
    # The two attributes below are generated methods, not available at pickling time.
    "forward_async",
    "forward",
)


# RPC handler.
def _instantiate_template(module_interface_cls, enable_moving_cpu_tensors_to_cuda):
    instantiator.instantiate_scriptable_remote_module_template(
        module_interface_cls, enable_moving_cpu_tensors_to_cuda
    )


def _create_module(module_cls, args, kwargs, device):
    module = module_cls(*args, **kwargs)
    if not isinstance(module, nn.Module):
        raise ValueError(
            "Expect `module_cls(*args, **kwargs)` to return an instance of <class nn.Module>, "
            f"but it returns an instance of {type(module)}."
        )
    module.to(device)
    return module


def _create_module_with_interface(
    module_cls, args, kwargs, device, module_interface_cls
):
    module = _create_module(module_cls, args, kwargs, device)
    if module_interface_cls is not None:
        module = torch.jit.script(module)
    return rpc.RRef(module, module_interface_cls)


def _param_rrefs(module_rref, recurse) -> List[rpc.RRef[Parameter]]:
    ret: List[rpc.RRef[Parameter]] = []
    for param in module_rref.local_value().parameters(recurse):
        ret.append(rpc.RRef(param))
    return ret


def _raise_not_supported(name: str) -> None:
    raise ValueError(f"Method ``{name}`` not supported for RemoteModule")


class _RemoteModule(nn.Module):
    def __new__(cls, *args, **kwargs):
        # Use __new__ for logging purposes.
        torch._C._log_api_usage_once("torch.distributed.nn.api.remote_module")
        return super().__new__(cls)

    def __init__(
        self,
        remote_device: str,
        module_cls: Type[nn.Module],
        args: Optional[Tuple] = None,
        kwargs: Optional[Dict[str, Any]] = None,
        _module_interface_cls: Any = None,
    ):
        """
        A RemoteModule instance can only be created after RPC initialization.

        It creates a user-specified module on a specified remote node.
        It behaves like a regular ``nn.Module`` except that the ``forward`` method is
        executed on the remote node.
        It takes care of autograd recording to ensure the backward pass propagates
        gradients back to the corresponding remote module.
        It can be shared across processes using the `RPC framework <https://pytorch.org/docs/stable/rpc.html>`__,
        without incurring any overhead of copying the actual module,
        which is equivalent to an :class:`~torch.distributed.rpc.RRef`
        pointing to the remote module.

        The arguments of ``forward_async`` and ``forward`` are the same as
        the ``forward`` method of the module returned by the ``module_cls``.

        Apart from ``forward_async`` and ``forward``, no other methods of ``nn.Module`` are supported for now.

        In particular, to create a hybrid model, local modules should typically be
        created outside of remote modules, rather than as submodules of any remote module (by calling ``add_module``).
        Hybrid Example:
                >>> class HybridModel(nn.Module):
                >>>     def __init__(self) -> None:
                >>>         nn.Module.__init__(self)
                >>>         self.remote_embedding = RemoteModule(...)
                >>>         self.local_linear = nn.Linear(...)

        For example, if ``module_cls`` returns an instance of ``nn.Linear``
        whose ``forward`` method has the signature ``def forward(input: Tensor) -> Tensor:``,
        the generated ``RemoteModule`` will have two methods with the signatures
        ``def forward(input: Tensor) -> Tensor:`` and
        ``def forward_async(input: Tensor) -> Future[Tensor]:``.

        .. note::
            If the remote module is placed on a cuda device,
            any input CPU tensors will be automatically moved to the same cuda device,
            and GPU tensors are returned over the wire according to the device map of the remote worker on TensorPipe RPC backend.
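
            For example, with the TensorPipe backend the device map could be set up roughly as
            follows before calling ``rpc.init_rpc`` (the worker name and GPU indices below are
            placeholders):

            >>> # xdoctest: +SKIP("distributed")
            >>> options = rpc.TensorPipeRpcBackendOptions()
            >>> # Map local cuda:0 to cuda:0 on "worker1" so GPU tensors can be returned on GPU.
            >>> options.set_device_map("worker1", {0: 0})
            >>> rpc.init_rpc("worker0", rank=0, world_size=2, rpc_backend_options=options)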

        Args:
            remote_device (str): Device on the destination worker where we'd like to place this module.
                The device can be a local device or a remote device specified by one of the following remote
                formats:

                    1. "rank:<rank>/<device>" (ex: "rank:0/cuda:0").
                    2. "<worker_name>/<device>" (ex: "trainer0/cuda:0").

                In addition, the device field is optional and the default value is "cpu".
            module_cls (nn.Module): Class for the module to be created remotely. For example,

                >>> class MyModule(nn.Module):
                >>>     def forward(self, input):
                >>>         return input + 1
                >>>
                >>> module_cls = MyModule

            args (Sequence, optional): args to be passed to ``module_cls``.
            kwargs (Dict, optional): kwargs to be passed to ``module_cls``.
            _module_interface_cls (type, optional): The TorchScript interface type for the module
                to be created. The type object should be decorated by @torch.jit.interface.
                If not provided, the generated RemoteModule is not torchscript-able.
                Warning, this is an experimental API and susceptible to frequent changes.

        Returns:
            A remote module instance which wraps the :class:`~nn.Module` created by the
            user-provided ``module_cls``. It has a blocking ``forward`` method and an
            asynchronous ``forward_async`` method that returns a future of the ``forward`` call
            on the user-provided module on the remote side.

        Example::
            Run the following code in two different processes:

            >>> # xdoctest: +SKIP("distributed")
            >>> # On worker 0:
            >>> import torch
            >>> import torch.distributed.rpc as rpc
            >>> from torch import nn, Tensor
            >>> from torch.distributed.nn.api.remote_module import RemoteModule
            >>>
            >>> rpc.init_rpc("worker0", rank=0, world_size=2)
            >>> remote_linear_module = RemoteModule(
            >>>     "worker1/cpu", nn.Linear, args=(20, 30),
            >>> )
            >>> input = torch.randn(128, 20)
            >>> ret_fut = remote_linear_module.forward_async(input)
            >>> ret = ret_fut.wait()
            >>> rpc.shutdown()

            >>> # On worker 1:
            >>> import torch
            >>> import torch.distributed.rpc as rpc
            >>>
            >>> rpc.init_rpc("worker1", rank=1, world_size=2)
            >>> rpc.shutdown()
        """
        super().__init__()

        enable_moving_cpu_tensors_to_cuda = self._prepare_init(remote_device)

        # Default arguments preparation.
        args = args if args is not None else ()
        kwargs = kwargs if kwargs is not None else {}

        if _module_interface_cls is not None:
            # Users rely on this field to know if this generated RemoteModule is TorchScript-able.
            self.is_scriptable = True

            # Instantiate template on remote side.
            fut = rpc.rpc_async(
                self.on,
                _instantiate_template,
                (_module_interface_cls, enable_moving_cpu_tensors_to_cuda),
            )

            self._init_template(
                _module_interface_cls, enable_moving_cpu_tensors_to_cuda
            )

            # Create the module on the remote side.
            fut.wait()  # Ensure remote_module_cls is available on remote side.

            # TODO: We need to change this to rpc.remote, and make it async (see the else branch below).
            # For that we need to be able to apply _module_interface_cls to the RRef returned by rpc.remote
            # See https://github.com/pytorch/pytorch/issues/58098 for more context.
            self.module_rref = rpc.rpc_sync(
                self.on,
                _create_module_with_interface,
                (module_cls, args, kwargs, self.device, _module_interface_cls),
            )
        else:
            self.is_scriptable = False
            self.generated_methods = (
                _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods
            )
            # Create the module on the remote side.
            self.module_rref = rpc.remote(
                self.on,
                _create_module,
                (module_cls, args, kwargs, self.device),
            )

        self._install_generated_methods()
        self._check_attribute_picklability()

    def remote_parameters(self, recurse: bool = True) -> List[rpc.RRef[Parameter]]:
        """
        Return a list of :class:`~torch.distributed.rpc.RRef` pointing to the remote module's parameters.

        This can typically be used in conjunction
        with :class:`~torch.distributed.optim.DistributedOptimizer` (see the example below).

        Args:
            recurse (bool): if True, then returns parameters of the remote
                module and all submodules of the remote module. Otherwise,
                returns only parameters that are direct members of the
                remote module.

        Returns:
            A list of :class:`~torch.distributed.rpc.RRef` (``List[RRef[nn.Parameter]]``)
            to remote module's parameters.
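
        Example::
            A typical use together with :class:`~torch.distributed.optim.DistributedOptimizer`
            could look roughly like this (``remote_linear_module`` is assumed to be a
            ``RemoteModule`` created as in the class example):

            >>> # xdoctest: +SKIP("distributed")
            >>> from torch.distributed.optim import DistributedOptimizer
            >>> param_rrefs = remote_linear_module.remote_parameters()
            >>> dist_optim = DistributedOptimizer(
            >>>     torch.optim.SGD,
            >>>     param_rrefs,
            >>>     lr=0.05,
            >>> )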
        """
        return rpc.rpc_sync(self.on, _param_rrefs, args=(self.module_rref, recurse))

    def get_module_rref(self) -> rpc.RRef[nn.Module]:
        """Return an :class:`~torch.distributed.rpc.RRef` (``RRef[nn.Module]``) pointing to the remote module."""
        return self.module_rref

    @torch.jit.export
    def __getstate__(self):
        raise RuntimeError(
            "Cannot pickle RemoteModule in python pickler. RemoteModule can only be pickled when using RPC"
        )

    @torch.jit.export
    def __setstate__(self, state):
        raise RuntimeError(
            "Cannot unpickle RemoteModule in python pickler. RemoteModule can only be unpickled when using RPC"
        )

    def register_buffer(
        self, name: str, tensor: Optional[Tensor], persistent: bool = True
    ) -> None:
        _raise_not_supported(self.register_buffer.__name__)

    def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
        _raise_not_supported(self.register_parameter.__name__)

    def add_module(self, name: str, module: Optional[Module]) -> None:
        _raise_not_supported(self.add_module.__name__)

    def apply(self: T, fn: Callable[[Module], None]) -> T:  # type: ignore[return]
        _raise_not_supported(self.apply.__name__)

    def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:  # type: ignore[return]
        _raise_not_supported(self.cuda.__name__)

    def ipu(self: T, device: Optional[Union[int, device]] = None) -> T:  # type: ignore[return]
        _raise_not_supported(self.ipu.__name__)

    def xpu(self: T, device: Optional[Union[int, device]] = None) -> T:  # type: ignore[return]
        _raise_not_supported(self.xpu.__name__)

    def cpu(self: T) -> T:  # type: ignore[return]
        _raise_not_supported(self.cpu.__name__)

    def type(self: T, dst_type: Union[dtype, str]) -> T:  # type: ignore[return]
        _raise_not_supported(self.type.__name__)

    def float(self: T) -> T:  # type: ignore[return]
        _raise_not_supported(self.float.__name__)

    def double(self: T) -> T:  # type: ignore[return]
        _raise_not_supported(self.double.__name__)

    def half(self: T) -> T:  # type: ignore[return]
        _raise_not_supported(self.half.__name__)

    def bfloat16(self: T) -> T:  # type: ignore[return]
        _raise_not_supported(self.bfloat16.__name__)

    def to(self, *args, **kwargs) -> T:  # type: ignore[misc, return, type-var]
        _raise_not_supported(self.to.__name__)

    def register_backward_hook(  # type: ignore[return]
        self, hook: Callable[[Module, _grad_t, _grad_t], Union[None, _grad_t]]
    ) -> RemovableHandle:
        _raise_not_supported(self.register_backward_hook.__name__)

    def register_forward_pre_hook(  # type: ignore[return]
        self,
        hook: Union[
            Callable[[T, Tuple[Any, ...]], Optional[Any]],
            Callable[
                [T, Tuple[Any, ...], Dict[str, Any]],
                Optional[Tuple[Any, Dict[str, Any]]],
            ],
        ],
        prepend: bool = False,
        with_kwargs: bool = False,
    ) -> RemovableHandle:
        _raise_not_supported(self.register_forward_pre_hook.__name__)

    def register_forward_hook(  # type: ignore[return, override]
        self,
        hook: Union[
            Callable[[T, Tuple[Any, ...], Any], Optional[Any]],
            Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],
        ],
        prepend: bool = False,
        with_kwargs: bool = False,
    ) -> RemovableHandle:
        _raise_not_supported(self.register_forward_hook.__name__)

    def state_dict(self, *args, **kwargs):
        _raise_not_supported(self.state_dict.__name__)

    def load_state_dict(
        self,
        state_dict: Mapping[str, Any],
        strict: bool = True,
        assign: bool = False,
    ):
        _raise_not_supported(self.load_state_dict.__name__)

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        raise ValueError(
            "Method ``parameters`` not supported for RemoteModule. Please use ``remote_parameters`` instead."
        )

    def named_parameters(  # type: ignore[return]
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, Parameter]]:
        _raise_not_supported(self.named_parameters.__name__)

    def buffers(self, recurse: bool = True) -> Iterator[Tensor]:  # type: ignore[return]
        _raise_not_supported(self.buffers.__name__)

    def named_buffers(  # type: ignore[return]
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, Tensor]]:
        _raise_not_supported(self.named_buffers.__name__)

    def children(self) -> Iterator[Module]:  # type: ignore[return]
        _raise_not_supported(self.children.__name__)

    def named_children(self) -> Iterator[Tuple[str, Module]]:  # type: ignore[return]
        _raise_not_supported(self.named_children.__name__)

    def modules(self) -> Iterator[Module]:  # type: ignore[return]
        _raise_not_supported(self.modules.__name__)

    def named_modules(
        self,
        memo: Optional[Set[Module]] = None,
        prefix: str = "",
        remove_duplicate: bool = True,
    ):
        _raise_not_supported(self.named_modules.__name__)
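
    # ``train()`` and ``eval()`` below are forwarded to the remote module through the
    # ``RRef.rpc_sync()`` proxy, so each call blocks until it finishes on the owner of
    # ``module_rref``; note that they return the result of the RPC call rather than ``self``.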
    def train(self: T, mode: bool = True) -> T:
        return self.module_rref.rpc_sync().train()  # type: ignore[operator, union-attr]

    def eval(self: T) -> T:
        return self.module_rref.rpc_sync().eval()  # type: ignore[operator, union-attr]

    def requires_grad_(self: T, requires_grad: bool = True) -> T:  # type: ignore[return]
        _raise_not_supported(self.requires_grad_.__name__)

    def zero_grad(self, set_to_none: bool = True) -> None:
        _raise_not_supported(self.zero_grad.__name__)

    def share_memory(self: T) -> T:  # type: ignore[return]
        _raise_not_supported(self.share_memory.__name__)

    def extra_repr(self) -> str:  # type: ignore[return]
        _raise_not_supported(self.extra_repr.__name__)

    def _prepare_init(self, remote_device_str: str) -> bool:
        """Prepare the initialization and return whether to enable automatically moving CPU tensors to CUDA devices."""
        # Sanity check.
        assert rpc._is_current_rpc_agent_set(), "RemoteModule only works in RPC."
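
        # The remote device string may be either "<worker_name>/<device>" or
        # "rank:<rank>/<device>" (e.g. "trainer0/cuda:0" or "rank:0/cpu"); when the
        # device part is omitted, it defaults to "cpu".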
        remote_device = _remote_device(remote_device_str)
        self.on = (
            remote_device.worker_name()
            if remote_device.worker_name() is not None
            else remote_device.rank()
        )
        self.device = str(remote_device.device())
        agent = rpc._get_current_rpc_agent()
        # If the device map of the remote worker is set,
        # then enable moving any input CPU tensors to the same cuda device.
        self.is_device_map_set = bool(
            agent._get_device_map(agent.get_worker_info(self.on))  # type: ignore[arg-type]
        )
        # ``enable_moving_cpu_tensors_to_cuda`` is less strict than ``is_device_map_set``:
        # If ``enable_moving_cpu_tensors_to_cuda`` is true, but the device map is not set,
        # then any CPU tensors can still be moved to a cuda device to run forward,
        # but the output must be moved back to CPU before being sent over the wire.
        enable_moving_cpu_tensors_to_cuda = torch.device(self.device).type == "cuda"
        return enable_moving_cpu_tensors_to_cuda

    def _init_template(self, module_interface_cls, enable_moving_cpu_tensors_to_cuda):
        """Instantiate the template on the local side."""
        generated_module = instantiator.instantiate_scriptable_remote_module_template(
            module_interface_cls, enable_moving_cpu_tensors_to_cuda
        )
        self.generated_methods = generated_module._generated_methods

    def _check_attribute_picklability(self):
        """Check that every attribute has explicitly defined picklability (i.e., whether it is pickled or not)."""
        for k in self.__dict__.keys():
            if (
                k not in _REMOTE_MODULE_PICKLED_ATTRIBUTES
                and k not in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING
            ):
                raise AttributeError(
                    f"Attribute {k} must be either in ``_REMOTE_MODULE_PICKLED_ATTRIBUTES`` or "
                    "``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING``."
                )

    def _install_generated_methods(self):
        for method in self.generated_methods:
            method_name = method.__name__
            method = torch.jit.export(method)
            setattr(self, method_name, types.MethodType(method, self))

    @staticmethod
    def init_from_module_rref(
        remote_device: str,
        module_rref: rpc.RRef[nn.Module],
        _module_interface_cls: Any = None,
    ):
        """
        Besides the constructor, a RemoteModule instance can also be initialized given a module RRef.

        This alternate initialization method can be particularly useful if we want to create multiple
        RemoteModule instances that share the same underlying module and reduce memory consumption.

        Moreover, this also provides a workaround for passing a script RemoteModule over RPC,
        which is not supported. The recommended way is as follows:

        1. the sender creates a RemoteModule;
        2. the sender sends its ``module_rref`` over RPC;
        3. the receiver calls this method to initialize another RemoteModule using the same ``module_rref``.

        Example::
            Run the following code in two different processes:

            >>> # xdoctest: +SKIP("distributed")
            >>> # On worker 0:
            >>> import torch
            >>> import torch.distributed.rpc as rpc
            >>> from torch import nn, Tensor
            >>> from torch.distributed.nn.api.remote_module import RemoteModule
            >>>
            >>> rpc.init_rpc("worker0", rank=0, world_size=2)
            >>> remote_module = RemoteModule(
            >>>     "worker1/cpu", nn.Linear, args=(20, 30),
            >>> )
            >>>
            >>> remote_module1 = rpc.rpc_sync(
            >>>     "worker1",
            >>>     RemoteModule.init_from_module_rref,
            >>>     ("worker1/cpu", remote_module.get_module_rref()),
            >>> )
            >>> rpc.shutdown()

            >>> # On worker 1:
            >>> import torch
            >>> import torch.distributed.rpc as rpc
            >>>
            >>> rpc.init_rpc("worker1", rank=1, world_size=2)
            >>> rpc.shutdown()

        Args:
            remote_device (str): Device on the destination worker where we'd like to place this module.
                The device can be a local device or a remote device specified by one of the following remote
                formats:

                    1. "rank:<rank>/<device>" (ex: "rank:0/cuda:0").
                    2. "<worker_name>/<device>" (ex: "trainer0/cuda:0").

                In addition, the device field is optional and the default value is "cpu".
            module_rref (RRef[nn.Module]): The module reference shared by both the caller and
                the created remote module.
            _module_interface_cls (type, optional): The TorchScript interface type for the module
                to be created. The type object should be decorated by @torch.jit.interface.
                If not provided, the generated RemoteModule is not torchscript-able.
                Warning, this is an experimental API and susceptible to frequent changes.

        Returns:
            A remote module instance which wraps the :class:`~nn.Module` referenced by the
            user-provided ``module_rref``. It has a blocking ``forward`` method and an
            asynchronous ``forward_async`` method that returns a future of the ``forward`` call
            on the user-provided module on the remote side.
        """
        # NOTE: if a new attribute is added to this class, also need to add it
        # to ``_REMOTE_MODULE_PICKLED_ATTRIBUTES`` for pickling/unpickling.

        remote_module = object.__new__(RemoteModule)

        enable_moving_cpu_tensors_to_cuda = remote_module._prepare_init(remote_device)

        if _module_interface_cls is not None:
            # Users rely on this field to know if this generated RemoteModule is TorchScript-able.
            remote_module.is_scriptable = True

            remote_module._init_template(
                _module_interface_cls, enable_moving_cpu_tensors_to_cuda
            )
        else:
            remote_module.is_scriptable = False
            remote_module.generated_methods = (
                _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods
            )
        remote_module.module_rref = module_rref

        remote_module._install_generated_methods()
        remote_module._check_attribute_picklability()

        return remote_module


class RemoteModule(_RemoteModule):
    """
    A RemoteModule instance can only be created after RPC initialization.

    It creates a user-specified module on a specified remote node.
    It behaves like a regular ``nn.Module`` except that the ``forward`` method is
    executed on the remote node.
    It takes care of autograd recording to ensure the backward pass propagates
    gradients back to the corresponding remote module.

    It generates two methods ``forward_async`` and ``forward`` based on the
    signature of the ``forward`` method of ``module_cls``.
    ``forward_async`` runs asynchronously and returns a Future. The arguments of
    ``forward_async`` and ``forward`` are the same as the ``forward`` method of the
    module returned by the ``module_cls``.

    For example, if ``module_cls`` returns an instance of ``nn.Linear``
    whose ``forward`` method has the signature ``def forward(input: Tensor) -> Tensor:``,
    the generated ``RemoteModule`` will have two methods with the signatures:

    | ``def forward(input: Tensor) -> Tensor:``
    | ``def forward_async(input: Tensor) -> Future[Tensor]:``

    Args:
        remote_device (str): Device on the destination worker where we'd like to place this module.
            The format should be "<workername>/<device>", where the device field can be parsed as torch.device type.
            E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
            In addition, the device field is optional and the default value is "cpu".
        module_cls (nn.Module): Class for the module to be created remotely. For example,

            >>> class MyModule(nn.Module):
            >>>     def forward(self, input):
            >>>         return input + 1
            >>>
            >>> module_cls = MyModule

        args (Sequence, optional): args to be passed to ``module_cls``.
        kwargs (Dict, optional): kwargs to be passed to ``module_cls``.

    Returns:
        A remote module instance which wraps the :class:`~nn.Module` created by the
        user-provided ``module_cls``. It has a blocking ``forward`` method and an
        asynchronous ``forward_async`` method that returns a future of the ``forward`` call
        on the user-provided module on the remote side.

    Example::
        Run the following code in two different processes:

        >>> # xdoctest: +SKIP("distributed")
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> from torch import nn, Tensor
        >>> from torch.distributed.nn.api.remote_module import RemoteModule
        >>>
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> remote_linear_module = RemoteModule(
        >>>     "worker1/cpu", nn.Linear, args=(20, 30),
        >>> )
        >>> input = torch.randn(128, 20)
        >>> ret_fut = remote_linear_module.forward_async(input)
        >>> ret = ret_fut.wait()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>>
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

    Furthermore, a more practical example that is combined with
    `DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__ (DDP)
    can be found in this `tutorial <https://pytorch.org/tutorials/advanced/rpc_ddp_tutorial.html>`__.
    """

    def __init__(
        self,
        remote_device: str,
        module_cls: Type[nn.Module],
        args: Optional[Tuple] = None,
        kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(remote_device, module_cls, args, kwargs)


def _remote_module_receiver(
    *remote_module_pickled_attrs,
):
    """Deserialize a RemoteModule."""
    serialized_remote_module = _SerializedRemoteModule._make(
        remote_module_pickled_attrs
    )
    m = object.__new__(RemoteModule)
    m.__dict__.update(serialized_remote_module._asdict())

    # Unpickling the attribute `module_rref` must invoke RRef's `_deserialize()` method.
    m.module_rref = rpc.PyRRef._deserialize(m.module_rref)

    # Install generated methods when unpickled.
    for method in m.generated_methods:
        method_name = method.__name__
        method = torch.jit.export(method)
        setattr(m, method_name, types.MethodType(method, m))

    return m


def _remote_module_reducer(remote_module):
    """Serialize a RemoteModule."""
    pickled_attrs = {}
    for k, v in remote_module.__dict__.items():
        # Pickling the attribute `module_rref` must invoke RRef's `_serialize()` method.
        if k == "module_rref":
            pickled_attrs[k] = v._serialize()
        elif k in _REMOTE_MODULE_PICKLED_ATTRIBUTES:
            pickled_attrs[k] = v
        # Check if unpickled attributes are all in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING.
        elif k not in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING:
            print(
                f"The new attribute ``{k}`` of RemoteModule is ignored during RPC pickling. "
                "To pickle this attribute, please add it to ``_REMOTE_MODULE_PICKLED_ATTRIBUTES``. "
                "Otherwise, please explicitly add it to ``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING``.",
                file=sys.stderr,
            )

    return (
        _remote_module_receiver,
        tuple(pickled_attrs.values()),
    )


def _recursive_script_module_receiver(
    recursive_script_module_serialized,
):
    """Deserialize a RecursiveScriptModule that does not contain a script RemoteModule."""
    f = io.BytesIO(recursive_script_module_serialized)
    m = torch.jit.load(f)
    return m


def _recursive_script_module_reducer(recursive_script_module):
    """Serialize a RecursiveScriptModule that does not contain a script RemoteModule, and raise an error otherwise."""
    if hasattr(recursive_script_module._c, "module_rref"):
        raise RuntimeError(
            "Passing a script RemoteModule over RPC is not supported. Please create a RemoteModule in the sender, "
            "send the `module_rref` to the receiver, and create a new instance on the receiver end by passing this `module_rref`."
        )

    f = io.BytesIO()
    torch.jit.save(recursive_script_module, f)
    return (_recursive_script_module_receiver, (f.getvalue(),))
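

# Registering these reducers with the internal RPC pickler is what allows a
# ``RemoteModule`` to be sent as an RPC argument (the plain Python pickler is
# explicitly rejected in ``__getstate__``/``__setstate__`` above). An illustrative
# sketch, where ``run_forward`` and ``inp`` are hypothetical objects defined by the caller:
#
#   >>> # xdoctest: +SKIP("distributed")
#   >>> rpc.rpc_sync("worker2", run_forward, args=(remote_linear_module, inp))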
_internal_rpc_pickler._register_reducer(RemoteModule, _remote_module_reducer)
_internal_rpc_pickler._register_reducer(
    torch.jit.RecursiveScriptModule, _recursive_script_module_reducer
)