1 #include <ATen/ThreadLocalState.h>
2 #include <torch/csrc/distributed/autograd/context/container.h>
3 #include <torch/csrc/distributed/autograd/utils.h>
4 #include <torch/csrc/distributed/rpc/message.h>
5 #include <torch/csrc/distributed/rpc/python_call.h>
6 #include <torch/csrc/distributed/rpc/python_functions.h>
7 #include <torch/csrc/distributed/rpc/python_remote_call.h>
8 #include <torch/csrc/distributed/rpc/python_resp.h>
9 #include <torch/csrc/distributed/rpc/python_rpc_handler.h>
10 #include <torch/csrc/distributed/rpc/rref_context.h>
11 #include <torch/csrc/distributed/rpc/rref_proto.h>
12 #include <torch/csrc/distributed/rpc/script_call.h>
13 #include <torch/csrc/distributed/rpc/script_remote_call.h>
14 #include <torch/csrc/distributed/rpc/script_resp.h>
15 #include <torch/csrc/distributed/rpc/torchscript_functions.h>
16 #include <torch/csrc/distributed/rpc/utils.h>
17 #include <torch/csrc/jit/runtime/operator.h>
18 #include <torch/csrc/utils/python_compat.h>
19 #include <exception>
20
21 namespace torch::distributed::rpc {
22
23 namespace {
24
toPyIValue(const Message & message)25 IValue toPyIValue(const Message& message) {
26 MessageType msgType = message.type();
27 auto response = deserializeResponse(message, msgType);
28 switch (msgType) {
29 case MessageType::SCRIPT_RET: {
30 auto& ret = static_cast<ScriptResp&>(*response);
31 Stack stack;
32 stack.push_back(ret.value());
33 // Need GIL to guard createPyObjectForStack() and its returned
34 // py::object
35 py::gil_scoped_acquire acquire;
36 return jit::toIValue(
37 torch::jit::createPyObjectForStack(std::move(stack)),
38 PyObjectType::get());
39 }
40 case MessageType::PYTHON_RET: {
41 // TODO: Try to avoid a copy here.
42 auto& resp = static_cast<PythonResp&>(*response);
43 auto& pythonRpcHandler = PythonRpcHandler::getInstance();
44 // Need GIL to destruct the py::object returned by deserialize()
45 py::gil_scoped_acquire acquire;
46 py::object value = pythonRpcHandler.deserialize(resp.serializedPyObj());
47 pythonRpcHandler.handleException(value);
48 return jit::toIValue(value, PyObjectType::get());
49 }
50 default: {
51 TORCH_CHECK(false, "Unrecognized response message type ", msgType);
52 }
53 }
54 }
55
matchBuiltinOp(const std::string & opName,const py::args & args,const py::kwargs & kwargs,Stack & stack)56 std::shared_ptr<Operator> matchBuiltinOp(
57 const std::string& opName,
58 const py::args& args,
59 const py::kwargs& kwargs,
60 Stack& stack) {
61 Symbol symbol = Symbol::fromQualString(opName);
62
63 std::shared_ptr<jit::Operator> matchedOperator;
64 if (symbol.is_aten()) {
65 // Prefer C10 ops so that they go through C10 dispatch. We expect the
66 // total # of possible overloaded ops (i.e. size of below ops list) to be
67 // small (i.e. it is 10 for torch.add) so a worst-case linear search should
68 // not incur significant extra overhead.
69 auto ops = torch::jit::getAllOperatorsFor(symbol);
70 std::vector<std::shared_ptr<torch::jit::Operator>> c10OpsForSymbol;
71 for (auto it = ops.begin(); it != ops.end();) {
72 std::shared_ptr<jit::Operator> op = *it;
73 if (op->isC10Op()) {
74 c10OpsForSymbol.emplace_back(std::move(op));
75 it = ops.erase(it);
76 } else {
77 ++it;
78 }
79 }
80
81 // Don't throw on failures in this call, since we are not examining on all
82 // operators here, and the matched operator may indeed not be a c10 op.
83 std::pair<std::shared_ptr<torch::jit::Operator>, torch::jit::Stack>
84 opWithStack;
85 try {
86 opWithStack = torch::jit::getOpWithStack(c10OpsForSymbol, args, kwargs);
87 } catch (const std::runtime_error& e) {
88 opWithStack = torch::jit::getOpWithStack(ops, args, kwargs);
89 }
90 matchedOperator = std::get<0>(opWithStack);
91 stack = std::get<1>(opWithStack);
92 }
93
94 // We should never hit this path, since if !matchedOperator, then the last
95 // call to getOpWithStack should have thrown.
96 TORCH_CHECK(
97 matchedOperator != nullptr,
98 "Failed to match operator name ",
99 opName,
100 " and arguments "
101 "(args: ",
102 args,
103 ", kwargs: ",
104 kwargs,
105 ") to a builtin operator");
106
107 return matchedOperator;
108 }
109
sendPythonRemoteCall(const WorkerInfo & dst,SerializedPyObj serializedPyObj,const IValue & rrefId,const IValue & forkId,const float rpcTimeoutSeconds,const bool isAsyncExecution)110 c10::intrusive_ptr<JitFuture> sendPythonRemoteCall(
111 const WorkerInfo& dst,
112 SerializedPyObj serializedPyObj,
113 const IValue& rrefId,
114 const IValue& forkId,
115 const float rpcTimeoutSeconds,
116 const bool isAsyncExecution) {
117 auto pythonRemoteCall = std::make_unique<PythonRemoteCall>(
118 std::move(serializedPyObj), rrefId, forkId, isAsyncExecution);
119
120 // set forceGradRecording to true as even if the args does not contain any
121 // tensor, the return value might still contain tensors.
122 auto agent = RpcAgent::getCurrentRpcAgent();
123 return torch::distributed::autograd::sendMessageWithAutograd(
124 *agent,
125 dst,
126 std::move(*pythonRemoteCall).toMessage(),
127 true /*forceGradRecording*/,
128 rpcTimeoutSeconds);
129 }
130
131 } // namespace
132
133 using namespace torch::distributed::autograd;
134
// Adapts a future that completes with a response Message into a future that
// is usable from Python.
//
// If `hasValue` is true, returns a child future (typed PyObjectType) that:
// - propagates the parent's error, or
// - completes with toPyIValue() applied to the response message.
// Python errors raised during that conversion are recorded on the child:
// - ValueError becomes pybind11::value_error;
// - TypeError becomes pybind11::type_error;
// - any other Python error becomes std::runtime_error.
// Other std::exceptions are recorded as-is.
//
// If `hasValue` is false, returns a future (typed NoneType) that completes with
// None or rethrows the parent's error.
c10::intrusive_ptr<JitFuture> toPyJitFuture(
    const c10::intrusive_ptr<JitFuture>& messageJitFuture,
    bool hasValue) {
  if (hasValue) {
    auto child = messageJitFuture->createInstance(PyObjectType::get());
    messageJitFuture->addCallback(
        at::wrapPropagateTLSState([child](JitFuture& future) {
          if (future.hasError()) {
            child->setError(future.exception_ptr());
          } else {
            // Keep an owning handle so the Message's lifetime does not
            // depend on temporaries created by value()/toCustomClass().
            auto messagePtr = future.value().toCustomClass<Message>();
            const Message& message = *messagePtr;

            // toPyIValue might throw and we need to record the appropriate
            // exception.
            IValue ivalue;
            try {
              ivalue = toPyIValue(message);
            } catch (py::error_already_set& e) {
              py::gil_scoped_acquire acquire;
              // FIXME: this is a temporary solution to add a special-case for
              // ValueError and TypeError, as those are already used in our
              // tests. We should have a more comprehensive coverage for other
              // types of exceptions as well.
              if (e.matches(PyExc_ValueError)) {
                child->setErrorIfNeeded(
                    std::make_exception_ptr(pybind11::value_error(e.what())));
              } else if (e.matches(PyExc_TypeError)) {
                child->setErrorIfNeeded(
                    std::make_exception_ptr(pybind11::type_error(e.what())));
              } else {
                child->setErrorIfNeeded(
                    std::make_exception_ptr(std::runtime_error(e.what())));
              }
              // py::error_already_set requires the GIL to destruct, so take
              // special care: hand the error back to Python and clear it
              // while the GIL acquired above is still held.
              e.restore();
              PyErr_Clear();
              return;
            } catch (const std::exception&) {
              child->setErrorIfNeeded(std::current_exception());
              return;
            }

            child->markCompleted(std::move(ivalue), future.storages());
          }
        }));
    return child;
  } else {
    return messageJitFuture->then(
        at::wrapPropagateTLSState([](JitFuture& future) {
          if (future.hasError()) {
            std::rethrow_exception(future.exception_ptr());
          } else {
            return IValue();
          }
        }),
        NoneType::get());
  }
}
194
pyRpcBuiltin(const WorkerInfo & dst,const std::string & opName,const py::args & args,const py::kwargs & kwargs,const float rpcTimeoutSeconds)195 c10::intrusive_ptr<JitFuture> pyRpcBuiltin(
196 const WorkerInfo& dst,
197 const std::string& opName,
198 const py::args& args,
199 const py::kwargs& kwargs,
200 const float rpcTimeoutSeconds) {
201 DCHECK(PyGILState_Check());
202 Stack stack;
203 auto op = matchBuiltinOp(opName, args, kwargs, stack);
204 // Release GIL since args and kwargs processing is done.
205 py::gil_scoped_release release;
206 auto scriptCall = std::make_unique<ScriptCall>(op, std::move(stack));
207 auto agent = RpcAgent::getCurrentRpcAgent();
208 return toPyJitFuture(sendMessageWithAutograd(
209 *agent,
210 dst,
211 std::move(*scriptCall).toMessage(),
212 false,
213 rpcTimeoutSeconds));
214 }
215
pyRpcPythonUdf(const WorkerInfo & dst,std::string & pickledPythonUDF,std::vector<torch::Tensor> & tensors,const float rpcTimeoutSeconds,const bool isAsyncExecution)216 c10::intrusive_ptr<JitFuture> pyRpcPythonUdf(
217 const WorkerInfo& dst,
218 std::string& pickledPythonUDF,
219 std::vector<torch::Tensor>& tensors,
220 const float rpcTimeoutSeconds,
221 const bool isAsyncExecution) {
222 DCHECK(!PyGILState_Check());
223 auto serializedPyObj =
224 SerializedPyObj(std::move(pickledPythonUDF), std::move(tensors));
225 auto pythonCall = std::make_unique<PythonCall>(
226 std::move(serializedPyObj), isAsyncExecution);
227
228 auto agent = RpcAgent::getCurrentRpcAgent();
229 return toPyJitFuture(sendMessageWithAutograd(
230 *agent,
231 dst,
232 std::move(*pythonCall).toMessage(),
233 true /*forceGradRecording*/,
234 rpcTimeoutSeconds));
235 }
236
pyRpcTorchscript(const std::string & dstWorkerName,const std::string & qualifiedNameStr,const py::tuple & argsTuple,const py::dict & kwargsDict,const float rpcTimeoutSeconds,const bool isAsyncExecution)237 c10::intrusive_ptr<JitFuture> pyRpcTorchscript(
238 const std::string& dstWorkerName,
239 const std::string& qualifiedNameStr,
240 const py::tuple& argsTuple,
241 const py::dict& kwargsDict,
242 const float rpcTimeoutSeconds,
243 const bool isAsyncExecution) {
244 // No need to catch exception here, if function can not be found,
245 // exception will be thrown in get_function() call; if args do not match
246 // with function schema, exception will be thrown in
247 // createStackForSchema() call.
248 DCHECK(!PyGILState_Check());
249 const c10::QualifiedName qualifiedName(qualifiedNameStr);
250 auto functionSchema = PythonRpcHandler::getInstance()
251 .jitCompilationUnit()
252 ->get_function(qualifiedName)
253 .getSchema();
254 Stack stack;
255 {
256 // Acquire GIL for py::args and py::kwargs processing.
257 py::gil_scoped_acquire acquire;
258 stack = torch::jit::createStackForSchema(
259 functionSchema,
260 argsTuple.cast<py::args>(),
261 kwargsDict.cast<py::kwargs>(),
262 std::nullopt);
263 }
264 DCHECK(!PyGILState_Check());
265 c10::intrusive_ptr<c10::ivalue::Future> fut = rpcTorchscript(
266 dstWorkerName,
267 qualifiedName,
268 functionSchema,
269 stack,
270 rpcTimeoutSeconds,
271 isAsyncExecution);
272 return fut;
273 }
274
pyRemoteBuiltin(const WorkerInfo & dst,const std::string & opName,const float rpcTimeoutSeconds,const py::args & args,const py::kwargs & kwargs)275 PyRRef pyRemoteBuiltin(
276 const WorkerInfo& dst,
277 const std::string& opName,
278 const float rpcTimeoutSeconds,
279 const py::args& args,
280 const py::kwargs& kwargs) {
281 DCHECK(PyGILState_Check());
282 Stack stack;
283 auto op = matchBuiltinOp(opName, args, kwargs, stack);
284 // Release GIL since args and kwargs processing is done.
285 py::gil_scoped_release release;
286 TypePtr returnType = op->schema().returns()[0].type();
287
288 auto& ctx = RRefContext::getInstance();
289 auto agent = RpcAgent::getCurrentRpcAgent();
290
291 if (ctx.getWorkerId() != dst.id_) {
292 auto userRRef = ctx.createUserRRef(dst.id_, returnType);
293
294 auto scriptRemoteCall = std::make_unique<ScriptRemoteCall>(
295 op, std::move(stack), userRRef->rrefId(), userRRef->forkId());
296
297 auto jitFuture = sendMessageWithAutograd(
298 *agent,
299 dst,
300 std::move(*scriptRemoteCall).toMessage(),
301 /*forceGradRecord */ false,
302 /* timeout */ rpcTimeoutSeconds);
303
304 userRRef->registerOwnerCreationFuture(jitFuture);
305 ctx.addPendingUser(userRRef->forkId(), userRRef);
306 jitFuture->addCallback(at::wrapPropagateTLSState(
307 [forkId{userRRef->forkId()}](JitFuture& future) {
308 callback::confirmPendingUser(future, forkId);
309 }));
310 return PyRRef(userRRef);
311 } else {
312 auto ownerRRef = ctx.createOwnerRRef(returnType);
313 // prevent this owner RRef being deleted due to other forks
314 ctx.addSelfAsFork(ownerRRef);
315
316 auto scriptRemoteCall = std::make_unique<ScriptRemoteCall>(
317 op, std::move(stack), ownerRRef->rrefId(), ownerRRef->rrefId());
318 auto jitFuture = sendMessageWithAutograd(
319 *agent,
320 dst,
321 std::move(*scriptRemoteCall).toMessage(),
322 /* forceGradRecord */ false,
323 /* timeout */ rpcTimeoutSeconds);
324
325 ownerRRef->registerOwnerCreationFuture(jitFuture);
326 // Builtin operators does not return py::object, and hence does not require
327 // GIL for destructing the potentially deleted OwerRRef.
328 jitFuture->addCallback(at::wrapPropagateTLSState(
329 [ownerRRefId = ownerRRef->rrefId()](JitFuture& future) {
330 callback::finishCreatingOwnerRRef(future, ownerRRefId);
331 }));
332 return PyRRef(ownerRRef);
333 }
334 }
335
pyRemotePythonUdf(const WorkerInfo & dst,std::string & pickledPythonUDF,std::vector<torch::Tensor> & tensors,const float rpcTimeoutSeconds,const bool isAsyncExecution)336 PyRRef pyRemotePythonUdf(
337 const WorkerInfo& dst,
338 std::string& pickledPythonUDF,
339 std::vector<torch::Tensor>& tensors,
340 const float rpcTimeoutSeconds,
341 const bool isAsyncExecution) {
342 DCHECK(!PyGILState_Check());
343 auto& ctx = RRefContext::getInstance();
344 auto serializedPyObj =
345 SerializedPyObj(std::move(pickledPythonUDF), std::move(tensors));
346
347 if (ctx.getWorkerId() != dst.id_) {
348 auto userRRef = ctx.createUserRRef(dst.id_, PyObjectType::get());
349 auto jitFuture = sendPythonRemoteCall(
350 dst,
351 std::move(serializedPyObj),
352 userRRef->rrefId().toIValue(),
353 userRRef->forkId().toIValue(),
354 rpcTimeoutSeconds,
355 isAsyncExecution);
356
357 userRRef->registerOwnerCreationFuture(jitFuture);
358 ctx.addPendingUser(userRRef->forkId(), userRRef);
359 jitFuture->addCallback(at::wrapPropagateTLSState(
360 [forkId{userRRef->forkId()}](JitFuture& future) {
361 callback::confirmPendingUser(future, forkId);
362 }));
363 return PyRRef(userRRef);
364 } else {
365 // Sending remote message to self
366 auto ownerRRef = ctx.createOwnerRRef(PyObjectType::get());
367 // prevent this owner RRef being deleted due to other forks
368 ctx.addSelfAsFork(ownerRRef);
369 auto jitFuture = sendPythonRemoteCall(
370 dst,
371 std::move(serializedPyObj),
372 ownerRRef->rrefId().toIValue(),
373 ownerRRef->rrefId().toIValue(),
374 rpcTimeoutSeconds,
375 isAsyncExecution);
376
377 ownerRRef->registerOwnerCreationFuture(jitFuture);
378 jitFuture->addCallback(at::wrapPropagateTLSState(
379 [ownerRRefId = ownerRRef->rrefId()](JitFuture& future) {
380 auto deletedRRef =
381 callback::finishCreatingOwnerRRef(future, ownerRRefId);
382 if (deletedRRef && deletedRRef->isPyObj()) {
383 py::gil_scoped_acquire ag;
384 deletedRRef.reset();
385 }
386 }));
387 return PyRRef(ownerRRef);
388 }
389 }
390
pyRemoteTorchscript(const std::string & dstWorkerName,const std::string & qualifiedNameStr,const float rpcTimeoutSeconds,const bool isAsyncExecution,const py::args & args,const py::kwargs & kwargs)391 PyRRef pyRemoteTorchscript(
392 const std::string& dstWorkerName,
393 const std::string& qualifiedNameStr,
394 const float rpcTimeoutSeconds,
395 const bool isAsyncExecution,
396 const py::args& args,
397 const py::kwargs& kwargs) {
398 DCHECK(!PyGILState_Check());
399 auto qualifiedName = c10::QualifiedName(qualifiedNameStr);
400 auto functionSchema = PythonRpcHandler::getInstance()
401 .jitCompilationUnit()
402 ->get_function(qualifiedName)
403 .getSchema();
404 Stack stack;
405 {
406 // Acquire GIL for py::args and py::kwargs processing.
407 py::gil_scoped_acquire ag;
408 stack = torch::jit::createStackForSchema(
409 functionSchema, args, kwargs, std::nullopt);
410 }
411 DCHECK(!PyGILState_Check());
412 auto rrefPtr = remoteTorchscript(
413 dstWorkerName,
414 qualifiedName,
415 functionSchema,
416 stack,
417 rpcTimeoutSeconds,
418 isAsyncExecution);
419 return PyRRef(rrefPtr);
420 }
421
422 } // namespace torch::distributed::rpc
423