# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs

import collections
import contextlib
import functools
import inspect
import logging
import threading
from typing import Any, Dict, Generic, Set, TYPE_CHECKING, TypeVar

import torch
from torch._C._distributed_rpc import (
    _cleanup_python_rpc_handler,
    _delete_all_user_and_unforked_owner_rrefs,
    _destroy_rref_context,
    _get_current_rpc_agent,
    _invoke_remote_builtin,
    _invoke_remote_python_udf,
    _invoke_remote_torchscript,
    _invoke_rpc_builtin,
    _invoke_rpc_python_udf,
    _invoke_rpc_torchscript,
    _is_current_rpc_agent_set,
    _reset_current_rpc_agent,
    _set_and_start_rpc_agent,
    get_rpc_timeout,
    PyRRef,
    RemoteProfilerManager,
    TensorPipeAgent,
    WorkerInfo,
)
from torch.futures import Future

from ._utils import _group_membership_management, _update_group_membership
from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT
from .internal import (
    _build_rpc_profiling_key,
    _internal_rpc_pickler,
    PythonUDF,
    RPCExecMode,
)


__all__ = [
    "shutdown",
    "get_worker_info",
    "remote",
    "rpc_sync",
    "rpc_async",
    "RRef",
    "AllGatherStates",
    "method_factory",
    "new_method",
]


logger = logging.getLogger(__name__)

# NB: Ignore RRef leaks during shutdown. Without this, applications would have
# to make sure there are no references to any RRef in the application code and
# that the Python GC has deleted those RRefs, which could result in a bad
# debugging experience, especially for large applications. Therefore, by
# default, we ignore RRef leaks during shutdown. This is usually fine, as
# shutdown means the application has finished training and no longer cares
# about states.
#
# To enable RRef leak checking, set _ignore_rref_leak to False.
_ignore_rref_leak = True
_default_pickler = _internal_rpc_pickler


@contextlib.contextmanager
def _use_rpc_pickler(rpc_pickler):
    r"""
    rpc_pickler: (.internal._InternalRPCPickler) Overrides the default RPC pickler
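
    Example (a minimal sketch; ``my_pickler`` below is a hypothetical object
    that implements the same interface as ``_internal_rpc_pickler``)::

        >>> # xdoctest: +SKIP
        >>> # `my_pickler` is a hypothetical custom pickler, not part of this module.
        >>> with _use_rpc_pickler(my_pickler):
        ...     rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 1))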
77    """
78    global _default_pickler
79    _default_pickler = rpc_pickler
80    try:
81        yield
82    finally:
83        _default_pickler = _internal_rpc_pickler
84
85
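# Decorator for public RPC APIs: raises a RuntimeError if the RPC agent has not
# been initialized via `torch.distributed.rpc.init_rpc`.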
def _require_initialized(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not _is_current_rpc_agent_set():
            raise RuntimeError(
                "RPC has not been initialized. Call "
                "torch.distributed.rpc.init_rpc first."
            )
        return func(*args, **kwargs)

    return wrapper


class AllGatherStates:
    def __init__(self):
        # Each `gathered_objects` starts as an empty dict.
        # The leader worker is elected as the first worker in a sorted worker
        # name list. Whenever a worker enters `_all_gather()`, it runs
        # `_gather_to_leader()` on the leader to add its own name and data
        # object to this dict. The leader also adds its own name to the dict
        # when calling `_all_gather()`.
        # Once `set(gathered_objects.keys()) == _ALL_WORKER_NAMES`, the leader
        # will broadcast the gathered dict to all follower workers and set
        # their `gathered_objects` field and the `proceed_signal` field.
        self.gathered_objects = {}
        # Each worker waits on this signal until it has received all gathered
        # objects.
        self.proceed_signal = threading.Event()


# States used by `_all_gather()`.
# `_ALL_WORKER_NAMES` is initialized when the RPC layer is initialized.
_ALL_WORKER_NAMES: Set[Any] = set()
_all_gather_dict_lock = threading.RLock()
_all_gather_sequence_id: Dict[str, int] = {}
_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(
    AllGatherStates
)


def _init_rpc_states(agent):
    worker_infos = agent.get_worker_infos()
    global _ALL_WORKER_NAMES
    _ALL_WORKER_NAMES = {worker_info.name for worker_info in worker_infos}

    # NB: backend implementation might have already set the rpc_agent.
    if not _is_current_rpc_agent_set():
        _set_and_start_rpc_agent(agent)


def _gather_to_leader(sequence_id, worker_name, obj, worker_names=None):
    with _all_gather_dict_lock:
        if not worker_names:
            worker_names = _ALL_WORKER_NAMES
            assert (
                worker_name in worker_names
            ), f"{worker_name} is not expected by leader."
        states = _all_gather_sequence_id_to_states[sequence_id]
        assert (
            worker_name not in states.gathered_objects
        ), f"{worker_name} reported intent sequence id {sequence_id} twice. "
        states.gathered_objects[worker_name] = obj
        if worker_names == set(states.gathered_objects.keys()):
            states.proceed_signal.set()


def _broadcast_to_followers(sequence_id, objects_map):
    with _all_gather_dict_lock:
        states = _all_gather_sequence_id_to_states[sequence_id]

    assert (
        not states.proceed_signal.is_set()
    ), f"Termination signal sequence id {sequence_id} got set twice."
    states.gathered_objects = objects_map
    states.proceed_signal.set()


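# Thread-local storage used by `_wait_all()` to collect the futures returned by
# `rpc_async` calls made inside that context manager.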
_thread_local_var = threading.local()


@contextlib.contextmanager
def _wait_all():
    r"""
    A context manager that collects all futures returned by ``rpc_async`` and
    waits on them at the context manager's exit, relieving the user of the need
    to call ``wait`` explicitly.


    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> with rpc._wait_all():
        >>>    fut_1 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
        >>>    fut_2 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
        >>> # fut_1 and fut_2 are waited on
    """
    _thread_local_var.future_list = []
    try:
        yield
    finally:
        try:
            torch.futures.wait_all(_thread_local_var.future_list)
        finally:
            del _thread_local_var.future_list


@_require_initialized
def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT):
    r"""
    This is similar to torch.distributed.all_gather(), but uses RPC. It picks
    the worker with the smallest name (in alphabetical order) as the leader.
    All followers then send their data ``obj`` to the leader. After the leader
    has received all objects, it broadcasts the results back to all followers.
    This function blocks until all workers have received the gathered results.
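
    For illustration, a minimal sketch of a call on any worker (this is an
    internal API; the worker names in the comment are hypothetical)::

        >>> # xdoctest: +SKIP
        >>> # Returns a dict mapping each participating worker name to the
        >>> # object it contributed, e.g. {"worker0": obj0, "worker1": obj1}.
        >>> gathered = _all_gather(obj)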
203    """
204    if not worker_names:
205        assert (
206            _ALL_WORKER_NAMES is not None
207        ), "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`."
208        worker_names = _ALL_WORKER_NAMES
209    leader_name = min(worker_names)
210
211    self_name = _get_current_rpc_agent().get_worker_info().name
212
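    # Build a sequence id that is unique per set of participating workers, so
    # that repeated `_all_gather` rounds over the same worker set do not clash.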
    with _all_gather_dict_lock:
        concat_names = "".join(sorted(worker_names))
        sequence_num = _all_gather_sequence_id.get(concat_names, 0)
        _all_gather_sequence_id[concat_names] = sequence_num + 1
        sequence_id = concat_names + str(sequence_num)

    is_leader = leader_name == self_name

    if timeout == UNSET_RPC_TIMEOUT:
        # Timeout is specified by the agent for RPC calls
        rpc_timeout = get_rpc_timeout()
        # No timeout for signal
        signal_timeout = None
    elif timeout == DEFAULT_SHUTDOWN_TIMEOUT:
        # No timeout for RPC
        rpc_timeout = timeout
        # No timeout for signal
        signal_timeout = None
    else:
        # Signal and RPC use the same timeout
        signal_timeout = rpc_timeout = timeout

    # Phase 1: Followers send their objects to the leader
    if is_leader:
        _gather_to_leader(sequence_id, self_name, obj, worker_names)
    else:
        rpc_sync(
            leader_name,
            _gather_to_leader,
            args=(sequence_id, self_name, obj, worker_names),
            timeout=rpc_timeout,
        )

    with _all_gather_dict_lock:
        states = _all_gather_sequence_id_to_states[sequence_id]

    # Timeout is either set by the function parameter or None (indefinite)
    states.proceed_signal.wait(timeout=signal_timeout)

    # Phase 2: Leader broadcasts gathered results to all followers
    # The leader's signal is the first to be unblocked, after receiving all
    # followers' data objects.
    if is_leader:
        worker_name_to_response_future_dict = {}
        for follower_name in worker_names - {leader_name}:
            fut = rpc_async(
                follower_name,
                _broadcast_to_followers,
                args=(sequence_id, states.gathered_objects),
                timeout=rpc_timeout,
            )
            worker_name_to_response_future_dict[follower_name] = fut

        errors = []
        for follower_name, fut in worker_name_to_response_future_dict.items():
            try:
                fut.wait()
            except RuntimeError as ex:
                errors.append((follower_name, ex))

        if errors:
            raise RuntimeError(
                f"Followers {[e[0] for e in errors]} timed out in _all_gather "
                f"after {rpc_timeout:.2f} seconds. The first exception is {errors[0][1]}"
            )

    # Clean up the states stored under this sequence_id
    with _all_gather_dict_lock:
        states = _all_gather_sequence_id_to_states.pop(sequence_id)
    return states.gathered_objects


@_require_initialized
def _barrier(worker_names):
    r"""
    Synchronizes local and remote RPC processes.

    This will block until all local and remote RPC processes specified in
    ``worker_names`` reach this method to wait for all outstanding work to
    complete.

    Args:
        worker_names (List[str]): The set of workers to synchronize.

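    Example (a minimal sketch; worker names are hypothetical)::

        >>> # xdoctest: +SKIP
        >>> # Blocks until both named workers have entered the barrier.
        >>> _barrier(["worker0", "worker1"])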
296    """
297    try:
298        _all_gather(None, set(worker_names))
299    except RuntimeError as ex:
300        logger.error("Failed to complete barrier, got error %s", ex)
301
302
303@_require_initialized
304def _wait_all_workers(timeout=DEFAULT_SHUTDOWN_TIMEOUT):
305    r"""
306    Block until all local and remote RPC processes reach this method and wait
307    for all outstanding work to complete. Every RPC process must call this
308    method before exit to perform a graceful shutdown. This should be used to
309    terminate the RPC framework, and there is no guarantee that the RPC
310    framework will work after this method returns.
311    """
312    try:
313        _all_gather(None, timeout=timeout)
314    except RuntimeError as ex:
315        logger.error(
316            "Failed to respond to 'Shutdown Proceed' in time, got error %s", ex
317        )
318        raise ex
319
320
@_require_initialized
def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT):
    r"""
    Perform a shutdown of the RPC agent, and then destroy the RPC agent. This
    stops the local agent from accepting outstanding requests, and shuts
    down the RPC framework by terminating all RPC threads. If ``graceful=True``,
    this will block until all local and remote RPC processes reach this method
    and wait for all outstanding work to complete. Otherwise, if
    ``graceful=False``, this is a local shutdown, and it does not wait for other
    RPC processes to reach this method.

    .. warning::
        For :class:`~torch.futures.Future` objects returned by
        :meth:`~torch.distributed.rpc.rpc_async`, ``future.wait()`` should not
        be called after ``shutdown()``.

    Args:
        graceful (bool): Whether to do a graceful shutdown or not. If True,
                         this will 1) wait until there are no pending system
                         messages for ``UserRRefs`` and delete them; 2) block
                         until all local and remote RPC processes have reached
                         this method and wait for all outstanding work to
                         complete.

    Example::
        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example,

        export MASTER_ADDR=localhost
        export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # xdoctest: +SKIP
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> # do some work
        >>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1))
        >>> # ready to shutdown
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> # wait for worker 0 to finish work, and then shutdown.
        >>> rpc.shutdown()
    """
    if graceful:
        try:
            agent = _get_current_rpc_agent()
            if not isinstance(agent, TensorPipeAgent) or agent.is_static_group:
                _wait_all_workers(timeout)
                _delete_all_user_and_unforked_owner_rrefs()
                agent.join(shutdown=True, timeout=timeout)
            else:
                # This is a dynamic group so we need to grab the token for the operation
                my_worker_info = agent.get_worker_info()
                my_name = my_worker_info.name
                with _group_membership_management(agent.store, my_name, False):
                    all_worker_infos = agent.get_worker_infos()
                    for worker in all_worker_infos:
                        if worker.name != my_name:
                            rpc_sync(
                                worker.name,
                                _update_group_membership,
                                args=(my_worker_info, [], {}, False),
                            )
                    agent.join(shutdown=True, timeout=timeout)
        finally:
            # In case of errors, continue to complete the local shutdown.
            _finalize_shutdown()
    else:
        _finalize_shutdown()


def _finalize_shutdown():
    try:
        # This raises a `TORCH_CHECK()` exception if an RRef leak is detected.
        _destroy_rref_context(_ignore_rref_leak)
    finally:
        _get_current_rpc_agent().shutdown()
        # Clean up the Python RPC handler in shutdown(); see comments in
        # PythonRpcHandler::cleanup(). It is called from the Python API because
        # cleanup() has a Python dependency: it assumes the Python interpreter
        # exists.
        # Regardless of whether an RRef leak exception is raised, this clean-up
        # code must run to avoid a destruction segfault in Python 3.5.
        #
        # future.wait() should not be called after shutdown().
        # PythonRpcHandler is cleaned up in shutdown(); after shutdown(),
        # Python objects returned from an RPC Python call can no longer be
        # resolved.
        _cleanup_python_rpc_handler()
        _reset_current_rpc_agent()


@_require_initialized
def get_worker_info(worker_name=None):
    r"""
    Get :class:`~torch.distributed.rpc.WorkerInfo` of a given worker name.
    Use this :class:`~torch.distributed.rpc.WorkerInfo` to avoid passing an
    expensive string on every invocation.

    Args:
        worker_name (str): the string name of a worker. If ``None``, return
                           the info of the current worker. (default ``None``)

    Returns:
        :class:`~torch.distributed.rpc.WorkerInfo` instance for the given
        ``worker_name`` or :class:`~torch.distributed.rpc.WorkerInfo` of the
        current worker if ``worker_name`` is ``None``.
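
    Example (a minimal sketch, using the two-worker setup shown elsewhere in
    this file)::

        >>> # xdoctest: +SKIP
        >>> # Look up the WorkerInfo once and reuse it for subsequent calls.
        >>> info = rpc.get_worker_info("worker1")
        >>> ret = rpc.rpc_sync(info, torch.add, args=(torch.ones(2), 1))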
435    """
436    if worker_name is not None:
437        return _get_current_rpc_agent().get_worker_info(worker_name)
438    else:
439        return _get_current_rpc_agent().get_worker_info()
440
441
442def _to_worker_info(to):
443    if isinstance(to, WorkerInfo):
444        return to
445    elif isinstance(to, (str, int)):
446        return get_worker_info(to)
447    else:
448        raise ValueError(f"Cannot get WorkerInfo from name {to}")
449
450
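# Helpers for resolving the type of an RRef's value: on the owner the type is
# read directly from the local value, while a user worker asks the owner over
# RPC.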
def _rref_typeof_on_owner(rref, blocking: bool = True):
    rref_type = type(rref.local_value())
    if blocking:
        return rref_type
    else:
        # Wrap the result into a completed Future, so that if ``blocking=False``
        # is specified, we return a future regardless of whether this call is on
        # the user or the owner.
        future = Future[type]()
        future.set_result(rref_type)
        return future


def _rref_typeof_on_user(
    rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True
):
    fut = rpc_async(rref.owner(), _rref_typeof_on_owner, args=(rref,), timeout=timeout)
    if blocking:
        return fut.wait()
    else:
        return fut


T = TypeVar("T")
GenericWithOneTypeVar = Generic[T]


if TYPE_CHECKING:

    class RRef(PyRRef[T], Generic[T]):
        pass

else:
    try:
        # Combine the implementation class and the type class.
        class RRef(PyRRef, Generic[T]):
            pass

    except TypeError:
        # TypeError: metaclass conflict: the metaclass of a derived class
        # must be a (non-strict) subclass of the metaclasses of all its bases
        # Mypy doesn't understand __class__ (mypy bug #4177)
        class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__):  # type: ignore[name-defined, misc, valid-type]
            pass

        # Combine the implementation class and the type class.
        # Types for classes expecting a certain generic parameter (mypy bug #7791)
        class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta):  # type: ignore[misc, no-redef, valid-type]
            pass


# Install docstrings from `PyRRef` to `RRef`.
#
# This is needed because pybind11 generates the parameter
# `self` as type `rpc.PyRRef`, so a `:inherited-members:`
# under `.. autoclass:: RRef` does not work.
# We have to go through the following process to replace `rpc.PyRRef` with `rpc.RRef`.
#
def method_factory(method_name, docstring):
    def method(self, *args, **kwargs):
        return getattr(super(RRef, self), method_name)(*args, **kwargs)

    if method.__doc__:
        method.__doc__ = docstring
    return method


for method_name, method in inspect.getmembers(PyRRef):
    # Ignore magic methods, except "__str__".
    if method_name.startswith("_") and method_name != "__str__":
        continue

    # Get pybind11 generated docstring.
    # It's like,
    """
    to_here(self: torch.distributed.rpc.PyRRef, timeout: float=-1.0) -> object

        Blocking call that copies the value of the RRef from the owner
        to the local node and returns it. If the current node is the
        owner, returns a reference to the local value.
    """
    docstring = getattr(method, "__doc__", None)
    assert docstring is not None, "RRef user-facing methods should all have docstrings."

    # Do surgery on pybind11 generated docstrings.
    docstring = docstring.replace(
        "torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef"
    )

    # Attach user-facing RRef method with modified docstring.
    new_method = method_factory(method_name, docstring)
    setattr(RRef, method_name, new_method)


@_require_initialized
def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
    r"""
    Make a remote call to run ``func`` on worker ``to`` and return an
    :class:`~torch.distributed.rpc.RRef` to the result value immediately.
    Worker ``to`` will be the owner of the returned
    :class:`~torch.distributed.rpc.RRef`, and the worker calling ``remote`` is
    a user. The owner manages the global reference count of its
    :class:`~torch.distributed.rpc.RRef`, and the owner
    :class:`~torch.distributed.rpc.RRef` is only destructed when globally there
    are no living references to it.

    Args:
        to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
        func (Callable): a callable function, such as Python callables, builtin
                         operators (e.g. :meth:`~torch.add`) and annotated
                         TorchScript functions.
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): a dictionary of keyword arguments for the ``func``
                       invocation.

        timeout (float, optional): timeout in seconds for this remote call. If the
                                   creation of this
                                   :class:`~torch.distributed.rpc.RRef` on worker
                                   ``to`` is not successfully processed on this
                                   worker within this timeout, then the next time
                                   there is an attempt to use the RRef (such as
                                   ``to_here()``), a timeout will be raised
                                   indicating this failure. A value of 0 indicates
                                   an infinite timeout, i.e. a timeout error will
                                   never be raised. If not provided, the default
                                   value set during initialization or with
                                   ``_set_rpc_timeout`` is used.

    Returns:
        A user :class:`~torch.distributed.rpc.RRef` instance to the result
        value. Use the blocking API :meth:`torch.distributed.rpc.RRef.to_here`
        to retrieve the result value locally.

    .. warning ::
        The ``remote`` API does not copy storages of argument tensors until
        sending them over the wire, which could be done by a different thread
        depending on the RPC backend type. The caller should make sure that the
        contents of those tensors stay intact until the returned RRef is
        confirmed by the owner, which can be checked using the
        :meth:`torch.distributed.rpc.RRef.confirmed_by_owner` API.

    .. warning ::
        Errors such as timeouts for the ``remote`` API are handled on a
        best-effort basis: when a call initiated by ``remote`` fails, for
        example with a timeout error, the error is handled and set on the
        resulting RRef asynchronously. If the RRef has not been used by the
        application before this handling (such as ``to_here`` or a fork call),
        then future uses of the ``RRef`` will appropriately raise errors.
        However, it is possible that the user application uses the ``RRef``
        before the errors are handled, in which case errors may not be raised
        as they have not yet been handled.

    Example::

        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example,

        export MASTER_ADDR=localhost
        export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # xdoctest: +SKIP
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
        >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
        >>> x = rref1.to_here() + rref2.to_here()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

        Below is an example of running a TorchScript function using RPC.

        >>> # On both workers:
        >>> @torch.jit.script
        >>> def my_script_add(tensor: torch.Tensor, scalar: int):
        >>>    return torch.add(tensor, scalar)

        >>> # On worker 0:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> rref = rpc.remote("worker1", my_script_add, args=(torch.ones(2), 3))
        >>> rref.to_here()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()
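
        As the first warning above notes, the caller should keep the argument
        tensors intact until the owner has confirmed the RRef. A minimal
        sketch (same two-worker setup) of checking that confirmation:

        >>> # xdoctest: +SKIP
        >>> x = torch.ones(2)
        >>> rref = rpc.remote("worker1", torch.add, args=(x, 3))
        >>> # Do not free or mutate `x` until this returns True.
        >>> rref.confirmed_by_owner()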
648    """
649    torch._C._log_api_usage_once("torch.distributed.rpc_remote")
650    qualified_name = torch.jit._builtins._find_builtin(func)
651    dst_worker_info = _to_worker_info(to)
652    should_profile = _get_should_profile()
653
654    ctx_manager = _enable_rpc_profiler(
655        should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info
656    )
657
658    with ctx_manager as rf:
659        args = args if args else ()
660        kwargs = kwargs if kwargs else {}
661
662        is_async_exec = hasattr(func, "_wrapped_async_rpc_function")
663
664        if is_async_exec:
665            wrapped = func._wrapped_async_rpc_function
666            if isinstance(wrapped, torch.jit.ScriptFunction):
667                func = wrapped
668
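        # Dispatch based on the kind of `func`: a builtin torch operator, a
        # TorchScript function, or a Python UDF that is pickled and shipped to
        # the destination worker.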
        if qualified_name is not None:
            rref = _invoke_remote_builtin(
                dst_worker_info, qualified_name, timeout, *args, **kwargs
            )
        elif isinstance(func, torch.jit.ScriptFunction):
            rref = _invoke_remote_torchscript(
                dst_worker_info.name,
                torch._jit_internal._qualified_name(func),
                timeout,
                is_async_exec,
                *args,
                **kwargs,
            )
        else:
            (pickled_python_udf, tensors) = _default_pickler.serialize(
                PythonUDF(func, args, kwargs)
            )
            rref = _invoke_remote_python_udf(
                dst_worker_info, pickled_python_udf, tensors, timeout, is_async_exec
            )
        # attach profiling information
        if should_profile:
            assert torch.autograd._profiler_enabled()
            assert rf is not None
            fut = rf._call_end_callbacks_on_future(rref._get_future())
            rref._set_profiling_future(fut)

    return rref


def _invoke_rpc(
    to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT
):
    if not callable(func):
        raise TypeError("function should be callable.")

    qualified_name = torch.jit._builtins._find_builtin(func)
    dst_worker_info = _to_worker_info(to)

    should_profile = _get_should_profile()

    ctx_manager = _enable_rpc_profiler(
        should_profile, qualified_name, func, rpc_type, dst_worker_info
    )

    with ctx_manager as rf:
        args = args if args else ()
        kwargs = kwargs if kwargs else {}

        is_async_exec = hasattr(func, "_wrapped_async_rpc_function")

        if is_async_exec:
            wrapped = func._wrapped_async_rpc_function
            if isinstance(wrapped, torch.jit.ScriptFunction):
                func = wrapped

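        # Same three-way dispatch as in `remote()`: builtin operator,
        # TorchScript function, or pickled Python UDF.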
        if qualified_name is not None:
            fut = _invoke_rpc_builtin(
                dst_worker_info, qualified_name, rpc_timeout, *args, **kwargs
            )
        elif isinstance(func, torch.jit.ScriptFunction):
            fut = _invoke_rpc_torchscript(
                dst_worker_info.name,
                torch._jit_internal._qualified_name(func),
                args,
                kwargs,
                rpc_timeout,
                is_async_exec,
            )
        else:
            (pickled_python_udf, tensors) = _default_pickler.serialize(
                PythonUDF(func, args, kwargs)
            )
            fut = _invoke_rpc_python_udf(
                dst_worker_info, pickled_python_udf, tensors, rpc_timeout, is_async_exec
            )
        if should_profile:
            assert torch.autograd._profiler_enabled()
            assert rf is not None
            # Schedule profiling callbacks to run when the future completes.
            # This returns a future that is completed when the original future
            # completes and the profiling callbacks have been completed as well,
            # to guarantee that fut.wait() completes the profiling. This new
            # future will contain the same value as the original future.
            fut = rf._call_end_callbacks_on_future(fut)
    return fut


@_require_initialized
def rpc_sync(to, func, args=None, kwargs=None, timeout: float = UNSET_RPC_TIMEOUT):
    r"""
    Make a blocking RPC call to run function ``func`` on worker ``to``. RPC
    messages are sent and received in parallel with the execution of Python
    code. This method is thread-safe.

    Args:
        to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
        func (Callable): a callable function, such as Python callables, builtin
                         operators (e.g. :meth:`~torch.add`) and annotated
                         TorchScript functions.
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): a dictionary of keyword arguments for the ``func``
                       invocation.
        timeout (float, optional): timeout in seconds to use for this RPC. If
                                   the RPC does not complete in this amount of
                                   time, an exception indicating it has
                                   timed out will be raised. A value of 0
                                   indicates an infinite timeout, i.e. a timeout
                                   error will never be raised. If not provided,
                                   the default value set during initialization
                                   or with ``_set_rpc_timeout`` is used.

    Returns:
        Returns the result of running ``func`` with ``args`` and ``kwargs``.

    Example::
        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example,

        export MASTER_ADDR=localhost
        export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # xdoctest: +SKIP
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

        Below is an example of running a TorchScript function using RPC.

        >>> # On both workers:
        >>> @torch.jit.script
        >>> def my_script_add(tensor: torch.Tensor, scalar: int):
        >>>    return torch.add(tensor, scalar)

        >>> # On worker 0:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> ret = rpc.rpc_sync("worker1", my_script_add, args=(torch.ones(2), 3))
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

    """
    torch._C._log_api_usage_once("torch.distributed.rpc_sync")
    fut = _invoke_rpc(to, func, RPCExecMode.SYNC, args, kwargs, timeout)
    return fut.wait()


@_require_initialized
def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
    r"""
    Make a non-blocking RPC call to run function ``func`` on worker ``to``. RPC
    messages are sent and received in parallel with the execution of Python
    code. This method is thread-safe. This method will immediately return a
    :class:`~torch.futures.Future` that can be awaited on.

    Args:
        to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
        func (Callable): a callable function, such as Python callables, builtin
                         operators (e.g. :meth:`~torch.add`) and annotated
                         TorchScript functions.
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): a dictionary of keyword arguments for the ``func``
                       invocation.
        timeout (float, optional): timeout in seconds to use for this RPC. If
                                   the RPC does not complete in this amount of
                                   time, an exception indicating it has
                                   timed out will be raised. A value of 0
                                   indicates an infinite timeout, i.e. a timeout
                                   error will never be raised. If not provided,
                                   the default value set during initialization
                                   or with ``_set_rpc_timeout`` is used.


    Returns:
        Returns a :class:`~torch.futures.Future` object that can be waited
        on. When completed, the return value of ``func`` on ``args`` and
        ``kwargs`` can be retrieved from the :class:`~torch.futures.Future`
        object.

    .. warning ::
        Using GPU tensors as arguments or return values of ``func`` is not
        supported since we don't support sending GPU tensors over the wire. You
        need to explicitly copy GPU tensors to CPU before using them as
        arguments or return values of ``func``.

    .. warning ::
        The ``rpc_async`` API does not copy storages of argument tensors until
        sending them over the wire, which could be done by a different thread
        depending on the RPC backend type. The caller should make sure that the
        contents of those tensors stay intact until the returned
        :class:`~torch.futures.Future` completes.

    Example::
        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example,

        export MASTER_ADDR=localhost
        export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # xdoctest: +SKIP
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
        >>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
        >>> result = fut1.wait() + fut2.wait()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()

        Below is an example of running a TorchScript function using RPC.

        >>> # On both workers:
        >>> @torch.jit.script
        >>> def my_script_add(tensor: torch.Tensor, scalar: int):
        >>>    return torch.add(tensor, scalar)

        >>> # On worker 0:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> fut = rpc.rpc_async("worker1", my_script_add, args=(torch.ones(2), 3))
        >>> ret = fut.wait()
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()
    """
    torch._C._log_api_usage_once("torch.distributed.rpc_async")
    fut = _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs, timeout)
    if hasattr(_thread_local_var, "future_list"):
        _thread_local_var.future_list.append(fut)
    return fut


def _get_should_profile():
    # The legacy profiler must be enabled; RPC profiling is not supported with
    # the Kineto profiler.
    ActiveProfilerType = torch._C._profiler.ActiveProfilerType
    return (
        torch.autograd._profiler_enabled()
        and torch._C._autograd._profiler_type()
        == ActiveProfilerType.LEGACY  # type: ignore[attr-defined]
    )


def _enable_rpc_profiler(
    should_profile, qualified_name, func, rpc_type, dst_worker_info
):
    ctx_manager = contextlib.nullcontext()

    if should_profile:
        # Create an appropriate string representation based on the type of func
        # (builtin, script, python)
        if qualified_name is None:
            func_name = (
                torch._jit_internal._qualified_name(func)
                if isinstance(func, torch.jit.ScriptFunction)
                else func.__qualname__
            )
        else:
            func_name = qualified_name
        # Build RPC profiling key.
        rpc_profiling_key = _build_rpc_profiling_key(
            rpc_type,
            func_name,
            get_worker_info().name,
            dst_worker_info.name,
        )
        RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key)
        # Mypy doesn't support re-def of a variable not in the same block (#1174)
        ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key)  # type: ignore[assignment]

    return ctx_manager