1 #include <torch/csrc/distributed/rpc/rpc_agent.h>
2 #include <torch/csrc/distributed/rpc/rref_proto.h>
3 #include <torch/csrc/jit/serialization/pickle.h>
4
5 #include <limits>
6
7 namespace torch::distributed::rpc {
8
9 namespace {
10
toIValues(const Message & message,MessageType type)11 c10::ivalue::TupleElements toIValues(const Message& message, MessageType type) {
12 TORCH_INTERNAL_ASSERT(
13 type == message.type(),
14 "Expecting message of type ",
15 type,
16 ", but got ",
17 message.type());
18 auto payload = static_cast<const char*>(message.payload().data());
19 auto payload_size = message.payload().size();
20
21 auto value = jit::unpickle(
22 payload,
23 payload_size,
24 *RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
25 message.tensors());
26 return std::move(*std::move(value).toTuple()).elements();
27 }
28
fromIValues(std::vector<IValue> ivalues,MessageType type)29 c10::intrusive_ptr<Message> fromIValues(
30 std::vector<IValue> ivalues,
31 MessageType type) {
32 std::vector<torch::Tensor> tensor_table;
33 auto payload = jit::pickle(
34 c10::ivalue::Tuple::create(std::move(ivalues)), &tensor_table);
35 return c10::make_intrusive<Message>(
36 std::move(payload), std::move(tensor_table), type);
37 }
38
39 } // namespace
40
41 /////////////////////////// RRefMessageBase //////////////////////////////////
42
rrefId()43 const RRefId& RRefMessageBase::rrefId() {
44 return rrefId_;
45 }
46
47 /////////////////////////// ForkMessageBase //////////////////////////////////
48
forkId()49 const ForkId& ForkMessageBase::forkId() {
50 return forkId_;
51 }
52
toMessageImpl()53 c10::intrusive_ptr<Message> ForkMessageBase::toMessageImpl() && {
54 return fromIValues({rrefId_.toIValue(), forkId_.toIValue()}, type_);
55 }
56
fromMessage(const Message & message,MessageType type)57 std::pair<RRefId, ForkId> ForkMessageBase::fromMessage(
58 const Message& message,
59 MessageType type) {
60 auto ivalues = toIValues(message, type);
61
62 TORCH_INTERNAL_ASSERT(
63 ivalues.size() == 2, "ForkMessageBase expects 2 IValue from message.");
64
65 return std::make_pair(
66 RRefId::fromIValue(ivalues[0]), ForkId::fromIValue(ivalues[1]));
67 }
68
69 /////////////////////////// RRef Protocol //////////////////////////////////
70
toMessageImpl()71 c10::intrusive_ptr<Message> ScriptRRefFetchCall::toMessageImpl() && {
72 std::vector<at::IValue> ivalues;
73 ivalues.reserve(2);
74 ivalues.emplace_back(rrefId_.toIValue());
75 ivalues.emplace_back(fromWorkerId_);
76 return fromIValues(std::move(ivalues), MessageType::SCRIPT_RREF_FETCH_CALL);
77 }
78
fromMessage(const Message & message)79 std::unique_ptr<ScriptRRefFetchCall> ScriptRRefFetchCall::fromMessage(
80 const Message& message) {
81 auto values = toIValues(message, MessageType::SCRIPT_RREF_FETCH_CALL);
82 TORCH_INTERNAL_ASSERT(
83 values.size() == 2, "ScriptRRefFetchCall expects 2 IValues from message");
84 auto id = values[1].toInt();
85 TORCH_INTERNAL_ASSERT(
86 id >= std::numeric_limits<worker_id_t>::min() &&
87 id <= std::numeric_limits<worker_id_t>::max(),
88 "ScriptRRefFetchCall fromWorkerId exceeds worker_id_t limit.")
89 return std::make_unique<ScriptRRefFetchCall>(
90 worker_id_t(id), RRefId::fromIValue(values[0]));
91 }
92
toMessageImpl()93 c10::intrusive_ptr<Message> PythonRRefFetchCall::toMessageImpl() && {
94 std::vector<at::IValue> ivalues;
95 ivalues.reserve(2);
96 ivalues.emplace_back(rrefId_.toIValue());
97 ivalues.emplace_back(fromWorkerId_);
98 return fromIValues(std::move(ivalues), MessageType::PYTHON_RREF_FETCH_CALL);
99 }
100
fromMessage(const Message & message)101 std::unique_ptr<PythonRRefFetchCall> PythonRRefFetchCall::fromMessage(
102 const Message& message) {
103 auto values = toIValues(message, MessageType::PYTHON_RREF_FETCH_CALL);
104 TORCH_INTERNAL_ASSERT(
105 values.size() == 2, "PythonRRefFetchCall expects 2 IValues from message");
106 auto id = values[1].toInt();
107 TORCH_INTERNAL_ASSERT(
108 id >= std::numeric_limits<worker_id_t>::min() &&
109 id <= std::numeric_limits<worker_id_t>::max(),
110 "PythonRRefFetchCall fromWorkerId exceeds worker_id_t limit.")
111 return std::make_unique<PythonRRefFetchCall>(
112 worker_id_t(id), RRefId::fromIValue(values[0]));
113 }
114
values()115 const std::vector<at::IValue>& RRefFetchRet::values() {
116 return values_;
117 }
118
toMessageImpl()119 c10::intrusive_ptr<Message> RRefFetchRet::toMessageImpl() && {
120 return fromIValues(values_, type_);
121 }
122
fromMessage(const Message & message)123 std::unique_ptr<ScriptRRefFetchRet> ScriptRRefFetchRet::fromMessage(
124 const Message& message) {
125 auto values = toIValues(message, MessageType::SCRIPT_RREF_FETCH_RET);
126 TORCH_INTERNAL_ASSERT(
127 values.size() == 1,
128 "RRef of IValue should contain a single IValue, but got ",
129 values.size());
130 return std::make_unique<ScriptRRefFetchRet>(std::move(values).vec());
131 }
132
fromMessage(const Message & message)133 std::unique_ptr<PythonRRefFetchRet> PythonRRefFetchRet::fromMessage(
134 const Message& message) {
135 return std::make_unique<PythonRRefFetchRet>(
136 toIValues(message, MessageType::PYTHON_RREF_FETCH_RET).vec());
137 }
138
fromMessage(const Message & message)139 std::unique_ptr<RRefUserDelete> RRefUserDelete::fromMessage(
140 const Message& message) {
141 auto pair =
142 ForkMessageBase::fromMessage(message, MessageType::RREF_USER_DELETE);
143 return std::make_unique<RRefUserDelete>(pair.first, pair.second);
144 }
145
fromMessage(const Message & message)146 std::unique_ptr<RemoteRet> RemoteRet::fromMessage(const Message& message) {
147 auto pair = ForkMessageBase::fromMessage(message, MessageType::REMOTE_RET);
148 return std::make_unique<RemoteRet>(pair.first, pair.second);
149 }
150
forkId() const151 const ForkId& RRefChildAccept::forkId() const {
152 return forkId_;
153 }
154
toMessageImpl()155 c10::intrusive_ptr<Message> RRefChildAccept::toMessageImpl() && {
156 return fromIValues({forkId_.toIValue()}, MessageType::RREF_CHILD_ACCEPT);
157 }
158
fromMessage(const Message & message)159 std::unique_ptr<RRefChildAccept> RRefChildAccept::fromMessage(
160 const Message& message) {
161 auto values = toIValues(message, MessageType::RREF_CHILD_ACCEPT);
162 TORCH_INTERNAL_ASSERT(values.size() == 1, "Expect 1 IValues from message.");
163
164 return std::make_unique<RRefChildAccept>(ForkId::fromIValue(values.back()));
165 }
166
fromMessage(const Message & message)167 std::unique_ptr<RRefForkRequest> RRefForkRequest::fromMessage(
168 const Message& message) {
169 auto pair =
170 ForkMessageBase::fromMessage(message, MessageType::RREF_FORK_REQUEST);
171 return std::make_unique<RRefForkRequest>(pair.first, pair.second);
172 }
173
toMessageImpl()174 c10::intrusive_ptr<Message> RRefAck::toMessageImpl() && {
175 return c10::make_intrusive<Message>(
176 std::vector<char>{}, std::vector<torch::Tensor>{}, MessageType::RREF_ACK);
177 }
178
fromMessage(const Message & message)179 std::unique_ptr<RRefAck> RRefAck::fromMessage(const Message& message) {
180 TORCH_INTERNAL_ASSERT(
181 message.type() == MessageType::RREF_ACK,
182 "Message type miss match, expect ",
183 MessageType::RREF_ACK,
184 ", but got ",
185 message.type());
186 return std::make_unique<RRefAck>();
187 }
188
189 } // namespace torch::distributed::rpc
190