# mypy: allow-untyped-defs
from __future__ import annotations

from typing import cast, Callable, Generic, List, Optional, Type, TypeVar, Union

import torch

__all__ = ['Future', 'collect_all', 'wait_all']

T = TypeVar("T")
S = TypeVar("S")


class _PyFutureMeta(type(torch._C.Future), type(Generic)):  # type: ignore[misc, no-redef]
    pass


class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
    r"""
    Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous
    execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It
    also exposes a set of APIs to add callback functions and set results.

    .. warning:: GPU support is a beta feature, subject to changes.
    """

    def __init__(self, *, devices: Optional[List[Union[int, str, torch.device]]] = None):
        r"""
        Create an empty unset ``Future``. If the future is intended to hold
        values containing CUDA tensors, (a superset of) their CUDA devices must
        be specified at construction. (This is only supported if
        ``torch.cuda.is_available()`` returns ``True``.) This is needed to
        ensure proper CUDA stream synchronization. The child futures, returned
        by the ``then`` method, will inherit these devices.

        Args:
            devices(``List[Union[int, str, torch.device]]``, optional): the set
                of devices on which tensors contained in this future's value are
                allowed to reside and on which callbacks are allowed to operate.
        """
        if devices is None:
            devices = []
        super().__init__([torch.device(d) for d in devices])

    def done(self) -> bool:
        r"""
        Return ``True`` if this ``Future`` is done. A ``Future`` is done if it
        has a result or an exception.

        If the value contains tensors that reside on GPUs, ``Future.done()``
        will return ``True`` even if the asynchronous kernels that are
        populating those tensors haven't yet completed running on the device,
        because at that stage the result is already usable, provided one
        performs the appropriate synchronizations (see :meth:`wait`).
        """
        return super().done()

    def wait(self) -> T:
        r"""
        Block until the value of this ``Future`` is ready.

        If the value contains tensors that reside on GPUs, then an additional
        synchronization is performed with the kernels (executing on the device)
        which may be asynchronously populating those tensors. Such sync is
        non-blocking: ``wait()`` will insert the necessary instructions in the
        current streams to ensure that further operations enqueued on those
        streams will be properly scheduled after the async kernels; once that
        is done, ``wait()`` will return, even if those kernels are still
        running. No further synchronization is required when accessing and
        using the values, as long as one doesn't change streams.

        Returns:
            The value held by this ``Future``. If the function (callback or RPC)
            creating the value has thrown an error, this ``wait`` method will
            also throw an error.
        """
        return super().wait()

    def value(self) -> T:
        r"""
        Obtain the value of an already-completed future.

        This method should only be called after a call to :meth:`wait` has
        completed, or inside a callback function passed to :meth:`then`. In
        other cases this ``Future`` may not yet hold a value and calling
        ``value()`` could fail.

        If the value contains tensors that reside on GPUs, then this method will
        *not* perform any additional synchronization. This should be done
        beforehand, separately, through a call to :meth:`wait` (except within
        callbacks, where :meth:`then` already takes care of it).

        Returns:
            The value held by this ``Future``. If the function (callback or RPC)
            creating the value has thrown an error, this ``value()`` method will
            also throw an error.
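
        Example (a minimal, CPU-only illustration; the name ``fut`` and the
        value ``5`` are arbitrary)::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> fut = torch.futures.Future()
            >>> fut.set_result(5)
            >>> fut.wait()  # per the note above, wait before calling value()
            5
            >>> fut.value()
            5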
        """
        return super().value()

    def then(self, callback: Callable[[Future[T]], S]) -> Future[S]:
        r"""
        Append the given callback function to this ``Future``, which will be run
        when the ``Future`` is completed. Multiple callbacks can be added to
        the same ``Future``, but the order in which they will be executed cannot
        be guaranteed (to enforce a certain order consider chaining:
        ``fut.then(cb1).then(cb2)``). The callback must take one argument, which
        is the reference to this ``Future``. The callback function can use the
        :meth:`value` method to get the value. Note that if this ``Future`` is
        already completed, the given callback will be run immediately inline.

        If the ``Future``'s value contains tensors that reside on GPUs, the
        callback might be invoked while the async kernels that are populating
        those tensors haven't yet finished executing on the device. However, the
        callback will be invoked with some dedicated streams set as current
        (fetched from a global pool) which will be synchronized with those
        kernels. Hence any operation performed by the callback on these tensors
        will be scheduled on the device after the kernels complete. In other
        words, as long as the callback doesn't switch streams, it can safely
        manipulate the result without any additional synchronization. This is
        similar to the non-blocking behavior of :meth:`wait`.

        Similarly, if the callback returns a value that contains tensors that
        reside on a GPU, it can do so even if the kernels that are producing
        these tensors are still running on the device, as long as the callback
        didn't change streams during its execution. If one wants to change
        streams, one must be careful to re-synchronize them with the original
        streams, that is, those that were current when the callback was invoked.

        Args:
            callback(``Callable``): a ``Callable`` that takes this ``Future`` as
                the only argument.

        Returns:
            A new ``Future`` object that holds the return value of the
            ``callback`` and will be marked as completed when the given
            ``callback`` finishes.

        .. note:: If the callback function throws, either
            through the original future being completed with an exception and
            calling ``fut.wait()``, or through other code in the callback, the
            future returned by ``then`` will be marked appropriately with the
            encountered error. However, if this callback later completes
            additional futures, those futures are not marked as completed with
            an error and the user is responsible for handling completion/waiting
            on those futures independently.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> def callback(fut):
            ...     print(f"RPC return value is {fut.wait()}.")
            >>> fut = torch.futures.Future()
            >>> # The inserted callback will print the return value when
            >>> # receiving the response from "worker1"
            >>> cb_fut = fut.then(callback)
            >>> chain_cb_fut = cb_fut.then(
            ...     lambda x: print(f"Chained cb done. {x.wait()}")
            ... )
            >>> fut.set_result(5)
            RPC return value is 5.
            Chained cb done. None
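            >>> # A further minimal illustration (the names ``fut2`` and
            >>> # ``doubled`` below are arbitrary): the future returned by
            >>> # ``then`` holds the callback's return value.
            >>> fut2 = torch.futures.Future()
            >>> doubled = fut2.then(lambda f: f.wait() * 2)
            >>> fut2.set_result(3)
            >>> doubled.wait()
            6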
        """
        return cast(Future[S], super().then(callback))

    def add_done_callback(self, callback: Callable[[Future[T]], None]) -> None:
        r"""
        Append the given callback function to this ``Future``, which will be run
        when the ``Future`` is completed. Multiple callbacks can be added to
        the same ``Future``, but the order in which they will be executed cannot
        be guaranteed. The callback must take one argument, which is the
        reference to this ``Future``. The callback function can use the
        :meth:`value` method to get the value. Note that if this ``Future`` is
        already completed, the given callback will be run inline.

        We recommend that you use the :meth:`then` method as it provides a way
        to synchronize after your callback has completed. ``add_done_callback``
        can be cheaper if your callback does not return anything. But both
        :meth:`then` and ``add_done_callback`` use the same callback
        registration API under the hood.

        With respect to GPU tensors, this method behaves in the same way as
        :meth:`then`.

        Args:
            callback(``Callable``): a ``Callable`` that takes in one argument,
                which is the reference to this ``Future``.

        .. note:: If the callback function throws, either
            through the original future being completed with an exception and
            calling ``fut.wait()``, or through other code in the callback,
            error handling must be carefully taken care of. For example, if
            this callback later completes additional futures, those futures are
            not marked as completed with an error and the user is responsible
            for handling completion/waiting on those futures independently.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> def callback(fut):
            ...     print("This will run after the future has finished.")
            ...     print(fut.wait())
            >>> fut = torch.futures.Future()
            >>> fut.add_done_callback(callback)
            >>> fut.set_result(5)
            This will run after the future has finished.
            5
        """
        super().add_done_callback(callback)

    def set_result(self, result: T) -> None:
        r"""
        Set the result for this ``Future``, which will mark this ``Future`` as
        completed and trigger all attached callbacks. Note that a ``Future``
        cannot be marked completed twice.

        If the result contains tensors that reside on GPUs, this method can be
        called even if the asynchronous kernels that are populating those
        tensors haven't yet completed running on the device, provided that the
        streams on which those kernels were enqueued are set as the current ones
        when this method is called. Put simply, it's safe to call this method
        immediately after launching those kernels, without any additional
        synchronization, as long as one doesn't change streams in between. This
        method will record events on all the relevant current streams and will
        use them to ensure proper scheduling for all the consumers of this
        ``Future``.

        Args:
            result (object): the result object of this ``Future``.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> import threading
            >>> import time
            >>> def slow_set_future(fut, value):
            ...     time.sleep(0.5)
            ...     fut.set_result(value)
            >>> fut = torch.futures.Future()
            >>> t = threading.Thread(
            ...     target=slow_set_future,
            ...     args=(fut, torch.ones(2) * 3)
            ... )
            >>> t.start()
            >>> print(fut.wait())
            tensor([3., 3.])
            >>> t.join()
        """
        super().set_result(result)

    def set_exception(self, result: T) -> None:
        r"""
        Set an exception for this ``Future``, which will mark this ``Future`` as
        completed with an error and trigger all attached callbacks. Note that
        when calling ``wait()``/``value()`` on this ``Future``, the exception
        set here will be raised inline.

        Args:
            result (BaseException): the exception for this ``Future``.

        Example::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
            >>> fut = torch.futures.Future()
            >>> fut.set_exception(ValueError("foo"))
            >>> fut.wait()
            Traceback (most recent call last):
            ...
            ValueError: foo
        """
        assert isinstance(result, Exception), f"{result} is of type {type(result)}, not an Exception."

        def raise_error(fut_result):
            raise fut_result

        super()._set_unwrap_func(raise_error)
        self.set_result(result)  # type: ignore[arg-type]


def collect_all(futures: List[Future]) -> Future[List[Future]]:
    r"""
    Collects the provided :class:`~torch.futures.Future` objects into a single
    combined :class:`~torch.futures.Future` that is completed when all of the
    sub-futures are completed.

    Args:
        futures (list): a list of :class:`~torch.futures.Future` objects.

    Returns:
        A :class:`~torch.futures.Future` object that holds a list of the
        passed-in Futures.

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
        >>> fut0 = torch.futures.Future()
        >>> fut1 = torch.futures.Future()
        >>> fut = torch.futures.collect_all([fut0, fut1])
        >>> fut0.set_result(0)
        >>> fut1.set_result(1)
        >>> fut_list = fut.wait()
        >>> print(f"fut0 result = {fut_list[0].wait()}")
        fut0 result = 0
        >>> print(f"fut1 result = {fut_list[1].wait()}")
        fut1 result = 1
    """
    return cast(Future[List[Future]], torch._C._collect_all(cast(List[torch._C.Future], futures)))


def wait_all(futures: List[Future]) -> List:
    r"""
    Waits for all provided futures to be complete, and returns
    the list of completed values. If any of the futures encounters an error,
    the method will exit early and report the error, without waiting for the
    other futures to complete.

    Args:
        futures (list): a list of :class:`~torch.futures.Future` objects.

    Returns:
        A list of the completed :class:`~torch.futures.Future` results. This
        method will throw an error if ``wait`` on any
        :class:`~torch.futures.Future` throws.
    """
    return [fut.wait() for fut in torch._C._collect_all(cast(List[torch._C.Future], futures)).wait()]
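

# A minimal usage sketch of ``collect_all``/``wait_all`` (illustrative only,
# CPU-only, with arbitrary names ``f0``/``f1``). It is guarded under
# ``__main__`` so that importing this module stays side-effect free.
if __name__ == "__main__":
    f0: Future[int] = Future()
    f1: Future[int] = Future()
    combined = collect_all([f0, f1])
    f0.set_result(1)
    f1.set_result(2)
    # ``collect_all`` yields a future holding a list of futures; ``wait_all``
    # flattens straight to the completed values.
    print([f.wait() for f in combined.wait()])  # prints [1, 2]
    print(wait_all([f0, f1]))  # prints [1, 2]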