xref: /aosp_15_r20/external/pytorch/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# mypy: allow-untyped-defs
3
4import torch.distributed as dist
5import torch.distributed.rpc as rpc
6
7
def _faulty_tensorpipe_construct_rpc_backend_options_handler(
    rpc_timeout,
    init_method,
    num_worker_threads,
    messages_to_fail,
    messages_to_delay,
    num_fail_sends,
    **kwargs,
):
    """Build the options object for the ``FAULTY_TENSORPIPE`` RPC backend.

    The timeout, init method and worker-thread count are forwarded unchanged
    to ``FaultyTensorPipeRpcBackendOptions``, together with the three
    fault-injection settings (``messages_to_fail``, ``messages_to_delay``,
    ``num_fail_sends``).

    Any additional keyword arguments collected in ``kwargs`` are accepted but
    ignored.

    Returns:
        A new ``FaultyTensorPipeRpcBackendOptions`` instance.
    """
    # Deferred to call time instead of module import time.
    # NOTE(review): presumably done to avoid an import cycle with the
    # package ``__init__``; confirm.
    from . import FaultyTensorPipeRpcBackendOptions

    option_values = dict(
        num_worker_threads=num_worker_threads,
        rpc_timeout=rpc_timeout,
        init_method=init_method,
        messages_to_fail=messages_to_fail,
        messages_to_delay=messages_to_delay,
        num_fail_sends=num_fail_sends,
    )
    return FaultyTensorPipeRpcBackendOptions(**option_values)
27
28
def _faulty_tensorpipe_init_backend_handler(
    store, name, rank, world_size, rpc_backend_options
):
    """Create a ``FaultyTensorPipeAgent`` and initialize RPC state with it.

    Args:
        store: Must be a ``dist.Store`` instance.
        name: Forwarded to the agent constructor.
        rank: Forwarded to the agent constructor.
        world_size: Forwarded to the agent constructor.
        rpc_backend_options: Must be a ``FaultyTensorPipeRpcBackendOptions``.

    Returns:
        The constructed agent, after ``api._init_rpc_states`` has been called
        with it.

    Raises:
        TypeError: If ``store`` is not a ``dist.Store``, or if
            ``rpc_backend_options`` is not a
            ``FaultyTensorPipeRpcBackendOptions``. The store is checked first.
    """
    from torch.distributed.rpc import api

    from . import FaultyTensorPipeAgent, FaultyTensorPipeRpcBackendOptions

    store_ok = isinstance(store, dist.Store)
    if not store_ok:
        raise TypeError(f"`store` must be a c10d::Store. {store}")

    options_ok = isinstance(rpc_backend_options, FaultyTensorPipeRpcBackendOptions)
    if not options_ok:
        raise TypeError(
            f"`rpc_backend_options` must be a `FaultyTensorPipeRpcBackendOptions`. {rpc_backend_options}"
        )

    # The agent is built with an empty reverse device map and an empty device
    # list, i.e. no device mapping is configured here.
    no_reverse_device_map = {}
    no_devices = []
    agent = FaultyTensorPipeAgent(
        store,
        name,
        rank,
        world_size,
        rpc_backend_options,
        no_reverse_device_map,
        no_devices,
    )
    api._init_rpc_states(agent)
    return agent
56
57
# Runs at import time: registers the fault-injecting TensorPipe backend under
# the name "FAULTY_TENSORPIPE". The registration pairs the options-construction
# handler with the agent-initialization handler defined above.
rpc.backend_registry.register_backend(
    "FAULTY_TENSORPIPE",
    _faulty_tensorpipe_construct_rpc_backend_options_handler,
    _faulty_tensorpipe_init_backend_handler,
)
63