#!/usr/bin/env python3
# mypy: allow-untyped-defs

# Registers the "FAULTY_TENSORPIPE" RPC backend with the RPC backend registry.
# The two handlers below are passed to `register_backend` at import time: one
# builds the backend options object, the other builds the agent.

import torch.distributed as dist
import torch.distributed.rpc as rpc


def _faulty_tensorpipe_construct_rpc_backend_options_handler(
    rpc_timeout,
    init_method,
    num_worker_threads,
    messages_to_fail,
    messages_to_delay,
    num_fail_sends,
    **kwargs,
):
    """Build a ``FaultyTensorPipeRpcBackendOptions`` from the given settings.

    The six named parameters are forwarded, by keyword, to the options
    constructor. Any additional keyword arguments are accepted through
    ``**kwargs`` and are not forwarded.

    Returns:
        The newly constructed ``FaultyTensorPipeRpcBackendOptions``.
    """
    # Imported at call time rather than at module import time.
    from . import FaultyTensorPipeRpcBackendOptions

    option_fields = {
        "num_worker_threads": num_worker_threads,
        "rpc_timeout": rpc_timeout,
        "init_method": init_method,
        "messages_to_fail": messages_to_fail,
        "messages_to_delay": messages_to_delay,
        "num_fail_sends": num_fail_sends,
    }
    return FaultyTensorPipeRpcBackendOptions(**option_fields)


def _faulty_tensorpipe_init_backend_handler(
    store, name, rank, world_size, rpc_backend_options
):
    """Create a ``FaultyTensorPipeAgent`` and hand it to the RPC API state.

    Args:
        store: Must be an instance of ``dist.Store``.
        name: Passed through to the agent constructor.
        rank: Passed through to the agent constructor.
        world_size: Passed through to the agent constructor.
        rpc_backend_options: Must be an instance of
            ``FaultyTensorPipeRpcBackendOptions``.

    Returns:
        The constructed agent, after it has been passed to
        ``api._init_rpc_states``.

    Raises:
        TypeError: If ``store`` or ``rpc_backend_options`` has the wrong type.
            ``store`` is checked first.
    """
    from torch.distributed.rpc import api

    from . import FaultyTensorPipeAgent, FaultyTensorPipeRpcBackendOptions

    store_ok = isinstance(store, dist.Store)
    if not store_ok:
        raise TypeError(f"`store` must be a c10d::Store. {store}")

    options_ok = isinstance(rpc_backend_options, FaultyTensorPipeRpcBackendOptions)
    if not options_ok:
        raise TypeError(
            f"`rpc_backend_options` must be a `FaultyTensorPipeRpcBackendOptions`. {rpc_backend_options}"
        )

    # The last two constructor arguments are always supplied empty here.
    reverse_device_map = {}
    devices = []
    faulty_agent = FaultyTensorPipeAgent(
        store,
        name,
        rank,
        world_size,
        rpc_backend_options,
        reverse_device_map,
        devices,
    )
    api._init_rpc_states(faulty_agent)

    return faulty_agent


# Import-time side effect: make the backend available under this name.
rpc.backend_registry.register_backend(
    "FAULTY_TENSORPIPE",
    _faulty_tensorpipe_construct_rpc_backend_options_handler,
    _faulty_tensorpipe_init_backend_handler,
)