#include <torch/csrc/python_headers.h>

#include <torch/csrc/distributed/rpc/request_callback_impl.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/tensorpipe_agent.h>
#include <torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.h>
#include <torch/csrc/utils/pybind.h>

#include <pybind11/chrono.h>

namespace torch {
namespace distributed {
namespace rpc {
namespace testing {

namespace {

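// Helper alias: bind a C++ class to Python with std::shared_ptr as the
// holder type, matching how these agent objects are owned on the C++ side.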
template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;

PyObject* faulty_agent_init(PyObject* _unused, PyObject* noargs) {
  // Add the FaultyTensorPipeAgent and its backend options object
  // to the python module torch._C._distributed_rpc_testing
  auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
  if (!torch_C_module) {
    throw python_error();
  }

  auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
  auto m = torch_C_m.def_submodule(
      "_distributed_rpc_testing", "distributed rpc testing bindings");
  auto module = py::handle(m).cast<py::module>();

  // Import the rpc_module so we can subclass TensorPipeAgent
  py::module rpc_module = py::module::import("torch.distributed.rpc");

#ifdef USE_TENSORPIPE
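  // Bind FaultyTensorPipeRpcBackendOptions as a Python subclass of
  // _TensorPipeRpcBackendOptionsBase. On top of the base options it exposes
  // the fault-injection knobs: which message types to fail outright, which
  // ones to delay (mapped to a delay value), and how many sends to fail.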
  shared_ptr_class_<FaultyTensorPipeRpcBackendOptions>(
      module,
      "FaultyTensorPipeRpcBackendOptions",
      rpc_module.attr("_TensorPipeRpcBackendOptionsBase"))
      .def(
          py::init<
              int,
              float,
              std::string,
              std::vector<std::string>,
              std::unordered_map<std::string, float>,
              int>(),
          py::arg("num_worker_threads"),
          py::arg("rpc_timeout"),
          py::arg("init_method"),
          py::arg("messages_to_fail"),
          py::arg("messages_to_delay"),
          py::arg("num_fail_sends"))
      .def_readwrite(
          "num_worker_threads", &TensorPipeRpcBackendOptions::numWorkerThreads)
      .def_readwrite(
          "messages_to_fail",
          &FaultyTensorPipeRpcBackendOptions::messagesToFail)
      .def_readwrite(
          "messages_to_delay",
          &FaultyTensorPipeRpcBackendOptions::messagesToDelay)
      .def_readwrite(
          "num_fail_sends", &FaultyTensorPipeRpcBackendOptions::numFailSends);

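  // Bind FaultyTensorPipeAgent as a Python subclass of TensorPipeAgent. The
  // init lambda constructs the agent with a RequestCallbackImpl and wraps it
  // in a shared_ptr whose custom deleter destroys the agent without holding
  // the GIL, so the potentially blocking destructor does not run under it.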
  shared_ptr_class_<FaultyTensorPipeAgent>(
      module, "FaultyTensorPipeAgent", rpc_module.attr("TensorPipeAgent"))
      .def(
          py::init(
              [](const c10::intrusive_ptr<::c10d::Store> store,
                 std::string name,
                 worker_id_t rank,
                 int world_size,
                 FaultyTensorPipeRpcBackendOptions opts,
                 std::unordered_map<std::string, DeviceMap> reverse_device_maps,
                 std::vector<c10::Device> devices) {
                return std::shared_ptr<FaultyTensorPipeAgent>(
                    new FaultyTensorPipeAgent(
                        store,
                        std::move(name),
                        rank,
                        world_size,
                        opts,
                        reverse_device_maps,
                        devices,
                        std::make_unique<RequestCallbackImpl>()),
                    impl::destroy_without_gil<FaultyTensorPipeAgent>);
              }),
          py::arg("store"),
          py::arg("name"),
          py::arg("rank"),
          py::arg("world_size"),
          py::arg("opts"),
          py::arg("reverse_device_maps"),
          py::arg("devices"))
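      // The remaining methods forward to the TensorPipeAgent / RpcAgent base
      // implementations and release the GIL for the duration of the call.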
      .def(
          "join",
          &TensorPipeAgent::join,
          py::call_guard<py::gil_scoped_release>(),
          py::arg("shutdown") = false,
          py::arg("timeout") = 0)
      .def(
          "shutdown",
          &TensorPipeAgent::shutdown,
          py::call_guard<py::gil_scoped_release>())
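      // get_worker_info is overloaded in C++; the explicit function-pointer
      // casts select the no-argument, by-name, and by-id overloads in turn.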
      .def(
          "get_worker_info",
          (const WorkerInfo& (TensorPipeAgent::*)(void) const) &
              RpcAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_info",
          (const WorkerInfo& (TensorPipeAgent::*)(const std::string&) const) &
              TensorPipeAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_info",
          (const WorkerInfo& (TensorPipeAgent::*)(worker_id_t id) const) &
              TensorPipeAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_infos",
          (std::vector<WorkerInfo>(TensorPipeAgent::*)() const) &
              TensorPipeAgent::getWorkerInfos,
          py::call_guard<py::gil_scoped_release>());
#endif // USE_TENSORPIPE
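
  // Illustrative Python-side usage of the bindings registered above (a
  // sketch only, not part of this file). It assumes the method table below
  // is installed on torch._C, and the init_method, message-type names, and
  // `store` (a c10d Store created elsewhere) are example values:
  //
  //   import torch
  //   torch._C._faulty_agent_init()
  //   m = torch._C._distributed_rpc_testing
  //   opts = m.FaultyTensorPipeRpcBackendOptions(
  //       num_worker_threads=8,
  //       rpc_timeout=20.0,
  //       init_method="tcp://localhost:29500",
  //       messages_to_fail=["PYTHON_CALL"],
  //       messages_to_delay={"SCRIPT_CALL": 1.5},
  //       num_fail_sends=3)
  //   agent = m.FaultyTensorPipeAgent(
  //       store, "worker0", 0, 2, opts, {}, [])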

  Py_RETURN_TRUE;
}

} // namespace

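// Method table for the Python extension: _faulty_agent_init takes no
// arguments (METH_NOARGS) and returns True once the bindings above have
// been registered.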
static PyMethodDef methods[] = { // NOLINT
    {"_faulty_agent_init", faulty_agent_init, METH_NOARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};

PyMethodDef* python_functions() {
  return methods;
}

} // namespace testing
} // namespace rpc
} // namespace distributed
} // namespace torch