import threading
import torch
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from ..modules import Module
from torch.cuda._utils import _get_device_index
from torch._utils import ExceptionWrapper

__all__ = ['get_a_var', 'parallel_apply']


def get_a_var(obj: Union[torch.Tensor, List[Any], Tuple[Any, ...], Dict[Any, Any]]) -> Optional[torch.Tensor]:
    # Return the first tensor found in a (possibly nested) input structure.
    # Used below to infer the target device when none is given explicitly.
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, (list, tuple)):
        for result in map(get_a_var, obj):
            if isinstance(result, torch.Tensor):
                return result
    if isinstance(obj, dict):
        for result in map(get_a_var, obj.items()):
            if isinstance(result, torch.Tensor):
                return result
    return None


def parallel_apply(
    modules: Sequence[Module],
    inputs: Sequence[Any],
    kwargs_tup: Optional[Sequence[Dict[str, Any]]] = None,
    devices: Optional[Sequence[Optional[Union[int, torch.device]]]] = None,
) -> List[Any]:
    r"""Apply each `module` in :attr:`modules` in parallel on each of :attr:`devices`.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        devices (list of int or torch.device): CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have the same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.
    """
    assert len(modules) == len(inputs), f'The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}'
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = (cast(Dict[str, Any], {}),) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = [_get_device_index(x, True) for x in devices]
    streams = [torch.cuda.current_stream(x) for x in devices]
    lock = threading.Lock()
    results = {}
    # Capture the caller's grad/autocast state so each worker thread can restore it.
    grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()

    def _worker(
        i: int,
        module: Module,
        input: Any,
        kwargs: Dict[str, Any],
        device: Optional[Union[int, torch.device]] = None,
        stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            # No device given: infer it from the first tensor found in the input.
            t = get_a_var(input)
            if t is None:
                with lock:
                    results[i] = ExceptionWrapper(
                        where=f"in replica {i}, no device was provided and no tensor input was found; "
                        "device cannot be resolved")
                return
            device = t.get_device()
        if stream is None:
            stream = torch.cuda.current_stream(device)
        try:
            with torch.cuda.device(device), torch.cuda.stream(
                stream
            ), torch.amp.autocast("cuda", enabled=autocast_enabled):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = module(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception:
            # Capture the exception and its traceback so it can be re-raised
            # later in the calling thread.
            with lock:
                results[i] = ExceptionWrapper(
                    where=f"in replica {i} on device {device}")

    if len(modules) > 1:
        # One thread per replica; with a single module, call the worker inline.
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device, stream))
                   for i, (module, input, kwargs, device, stream) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices, streams))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0], streams[0])

    # Collect outputs in replica order, re-raising any captured exception.
    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs
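

# ---------------------------------------------------------------------------
# Illustrative usage sketch (added for clarity, not part of the original
# module): parallel_apply expects one replica, one input chunk, and optionally
# one device per entry. The layer sizes and device indices below are arbitrary
# assumptions; DataParallel normally builds the replicas and input chunks
# itself via `replicate` and `scatter`. When copying this snippet elsewhere,
# import parallel_apply from torch.nn.parallel.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    if torch.cuda.is_available():
        n = min(2, torch.cuda.device_count())
        device_ids = list(range(n))
        # One replica per device. Each replica here has independent weights;
        # torch.nn.parallel.replicate would normally produce copies of one model.
        replicas = [nn.Linear(4, 2).cuda(d) for d in device_ids]
        # One input chunk per replica, already placed on the matching device.
        chunks = [torch.randn(8, 4, device=d) for d in device_ids]
        outputs = parallel_apply(replicas, chunks, devices=device_ids)
        for d, out in zip(device_ids, outputs):
            print(f"device {d}: output shape {tuple(out.shape)}")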