#include <torch/csrc/python_headers.h>

#include <torch/csrc/distributed/rpc/request_callback_impl.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/tensorpipe_agent.h>
#include <torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.h>
#include <torch/csrc/utils/pybind.h>

#include <pybind11/chrono.h>

#include <exception>
10
11 namespace torch {
12 namespace distributed {
13 namespace rpc {
14 namespace testing {
15
16 namespace {
17
18 template <typename T>
19 using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
20
faulty_agent_init(PyObject * _unused,PyObject * noargs)21 PyObject* faulty_agent_init(PyObject* _unused, PyObject* noargs) {
22 // Add the FaultyTensorPipeAgent and its backend options object
23 // to the python module torch._C._distributed_rpc_testing
24 auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
25 if (!torch_C_module) {
26 throw python_error();
27 }
28
29 auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
30 auto m = torch_C_m.def_submodule(
31 "_distributed_rpc_testing", "distributed rpc testing bindings");
32 auto module = py::handle(m).cast<py::module>();
33
34 // Import the rpc_module so we can subclass TensorPipeAgent
35 py::module rpc_module = py::module::import("torch.distributed.rpc");
36
37 #ifdef USE_TENSORPIPE
38 shared_ptr_class_<FaultyTensorPipeRpcBackendOptions>(
39 module,
40 "FaultyTensorPipeRpcBackendOptions",
41 rpc_module.attr("_TensorPipeRpcBackendOptionsBase"))
42 .def(
43 py::init<
44 int,
45 float,
46 std::string,
47 std::vector<std::string>,
48 std::unordered_map<std::string, float>,
49 int>(),
50 py::arg("num_worker_threads"),
51 py::arg("rpc_timeout"),
52 py::arg("init_method"),
53 py::arg("messages_to_fail"),
54 py::arg("messages_to_delay"),
55 py::arg("num_fail_sends"))
56 .def_readwrite(
57 "num_worker_threads", &TensorPipeRpcBackendOptions::numWorkerThreads)
58 .def_readwrite(
59 "messages_to_fail",
60 &FaultyTensorPipeRpcBackendOptions::messagesToFail)
61 .def_readwrite(
62 "messages_to_delay",
63 &FaultyTensorPipeRpcBackendOptions::messagesToDelay)
64 .def_readwrite(
65 "num_fail_sends", &FaultyTensorPipeRpcBackendOptions::numFailSends);
66
67 shared_ptr_class_<FaultyTensorPipeAgent>(
68 module, "FaultyTensorPipeAgent", rpc_module.attr("TensorPipeAgent"))
69 .def(
70 py::init(
71 [](const c10::intrusive_ptr<::c10d::Store> store,
72 std::string name,
73 worker_id_t rank,
74 int world_size,
75 FaultyTensorPipeRpcBackendOptions opts,
76 std::unordered_map<std::string, DeviceMap> reverse_device_maps,
77 std::vector<c10::Device> devices) {
78 return std::shared_ptr<FaultyTensorPipeAgent>(
79 new FaultyTensorPipeAgent(
80 store,
81 std::move(name),
82 rank,
83 world_size,
84 opts,
85 reverse_device_maps,
86 devices,
87 std::make_unique<RequestCallbackImpl>()),
88 impl::destroy_without_gil<FaultyTensorPipeAgent>);
89 }),
90 py::arg("store"),
91 py::arg("name"),
92 py::arg("rank"),
93 py::arg("world_size"),
94 py::arg("opts"),
95 py::arg("reverse_device_maps"),
96 py::arg("devices"))
97 .def(
98 "join",
99 &TensorPipeAgent::join,
100 py::call_guard<py::gil_scoped_release>(),
101 py::arg("shutdown") = false,
102 py::arg("timeout") = 0)
103 .def(
104 "shutdown",
105 &TensorPipeAgent::shutdown,
106 py::call_guard<py::gil_scoped_release>())
107 .def(
108 "get_worker_info",
109 (const WorkerInfo& (TensorPipeAgent::*)(void) const) &
110 RpcAgent::getWorkerInfo,
111 py::call_guard<py::gil_scoped_release>())
112 .def(
113 "get_worker_info",
114 (const WorkerInfo& (TensorPipeAgent::*)(const std::string&) const) &
115 TensorPipeAgent::getWorkerInfo,
116 py::call_guard<py::gil_scoped_release>())
117 .def(
118 "get_worker_info",
119 (const WorkerInfo& (TensorPipeAgent::*)(worker_id_t id) const) &
120 TensorPipeAgent::getWorkerInfo,
121 py::call_guard<py::gil_scoped_release>())
122 .def(
123 "get_worker_infos",
124 (std::vector<WorkerInfo>(TensorPipeAgent::*)() const) &
125 TensorPipeAgent::getWorkerInfos,
126 py::call_guard<py::gil_scoped_release>());
127 #endif // USE_TENSORPIPE
128
129 Py_RETURN_TRUE;
130 }
131
132 } // namespace
133
134 static PyMethodDef methods[] = { // NOLINT
135 {"_faulty_agent_init", faulty_agent_init, METH_NOARGS, nullptr},
136 {nullptr, nullptr, 0, nullptr}};
137
python_functions()138 PyMethodDef* python_functions() {
139 return methods;
140 }
141
142 } // namespace testing
143 } // namespace rpc
144 } // namespace distributed
145 } // namespace torch
146