1 #include <torch/csrc/distributed/rpc/python_call.h>
2
3 namespace torch::distributed::rpc {
4
PythonCall(SerializedPyObj && serializedPyObj,bool isAsyncExecution)5 PythonCall::PythonCall(SerializedPyObj&& serializedPyObj, bool isAsyncExecution)
6 : serializedPyObj_(std::move(serializedPyObj)),
7 isAsyncExecution_(isAsyncExecution) {}
8
toMessageImpl()9 c10::intrusive_ptr<Message> PythonCall::toMessageImpl() && {
10 std::vector<char> payload;
11 payload.reserve(serializedPyObj_.payload_.length() + 1);
12 payload.push_back(isAsyncExecution_ ? 1 : 0);
13 payload.insert(
14 payload.end(),
15 serializedPyObj_.payload_.begin(),
16 serializedPyObj_.payload_.end());
17
18 return c10::make_intrusive<Message>(
19 std::move(payload),
20 std::move(serializedPyObj_.tensors_),
21 MessageType::PYTHON_CALL);
22 }
23
fromMessage(const Message & message)24 std::unique_ptr<PythonCall> PythonCall::fromMessage(const Message& message) {
25 TORCH_INTERNAL_ASSERT(
26 !message.payload().empty(),
27 "Failed to convert an RPC message to PythonCall, the payload should at "
28 "least contain one byte indicating whether this is an async function, "
29 "but got payload of size ",
30 message.payload().size());
31 const char& c = message.payload()[0];
32 TORCH_INTERNAL_ASSERT(c == 0 || c == 1);
33 bool isAsyncExecution = (c == 1);
34 std::string payload(message.payload().begin() + 1, message.payload().end());
35 std::vector<Tensor> tensors = message.tensors();
36 SerializedPyObj serializedPyObj(std::move(payload), std::move(tensors));
37 return std::make_unique<PythonCall>(
38 std::move(serializedPyObj), isAsyncExecution);
39 }
40
serializedPyObj() const41 const SerializedPyObj& PythonCall::serializedPyObj() const {
42 return serializedPyObj_;
43 }
44
45 } // namespace torch::distributed::rpc
46