#!/usr/bin/python3
# mypy: allow-untyped-defs
import collections
import io
import sys
import types
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
)

import torch
import torch.distributed.rpc as rpc
from torch import device, dtype, nn, Tensor
from torch.distributed import _remote_device
from torch.distributed.nn.jit import instantiator
from torch.distributed.rpc.internal import _internal_rpc_pickler
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.utils.hooks import RemovableHandle


__all__ = ["RemoteModule"]

_grad_t = Union[Tuple[Tensor, ...], Tensor]
# See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use
# of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be
# the type of the subclass, not the looser type of `Module`.
T = TypeVar("T", bound="Module")

_NON_SCRIPTABLE_REMOTE_MODULE_MODULE = (
    instantiator.instantiate_non_scriptable_remote_module_template()
)

_REMOTE_MODULE_PICKLED_ATTRIBUTES = (
    "on",
    "device",
    "is_device_map_set",
    "is_scriptable",
    "generated_methods",
    "module_rref",
)

_SerializedRemoteModule = collections.namedtuple("_SerializedRemoteModule", _REMOTE_MODULE_PICKLED_ATTRIBUTES)  # type: ignore[misc]

# These attributes are mostly from RemoteModule's parent class and are intentionally not pickled.
# A new attribute of RemoteModule should be either in _REMOTE_MODULE_PICKLED_ATTRIBUTES
# or _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING.
# Otherwise, it will not be pickled.
_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING = (
    "training",
    "_parameters",
    "_buffers",
    "_non_persistent_buffers_set",
    "_backward_hooks",
    "_backward_pre_hooks",
    "_is_full_backward_hook",
    "_forward_hooks",
    "_forward_hooks_with_kwargs",
    "_forward_hooks_always_called",
    "_forward_pre_hooks",
    "_forward_pre_hooks_with_kwargs",
    "_state_dict_hooks",
    "_state_dict_pre_hooks",
    "_load_state_dict_pre_hooks",
    "_load_state_dict_post_hooks",
    "_modules",
    # The two attributes below are generated methods, not available at pickling time.
    "forward_async",
    "forward",
)
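# Illustrative note (``some_new_flag`` is a hypothetical attribute name): if ``__init__``
# ever set ``self.some_new_flag`` without listing it in one of the two tuples above,
# ``_check_attribute_picklability()`` would raise an AttributeError at construction time;
# an attribute assigned only after construction would instead be skipped by
# ``_remote_module_reducer`` with a warning printed to stderr during RPC pickling.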


# RPC handler.
def _instantiate_template(module_interface_cls, enable_moving_cpu_tensors_to_cuda):
    instantiator.instantiate_scriptable_remote_module_template(
        module_interface_cls, enable_moving_cpu_tensors_to_cuda
    )


def _create_module(module_cls, args, kwargs, device):
    module = module_cls(*args, **kwargs)
    if not isinstance(module, nn.Module):
        raise ValueError(
            "Expect `module_cls(*args, **kwargs)` to return an instance of <class nn.Module>, "
            f"but it returns an instance of {type(module)}."
        )
    module.to(device)
    return module


def _create_module_with_interface(
    module_cls, args, kwargs, device, module_interface_cls
):
    module = _create_module(module_cls, args, kwargs, device)
    if module_interface_cls is not None:
        module = torch.jit.script(module)
    return rpc.RRef(module, module_interface_cls)


def _param_rrefs(module_rref, recurse) -> List[rpc.RRef[Parameter]]:
    ret: List[rpc.RRef[Parameter]] = []
    for param in module_rref.local_value().parameters(recurse):
        ret.append(rpc.RRef(param))
    return ret


def _raise_not_supported(name: str) -> None:
    raise ValueError(f"Method ``{name}`` not supported for RemoteModule")


class _RemoteModule(nn.Module):
    def __new__(cls, *args, **kwargs):
        # Use __new__ for logging purposes.
        torch._C._log_api_usage_once("torch.distributed.nn.api.remote_module")
        return super().__new__(cls)

    def __init__(
        self,
        remote_device: str,
        module_cls: Type[nn.Module],
        args: Optional[Tuple] = None,
        kwargs: Optional[Dict[str, Any]] = None,
        _module_interface_cls: Any = None,
    ):
        """
        A RemoteModule instance can only be created after RPC initialization.

        It creates a user-specified module on a specified remote node.
        It behaves like a regular ``nn.Module`` except that the ``forward`` method is
        executed on the remote node.
        It takes care of autograd recording to ensure the backward pass propagates
        gradients back to the corresponding remote module.
        It can be shared across processes using the `RPC framework <https://pytorch.org/docs/stable/rpc.html>`__,
        without incurring any overheads of copying the actual module,
        which is equivalent to an :class:`~torch.distributed.rpc.RRef`
        pointing to the remote module.

        The arguments of ``forward_async`` and ``forward`` are the same as
        the ``forward`` method of the module returned by the ``module_cls``.

        Apart from ``forward_async`` and ``forward``, no other methods of ``nn.Module`` are supported for now.

        In particular, to create a hybrid model, local modules should typically be
        created outside of remote modules, rather than as submodules of any remote module (by calling ``add_module``).
        Hybrid Example:
                >>> class HybridModel(nn.Module):
                >>>     def __init__(self) -> None:
                >>>         nn.Module.__init__(self)
                >>>         self.remote_embedding = RemoteModule(...)
                >>>         self.local_linear = nn.Linear(...)

        For example, if ``module_cls`` returns an instance of ``nn.Linear``
        that has the ``forward`` method signature ``def forward(input: Tensor) -> Tensor:``,
        the generated ``RemoteModule`` will have 2 methods with the signatures
        ``def forward(input: Tensor) -> Tensor:`` and
        ``def forward_async(input: Tensor) -> Future[Tensor]:``.

        .. note::
            If the remote module is placed on a cuda device,
            any input CPU tensors will be automatically moved to the same cuda device,
            and GPU tensors are returned over the wire according to the device map of the remote worker on TensorPipe RPC backend.

        Args:
            remote_device (str): Device on the destination worker where we'd like to place this module.
                The device can be a local device or a remote device specified by one of the following remote
                formats:

                    1. "rank:<rank>/<device>" (ex: "rank:0/cuda:0").
                    2. "<worker_name>/<device>" (ex: "trainer0/cuda:0").

                In addition, the device field is optional and the default value is "cpu".
            module_cls (nn.Module): For example,
                >>> class MyModule(nn.Module):
                >>>     def forward(input):
                >>>         return input + 1
                >>>
                >>> module_cls = MyModule
            args (Sequence, optional): args to be passed to ``module_cls``.
            kwargs (Dict, optional): kwargs to be passed to ``module_cls``.
            _module_interface_cls (type, optional): The TorchScript interface type for the module
                to be created. The type object should be decorated by @torch.jit.interface.
                If not provided, the generated RemoteModule is not torchscript-able.
                Warning, this is an experimental API and susceptible to frequent changes.

        Returns:
            A remote module instance which wraps the :class:`~nn.Module` created by the
            user-provided ``module_cls``; it has a blocking ``forward`` method and an
            asynchronous ``forward_async`` method that returns a future of the ``forward`` call
            on the user-provided module on the remote side.

        Example::
            Run the following code in two different processes:

            >>> # xdoctest: +SKIP("distributed")
            >>> # On worker 0:
            >>> import torch
            >>> import torch.distributed.rpc as rpc
            >>> from torch import nn, Tensor
            >>> from torch.distributed.nn.api.remote_module import RemoteModule
            >>>
            >>> rpc.init_rpc("worker0", rank=0, world_size=2)
            >>> remote_linear_module = RemoteModule(
            >>>     "worker1/cpu", nn.Linear, args=(20, 30),
            >>> )
            >>> input = torch.randn(128, 20)
            >>> ret_fut = remote_linear_module.forward_async(input)
            >>> ret = ret_fut.wait()
            >>> rpc.shutdown()

            >>> # On worker 1:
            >>> import torch
            >>> import torch.distributed.rpc as rpc
            >>>
            >>> rpc.init_rpc("worker1", rank=1, world_size=2)
            >>> rpc.shutdown()
        """
        super().__init__()

        enable_moving_cpu_tensors_to_cuda = self._prepare_init(remote_device)

        # Default arguments preparation.
        args = args if args is not None else ()
        kwargs = kwargs if kwargs is not None else {}

        if _module_interface_cls is not None:
            # Users rely on this field to know if this generated RemoteModule is TorchScript-able.
            self.is_scriptable = True

            # Instantiate template on remote side.
            fut = rpc.rpc_async(
                self.on,
                _instantiate_template,
                (_module_interface_cls, enable_moving_cpu_tensors_to_cuda),
            )

            self._init_template(
                _module_interface_cls, enable_moving_cpu_tensors_to_cuda
            )

            # Create the module on the remote side.
            fut.wait()  # Ensure remote_module_cls is available on remote side.

            # TODO: We need to change this to rpc.remote, and make it async (see the else branch below).
            # For that we need to be able to apply _module_interface_cls to the RRef returned by rpc.remote
            # See https://github.com/pytorch/pytorch/issues/58098 for more context.
            self.module_rref = rpc.rpc_sync(
                self.on,
                _create_module_with_interface,
                (module_cls, args, kwargs, self.device, _module_interface_cls),
            )
        else:
            self.is_scriptable = False
            self.generated_methods = (
                _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods
            )
            # Create the module on the remote side.
            self.module_rref = rpc.remote(
                self.on,
                _create_module,
                (module_cls, args, kwargs, self.device),
            )

        self._install_generated_methods()
        self._check_attribute_picklability()

    def remote_parameters(self, recurse: bool = True) -> List[rpc.RRef[Parameter]]:
        """
        Return a list of :class:`~torch.distributed.rpc.RRef` pointing to the remote module's parameters.

        This can typically be used in conjunction
        with :class:`~torch.distributed.optim.DistributedOptimizer`.

        Args:
            recurse (bool): if True, then returns parameters of the remote
                module and all submodules of the remote module. Otherwise,
                returns only parameters that are direct members of the
                remote module.

        Returns:
            A list of :class:`~torch.distributed.rpc.RRef` (``List[RRef[nn.Parameter]]``)
            to remote module's parameters.
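
        Example::
            A minimal usage sketch, assuming an already-constructed ``RemoteModule``
            named ``remote_linear_module`` (a hypothetical name) and an initialized
            RPC framework; the returned parameter RRefs are fed into
            :class:`~torch.distributed.optim.DistributedOptimizer`:

            >>> # xdoctest: +SKIP("distributed")
            >>> from torch.distributed.optim import DistributedOptimizer
            >>> dist_optim = DistributedOptimizer(
            >>>     torch.optim.SGD,
            >>>     remote_linear_module.remote_parameters(),
            >>>     lr=0.05,
            >>> )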
        """
        return rpc.rpc_sync(self.on, _param_rrefs, args=(self.module_rref, recurse))

    def get_module_rref(self) -> rpc.RRef[nn.Module]:
        """Return an :class:`~torch.distributed.rpc.RRef` (``RRef[nn.Module]``) pointing to the remote module."""
        return self.module_rref

    @torch.jit.export
    def __getstate__(self):
        raise RuntimeError(
            "Cannot pickle RemoteModule in python pickler. RemoteModule can only be pickled when using RPC"
        )

    @torch.jit.export
    def __setstate__(self, state):
        raise RuntimeError(
            "Cannot unpickle RemoteModule in python pickler. RemoteModule can only be unpickled when using RPC"
        )

    def register_buffer(
        self, name: str, tensor: Optional[Tensor], persistent: bool = True
    ) -> None:
        _raise_not_supported(self.register_buffer.__name__)

    def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
        _raise_not_supported(self.register_parameter.__name__)

    def add_module(self, name: str, module: Optional[Module]) -> None:
        _raise_not_supported(self.add_module.__name__)

    def apply(self: T, fn: Callable[[Module], None]) -> T:  # type: ignore[return]
        _raise_not_supported(self.apply.__name__)

    def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:  # type: ignore[return]
        _raise_not_supported(self.cuda.__name__)

    def ipu(self: T, device: Optional[Union[int, device]] = None) -> T:  # type: ignore[return]
        _raise_not_supported(self.ipu.__name__)

    def xpu(self: T, device: Optional[Union[int, device]] = None) -> T:  # type: ignore[return]
        _raise_not_supported(self.xpu.__name__)

    def cpu(self: T) -> T:  # type: ignore[return]
        _raise_not_supported(self.cpu.__name__)

    def type(self: T, dst_type: Union[dtype, str]) -> T:  # type: ignore[return]
        _raise_not_supported(self.type.__name__)

    def float(self: T) -> T:  # type: ignore[return]
        _raise_not_supported(self.float.__name__)

    def double(self: T) -> T:  # type: ignore[return]
        _raise_not_supported(self.double.__name__)

    def half(self: T) -> T:  # type: ignore[return]
        _raise_not_supported(self.half.__name__)

    def bfloat16(self: T) -> T:  # type: ignore[return]
        _raise_not_supported(self.bfloat16.__name__)

    def to(self, *args, **kwargs) -> T:  # type: ignore[misc, return, type-var]
        _raise_not_supported(self.to.__name__)

    def register_backward_hook(  # type: ignore[return]
        self, hook: Callable[[Module, _grad_t, _grad_t], Union[None, _grad_t]]
    ) -> RemovableHandle:
        _raise_not_supported(self.register_backward_hook.__name__)

    def register_forward_pre_hook(  # type: ignore[return]
        self,
        hook: Union[
            Callable[[T, Tuple[Any, ...]], Optional[Any]],
            Callable[
                [T, Tuple[Any, ...], Dict[str, Any]],
                Optional[Tuple[Any, Dict[str, Any]]],
            ],
        ],
        prepend: bool = False,
        with_kwargs: bool = False,
    ) -> RemovableHandle:
        _raise_not_supported(self.register_forward_pre_hook.__name__)

    def register_forward_hook(  # type: ignore[return, override]
        self,
        hook: Union[
            Callable[[T, Tuple[Any, ...], Any], Optional[Any]],
            Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],
        ],
        prepend: bool = False,
        with_kwargs: bool = False,
    ) -> RemovableHandle:
        _raise_not_supported(self.register_forward_hook.__name__)

    def state_dict(self, *args, **kwargs):
        _raise_not_supported(self.state_dict.__name__)

    def load_state_dict(
        self,
        state_dict: Mapping[str, Any],
        strict: bool = True,
        assign: bool = False,
    ):
        _raise_not_supported(self.load_state_dict.__name__)

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        raise ValueError(
            "Method ``parameters`` not supported for RemoteModule. Please use ``remote_parameters`` instead."
        )

    def named_parameters(  # type: ignore[return]
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, Parameter]]:
        _raise_not_supported(self.named_parameters.__name__)

    def buffers(self, recurse: bool = True) -> Iterator[Tensor]:  # type: ignore[return]
        _raise_not_supported(self.buffers.__name__)

    def named_buffers(  # type: ignore[return]
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, Tensor]]:
        _raise_not_supported(self.named_buffers.__name__)

    def children(self) -> Iterator[Module]:  # type: ignore[return]
        _raise_not_supported(self.children.__name__)

    def named_children(self) -> Iterator[Tuple[str, Module]]:  # type: ignore[return]
        _raise_not_supported(self.named_children.__name__)

    def modules(self) -> Iterator[Module]:  # type: ignore[return]
        _raise_not_supported(self.modules.__name__)

    def named_modules(
        self,
        memo: Optional[Set[Module]] = None,
        prefix: str = "",
        remove_duplicate: bool = True,
    ):
        _raise_not_supported(self.named_modules.__name__)

    def train(self: T, mode: bool = True) -> T:
        return self.module_rref.rpc_sync().train(mode)  # type: ignore[operator, union-attr]

    def eval(self: T) -> T:
        return self.module_rref.rpc_sync().eval()  # type: ignore[operator, union-attr]

    def requires_grad_(self: T, requires_grad: bool = True) -> T:  # type: ignore[return]
        _raise_not_supported(self.requires_grad_.__name__)

    def zero_grad(self, set_to_none: bool = True) -> None:
        _raise_not_supported(self.zero_grad.__name__)

    def share_memory(self: T) -> T:  # type: ignore[return]
        _raise_not_supported(self.share_memory.__name__)

    def extra_repr(self) -> str:  # type: ignore[return]
        _raise_not_supported(self.extra_repr.__name__)

    def _prepare_init(self, remote_device_str: str) -> bool:
        """Prepare the initialization and return whether to enable automatically moving CPU tensors to CUDA devices."""
        # Sanity check.
        assert rpc._is_current_rpc_agent_set(), "RemoteModule only works in RPC."

        remote_device = _remote_device(remote_device_str)
        self.on = (
            remote_device.worker_name()
            if remote_device.worker_name() is not None
            else remote_device.rank()
        )
        self.device = str(remote_device.device())
        agent = rpc._get_current_rpc_agent()
        # If the device map of the remote worker is set,
        # then enable moving any input CPU tensors to the same cuda device.
        self.is_device_map_set = bool(
            agent._get_device_map(agent.get_worker_info(self.on))  # type: ignore[arg-type]
        )
        # ``enable_moving_cpu_tensors_to_cuda`` is less strict than ``is_device_map_set``:
        # If ``enable_moving_cpu_tensors_to_cuda`` is true, but the device map is not set,
        # then any CPU tensors can still be moved to a cuda device to run forward,
        # but the output must be moved back to CPU before being sent over the wire.
        enable_moving_cpu_tensors_to_cuda = torch.device(self.device).type == "cuda"
        return enable_moving_cpu_tensors_to_cuda

    def _init_template(self, module_interface_cls, enable_moving_cpu_tensors_to_cuda):
        """Instantiate template on local side."""
        generated_module = instantiator.instantiate_scriptable_remote_module_template(
            module_interface_cls, enable_moving_cpu_tensors_to_cuda
        )
        self.generated_methods = generated_module._generated_methods

    def _check_attribute_picklability(self):
        """Check that every attribute has explicitly defined picklability (i.e., is listed as pickled or ignored for pickling)."""
        for k in self.__dict__.keys():
            if (
                k not in _REMOTE_MODULE_PICKLED_ATTRIBUTES
                and k not in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING
            ):
                raise AttributeError(
                    f"Attribute {k} must be either in ``_REMOTE_MODULE_PICKLED_ATTRIBUTES`` or "
                    "``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING``."
                )

    def _install_generated_methods(self):
        for method in self.generated_methods:
            method_name = method.__name__
            method = torch.jit.export(method)
            setattr(self, method_name, types.MethodType(method, self))

    @staticmethod
    def init_from_module_rref(
        remote_device: str,
        module_rref: rpc.RRef[nn.Module],
        _module_interface_cls: Any = None,
    ):
        """
        Besides the constructor, a RemoteModule instance can also be initialized given a module RRef.

        This alternate initialization method can be particularly useful if we want to create multiple
        RemoteModule instances that share the same underlying module and reduce memory consumption.

        Moreover, this also provides a workaround for passing script RemoteModule over RPC,
        which is not supported. The recommended way is as follows:

            1. the sender creates a RemoteModule;
            2. the sender sends its ``module_rref`` over RPC;
            3. the receiver calls this method to initialize another RemoteModule using the same ``module_rref``.

        Example::
            Run the following code in two different processes:

            >>> # xdoctest: +SKIP("distributed")
            >>> # On worker 0:
            >>> import torch
            >>> import torch.distributed.rpc as rpc
            >>> from torch import nn, Tensor
            >>> from torch.distributed.nn.api.remote_module import RemoteModule
            >>>
            >>> rpc.init_rpc("worker0", rank=0, world_size=2)
            >>> remote_module = RemoteModule(
            >>>     "worker1/cpu", nn.Linear, args=(20, 30),
            >>> )
            >>>
            >>> remote_module1 = rpc.rpc_sync(
            >>>     "worker1",
            >>>     RemoteModule.init_from_module_rref,
            >>>     ("worker1/cpu", remote_module.get_module_rref()),
            >>> )
            >>> rpc.shutdown()

            >>> # On worker 1:
            >>> import torch
            >>> import torch.distributed.rpc as rpc
            >>>
            >>> rpc.init_rpc("worker1", rank=1, world_size=2)
            >>> rpc.shutdown()

        Args:
            remote_device (str): Device on the destination worker where we'd like to place this module.
                The device can be a local device or a remote device specified by one of the following remote
                formats:

                    1. "rank:<rank>/<device>" (ex: "rank:0/cuda:0").
                    2. "<worker_name>/<device>" (ex: "trainer0/cuda:0").

                In addition, the device field is optional and the default value is "cpu".
            module_rref (RRef[nn.Module]): The module reference shared by both the caller and
                the created remote module.
            _module_interface_cls (type, optional): The TorchScript interface type for the module
                to be created. The type object should be decorated by @torch.jit.interface.
                If not provided, the generated RemoteModule is not torchscript-able.
                Warning, this is an experimental API and susceptible to frequent changes.

        Returns:
            A remote module instance which wraps the :class:`~nn.Module` created by the
            user-provided ``module_rref``; it has a blocking ``forward`` method and an
            asynchronous ``forward_async`` method that returns a future of the ``forward`` call
            on the user-provided module on the remote side.
        """
        # NOTE: if a new attribute is added to this class, also need to add it
        # to ``_REMOTE_MODULE_PICKLED_ATTRIBUTES`` for pickling/unpickling.

        remote_module = object.__new__(RemoteModule)

        enable_moving_cpu_tensors_to_cuda = remote_module._prepare_init(remote_device)

        if _module_interface_cls is not None:
            # Users rely on this field to know if this generated RemoteModule is TorchScript-able.
            remote_module.is_scriptable = True

            remote_module._init_template(
                _module_interface_cls, enable_moving_cpu_tensors_to_cuda
            )
        else:
            remote_module.is_scriptable = False
            remote_module.generated_methods = (
                _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods
            )
        remote_module.module_rref = module_rref

        remote_module._install_generated_methods()
        remote_module._check_attribute_picklability()

        return remote_module


class RemoteModule(_RemoteModule):
    """
        A RemoteModule instance can only be created after RPC initialization.

        It creates a user-specified module on a specified remote node.
        It behaves like a regular ``nn.Module`` except that the ``forward`` method is
        executed on the remote node.
        It takes care of autograd recording to ensure the backward pass propagates
        gradients back to the corresponding remote module.

        It generates two methods ``forward_async`` and ``forward`` based on the
        signature of the ``forward`` method of ``module_cls``. ``forward_async``
        runs asynchronously and returns a Future. The arguments of ``forward_async``
        and ``forward`` are the same as the ``forward`` method of the module
        returned by the ``module_cls``.

        For example, if ``module_cls`` returns an instance of ``nn.Linear``
        that has the ``forward`` method signature ``def forward(input: Tensor) -> Tensor:``,
        the generated ``RemoteModule`` will have 2 methods with the signatures:

        | ``def forward(input: Tensor) -> Tensor:``
        | ``def forward_async(input: Tensor) -> Future[Tensor]:``

    Args:
        remote_device (str): Device on the destination worker where we'd like to place this module.
            The format should be "<workername>/<device>", where the device field can be parsed as torch.device type.
            E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
            In addition, the device field is optional and the default value is "cpu".
        module_cls (nn.Module): Class for the module to be created remotely. For example,

            >>> class MyModule(nn.Module):
            >>>     def forward(input):
            >>>         return input + 1
            >>>
            >>> module_cls = MyModule

        args (Sequence, optional): args to be passed to ``module_cls``.
        kwargs (Dict, optional): kwargs to be passed to ``module_cls``.

    Returns:
        A remote module instance which wraps the :class:`~nn.Module` created by the
        user-provided ``module_cls``; it has a blocking ``forward`` method and an
        asynchronous ``forward_async`` method that returns a future of the ``forward`` call
        on the user-provided module on the remote side.

    Example::
        Run the following code in two different processes:

        >>> # xdoctest: +SKIP("distributed")
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> from torch import nn, Tensor
        >>> from torch.distributed.nn.api.remote_module import RemoteModule
        >>>
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> remote_linear_module = RemoteModule(
        >>>     "worker1/cpu", nn.Linear, args=(20, 30),
        >>> )
        >>> input = torch.randn(128, 20)
        >>> ret_fut = remote_linear_module.forward_async(input)
        >>> ret = ret_fut.wait()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>>
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

        Furthermore, a more practical example that is combined with
        `DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__ (DDP)
        can be found in this `tutorial <https://pytorch.org/tutorials/advanced/rpc_ddp_tutorial.html>`__.
    """

    def __init__(
        self,
        remote_device: str,
        module_cls: Type[nn.Module],
        args: Optional[Tuple] = None,
        kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(remote_device, module_cls, args, kwargs)


def _remote_module_receiver(
    *remote_module_pickled_attrs,
):
    """Deserialize a RemoteModule."""
    serialized_remote_module = _SerializedRemoteModule._make(
        remote_module_pickled_attrs
    )
    m = object.__new__(RemoteModule)
    m.__dict__.update(serialized_remote_module._asdict())

    # Unpickling the attribute `module_rref` must invoke RRef's `_deserialize()` method.
    m.module_rref = rpc.PyRRef._deserialize(m.module_rref)

    # Install generated methods when unpickled.
    for method in m.generated_methods:
        method_name = method.__name__
        method = torch.jit.export(method)
        setattr(m, method_name, types.MethodType(method, m))

    return m


def _remote_module_reducer(remote_module):
    """Serialize a RemoteModule."""
    pickled_attrs = {}
    for k, v in remote_module.__dict__.items():
        # Pickling the attribute `module_rref` must invoke RRef's `_serialize()` method.
        if k == "module_rref":
            pickled_attrs[k] = v._serialize()
        elif k in _REMOTE_MODULE_PICKLED_ATTRIBUTES:
            pickled_attrs[k] = v
        # Warn if a new attribute is in neither _REMOTE_MODULE_PICKLED_ATTRIBUTES
        # nor _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING.
        elif k not in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING:
            print(
                f"The new attribute ``{k}`` of RemoteModule is ignored during RPC pickling. "
                "To pickle this attribute, please add it to ``_REMOTE_MODULE_PICKLED_ATTRIBUTES``. "
                "Otherwise, please explicitly add it to ``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING``.",
                file=sys.stderr,
            )

    return (
        _remote_module_receiver,
        tuple(pickled_attrs.values()),
    )


def _recursive_script_module_receiver(
    recursive_script_module_serialized,
):
    """Deserialize a RecursiveScriptModule that does not contain a script RemoteModule."""
    f = io.BytesIO(recursive_script_module_serialized)
    m = torch.jit.load(f)
    return m


def _recursive_script_module_reducer(recursive_script_module):
    """Serialize a RecursiveScriptModule that does not contain a script RemoteModule, and raise an error otherwise."""
    if hasattr(recursive_script_module._c, "module_rref"):
        raise RuntimeError(
            "Passing a script RemoteModule over RPC is not supported. Please create a RemoteModule in the sender, "
            "send the `module_rref` to the receiver, and create a new instance on the receiver end by passing this `module_rref`."
        )

    f = io.BytesIO()
    torch.jit.save(recursive_script_module, f)
    return (_recursive_script_module_receiver, (f.getvalue(),))


_internal_rpc_pickler._register_reducer(RemoteModule, _remote_module_reducer)
_internal_rpc_pickler._register_reducer(
    torch.jit.RecursiveScriptModule, _recursive_script_module_reducer
)
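
# Usage sketch (illustrative only; the worker name "worker1", the function ``score``, and the
# variables ``remote_linear_module`` / ``inputs`` are hypothetical): because the reducers above
# are registered with the internal RPC pickler, a (non-script) RemoteModule can be passed over
# RPC like any other argument, e.g.
#
#     rpc.rpc_sync("worker1", score, args=(remote_linear_module, inputs))
#
# The callee receives a RemoteModule rebuilt by ``_remote_module_receiver`` that shares the same
# ``module_rref``, while script RemoteModules are rejected by ``_recursive_script_module_reducer``.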