1 #include <torch/csrc/distributed/rpc/python_remote_call.h>
2 #include <torch/csrc/distributed/rpc/rpc_agent.h>
3 #include <torch/csrc/jit/serialization/pickle.h>
4
5 namespace torch::distributed::rpc {
6
PythonRemoteCall(SerializedPyObj && serializedPyObj,at::IValue retRRefId,at::IValue retForkId,const bool isAsyncExecution)7 PythonRemoteCall::PythonRemoteCall(
8 SerializedPyObj&& serializedPyObj,
9 at::IValue retRRefId,
10 at::IValue retForkId,
11 const bool isAsyncExecution)
12 : serializedPyObj_(std::move(serializedPyObj)),
13 retRRefId_(std::move(retRRefId)),
14 retForkId_(std::move(retForkId)),
15 isAsyncExecution_(isAsyncExecution) {}
16
toMessageImpl()17 c10::intrusive_ptr<Message> PythonRemoteCall::toMessageImpl() && {
18 std::vector<IValue> ivalues = std::move(serializedPyObj_).toIValues();
19 ivalues.emplace_back(retRRefId_);
20 ivalues.emplace_back(retForkId_);
21 ivalues.emplace_back(isAsyncExecution_);
22
23 std::vector<torch::Tensor> tensor_table;
24 auto payload =
25 jit::pickle(c10::ivalue::Tuple::create(ivalues), &tensor_table);
26
27 return c10::make_intrusive<Message>(
28 std::move(payload),
29 std::move(tensor_table),
30 MessageType::PYTHON_REMOTE_CALL);
31 }
32
fromMessage(const Message & message)33 std::unique_ptr<PythonRemoteCall> PythonRemoteCall::fromMessage(
34 const Message& message) {
35 auto payload = static_cast<const char*>(message.payload().data());
36 auto payload_size = message.payload().size();
37
38 auto value = jit::unpickle(
39 payload,
40 payload_size,
41 *RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
42 message.tensors());
43 auto values = value.toTupleRef().elements().vec();
44
45 // remove the last elements from values and convert it back to an RRef
46 TORCH_INTERNAL_ASSERT(
47 values.size() > 3,
48 "Expect at least 4 elements in the unpickled values, but got ",
49 values.size());
50 bool isAsyncExecution = values.back().toBool();
51 values.pop_back();
52 auto retForkId = std::move(values.back());
53 values.pop_back();
54 auto retRRefId = std::move(values.back());
55 values.pop_back();
56 auto serializedPyObj = SerializedPyObj::fromIValues(std::move(values));
57
58 return std::make_unique<PythonRemoteCall>(
59 std::move(serializedPyObj),
60 std::move(retRRefId),
61 std::move(retForkId),
62 isAsyncExecution);
63 }
64
65 } // namespace torch::distributed::rpc
66