// torch/csrc/distributed/rpc/python_functions.cpp
#include <ATen/ThreadLocalState.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/python_call.h>
#include <torch/csrc/distributed/rpc/python_functions.h>
#include <torch/csrc/distributed/rpc/python_remote_call.h>
#include <torch/csrc/distributed/rpc/python_resp.h>
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <torch/csrc/distributed/rpc/script_resp.h>
#include <torch/csrc/distributed/rpc/torchscript_functions.h>
#include <torch/csrc/distributed/rpc/utils.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/utils/python_compat.h>
#include <exception>

namespace torch::distributed::rpc {

namespace {

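// Converts an RPC response Message (SCRIPT_RET or PYTHON_RET) into an IValue
// wrapping the corresponding py::object.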
IValue toPyIValue(const Message& message) {
  MessageType msgType = message.type();
  auto response = deserializeResponse(message, msgType);
  switch (msgType) {
    case MessageType::SCRIPT_RET: {
      auto& ret = static_cast<ScriptResp&>(*response);
      Stack stack;
      stack.push_back(ret.value());
      // Need GIL to guard createPyObjectForStack() and its returned
      // py::object
      py::gil_scoped_acquire acquire;
      return jit::toIValue(
          torch::jit::createPyObjectForStack(std::move(stack)),
          PyObjectType::get());
    }
    case MessageType::PYTHON_RET: {
      // TODO: Try to avoid a copy here.
      auto& resp = static_cast<PythonResp&>(*response);
      auto& pythonRpcHandler = PythonRpcHandler::getInstance();
      // Need GIL to destruct the py::object returned by deserialize()
      py::gil_scoped_acquire acquire;
      py::object value = pythonRpcHandler.deserialize(resp.serializedPyObj());
      pythonRpcHandler.handleException(value);
      return jit::toIValue(value, PyObjectType::get());
    }
    default: {
      TORCH_CHECK(false, "Unrecognized response message type ", msgType);
    }
  }
}

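// Matches a qualified builtin operator name plus Python args/kwargs to a JIT
// operator overload and fills `stack` with the converted IValue arguments.
// A minimal usage sketch mirroring the call sites below (the op name
// "aten::add" is only an illustrative example):
//
//   Stack stack;
//   // args / kwargs are py::args / py::kwargs captured while holding the GIL.
//   auto op = matchBuiltinOp("aten::add", args, kwargs, stack);
//   // `op` is the matched overload; `stack` now holds its inputs.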
std::shared_ptr<Operator> matchBuiltinOp(
    const std::string& opName,
    const py::args& args,
    const py::kwargs& kwargs,
    Stack& stack) {
  Symbol symbol = Symbol::fromQualString(opName);

  std::shared_ptr<jit::Operator> matchedOperator;
  if (symbol.is_aten()) {
    // Prefer C10 ops so that they go through C10 dispatch. We expect the
    // total # of possible overloaded ops (i.e. size of below ops list) to be
    // small (e.g. it is 10 for torch.add), so a worst-case linear search
    // should not incur significant extra overhead.
    auto ops = torch::jit::getAllOperatorsFor(symbol);
    std::vector<std::shared_ptr<torch::jit::Operator>> c10OpsForSymbol;
    for (auto it = ops.begin(); it != ops.end();) {
      std::shared_ptr<jit::Operator> op = *it;
      if (op->isC10Op()) {
        c10OpsForSymbol.emplace_back(std::move(op));
        it = ops.erase(it);
      } else {
        ++it;
      }
    }

    // Don't throw on failures in this call, since we are not examining all
    // operators here, and the matched operator may indeed not be a c10 op.
    std::pair<std::shared_ptr<torch::jit::Operator>, torch::jit::Stack>
        opWithStack;
    try {
      opWithStack = torch::jit::getOpWithStack(c10OpsForSymbol, args, kwargs);
    } catch (const std::runtime_error& e) {
      opWithStack = torch::jit::getOpWithStack(ops, args, kwargs);
    }
    matchedOperator = std::get<0>(opWithStack);
    stack = std::get<1>(opWithStack);
  }

  // We should never hit this path, since if !matchedOperator, then the last
  // call to getOpWithStack should have thrown.
  TORCH_CHECK(
      matchedOperator != nullptr,
      "Failed to match operator name ",
      opName,
      " and arguments "
      "(args: ",
      args,
      ", kwargs: ",
      kwargs,
      ") to a builtin operator");

  return matchedOperator;
}

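// Helper used by pyRemotePythonUdf below: wraps the serialized Python UDF into
// a PythonRemoteCall and sends it through the current RpcAgent with autograd
// metadata attached.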
c10::intrusive_ptr<JitFuture> sendPythonRemoteCall(
    const WorkerInfo& dst,
    SerializedPyObj serializedPyObj,
    const IValue& rrefId,
    const IValue& forkId,
    const float rpcTimeoutSeconds,
    const bool isAsyncExecution) {
  auto pythonRemoteCall = std::make_unique<PythonRemoteCall>(
      std::move(serializedPyObj), rrefId, forkId, isAsyncExecution);

  // Set forceGradRecording to true: even if the args do not contain any
  // tensor, the return value might still contain tensors.
  auto agent = RpcAgent::getCurrentRpcAgent();
  return torch::distributed::autograd::sendMessageWithAutograd(
      *agent,
      dst,
      std::move(*pythonRemoteCall).toMessage(),
      true /*forceGradRecording*/,
      rpcTimeoutSeconds);
}

} // namespace

using namespace torch::distributed::autograd;

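// Converts a JitFuture resolving to an rpc::Message into a future the Python
// layer can consume: with hasValue == true the child future holds the
// deserialized py::object (as an IValue of PyObjectType); otherwise it only
// propagates completion and errors and resolves to None.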
c10::intrusive_ptr<JitFuture> toPyJitFuture(
    const c10::intrusive_ptr<JitFuture>& messageJitFuture,
    bool hasValue) {
  if (hasValue) {
    auto child = messageJitFuture->createInstance(PyObjectType::get());
    messageJitFuture->addCallback(
        at::wrapPropagateTLSState([child](JitFuture& future) {
          if (future.hasError()) {
            child->setError(future.exception_ptr());
          } else {
            const Message& message = *future.value().toCustomClass<Message>();

            // toPyIValue might throw and we need to record the appropriate
            // exception.
            IValue ivalue;
            try {
              ivalue = toPyIValue(message);
            } catch (py::error_already_set& e) {
              py::gil_scoped_acquire acquire;
              // FIXME: this is a temporary solution that special-cases
              // ValueError and TypeError, as those are already used in our
              // tests. We should have more comprehensive coverage for other
              // types of exceptions as well.
              if (e.matches(PyExc_ValueError)) {
                child->setErrorIfNeeded(
                    std::make_exception_ptr(pybind11::value_error(e.what())));
              } else if (e.matches(PyExc_TypeError)) {
                child->setErrorIfNeeded(
                    std::make_exception_ptr(pybind11::type_error(e.what())));
              } else {
                // py::error_already_set requires the GIL to destruct; take
                // special care.
                child->setErrorIfNeeded(
                    std::make_exception_ptr(std::runtime_error(e.what())));
              }
              e.restore();
              PyErr_Clear();
              return;
            } catch (std::exception& e) {
              child->setErrorIfNeeded(std::current_exception());
              return;
            }

            child->markCompleted(ivalue, future.storages());
          }
        }));
    return child;
  } else {
    return messageJitFuture->then(
        at::wrapPropagateTLSState([](JitFuture& future) {
          if (future.hasError()) {
            std::rethrow_exception(future.exception_ptr());
          } else {
            return IValue();
          }
        }),
        NoneType::get());
  }
}

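// Expected entry point for rpc_sync()/rpc_async() calls whose target resolves
// to a builtin ATen operator. A hedged Python-side sketch (assumes init_rpc()
// has been called; the worker name "worker1" is illustrative):
//
//   import torch
//   import torch.distributed.rpc as rpc
//   fut = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 1))
//   result = fut.wait()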
c10::intrusive_ptr<JitFuture> pyRpcBuiltin(
    const WorkerInfo& dst,
    const std::string& opName,
    const py::args& args,
    const py::kwargs& kwargs,
    const float rpcTimeoutSeconds) {
  DCHECK(PyGILState_Check());
  Stack stack;
  auto op = matchBuiltinOp(opName, args, kwargs, stack);
  // Release GIL since args and kwargs processing is done.
  py::gil_scoped_release release;
  auto scriptCall = std::make_unique<ScriptCall>(op, std::move(stack));
  auto agent = RpcAgent::getCurrentRpcAgent();
  return toPyJitFuture(sendMessageWithAutograd(
      *agent,
      dst,
      std::move(*scriptCall).toMessage(),
      false,
      rpcTimeoutSeconds));
}

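// Expected entry point for rpc_sync()/rpc_async() calls whose target is a
// plain Python callable; the Python layer pickles the UDF and its tensors into
// `pickledPythonUDF` / `tensors` before reaching this function. Hedged sketch
// (worker name illustrative):
//
//   def double_it(t):  # defined at module scope so it pickles
//       return t * 2
//   fut = rpc.rpc_async("worker1", double_it, args=(torch.ones(2),))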
c10::intrusive_ptr<JitFuture> pyRpcPythonUdf(
    const WorkerInfo& dst,
    std::string& pickledPythonUDF,
    std::vector<torch::Tensor>& tensors,
    const float rpcTimeoutSeconds,
    const bool isAsyncExecution) {
  DCHECK(!PyGILState_Check());
  auto serializedPyObj =
      SerializedPyObj(std::move(pickledPythonUDF), std::move(tensors));
  auto pythonCall = std::make_unique<PythonCall>(
      std::move(serializedPyObj), isAsyncExecution);

  auto agent = RpcAgent::getCurrentRpcAgent();
  return toPyJitFuture(sendMessageWithAutograd(
      *agent,
      dst,
      std::move(*pythonCall).toMessage(),
      true /*forceGradRecording*/,
      rpcTimeoutSeconds));
}

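// Expected entry point for rpc_sync()/rpc_async() calls targeting a
// TorchScript function, identified here by its qualified name. Hedged
// Python-side sketch (worker name illustrative):
//
//   @torch.jit.script
//   def script_add(t: torch.Tensor) -> torch.Tensor:
//       return t + 1
//   fut = rpc.rpc_async("worker1", script_add, args=(torch.ones(2),))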
c10::intrusive_ptr<JitFuture> pyRpcTorchscript(
    const std::string& dstWorkerName,
    const std::string& qualifiedNameStr,
    const py::tuple& argsTuple,
    const py::dict& kwargsDict,
    const float rpcTimeoutSeconds,
    const bool isAsyncExecution) {
  // No need to catch exceptions here: if the function cannot be found, an
  // exception will be thrown in the get_function() call; if the args do not
  // match the function schema, an exception will be thrown in the
  // createStackForSchema() call.
  DCHECK(!PyGILState_Check());
  const c10::QualifiedName qualifiedName(qualifiedNameStr);
  auto functionSchema = PythonRpcHandler::getInstance()
                            .jitCompilationUnit()
                            ->get_function(qualifiedName)
                            .getSchema();
  Stack stack;
  {
    // Acquire GIL for py::args and py::kwargs processing.
    py::gil_scoped_acquire acquire;
    stack = torch::jit::createStackForSchema(
        functionSchema,
        argsTuple.cast<py::args>(),
        kwargsDict.cast<py::kwargs>(),
        std::nullopt);
  }
  DCHECK(!PyGILState_Check());
  c10::intrusive_ptr<c10::ivalue::Future> fut = rpcTorchscript(
      dstWorkerName,
      qualifiedName,
      functionSchema,
      stack,
      rpcTimeoutSeconds,
      isAsyncExecution);
  return fut;
}

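// rpc.remote() counterpart of pyRpcBuiltin: returns a PyRRef to the remote
// result instead of a future for its value. If dst is the local worker, an
// OwnerRRef is created directly (see the else-branch below). Hedged sketch
// (worker name illustrative):
//
//   rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
//   value = rref.to_here()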
PyRRef pyRemoteBuiltin(
    const WorkerInfo& dst,
    const std::string& opName,
    const float rpcTimeoutSeconds,
    const py::args& args,
    const py::kwargs& kwargs) {
  DCHECK(PyGILState_Check());
  Stack stack;
  auto op = matchBuiltinOp(opName, args, kwargs, stack);
  // Release GIL since args and kwargs processing is done.
  py::gil_scoped_release release;
  TypePtr returnType = op->schema().returns()[0].type();

  auto& ctx = RRefContext::getInstance();
  auto agent = RpcAgent::getCurrentRpcAgent();

  if (ctx.getWorkerId() != dst.id_) {
    auto userRRef = ctx.createUserRRef(dst.id_, returnType);

    auto scriptRemoteCall = std::make_unique<ScriptRemoteCall>(
        op, std::move(stack), userRRef->rrefId(), userRRef->forkId());

    auto jitFuture = sendMessageWithAutograd(
        *agent,
        dst,
        std::move(*scriptRemoteCall).toMessage(),
        /* forceGradRecording */ false,
        /* timeout */ rpcTimeoutSeconds);

    userRRef->registerOwnerCreationFuture(jitFuture);
    ctx.addPendingUser(userRRef->forkId(), userRRef);
    jitFuture->addCallback(at::wrapPropagateTLSState(
        [forkId{userRRef->forkId()}](JitFuture& future) {
          callback::confirmPendingUser(future, forkId);
        }));
    return PyRRef(userRRef);
  } else {
    auto ownerRRef = ctx.createOwnerRRef(returnType);
    // Prevent this owner RRef from being deleted due to other forks.
    ctx.addSelfAsFork(ownerRRef);

    auto scriptRemoteCall = std::make_unique<ScriptRemoteCall>(
        op, std::move(stack), ownerRRef->rrefId(), ownerRRef->rrefId());
    auto jitFuture = sendMessageWithAutograd(
        *agent,
        dst,
        std::move(*scriptRemoteCall).toMessage(),
        /* forceGradRecording */ false,
        /* timeout */ rpcTimeoutSeconds);

    ownerRRef->registerOwnerCreationFuture(jitFuture);
    // Builtin operators do not return py::object, and hence do not require
    // the GIL for destructing the potentially deleted OwnerRRef.
    jitFuture->addCallback(at::wrapPropagateTLSState(
        [ownerRRefId = ownerRRef->rrefId()](JitFuture& future) {
          callback::finishCreatingOwnerRRef(future, ownerRRefId);
        }));
    return PyRRef(ownerRRef);
  }
}

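// rpc.remote() counterpart of pyRpcPythonUdf: returns a PyRRef to the result
// of a pickled Python UDF executed on dst. Hedged sketch (reusing the
// illustrative double_it above):
//
//   rref = rpc.remote("worker1", double_it, args=(torch.ones(2),))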
PyRRef pyRemotePythonUdf(
    const WorkerInfo& dst,
    std::string& pickledPythonUDF,
    std::vector<torch::Tensor>& tensors,
    const float rpcTimeoutSeconds,
    const bool isAsyncExecution) {
  DCHECK(!PyGILState_Check());
  auto& ctx = RRefContext::getInstance();
  auto serializedPyObj =
      SerializedPyObj(std::move(pickledPythonUDF), std::move(tensors));

  if (ctx.getWorkerId() != dst.id_) {
    auto userRRef = ctx.createUserRRef(dst.id_, PyObjectType::get());
    auto jitFuture = sendPythonRemoteCall(
        dst,
        std::move(serializedPyObj),
        userRRef->rrefId().toIValue(),
        userRRef->forkId().toIValue(),
        rpcTimeoutSeconds,
        isAsyncExecution);

    userRRef->registerOwnerCreationFuture(jitFuture);
    ctx.addPendingUser(userRRef->forkId(), userRRef);
    jitFuture->addCallback(at::wrapPropagateTLSState(
        [forkId{userRRef->forkId()}](JitFuture& future) {
          callback::confirmPendingUser(future, forkId);
        }));
    return PyRRef(userRRef);
  } else {
    // Sending remote message to self.
    auto ownerRRef = ctx.createOwnerRRef(PyObjectType::get());
    // Prevent this owner RRef from being deleted due to other forks.
    ctx.addSelfAsFork(ownerRRef);
    auto jitFuture = sendPythonRemoteCall(
        dst,
        std::move(serializedPyObj),
        ownerRRef->rrefId().toIValue(),
        ownerRRef->rrefId().toIValue(),
        rpcTimeoutSeconds,
        isAsyncExecution);

    ownerRRef->registerOwnerCreationFuture(jitFuture);
    jitFuture->addCallback(at::wrapPropagateTLSState(
        [ownerRRefId = ownerRRef->rrefId()](JitFuture& future) {
          auto deletedRRef =
              callback::finishCreatingOwnerRRef(future, ownerRRefId);
          if (deletedRRef && deletedRRef->isPyObj()) {
            py::gil_scoped_acquire ag;
            deletedRRef.reset();
          }
        }));
    return PyRRef(ownerRRef);
  }
}

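// rpc.remote() counterpart of pyRpcTorchscript: returns a PyRRef to the result
// of a TorchScript function executed on dstWorkerName. Hedged sketch (reusing
// the illustrative script_add above):
//
//   rref = rpc.remote("worker1", script_add, args=(torch.ones(2),))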
PyRRef pyRemoteTorchscript(
    const std::string& dstWorkerName,
    const std::string& qualifiedNameStr,
    const float rpcTimeoutSeconds,
    const bool isAsyncExecution,
    const py::args& args,
    const py::kwargs& kwargs) {
  DCHECK(!PyGILState_Check());
  auto qualifiedName = c10::QualifiedName(qualifiedNameStr);
  auto functionSchema = PythonRpcHandler::getInstance()
                            .jitCompilationUnit()
                            ->get_function(qualifiedName)
                            .getSchema();
  Stack stack;
  {
    // Acquire GIL for py::args and py::kwargs processing.
    py::gil_scoped_acquire ag;
    stack = torch::jit::createStackForSchema(
        functionSchema, args, kwargs, std::nullopt);
  }
  DCHECK(!PyGILState_Check());
  auto rrefPtr = remoteTorchscript(
      dstWorkerName,
      qualifiedName,
      functionSchema,
      stack,
      rpcTimeoutSeconds,
      isAsyncExecution);
  return PyRRef(rrefPtr);
}

} // namespace torch::distributed::rpc