xref: /aosp_15_r20/external/pytorch/torch/distributed/elastic/rendezvous/registry.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Copyright (c) Facebook, Inc. and its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7from .api import (
8    rendezvous_handler_registry as handler_registry,
9    RendezvousHandler,
10    RendezvousParameters,
11)
12from .dynamic_rendezvous import create_handler
13
14
# Public API of this module: only ``get_rendezvous_handler`` is exported; the
# ``_create_*`` backend factories and ``_register_default_handlers`` are private.
__all__ = ["get_rendezvous_handler"]
16
17
def _create_static_handler(params: RendezvousParameters) -> RendezvousHandler:
    """Build a rendezvous handler for the ``static`` backend.

    Delegates to ``static_tcp_rendezvous.create_rdzv_handler``. The module is
    imported inside the function, so it is only loaded when this backend is
    actually requested.
    """
    from . import static_tcp_rendezvous as static_module

    factory = static_module.create_rdzv_handler
    return factory(params)
22
23
def _create_etcd_handler(params: RendezvousParameters) -> RendezvousHandler:
    """Build a rendezvous handler for the legacy ``etcd`` backend.

    Delegates to ``etcd_rendezvous.create_rdzv_handler``. The import is deferred
    to call time so the etcd module is only loaded when this backend is used
    (presumably to keep optional etcd dependencies out of the registry's
    import path — confirm).
    """
    from . import etcd_rendezvous as etcd_module

    factory = etcd_module.create_rdzv_handler
    return factory(params)
28
29
def _create_etcd_v2_handler(params: RendezvousParameters) -> RendezvousHandler:
    """Build a dynamic rendezvous handler on top of the etcd (v2) backend.

    ``etcd_rendezvous_backend.create_backend`` produces a ``(backend, store)``
    pair, which is then wired into the dynamic-rendezvous ``create_handler``
    factory together with ``params``. The backend module is imported lazily.
    """
    from .etcd_rendezvous_backend import create_backend as make_etcd_backend

    etcd_backend, etcd_store = make_etcd_backend(params)
    return create_handler(etcd_store, etcd_backend, params)
36
37
def _create_c10d_handler(params: RendezvousParameters) -> RendezvousHandler:
    """Build a dynamic rendezvous handler on top of the c10d backend.

    ``c10d_rendezvous_backend.create_backend`` produces a ``(backend, store)``
    pair, which is then wired into the dynamic-rendezvous ``create_handler``
    factory together with ``params``. The backend module is imported lazily.
    """
    from .c10d_rendezvous_backend import create_backend as make_c10d_backend

    c10d_backend, c10d_store = make_c10d_backend(params)
    return create_handler(c10d_store, c10d_backend, params)
44
45
def _register_default_handlers() -> None:
    """Register the built-in rendezvous backends with the shared handler registry.

    Backends are registered in this order: ``etcd``, ``etcd-v2``, ``c10d``,
    ``static``. Each name maps to the private factory defined in this module.
    """
    builtin_backends = (
        ("etcd", _create_etcd_handler),
        ("etcd-v2", _create_etcd_v2_handler),
        ("c10d", _create_c10d_handler),
        ("static", _create_static_handler),
    )
    for backend_name, factory in builtin_backends:
        handler_registry.register(backend_name, factory)
51
52
def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler:
    """
    Obtain a reference to a :py:class:`RendezvousHandler`.

    The handler is created by the shared rendezvous handler registry
    (``rendezvous_handler_registry``) from ``params``. Which factory is used is
    presumably selected by ``params.backend`` — confirm against
    ``RendezvousHandlerRegistry.create_handler``, which is also where any
    error for an unregistered backend originates.

    Custom rendezvous handlers can be registered by

    ::

      from torch.distributed.elastic.rendezvous import (
          RendezvousParameters,
          rendezvous_handler_registry,
      )
      from torch.distributed.elastic.rendezvous.registry import get_rendezvous_handler

      def create_my_rdzv(params: RendezvousParameters):
          return MyCustomRdzv(params)

      rendezvous_handler_registry.register("my_rdzv_backend_name", create_my_rdzv)

      params = RendezvousParameters(
          backend="my_rdzv_backend_name",
          endpoint="localhost:29400",
          run_id="my_run",
          min_nodes=1,
          max_nodes=1,
      )
      my_rdzv_handler = get_rendezvous_handler(params)

    Args:
        params: The rendezvous configuration passed through to the registry.

    Returns:
        The handler produced by the registered factory for the requested backend.
    """
    return handler_registry.create_handler(params)
72