#!/usr/bin/python3
# mypy: allow-untyped-defs


def get_remote_module_template(enable_moving_cpu_tensors_to_cuda: bool):
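    """
    Return the source-code template for a generated remote module.

    The result is ``_TEMPLATE_PREFIX`` followed by one of the two
    ``_remote_forward`` templates below, selected by
    ``enable_moving_cpu_tensors_to_cuda``; the ``{...}`` placeholders are
    filled in later by the caller.
    """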
    return _TEMPLATE_PREFIX + (
        _REMOTE_FORWARD_TEMPLATE_ENABLE_MOVING_CPU_TENSORS_TO_CUDA
        if enable_moving_cpu_tensors_to_cuda
        else _REMOTE_FORWARD_TEMPLATE
    )


_TEMPLATE_PREFIX = """from typing import *

import torch
import torch.distributed.rpc as rpc
from torch import Tensor
from torch._jit_internal import Future
from torch.distributed.rpc import RRef
from typing import Tuple  # pyre-ignore: unused import


{assign_module_interface_cls}


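# Asynchronously run ``_remote_forward`` on the owner of ``module_rref`` and
# return a Future that resolves to the forward result.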
def forward_async(self, {arg_types}){arrow_and_future_return_type}:
    args = (self.module_rref, self.device, self.is_device_map_set, {args})
    kwargs = {{{kwargs}}}
    return rpc.rpc_async(
        self.module_rref.owner(),
        _remote_forward,
        args,
        kwargs,
    )


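# Same RPC as ``forward_async``, but block on the returned Future and return
# the result directly.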
def forward(self, {arg_types}){arrow_and_return_type}:
    args = (self.module_rref, self.device, self.is_device_map_set, {args})
    kwargs = {{{kwargs}}}
    ret_fut = rpc.rpc_async(
        self.module_rref.owner(),
        _remote_forward,
        args,
        kwargs,
    )
    return ret_fut.wait()


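# The forward methods generated from this template, collected so the
# surrounding RemoteModule machinery can attach them to module instances.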
_generated_methods = [
    forward_async,
    forward,
]


{jit_script_decorator}
"""

# This template may cause a typing error (a mismatch between ``Tuple[()]`` and ``Tuple[Any]``)
# even if the generated code is only instantiated and never executed.
# Therefore, the handling that moves CPU tensors to a CUDA device is only included when necessary.
# TODO: Merge these two templates together once TorchScript syntax is improved.
_REMOTE_FORWARD_TEMPLATE_ENABLE_MOVING_CPU_TENSORS_TO_CUDA = """
def _remote_forward(
    module_rref: RRef[module_interface_cls], device: str, is_device_map_set: bool, {arg_types}){arrow_and_return_type}:
    module = module_rref.local_value()
    device = torch.device(device)

    if device.type != "cuda":
        return module.forward({args}, {kwargs})

    # If the module is on a CUDA device,
    # move any CPU tensor in args or kwargs to the same CUDA device.
    # Since TorchScript does not support generator expressions,
    # concatenation has to be used instead of
    # ``tuple(i.to(device) if isinstance(i, Tensor) else i for i in *args)``.
    args = ({args},)
    out_args: Tuple[()] = ()
    for arg in args:
        arg = (arg.to(device),) if isinstance(arg, Tensor) else (arg,)
        out_args = out_args + arg

    kwargs = {{{kwargs}}}
    for k, v in kwargs.items():
        if isinstance(v, Tensor):
            kwargs[k] = kwargs[k].to(device)

    if is_device_map_set:
        return module.forward(*out_args, {kwargs})

    # If the device map is empty, only CPU tensors are allowed to be sent over the wire,
    # so any GPU tensor in the output has to be moved to CPU.
    # Since TorchScript does not support generator expressions,
    # concatenation has to be used instead of
    # ``tuple(i.cpu() if isinstance(i, Tensor) else i for i in module.forward(*out_args, {kwargs}))``.
    ret: Tuple[()] = ()
    for i in module.forward(*out_args, {kwargs}):
        i = (i.cpu(),) if isinstance(i, Tensor) else (i,)
        ret = ret + i
    return ret
"""
_REMOTE_FORWARD_TEMPLATE = """
def _remote_forward(
    module_rref: RRef[module_interface_cls], device: str, is_device_map_set: bool, {arg_types}){arrow_and_return_type}:
    module = module_rref.local_value()

    return module.forward({args}, {kwargs})
"""
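
# A minimal sketch of how this template might be filled in. The values below
# (the interface name, argument signature, and decorator) are hypothetical
# examples, not values defined in this file; in practice they are supplied by
# the code that instantiates the template:
#
#     template = get_remote_module_template(enable_moving_cpu_tensors_to_cuda=True)
#     source = template.format(
#         assign_module_interface_cls="module_interface_cls = MyModuleInterface",
#         arg_types="input: Tensor",
#         arrow_and_future_return_type=" -> Future[Tensor]",
#         arrow_and_return_type=" -> Tensor",
#         args="input",
#         kwargs="",
#         jit_script_decorator="@torch.jit.script",
#     )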