xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/python_remote_call.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/rpc/python_remote_call.h>
2 #include <torch/csrc/distributed/rpc/rpc_agent.h>
3 #include <torch/csrc/jit/serialization/pickle.h>
4 
5 namespace torch::distributed::rpc {
6 
PythonRemoteCall(SerializedPyObj && serializedPyObj,at::IValue retRRefId,at::IValue retForkId,const bool isAsyncExecution)7 PythonRemoteCall::PythonRemoteCall(
8     SerializedPyObj&& serializedPyObj,
9     at::IValue retRRefId,
10     at::IValue retForkId,
11     const bool isAsyncExecution)
12     : serializedPyObj_(std::move(serializedPyObj)),
13       retRRefId_(std::move(retRRefId)),
14       retForkId_(std::move(retForkId)),
15       isAsyncExecution_(isAsyncExecution) {}
16 
toMessageImpl()17 c10::intrusive_ptr<Message> PythonRemoteCall::toMessageImpl() && {
18   std::vector<IValue> ivalues = std::move(serializedPyObj_).toIValues();
19   ivalues.emplace_back(retRRefId_);
20   ivalues.emplace_back(retForkId_);
21   ivalues.emplace_back(isAsyncExecution_);
22 
23   std::vector<torch::Tensor> tensor_table;
24   auto payload =
25       jit::pickle(c10::ivalue::Tuple::create(ivalues), &tensor_table);
26 
27   return c10::make_intrusive<Message>(
28       std::move(payload),
29       std::move(tensor_table),
30       MessageType::PYTHON_REMOTE_CALL);
31 }
32 
fromMessage(const Message & message)33 std::unique_ptr<PythonRemoteCall> PythonRemoteCall::fromMessage(
34     const Message& message) {
35   auto payload = static_cast<const char*>(message.payload().data());
36   auto payload_size = message.payload().size();
37 
38   auto value = jit::unpickle(
39       payload,
40       payload_size,
41       *RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
42       message.tensors());
43   auto values = value.toTupleRef().elements().vec();
44 
45   // remove the last elements from values and convert it back to an RRef
46   TORCH_INTERNAL_ASSERT(
47       values.size() > 3,
48       "Expect at least 4 elements in the unpickled values, but got ",
49       values.size());
50   bool isAsyncExecution = values.back().toBool();
51   values.pop_back();
52   auto retForkId = std::move(values.back());
53   values.pop_back();
54   auto retRRefId = std::move(values.back());
55   values.pop_back();
56   auto serializedPyObj = SerializedPyObj::fromIValues(std::move(values));
57 
58   return std::make_unique<PythonRemoteCall>(
59       std::move(serializedPyObj),
60       std::move(retRRefId),
61       std::move(retForkId),
62       isAsyncExecution);
63 }
64 
65 } // namespace torch::distributed::rpc
66