1# mypy: allow-untyped-decorators 2# mypy: allow-untyped-defs 3 4import collections 5import contextlib 6import functools 7import inspect 8import logging 9import threading 10from typing import Any, Dict, Generic, Set, TYPE_CHECKING, TypeVar 11 12import torch 13from torch._C._distributed_rpc import ( 14 _cleanup_python_rpc_handler, 15 _delete_all_user_and_unforked_owner_rrefs, 16 _destroy_rref_context, 17 _get_current_rpc_agent, 18 _invoke_remote_builtin, 19 _invoke_remote_python_udf, 20 _invoke_remote_torchscript, 21 _invoke_rpc_builtin, 22 _invoke_rpc_python_udf, 23 _invoke_rpc_torchscript, 24 _is_current_rpc_agent_set, 25 _reset_current_rpc_agent, 26 _set_and_start_rpc_agent, 27 get_rpc_timeout, 28 PyRRef, 29 RemoteProfilerManager, 30 TensorPipeAgent, 31 WorkerInfo, 32) 33from torch.futures import Future 34 35from ._utils import _group_membership_management, _update_group_membership 36from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT 37from .internal import ( 38 _build_rpc_profiling_key, 39 _internal_rpc_pickler, 40 PythonUDF, 41 RPCExecMode, 42) 43 44 45__all__ = [ 46 "shutdown", 47 "get_worker_info", 48 "remote", 49 "rpc_sync", 50 "rpc_async", 51 "RRef", 52 "AllGatherStates", 53 "method_factory", 54 "new_method", 55] 56 57 58logger = logging.getLogger(__name__) 59 60# NB: Ignoring RRef leaks during shutdown. Without this, applications have to 61# make sure there is no references to any RRef in the application code and 62# Python GC has done its job to delete those RRefs. This is could result in bad 63# debugging experiences especially when for large applications. Therefore, by 64# default, we are going to ignore RRef leaks during shutdown. This is usually 65# fine as shutdown means applications have done training and no longer care 66# about states. 67# 68# To enable RRef leak checking, set this _ignore_rref_leak to False 69_ignore_rref_leak = True 70_default_pickler = _internal_rpc_pickler 71 72 73@contextlib.contextmanager 74def _use_rpc_pickler(rpc_pickler): 75 r""" 76 rpc_pickler: (.internal._InternalRPCPickler) Overrides the default RPC pickler 77 """ 78 global _default_pickler 79 _default_pickler = rpc_pickler 80 try: 81 yield 82 finally: 83 _default_pickler = _internal_rpc_pickler 84 85 86def _require_initialized(func): 87 @functools.wraps(func) 88 def wrapper(*args, **kwargs): 89 if not _is_current_rpc_agent_set(): 90 raise RuntimeError( 91 "RPC has not been initialized. Call " 92 "torch.distributed.rpc.init_rpc first." 93 ) 94 return func(*args, **kwargs) 95 96 return wrapper 97 98 99class AllGatherStates: 100 def __init__(self): 101 # Each `gathered_objects` is an empty dict at beginning. 102 # The leader worker is elected as the first worker in a sorted worker 103 # name list. Whenever there is a worker entering `_all_gather()`, it 104 # runs `_gather_to_leader()` on the leader to add its own name and 105 # data obj to this dict. The leader also adds itself's name to the dict 106 # on calling `_all_gather()`. 107 # Once `set(gathered_objects.keys()) == _ALL_WORKER_NAMES`, the leader 108 # will broadcast the gathered dict to all follower workers and set their 109 # `gathered_objects` field and the `proceed_signal` field. 110 self.gathered_objects = {} 111 # All workers wait on this signal until it receives all gathered 112 # objects. 113 self.proceed_signal = threading.Event() 114 115 116# States used by `def _all_gather()`. 117# `_ALL_WORKER_NAMES` is initialized on initializing RPC layer. 118_ALL_WORKER_NAMES: Set[Any] = set() 119_all_gather_dict_lock = threading.RLock() 120_all_gather_sequence_id: Dict[str, int] = {} 121_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict( 122 AllGatherStates 123) 124 125 126def _init_rpc_states(agent): 127 worker_infos = agent.get_worker_infos() 128 global _ALL_WORKER_NAMES 129 _ALL_WORKER_NAMES = {worker_info.name for worker_info in worker_infos} 130 131 # NB: backend implementation might have already set the rpc_agent. 132 if not _is_current_rpc_agent_set(): 133 _set_and_start_rpc_agent(agent) 134 135 136def _gather_to_leader(sequence_id, worker_name, obj, worker_names=None): 137 with _all_gather_dict_lock: 138 if not worker_names: 139 worker_names = _ALL_WORKER_NAMES 140 assert ( 141 worker_name in worker_names 142 ), f"{worker_name} is not expected by leader." 143 states = _all_gather_sequence_id_to_states[sequence_id] 144 assert ( 145 worker_name not in states.gathered_objects 146 ), f"{worker_name} reported intent sequence id {sequence_id} twice. " 147 states.gathered_objects[worker_name] = obj 148 if worker_names == set(states.gathered_objects.keys()): 149 states.proceed_signal.set() 150 151 152def _broadcast_to_followers(sequence_id, objects_map): 153 with _all_gather_dict_lock: 154 states = _all_gather_sequence_id_to_states[sequence_id] 155 156 assert ( 157 not states.proceed_signal.is_set() 158 ), f"Termination signal sequence id {sequence_id} got set twice." 159 states.gathered_objects = objects_map 160 states.proceed_signal.set() 161 162 163_thread_local_var = threading.local() 164 165 166@contextlib.contextmanager 167def _wait_all(): 168 r""" 169 A context manager that collects all futures returned by ``rpc_async`` and 170 waits them on the context manager's exit; relieving the user of needing 171 to explicitly call wait. 172 173 174 Example:: 175 >>> # xdoctest: +SKIP("distributed") 176 >>> # On worker 0: 177 >>> import torch 178 >>> import torch.distributed.rpc as rpc 179 >>> rpc.init_rpc("worker0", rank=0, world_size=2) 180 >>> with rpc._wait_all(): 181 >>> fut_1 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1)) 182 >>> fut_2 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1)) 183 >>> #fut_1 and fut_2 are waited on 184 """ 185 _thread_local_var.future_list = [] 186 try: 187 yield 188 finally: 189 try: 190 torch.futures.wait_all(_thread_local_var.future_list) 191 finally: 192 del _thread_local_var.future_list 193 194 195@_require_initialized 196def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT): 197 r""" 198 This is similar to torch.distributed.all_gather(), but is using RPC. It 199 picks the worker with the smallest name (alphabetic order) as the leader. 200 Then all followers send their data ``obj`` to the leader. After the leader 201 has received all, it will broadcast the results back to all followers. This 202 function blocks until all workers have received the gathered results. 203 """ 204 if not worker_names: 205 assert ( 206 _ALL_WORKER_NAMES is not None 207 ), "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`." 208 worker_names = _ALL_WORKER_NAMES 209 leader_name = min(worker_names) 210 211 self_name = _get_current_rpc_agent().get_worker_info().name 212 213 with _all_gather_dict_lock: 214 concat_names = "".join(sorted(worker_names)) 215 sequence_num = _all_gather_sequence_id.get(concat_names, 0) 216 _all_gather_sequence_id[concat_names] = sequence_num + 1 217 sequence_id = concat_names + str(sequence_num) 218 219 is_leader = leader_name == self_name 220 221 if timeout == UNSET_RPC_TIMEOUT: 222 # Timeout is specified by agent for RPC calls 223 rpc_timeout = get_rpc_timeout() 224 # No timeout for signal 225 signal_timeout = None 226 elif timeout == DEFAULT_SHUTDOWN_TIMEOUT: 227 # No timeout for RPC 228 rpc_timeout = timeout 229 # No timeout for signal 230 signal_timeout = None 231 else: 232 # Signal and RPC timeout use the same timeout 233 signal_timeout = rpc_timeout = timeout 234 235 # Phase 1: Followers send it's object to the leader 236 if is_leader: 237 _gather_to_leader(sequence_id, self_name, obj, worker_names) 238 else: 239 rpc_sync( 240 leader_name, 241 _gather_to_leader, 242 args=(sequence_id, self_name, obj, worker_names), 243 timeout=rpc_timeout, 244 ) 245 246 with _all_gather_dict_lock: 247 states = _all_gather_sequence_id_to_states[sequence_id] 248 249 # Timeout is either set by function parameter or None (which is indefinite) 250 states.proceed_signal.wait(timeout=signal_timeout) 251 252 # Phase 2: Leader broadcast gathered results to all followers 253 # Leader's signal is the first to be unblocked, after receiving all 254 # followers' data objects. 255 if is_leader: 256 worker_name_to_response_future_dict = {} 257 for follower_name in worker_names - {leader_name}: 258 fut = rpc_async( 259 follower_name, 260 _broadcast_to_followers, 261 args=(sequence_id, states.gathered_objects), 262 timeout=rpc_timeout, 263 ) 264 worker_name_to_response_future_dict[follower_name] = fut 265 266 errors = [] 267 for follower_name, fut in worker_name_to_response_future_dict.items(): 268 try: 269 fut.wait() 270 except RuntimeError as ex: 271 errors.append((follower_name, ex)) 272 273 if errors: 274 raise RuntimeError( 275 f"Followers {[e[0] for e in errors]} timed out in _all_gather " 276 f"after {rpc_timeout:.2f} seconds. The first exception is {errors[0][1]}" 277 ) 278 279 # Clean up for the states using the sequence_id 280 with _all_gather_dict_lock: 281 states = _all_gather_sequence_id_to_states.pop(sequence_id) 282 return states.gathered_objects 283 284 285@_require_initialized 286def _barrier(worker_names): 287 r""" 288 Synchronizes local and remote RPC processes. 289 290 This will block until all local and remote RPC processes specified under worker_names 291 reach this method to wait for all outstanding work to complete. 292 293 Args: 294 worker_names (List[str]): The set of workers to synchronize. 295 296 """ 297 try: 298 _all_gather(None, set(worker_names)) 299 except RuntimeError as ex: 300 logger.error("Failed to complete barrier, got error %s", ex) 301 302 303@_require_initialized 304def _wait_all_workers(timeout=DEFAULT_SHUTDOWN_TIMEOUT): 305 r""" 306 Block until all local and remote RPC processes reach this method and wait 307 for all outstanding work to complete. Every RPC process must call this 308 method before exit to perform a graceful shutdown. This should be used to 309 terminate the RPC framework, and there is no guarantee that the RPC 310 framework will work after this method returns. 311 """ 312 try: 313 _all_gather(None, timeout=timeout) 314 except RuntimeError as ex: 315 logger.error( 316 "Failed to respond to 'Shutdown Proceed' in time, got error %s", ex 317 ) 318 raise ex 319 320 321@_require_initialized 322def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT): 323 r""" 324 Perform a shutdown of the RPC agent, and then destroy the RPC agent. This 325 stops the local agent from accepting outstanding requests, and shuts 326 down the RPC framework by terminating all RPC threads. If ``graceful=True``, 327 this will block until all local and remote RPC processes reach this method 328 and wait for all outstanding work to complete. Otherwise, if 329 ``graceful=False``, this is a local shutdown, and it does not wait for other 330 RPC processes to reach this method. 331 332 .. warning:: 333 For :class:`~torch.futures.Future` objects returned by 334 :meth:`~torch.distributed.rpc.rpc_async`, ``future.wait()`` should not 335 be called after ``shutdown()``. 336 337 Args: 338 graceful (bool): Whether to do a graceful shutdown or not. If True, 339 this will 1) wait until there is no pending system 340 messages for ``UserRRefs`` and delete them; 2) block 341 until all local and remote RPC processes have reached 342 this method and wait for all outstanding work to 343 complete. 344 345 Example:: 346 Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly 347 on both workers. Refer to :meth:`~torch.distributed.init_process_group` 348 API for more details. For example, 349 350 export MASTER_ADDR=localhost 351 export MASTER_PORT=5678 352 353 Then run the following code in two different processes: 354 355 >>> # xdoctest: +SKIP 356 >>> # On worker 0: 357 >>> import torch 358 >>> import torch.distributed.rpc as rpc 359 >>> rpc.init_rpc("worker0", rank=0, world_size=2) 360 >>> # do some work 361 >>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1)) 362 >>> # ready to shutdown 363 >>> rpc.shutdown() 364 365 >>> # On worker 1: 366 >>> import torch.distributed.rpc as rpc 367 >>> rpc.init_rpc("worker1", rank=1, world_size=2) 368 >>> # wait for worker 0 to finish work, and then shutdown. 369 >>> rpc.shutdown() 370 """ 371 if graceful: 372 try: 373 agent = _get_current_rpc_agent() 374 if not isinstance(agent, TensorPipeAgent) or agent.is_static_group: 375 _wait_all_workers(timeout) 376 _delete_all_user_and_unforked_owner_rrefs() 377 agent.join(shutdown=True, timeout=timeout) 378 else: 379 # This is a dynamic group so we need to grab the token for the operation 380 my_worker_info = agent.get_worker_info() 381 my_name = my_worker_info.name 382 with _group_membership_management(agent.store, my_name, False): 383 all_worker_infos = agent.get_worker_infos() 384 for worker in all_worker_infos: 385 if worker.name != my_name: 386 rpc_sync( 387 worker.name, 388 _update_group_membership, 389 args=(my_worker_info, [], {}, False), 390 ) 391 agent.join(shutdown=True, timeout=timeout) 392 finally: 393 # In case of errors, continue to complete the local shutdown. 394 _finalize_shutdown() 395 else: 396 _finalize_shutdown() 397 398 399def _finalize_shutdown(): 400 try: 401 # This raises a `TORCH_CHECK()` exception on RRef leak detected. 402 _destroy_rref_context(_ignore_rref_leak) 403 finally: 404 _get_current_rpc_agent().shutdown() 405 # clean up python rpc handler in shutdown(), see comments in 406 # PythonRpcHandler::cleanup(), call it in python API because the 407 # cleanup() function has python dependency, it assumes python 408 # interpreter exists. 409 # No matter if RRef leak exception is raised, this clean-up code 410 # must run to avoid destruction segfault in Python 3.5. 411 # 412 # future.wait() should not be called after shutdown(). 413 # pythonRpcHandler is cleaned up in shutdown(), after 414 # shutdown(), python objects returned from rpc python call can not be 415 # resolved. 416 _cleanup_python_rpc_handler() 417 _reset_current_rpc_agent() 418 419 420@_require_initialized 421def get_worker_info(worker_name=None): 422 r""" 423 Get :class:`~torch.distributed.rpc.WorkerInfo` of a given worker name. 424 Use this :class:`~torch.distributed.rpc.WorkerInfo` to avoid passing an 425 expensive string on every invocation. 426 427 Args: 428 worker_name (str): the string name of a worker. If ``None``, return the 429 the id of the current worker. (default ``None``) 430 431 Returns: 432 :class:`~torch.distributed.rpc.WorkerInfo` instance for the given 433 ``worker_name`` or :class:`~torch.distributed.rpc.WorkerInfo` of the 434 current worker if ``worker_name`` is ``None``. 435 """ 436 if worker_name is not None: 437 return _get_current_rpc_agent().get_worker_info(worker_name) 438 else: 439 return _get_current_rpc_agent().get_worker_info() 440 441 442def _to_worker_info(to): 443 if isinstance(to, WorkerInfo): 444 return to 445 elif isinstance(to, (str, int)): 446 return get_worker_info(to) 447 else: 448 raise ValueError(f"Cannot get WorkerInfo from name {to}") 449 450 451def _rref_typeof_on_owner(rref, blocking: bool = True): 452 rref_type = type(rref.local_value()) 453 if blocking: 454 return rref_type 455 else: 456 # Wrap result into a completed Future. This is so that if blocking=`False` 457 # is specified, we return a future regardless of if this call is on user 458 # or owner. 459 future = Future[type]() 460 future.set_result(rref_type) 461 return future 462 463 464def _rref_typeof_on_user( 465 rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True 466): 467 fut = rpc_async(rref.owner(), _rref_typeof_on_owner, args=(rref,), timeout=timeout) 468 if blocking: 469 return fut.wait() 470 else: 471 return fut 472 473 474T = TypeVar("T") 475GenericWithOneTypeVar = Generic[T] 476 477 478if TYPE_CHECKING: 479 480 class RRef(PyRRef[T], Generic[T]): 481 pass 482 483else: 484 try: 485 # Combine the implementation class and the type class. 486 class RRef(PyRRef, Generic[T]): 487 pass 488 489 except TypeError: 490 # TypeError: metaclass conflict: the metaclass of a derived class 491 # must be a (non-strict) subclass of the metaclasses of all its bases 492 # Mypy doesn't understand __class__ (mypy bug #4177) 493 class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__): # type: ignore[name-defined, misc, valid-type] 494 pass 495 496 # Combine the implementation class and the type class. 497 # Types for classes expecting a certain generic parameter (mypy bug #7791) 498 class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta): # type: ignore[misc, no-redef, valid-type] 499 pass 500 501 502# Install docstrings from `PyRRef` to `RRef`. 503# 504# This is for the fact that pybind11 generates the parameter 505# `self` as type `rpc.PyRRef`, so a `:inherited-members:` 506# under `.. autoclass:: RRef` does not work. 507# we have to do the following process to replace `rpc.PyRRef` with `rpc.RRef`. 508# 509def method_factory(method_name, docstring): 510 def method(self, *args, **kwargs): 511 return getattr(super(RRef, self), method_name)(*args, **kwargs) 512 513 if method.__doc__: 514 method.__doc__ = docstring 515 return method 516 517 518for method_name, method in inspect.getmembers(PyRRef): 519 # Ignore magic methods, except "__str__". 520 if method_name.startswith("_") and method_name != "__str__": 521 continue 522 523 # Get pybind11 generated docstring. 524 # It's like, 525 """ 526 to_here(self: torch.distributed.rpc.PyRRef, timeout: float=-1.0) -> object 527 528 Blocking call that copies the value of the RRef from the owner 529 to the local node and returns it. If the current node is the 530 owner, returns a reference to the local value. 531 """ 532 docstring = getattr(method, "__doc__", None) 533 assert docstring is not None, "RRef user-facing methods should all have docstrings." 534 535 # Do surgery on pybind11 generated docstrings. 536 docstring = docstring.replace( 537 "torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef" 538 ) 539 540 # Attach user-facing RRef method with modified docstring. 541 new_method = method_factory(method_name, docstring) 542 setattr(RRef, method_name, new_method) 543 544 545@_require_initialized 546def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): 547 r""" 548 Make a remote call to run ``func`` on worker ``to`` and return an 549 :class:`~torch.distributed.rpc.RRef` to the result value immediately. 550 Worker ``to`` will be the owner of the returned 551 :class:`~torch.distributed.rpc.RRef`, and the worker calling ``remote`` is 552 a user. The owner manages the global reference count of its 553 :class:`~torch.distributed.rpc.RRef`, and the owner 554 :class:`~torch.distributed.rpc.RRef` is only destructed when globally there 555 are no living references to it. 556 557 Args: 558 to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker. 559 func (Callable): a callable function, such as Python callables, builtin 560 operators (e.g. :meth:`~torch.add`) and annotated 561 TorchScript functions. 562 args (tuple): the argument tuple for the ``func`` invocation. 563 kwargs (dict): is a dictionary of keyword arguments for the ``func`` 564 invocation. 565 566 timeout (float, optional): timeout in seconds for this remote call. If the 567 creation of this 568 :class:`~torch.distributed.rpc.RRef` on worker 569 ``to`` is not successfully processed on this 570 worker within this timeout, then the next time 571 there is an attempt to use the RRef (such as 572 ``to_here()``), a timeout will be raised 573 indicating this failure. A value of 0 indicates 574 an infinite timeout, i.e. a timeout error will 575 never be raised. If not provided, the default 576 value set during initialization or with 577 ``_set_rpc_timeout`` is used. 578 579 Returns: 580 A user :class:`~torch.distributed.rpc.RRef` instance to the result 581 value. Use the blocking API :meth:`torch.distributed.rpc.RRef.to_here` 582 to retrieve the result value locally. 583 584 .. warning :: 585 The ``remote`` API does not copy storages of argument tensors until 586 sending them over the wire, which could be done by a different thread 587 depending on the RPC backend type. The caller should make sure that the 588 contents of those tensors stay intact until the returned RRef is 589 confirmed by the owner, which can be checked using the 590 :meth:`torch.distributed.rpc.RRef.confirmed_by_owner` API. 591 592 .. warning :: 593 Errors such as timeouts for the ``remote`` API are handled on a 594 best-effort basis. This means that when remote calls initiated by 595 ``remote`` fail, such as with a timeout error, we take a best-effort 596 approach to error handling. This means that errors are handled and set 597 on the resulting RRef on an asynchronous basis. If the RRef has not been 598 used by the application before this handling (such as ``to_here`` or 599 fork call), then future uses of the ``RRef`` will appropriately raise 600 errors. However, it is possible that the user application will use the 601 ``RRef`` before the errors are handled. In this case, errors may not be 602 raised as they have not yet been handled. 603 604 Example:: 605 606 Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly 607 on both workers. Refer to :meth:`~torch.distributed.init_process_group` 608 API for more details. For example, 609 610 export MASTER_ADDR=localhost 611 export MASTER_PORT=5678 612 613 Then run the following code in two different processes: 614 615 >>> # xdoctest: +SKIP 616 >>> # On worker 0: 617 >>> import torch 618 >>> import torch.distributed.rpc as rpc 619 >>> rpc.init_rpc("worker0", rank=0, world_size=2) 620 >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3)) 621 >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1)) 622 >>> x = rref1.to_here() + rref2.to_here() 623 >>> rpc.shutdown() 624 625 >>> # On worker 1: 626 >>> import torch.distributed.rpc as rpc 627 >>> rpc.init_rpc("worker1", rank=1, world_size=2) 628 >>> rpc.shutdown() 629 630 Below is an example of running a TorchScript function using RPC. 631 632 >>> # On both workers: 633 >>> @torch.jit.script 634 >>> def my_script_add(tensor: torch.Tensor, scalar: int): 635 >>> return torch.add(tensor, scalar) 636 637 >>> # On worker 0: 638 >>> import torch.distributed.rpc as rpc 639 >>> rpc.init_rpc("worker0", rank=0, world_size=2) 640 >>> rref = rpc.remote("worker1", my_script_add, args=(torch.ones(2), 3)) 641 >>> rref.to_here() 642 >>> rpc.shutdown() 643 644 >>> # On worker 1: 645 >>> import torch.distributed.rpc as rpc 646 >>> rpc.init_rpc("worker1", rank=1, world_size=2) 647 >>> rpc.shutdown() 648 """ 649 torch._C._log_api_usage_once("torch.distributed.rpc_remote") 650 qualified_name = torch.jit._builtins._find_builtin(func) 651 dst_worker_info = _to_worker_info(to) 652 should_profile = _get_should_profile() 653 654 ctx_manager = _enable_rpc_profiler( 655 should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info 656 ) 657 658 with ctx_manager as rf: 659 args = args if args else () 660 kwargs = kwargs if kwargs else {} 661 662 is_async_exec = hasattr(func, "_wrapped_async_rpc_function") 663 664 if is_async_exec: 665 wrapped = func._wrapped_async_rpc_function 666 if isinstance(wrapped, torch.jit.ScriptFunction): 667 func = wrapped 668 669 if qualified_name is not None: 670 rref = _invoke_remote_builtin( 671 dst_worker_info, qualified_name, timeout, *args, **kwargs 672 ) 673 elif isinstance(func, torch.jit.ScriptFunction): 674 rref = _invoke_remote_torchscript( 675 dst_worker_info.name, 676 torch._jit_internal._qualified_name(func), 677 timeout, 678 is_async_exec, 679 *args, 680 **kwargs, 681 ) 682 else: 683 (pickled_python_udf, tensors) = _default_pickler.serialize( 684 PythonUDF(func, args, kwargs) 685 ) 686 rref = _invoke_remote_python_udf( 687 dst_worker_info, pickled_python_udf, tensors, timeout, is_async_exec 688 ) 689 # attach profiling information 690 if should_profile: 691 assert torch.autograd._profiler_enabled() 692 assert rf is not None 693 fut = rf._call_end_callbacks_on_future(rref._get_future()) 694 rref._set_profiling_future(fut) 695 696 return rref 697 698 699def _invoke_rpc( 700 to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT 701): 702 if not callable(func): 703 raise TypeError("function should be callable.") 704 705 qualified_name = torch.jit._builtins._find_builtin(func) 706 dst_worker_info = _to_worker_info(to) 707 708 should_profile = _get_should_profile() 709 710 ctx_manager = _enable_rpc_profiler( 711 should_profile, qualified_name, func, rpc_type, dst_worker_info 712 ) 713 714 with ctx_manager as rf: 715 args = args if args else () 716 kwargs = kwargs if kwargs else {} 717 718 is_async_exec = hasattr(func, "_wrapped_async_rpc_function") 719 720 if is_async_exec: 721 wrapped = func._wrapped_async_rpc_function 722 if isinstance(wrapped, torch.jit.ScriptFunction): 723 func = wrapped 724 725 if qualified_name is not None: 726 fut = _invoke_rpc_builtin( 727 dst_worker_info, qualified_name, rpc_timeout, *args, **kwargs 728 ) 729 elif isinstance(func, torch.jit.ScriptFunction): 730 fut = _invoke_rpc_torchscript( 731 dst_worker_info.name, 732 torch._jit_internal._qualified_name(func), 733 args, 734 kwargs, 735 rpc_timeout, 736 is_async_exec, 737 ) 738 else: 739 (pickled_python_udf, tensors) = _default_pickler.serialize( 740 PythonUDF(func, args, kwargs) 741 ) 742 fut = _invoke_rpc_python_udf( 743 dst_worker_info, pickled_python_udf, tensors, rpc_timeout, is_async_exec 744 ) 745 if should_profile: 746 assert torch.autograd._profiler_enabled() 747 assert rf is not None 748 # Schedule profiling callbacks to run when the future completes. 749 # This returns a future that is completed when the original future 750 # completes and the profiling callbacks have been completed as well, 751 # to guarantee that fut.wait() completes the profiling. This new 752 # future will contain the same value as the original future. 753 fut = rf._call_end_callbacks_on_future(fut) 754 return fut 755 756 757@_require_initialized 758def rpc_sync(to, func, args=None, kwargs=None, timeout: float = UNSET_RPC_TIMEOUT): 759 r""" 760 Make a blocking RPC call to run function ``func`` on worker ``to``. RPC 761 messages are sent and received in parallel to execution of Python code. This 762 method is thread-safe. 763 764 Args: 765 to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker. 766 func (Callable): a callable function, such as Python callables, builtin 767 operators (e.g. :meth:`~torch.add`) and annotated 768 TorchScript functions. 769 args (tuple): the argument tuple for the ``func`` invocation. 770 kwargs (dict): is a dictionary of keyword arguments for the ``func`` 771 invocation. 772 timeout (float, optional): timeout in seconds to use for this RPC. If 773 the RPC does not complete in this amount of 774 time, an exception indicating it has 775 timed out will be raised. A value of 0 776 indicates an infinite timeout, i.e. a timeout 777 error will never be raised. If not provided, 778 the default value set during initialization 779 or with ``_set_rpc_timeout`` is used. 780 781 Returns: 782 Returns the result of running ``func`` with ``args`` and ``kwargs``. 783 784 Example:: 785 Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly 786 on both workers. Refer to :meth:`~torch.distributed.init_process_group` 787 API for more details. For example, 788 789 export MASTER_ADDR=localhost 790 export MASTER_PORT=5678 791 792 Then run the following code in two different processes: 793 794 >>> # xdoctest: +SKIP 795 >>> # On worker 0: 796 >>> import torch 797 >>> import torch.distributed.rpc as rpc 798 >>> rpc.init_rpc("worker0", rank=0, world_size=2) 799 >>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3)) 800 >>> rpc.shutdown() 801 802 >>> # On worker 1: 803 >>> import torch.distributed.rpc as rpc 804 >>> rpc.init_rpc("worker1", rank=1, world_size=2) 805 >>> rpc.shutdown() 806 807 Below is an example of running a TorchScript function using RPC. 808 809 >>> # On both workers: 810 >>> @torch.jit.script 811 >>> def my_script_add(tensor: torch.Tensor, scalar: int): 812 >>> return torch.add(tensor, scalar) 813 814 >>> # On worker 0: 815 >>> import torch.distributed.rpc as rpc 816 >>> rpc.init_rpc("worker0", rank=0, world_size=2) 817 >>> ret = rpc.rpc_sync("worker1", my_script_add, args=(torch.ones(2), 3)) 818 >>> rpc.shutdown() 819 820 >>> # On worker 1: 821 >>> import torch.distributed.rpc as rpc 822 >>> rpc.init_rpc("worker1", rank=1, world_size=2) 823 >>> rpc.shutdown() 824 825 """ 826 torch._C._log_api_usage_once("torch.distributed.rpc_sync") 827 fut = _invoke_rpc(to, func, RPCExecMode.SYNC, args, kwargs, timeout) 828 return fut.wait() 829 830 831@_require_initialized 832def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): 833 r""" 834 Make a non-blocking RPC call to run function ``func`` on worker ``to``. RPC 835 messages are sent and received in parallel to execution of Python code. This 836 method is thread-safe. This method will immediately return a 837 :class:`~torch.futures.Future` that can be awaited on. 838 839 Args: 840 to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker. 841 func (Callable): a callable function, such as Python callables, builtin 842 operators (e.g. :meth:`~torch.add`) and annotated 843 TorchScript functions. 844 args (tuple): the argument tuple for the ``func`` invocation. 845 kwargs (dict): is a dictionary of keyword arguments for the ``func`` 846 invocation. 847 timeout (float, optional): timeout in seconds to use for this RPC. If 848 the RPC does not complete in this amount of 849 time, an exception indicating it has 850 timed out will be raised. A value of 0 851 indicates an infinite timeout, i.e. a timeout 852 error will never be raised. If not provided, 853 the default value set during initialization 854 or with ``_set_rpc_timeout`` is used. 855 856 857 Returns: 858 Returns a :class:`~torch.futures.Future` object that can be waited 859 on. When completed, the return value of ``func`` on ``args`` and 860 ``kwargs`` can be retrieved from the :class:`~torch.futures.Future` 861 object. 862 863 .. warning :: 864 Using GPU tensors as arguments or return values of ``func`` is not 865 supported since we don't support sending GPU tensors over the wire. You 866 need to explicitly copy GPU tensors to CPU before using them as 867 arguments or return values of ``func``. 868 869 .. warning :: 870 The ``rpc_async`` API does not copy storages of argument tensors until 871 sending them over the wire, which could be done by a different thread 872 depending on the RPC backend type. The caller should make sure that the 873 contents of those tensors stay intact until the returned 874 :class:`~torch.futures.Future` completes. 875 876 Example:: 877 Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly 878 on both workers. Refer to :meth:`~torch.distributed.init_process_group` 879 API for more details. For example, 880 881 export MASTER_ADDR=localhost 882 export MASTER_PORT=5678 883 884 Then run the following code in two different processes: 885 886 >>> # xdoctest: +SKIP 887 >>> # On worker 0: 888 >>> import torch 889 >>> import torch.distributed.rpc as rpc 890 >>> rpc.init_rpc("worker0", rank=0, world_size=2) 891 >>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3)) 892 >>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2)) 893 >>> result = fut1.wait() + fut2.wait() 894 >>> rpc.shutdown() 895 896 >>> # On worker 1: 897 >>> import torch.distributed.rpc as rpc 898 >>> rpc.init_rpc("worker1", rank=1, world_size=2) 899 >>> rpc.shutdown() 900 901 Below is an example of running a TorchScript function using RPC. 902 903 >>> # On both workers: 904 >>> @torch.jit.script 905 >>> def my_script_add(tensor: torch.Tensor, scalar: int): 906 >>> return torch.add(tensor, scalar) 907 908 >>> # On worker 0: 909 >>> import torch.distributed.rpc as rpc 910 >>> rpc.init_rpc("worker0", rank=0, world_size=2) 911 >>> fut = rpc.rpc_async("worker1", my_script_add, args=(torch.ones(2), 3)) 912 >>> ret = fut.wait() 913 >>> rpc.shutdown() 914 915 >>> # On worker 1: 916 >>> import torch.distributed.rpc as rpc 917 >>> rpc.init_rpc("worker1", rank=1, world_size=2) 918 >>> rpc.shutdown() 919 """ 920 torch._C._log_api_usage_once("torch.distributed.rpc_async") 921 fut = _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs, timeout) 922 if hasattr(_thread_local_var, "future_list"): 923 _thread_local_var.future_list.append(fut) 924 return fut 925 926 927def _get_should_profile(): 928 # Legacy profiler should be enabled. RPC profiling is not supported with 929 # Kineto profiler. 930 ActiveProfilerType = torch._C._profiler.ActiveProfilerType 931 return ( 932 torch.autograd._profiler_enabled() 933 and torch._C._autograd._profiler_type() 934 == ActiveProfilerType.LEGACY # type: ignore[attr-defined] 935 ) 936 937 938def _enable_rpc_profiler( 939 should_profile, qualified_name, func, rpc_type, dst_worker_info 940): 941 ctx_manager = contextlib.nullcontext() 942 943 if should_profile: 944 # Create appropriate string representation based on type of func 945 # (builtin, script, python) 946 if qualified_name is None: 947 func_name = ( 948 torch._jit_internal._qualified_name(func) 949 if isinstance(func, torch.jit.ScriptFunction) 950 else func.__qualname__ 951 ) 952 else: 953 func_name = qualified_name 954 # Build RPC profiling key. 955 rpc_profiling_key = _build_rpc_profiling_key( 956 rpc_type, 957 func_name, 958 get_worker_info().name, 959 dst_worker_info.name, 960 ) 961 RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key) 962 # Mypy doesn't support re-def of a variable not in the same block (#1174) 963 ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key) # type: ignore[assignment] 964 965 return ctx_manager 966