# torch/nn/parallel/parallel_apply.py
import threading
import torch
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast

from ..modules import Module
from torch.cuda._utils import _get_device_index
from torch._utils import ExceptionWrapper

__all__ = ['get_a_var', 'parallel_apply']

def get_a_var(obj: Union[torch.Tensor, List[Any], Tuple[Any, ...], Dict[Any, Any]]) -> Optional[torch.Tensor]:
    """Return the first tensor found in a (possibly nested) list/tuple/dict input, or ``None``."""
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, (list, tuple)):
        for result in map(get_a_var, obj):
            if isinstance(result, torch.Tensor):
                return result
    if isinstance(obj, dict):
        for result in map(get_a_var, obj.items()):
            if isinstance(result, torch.Tensor):
                return result
    return None

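# A minimal sketch of `get_a_var`'s role (kept in comments so nothing executes
# at import time; the tensor below is an assumption for illustration):
#
#   t = torch.randn(2, device="cuda:0")
#   get_a_var(({"k": 0}, [t]))  # -> t; its device is used to place the replica
#   get_a_var((1, "a"))         # -> None; no tensor anywhere in the input
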
def parallel_apply(
    modules: Sequence[Module],
    inputs: Sequence[Any],
    kwargs_tup: Optional[Sequence[Dict[str, Any]]] = None,
    devices: Optional[Sequence[Optional[Union[int, torch.device]]]] = None,
) -> List[Any]:
    r"""Apply each `module` in :attr:`modules` in parallel on each of :attr:`devices`.

    Args:
        modules (sequence of Module): modules to be parallelized
        inputs (sequence): inputs to the modules
        kwargs_tup (sequence of dict, optional): keyword arguments for each module
        devices (sequence of int or torch.device, optional): CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have the same length. Moreover, each
    element of :attr:`inputs` can either be a single object passed as the only
    argument to a module, or a collection of positional arguments.
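
    Example (a minimal sketch, assuming at least two CUDA devices)::

        >>> # xdoctest: +SKIP
        >>> replicas = [torch.nn.Linear(10, 5).cuda(i) for i in range(2)]
        >>> inputs = [torch.randn(8, 10, device=f"cuda:{i}") for i in range(2)]
        >>> outputs = parallel_apply(replicas, inputs, devices=[0, 1])
        >>> [tuple(out.shape) for out in outputs]
        [(8, 5), (8, 5)]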
    """
    assert len(modules) == len(inputs), (
        f'The number of modules {len(modules)} is not equal to '
        f'the number of inputs {len(inputs)}'
    )
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = (cast(Dict[str, Any], {}),) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
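    # Normalize each entry to a device index and record the current stream on
    # that device, so each worker launches kernels on the caller's streams.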
    devices = [_get_device_index(x, True) for x in devices]
    streams = [torch.cuda.current_stream(x) for x in devices]
    lock = threading.Lock()
    results = {}
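    # Grad mode and autocast state are thread-local, so snapshot the caller's
    # settings here and re-apply them inside every worker thread.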
    grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()

    def _worker(
        i: int,
        module: Module,
        input: Any,
        kwargs: Dict[str, Any],
        device: Optional[Union[int, torch.device]] = None,
        stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        torch.set_grad_enabled(grad_enabled)
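        # No device given: fall back to the device of the first tensor found in
        # the input; fail with a descriptive error if there is none.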
        if device is None:
            t = get_a_var(input)
            if t is None:
                with lock:
                    results[i] = ExceptionWrapper(
                        where=f"in replica {i}, no device was provided and no tensor input was found; "
                        "device cannot be resolved")
                return
            device = t.get_device()
        if stream is None:
            stream = torch.cuda.current_stream(device)
        try:
            with torch.cuda.device(device), torch.cuda.stream(
                stream
            ), torch.amp.autocast("cuda", enabled=autocast_enabled):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = module(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception:
            with lock:
                results[i] = ExceptionWrapper(
                    where=f"in replica {i} on device {device}")

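    # One thread per replica; with a single module there is nothing to overlap,
    # so the worker runs inline and the threading overhead is skipped.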
    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device, stream))
                   for i, (module, input, kwargs, device, stream) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices, streams))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0], streams[0])

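    # Collect results in input order; a worker that failed stored an
    # ExceptionWrapper, which re-raises the original error with its traceback.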
    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs