xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/python_call.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/rpc/python_call.h>
2 
3 namespace torch::distributed::rpc {
4 
PythonCall(SerializedPyObj && serializedPyObj,bool isAsyncExecution)5 PythonCall::PythonCall(SerializedPyObj&& serializedPyObj, bool isAsyncExecution)
6     : serializedPyObj_(std::move(serializedPyObj)),
7       isAsyncExecution_(isAsyncExecution) {}
8 
toMessageImpl()9 c10::intrusive_ptr<Message> PythonCall::toMessageImpl() && {
10   std::vector<char> payload;
11   payload.reserve(serializedPyObj_.payload_.length() + 1);
12   payload.push_back(isAsyncExecution_ ? 1 : 0);
13   payload.insert(
14       payload.end(),
15       serializedPyObj_.payload_.begin(),
16       serializedPyObj_.payload_.end());
17 
18   return c10::make_intrusive<Message>(
19       std::move(payload),
20       std::move(serializedPyObj_.tensors_),
21       MessageType::PYTHON_CALL);
22 }
23 
fromMessage(const Message & message)24 std::unique_ptr<PythonCall> PythonCall::fromMessage(const Message& message) {
25   TORCH_INTERNAL_ASSERT(
26       !message.payload().empty(),
27       "Failed to convert an RPC message to PythonCall, the payload should at "
28       "least contain one byte indicating whether this is an async function, "
29       "but got payload of size ",
30       message.payload().size());
31   const char& c = message.payload()[0];
32   TORCH_INTERNAL_ASSERT(c == 0 || c == 1);
33   bool isAsyncExecution = (c == 1);
34   std::string payload(message.payload().begin() + 1, message.payload().end());
35   std::vector<Tensor> tensors = message.tensors();
36   SerializedPyObj serializedPyObj(std::move(payload), std::move(tensors));
37   return std::make_unique<PythonCall>(
38       std::move(serializedPyObj), isAsyncExecution);
39 }
40 
serializedPyObj() const41 const SerializedPyObj& PythonCall::serializedPyObj() const {
42   return serializedPyObj_;
43 }
44 
45 } // namespace torch::distributed::rpc
46