#!/usr/bin/python3
# mypy: allow-untyped-defs


def get_remote_module_template(enable_moving_cpu_tensors_to_cuda: bool):
    """Return the source-code template for a remote module's generated methods.

    If ``enable_moving_cpu_tensors_to_cuda`` is True, the generated
    ``_remote_forward`` also moves CPU tensors in the inputs (and, when no
    device map is set, GPU tensors in the outputs) to the appropriate device;
    otherwise it simply forwards the call.
    """
    return _TEMPLATE_PREFIX + (
        _REMOTE_FORWARD_TEMPLATE_ENABLE_MOVING_CPU_TENSORS_TO_CUDA
        if enable_moving_cpu_tensors_to_cuda
        else _REMOTE_FORWARD_TEMPLATE
    )


_TEMPLATE_PREFIX = """from typing import *

import torch
import torch.distributed.rpc as rpc
from torch import Tensor
from torch._jit_internal import Future
from torch.distributed.rpc import RRef
from typing import Tuple  # pyre-ignore: unused import


{assign_module_interface_cls}


def forward_async(self, {arg_types}){arrow_and_future_return_type}:
    args = (self.module_rref, self.device, self.is_device_map_set, {args})
    kwargs = {{{kwargs}}}
    return rpc.rpc_async(
        self.module_rref.owner(),
        _remote_forward,
        args,
        kwargs,
    )


def forward(self, {arg_types}){arrow_and_return_type}:
    args = (self.module_rref, self.device, self.is_device_map_set, {args})
    kwargs = {{{kwargs}}}
    ret_fut = rpc.rpc_async(
        self.module_rref.owner(),
        _remote_forward,
        args,
        kwargs,
    )
    return ret_fut.wait()


_generated_methods = [
    forward_async,
    forward,
]


{jit_script_decorator}
"""

# This template may cause a typing error (a mismatch between ``Tuple[()]`` and ``Tuple[Any]``)
# even if the generated code is only instantiated and never executed.
# Therefore, the handling that moves CPU tensors to a CUDA device is only included when necessary.
# TODO: Merge these two templates together in the future once TorchScript syntax is improved.
_REMOTE_FORWARD_TEMPLATE_ENABLE_MOVING_CPU_TENSORS_TO_CUDA = """
def _remote_forward(
    module_rref: RRef[module_interface_cls], device: str, is_device_map_set: bool, {arg_types}){arrow_and_return_type}:
    module = module_rref.local_value()
    device = torch.device(device)

    if device.type != "cuda":
        return module.forward({args}, {kwargs})

    # If the module is on a CUDA device,
    # move any CPU tensor in args or kwargs to the same CUDA device.
    # Since TorchScript does not support generator expressions,
    # concatenation has to be used instead of
    # ``tuple(i.to(device) if isinstance(i, Tensor) else i for i in args)``.
    args = ({args},)
    out_args: Tuple[()] = ()
    for arg in args:
        arg = (arg.to(device),) if isinstance(arg, Tensor) else (arg,)
        out_args = out_args + arg

    kwargs = {{{kwargs}}}
    for k, v in kwargs.items():
        if isinstance(v, Tensor):
            kwargs[k] = kwargs[k].to(device)

    if is_device_map_set:
        return module.forward(*out_args, {kwargs})

    # If the device map is empty, only CPU tensors are allowed to be sent over the wire,
    # so any GPU tensor in the output has to be moved to CPU.
    # Since TorchScript does not support generator expressions,
    # concatenation has to be used instead of
    # ``tuple(i.cpu() if isinstance(i, Tensor) else i for i in module.forward(*out_args, {kwargs}))``.
    ret: Tuple[()] = ()
    for i in module.forward(*out_args, {kwargs}):
        i = (i.cpu(),) if isinstance(i, Tensor) else (i,)
        ret = ret + i
    return ret
"""

_REMOTE_FORWARD_TEMPLATE = """
def _remote_forward(
    module_rref: RRef[module_interface_cls], device: str, is_device_map_set: bool, {arg_types}){arrow_and_return_type}:
    module = module_rref.local_value()

    return module.forward({args}, {kwargs})
"""
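

# The block below is a minimal, illustrative sketch and is not part of the
# template machinery itself.  It shows how the string returned by
# ``get_remote_module_template`` can be filled in with ``str.format``; every
# placeholder value here is an assumed example for a module whose forward
# takes a single Tensor argument, not the values a real caller would build.
if __name__ == "__main__":
    _example_source = get_remote_module_template(
        enable_moving_cpu_tensors_to_cuda=True
    ).format(
        assign_module_interface_cls="module_interface_cls = None",
        arg_types="input: Tensor",
        arrow_and_future_return_type=" -> Future[Tensor]",
        arrow_and_return_type=" -> Tensor",
        args="input",
        kwargs="",
        jit_script_decorator="",
    )
    # Print the generated source for inspection.
    print(_example_source)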