# mypy: allow-untyped-defs
from __future__ import annotations

from typing import cast, Callable, Generic, List, Optional, Type, TypeVar, Union

import torch

__all__ = ['Future', 'collect_all', 'wait_all']

T = TypeVar("T")
S = TypeVar("S")


class _PyFutureMeta(type(torch._C.Future), type(Generic)):  # type: ignore[misc, no-redef]
    pass


class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
    r"""
    Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous
    execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It
    also exposes a set of APIs to add callback functions and set results.

    .. warning:: GPU support is a beta feature, subject to changes.
    """

    def __init__(self, *, devices: Optional[List[Union[int, str, torch.device]]] = None):
        r"""
        Create an empty unset ``Future``. If the future is intended to hold
        values containing CUDA tensors, (a superset of) their CUDA devices must
        be specified at construction. (This is only supported if
        ``torch.cuda.is_available()`` returns ``True``). This is needed to
        ensure proper CUDA stream synchronization. The child futures, returned
        by the ``then`` method, will inherit these devices.

        Args:
            devices(``List[Union[int, str, torch.device]]``, optional): the set
                of devices on which tensors contained in this future's value are
                allowed to reside and on which callbacks are allowed to operate.
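
        A minimal sketch, assuming a CUDA-capable build with device 0 available
        (hence the CUDA doctest gate instead of the futures gate used elsewhere):

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
            >>> fut = torch.futures.Future(devices=["cuda:0"])
            >>> fut.set_result(torch.ones(2, device="cuda:0"))
            >>> fut.wait().device
            device(type='cuda', index=0)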
        """
        if devices is None:
            devices = []
        super().__init__([torch.device(d) for d in devices])

    def done(self) -> bool:
        r"""
        Return ``True`` if this ``Future`` is done. A ``Future`` is done if it
        has a result or an exception.

        If the value contains tensors that reside on GPUs, ``Future.done()``
        will return ``True`` even if the asynchronous kernels that are
        populating those tensors haven't yet completed running on the device,
        because at that stage the result is already usable, provided one
        performs the appropriate synchronizations (see :meth:`wait`).
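
        A minimal sketch of the state transition:

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> fut = torch.futures.Future()
            >>> fut.done()
            False
            >>> fut.set_result(1)
            >>> fut.done()
            True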
        """
        return super().done()

    def wait(self) -> T:
        r"""
        Block until the value of this ``Future`` is ready.

        If the value contains tensors that reside on GPUs, then an additional
        synchronization is performed with the kernels (executing on the device)
        which may be asynchronously populating those tensors. Such sync is
        non-blocking, which means that ``wait()`` will insert the necessary
        instructions in the current streams to ensure that further operations
        enqueued on those streams will be properly scheduled after the async
        kernels but, once that is done, ``wait()`` will return, even if those
        kernels are still running. No further synchronization is required when
        accessing and using the values, as long as one doesn't change streams.

        Returns:
            The value held by this ``Future``. If the function (callback or RPC)
            creating the value has thrown an error, this ``wait`` method will
            also throw an error.
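
        A minimal CPU-only sketch (no stream synchronization is involved here):

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> fut = torch.futures.Future()
            >>> fut.set_result(torch.ones(2))
            >>> fut.wait()
            tensor([1., 1.])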
        """
        return super().wait()

    def value(self) -> T:
        r"""
        Obtain the value of an already-completed future.

        This method should only be called after a call to :meth:`wait` has
        completed, or inside a callback function passed to :meth:`then`. In
        other cases this ``Future`` may not yet hold a value and calling
        ``value()`` could fail.

        If the value contains tensors that reside on GPUs, then this method will
        *not* perform any additional synchronization. This should be done
        beforehand, separately, through a call to :meth:`wait` (except within
        callbacks, for which it's already being taken care of by :meth:`then`).

        Returns:
            The value held by this ``Future``. If the function (callback or RPC)
            creating the value has thrown an error, this ``value()`` method will
            also throw an error.
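
        A minimal sketch, waiting first and only then reading the stored value:

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> fut = torch.futures.Future()
            >>> fut.set_result(3)
            >>> fut.wait()
            3
            >>> fut.value()
            3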
        """
        return super().value()

    def then(self, callback: Callable[[Future[T]], S]) -> Future[S]:
        r"""
        Append the given callback function to this ``Future``, which will be run
        when the ``Future`` is completed.  Multiple callbacks can be added to
        the same ``Future``, but the order in which they will be executed cannot
        be guaranteed (to enforce a certain order, consider chaining:
        ``fut.then(cb1).then(cb2)``). The callback must take one argument, which
        is the reference to this ``Future``. The callback function can use the
        :meth:`value` method to get the value. Note that if this ``Future`` is
        already completed, the given callback will be run immediately inline.

        If the ``Future``'s value contains tensors that reside on GPUs, the
        callback might be invoked while the async kernels that are populating
        those tensors haven't yet finished executing on the device. However, the
        callback will be invoked with some dedicated streams set as current
        (fetched from a global pool) which will be synchronized with those
        kernels. Hence any operation performed by the callback on these tensors
        will be scheduled on the device after the kernels complete. In other
        words, as long as the callback doesn't switch streams, it can safely
        manipulate the result without any additional synchronization. This is
        similar to the non-blocking behavior of :meth:`wait`.

        Similarly, if the callback returns a value that contains tensors that
        reside on a GPU, it can do so even if the kernels that are producing
        these tensors are still running on the device, as long as the callback
        didn't change streams during its execution. If one wants to change
        streams, one must be careful to re-synchronize them with the original
        streams, that is, those that were current when the callback was invoked.

        Args:
            callback(``Callable``): a ``Callable`` that takes this ``Future`` as
                                    the only argument.

        Returns:
            A new ``Future`` object that holds the return value of the
            ``callback`` and will be marked as completed when the given
            ``callback`` finishes.

        .. note:: Note that if the callback function throws, either
            through the original future being completed with an exception and
            calling ``fut.wait()``, or through other code in the callback, the
            future returned by ``then`` will be marked appropriately with the
            encountered error. However, if this callback later completes
            additional futures, those futures are not marked as completed with
            an error and the user is responsible for handling completion/waiting
            on those futures independently.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> def callback(fut):
            ...     print(f"RPC return value is {fut.wait()}.")
            >>> fut = torch.futures.Future()
            >>> # The inserted callback will print the return value when
            >>> # receiving the response from "worker1"
            >>> cb_fut = fut.then(callback)
            >>> chain_cb_fut = cb_fut.then(
            ...     lambda x: print(f"Chained cb done. {x.wait()}")
            ... )
            >>> fut.set_result(5)
            RPC return value is 5.
            Chained cb done. None
        """
        return cast(Future[S], super().then(callback))

    def add_done_callback(self, callback: Callable[[Future[T]], None]) -> None:
        r"""
        Append the given callback function to this ``Future``, which will be run
        when the ``Future`` is completed.  Multiple callbacks can be added to
        the same ``Future``, but the order in which they will be executed cannot
        be guaranteed. The callback must take one argument, which is the
        reference to this ``Future``. The callback function can use the
        :meth:`value` method to get the value. Note that if this ``Future`` is
        already completed, the given callback will be run inline.

        We recommend that you use the :meth:`then` method as it provides a way
        to synchronize after your callback has completed. ``add_done_callback``
        can be cheaper if your callback does not return anything. But both
        :meth:`then` and ``add_done_callback`` use the same callback
        registration API under the hood.

        With respect to GPU tensors, this method behaves in the same way as
        :meth:`then`.

        Args:
            callback(``Callable``): a ``Callable`` that takes in one argument,
                which is the reference to this ``Future``.

        .. note:: Note that if the callback function throws, either
            through the original future being completed with an exception and
            calling ``fut.wait()``, or through other code in the callback,
            error handling must be carefully taken care of. For example, if
            this callback later completes additional futures, those futures are
            not marked as completed with an error and the user is responsible
            for handling completion/waiting on those futures independently.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> def callback(fut):
            ...     print("This will run after the future has finished.")
            ...     print(fut.wait())
            >>> fut = torch.futures.Future()
            >>> fut.add_done_callback(callback)
            >>> fut.set_result(5)
            This will run after the future has finished.
            5
        """
        super().add_done_callback(callback)

    def set_result(self, result: T) -> None:
        r"""
        Set the result for this ``Future``, which will mark this ``Future`` as
        completed and trigger all attached callbacks. Note that a ``Future``
        cannot be marked completed twice.

        If the result contains tensors that reside on GPUs, this method can be
        called even if the asynchronous kernels that are populating those
        tensors haven't yet completed running on the device, provided that the
        streams on which those kernels were enqueued are set as the current ones
        when this method is called. Put simply, it's safe to call this method
        immediately after launching those kernels, without any additional
        synchronization, as long as one doesn't change streams in between. This
        method will record events on all the relevant current streams and will
        use them to ensure proper scheduling for all the consumers of this
        ``Future``.

        Args:
            result (object): the result object of this ``Future``.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> import threading
            >>> import time
            >>> def slow_set_future(fut, value):
            ...     time.sleep(0.5)
            ...     fut.set_result(value)
            >>> fut = torch.futures.Future()
            >>> t = threading.Thread(
            ...     target=slow_set_future,
            ...     args=(fut, torch.ones(2) * 3)
            ... )
            >>> t.start()
            >>> print(fut.wait())
            tensor([3., 3.])
            >>> t.join()
        """
        super().set_result(result)

    def set_exception(self, result: T) -> None:
        r"""
        Set an exception for this ``Future``, which will mark this ``Future`` as
        completed with an error and trigger all attached callbacks. Note that
        when calling wait()/value() on this ``Future``, the exception set here
        will be raised inline.

        Args:
            result (BaseException): the exception for this ``Future``.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> fut = torch.futures.Future()
            >>> fut.set_exception(ValueError("foo"))
            >>> fut.wait()
            Traceback (most recent call last):
            ...
            ValueError: foo
        """
        assert isinstance(result, Exception), f"{result} is of type {type(result)}, not an Exception."

        def raise_error(fut_result):
            raise fut_result

        super()._set_unwrap_func(raise_error)
        self.set_result(result)  # type: ignore[arg-type]


def collect_all(futures: List[Future]) -> Future[List[Future]]:
    r"""
    Collects the provided :class:`~torch.futures.Future` objects into a single
    combined :class:`~torch.futures.Future` that is completed when all of the
    sub-futures are completed.

    Args:
        futures (list): a list of :class:`~torch.futures.Future` objects.

    Returns:
        A :class:`~torch.futures.Future` object whose value is a list of the
        passed-in :class:`~torch.futures.Future` objects.

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
        >>> fut0 = torch.futures.Future()
        >>> fut1 = torch.futures.Future()
        >>> fut = torch.futures.collect_all([fut0, fut1])
        >>> fut0.set_result(0)
        >>> fut1.set_result(1)
        >>> fut_list = fut.wait()
        >>> print(f"fut0 result = {fut_list[0].wait()}")
        fut0 result = 0
        >>> print(f"fut1 result = {fut_list[1].wait()}")
        fut1 result = 1
    """
    return cast(Future[List[Future]], torch._C._collect_all(cast(List[torch._C.Future], futures)))


def wait_all(futures: List[Future]) -> List:
    r"""
    Waits for all provided futures to be complete, and returns
    the list of completed values. If any of the futures encounters an error,
    the method will exit early and report the error without waiting for the
    other futures to complete.

    Args:
        futures (list): a list of :class:`~torch.futures.Future` objects.

    Returns:
        A list of the completed :class:`~torch.futures.Future` results. This
        method will throw an error if ``wait`` on any
        :class:`~torch.futures.Future` throws.
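
    A minimal sketch mirroring the ``collect_all`` example above:

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
        >>> fut0 = torch.futures.Future()
        >>> fut1 = torch.futures.Future()
        >>> fut0.set_result(0)
        >>> fut1.set_result(1)
        >>> torch.futures.wait_all([fut0, fut1])
        [0, 1]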
    """
    return [fut.wait() for fut in torch._C._collect_all(cast(List[torch._C.Future], futures)).wait()]
320