xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/rref_proto.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/rpc/rpc_agent.h>
2 #include <torch/csrc/distributed/rpc/rref_proto.h>
3 #include <torch/csrc/jit/serialization/pickle.h>
4 
5 #include <limits>
6 
7 namespace torch::distributed::rpc {
8 
9 namespace {
10 
toIValues(const Message & message,MessageType type)11 c10::ivalue::TupleElements toIValues(const Message& message, MessageType type) {
12   TORCH_INTERNAL_ASSERT(
13       type == message.type(),
14       "Expecting message of type ",
15       type,
16       ", but got ",
17       message.type());
18   auto payload = static_cast<const char*>(message.payload().data());
19   auto payload_size = message.payload().size();
20 
21   auto value = jit::unpickle(
22       payload,
23       payload_size,
24       *RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
25       message.tensors());
26   return std::move(*std::move(value).toTuple()).elements();
27 }
28 
fromIValues(std::vector<IValue> ivalues,MessageType type)29 c10::intrusive_ptr<Message> fromIValues(
30     std::vector<IValue> ivalues,
31     MessageType type) {
32   std::vector<torch::Tensor> tensor_table;
33   auto payload = jit::pickle(
34       c10::ivalue::Tuple::create(std::move(ivalues)), &tensor_table);
35   return c10::make_intrusive<Message>(
36       std::move(payload), std::move(tensor_table), type);
37 }
38 
39 } // namespace
40 
41 /////////////////////////// RRefMessageBase //////////////////////////////////
42 
rrefId()43 const RRefId& RRefMessageBase::rrefId() {
44   return rrefId_;
45 }
46 
47 /////////////////////////// ForkMessageBase //////////////////////////////////
48 
forkId()49 const ForkId& ForkMessageBase::forkId() {
50   return forkId_;
51 }
52 
toMessageImpl()53 c10::intrusive_ptr<Message> ForkMessageBase::toMessageImpl() && {
54   return fromIValues({rrefId_.toIValue(), forkId_.toIValue()}, type_);
55 }
56 
fromMessage(const Message & message,MessageType type)57 std::pair<RRefId, ForkId> ForkMessageBase::fromMessage(
58     const Message& message,
59     MessageType type) {
60   auto ivalues = toIValues(message, type);
61 
62   TORCH_INTERNAL_ASSERT(
63       ivalues.size() == 2, "ForkMessageBase expects 2 IValue from message.");
64 
65   return std::make_pair(
66       RRefId::fromIValue(ivalues[0]), ForkId::fromIValue(ivalues[1]));
67 }
68 
69 /////////////////////////// RRef Protocol //////////////////////////////////
70 
toMessageImpl()71 c10::intrusive_ptr<Message> ScriptRRefFetchCall::toMessageImpl() && {
72   std::vector<at::IValue> ivalues;
73   ivalues.reserve(2);
74   ivalues.emplace_back(rrefId_.toIValue());
75   ivalues.emplace_back(fromWorkerId_);
76   return fromIValues(std::move(ivalues), MessageType::SCRIPT_RREF_FETCH_CALL);
77 }
78 
fromMessage(const Message & message)79 std::unique_ptr<ScriptRRefFetchCall> ScriptRRefFetchCall::fromMessage(
80     const Message& message) {
81   auto values = toIValues(message, MessageType::SCRIPT_RREF_FETCH_CALL);
82   TORCH_INTERNAL_ASSERT(
83       values.size() == 2, "ScriptRRefFetchCall expects 2 IValues from message");
84   auto id = values[1].toInt();
85   TORCH_INTERNAL_ASSERT(
86       id >= std::numeric_limits<worker_id_t>::min() &&
87           id <= std::numeric_limits<worker_id_t>::max(),
88       "ScriptRRefFetchCall fromWorkerId exceeds worker_id_t limit.")
89   return std::make_unique<ScriptRRefFetchCall>(
90       worker_id_t(id), RRefId::fromIValue(values[0]));
91 }
92 
toMessageImpl()93 c10::intrusive_ptr<Message> PythonRRefFetchCall::toMessageImpl() && {
94   std::vector<at::IValue> ivalues;
95   ivalues.reserve(2);
96   ivalues.emplace_back(rrefId_.toIValue());
97   ivalues.emplace_back(fromWorkerId_);
98   return fromIValues(std::move(ivalues), MessageType::PYTHON_RREF_FETCH_CALL);
99 }
100 
fromMessage(const Message & message)101 std::unique_ptr<PythonRRefFetchCall> PythonRRefFetchCall::fromMessage(
102     const Message& message) {
103   auto values = toIValues(message, MessageType::PYTHON_RREF_FETCH_CALL);
104   TORCH_INTERNAL_ASSERT(
105       values.size() == 2, "PythonRRefFetchCall expects 2 IValues from message");
106   auto id = values[1].toInt();
107   TORCH_INTERNAL_ASSERT(
108       id >= std::numeric_limits<worker_id_t>::min() &&
109           id <= std::numeric_limits<worker_id_t>::max(),
110       "PythonRRefFetchCall fromWorkerId exceeds worker_id_t limit.")
111   return std::make_unique<PythonRRefFetchCall>(
112       worker_id_t(id), RRefId::fromIValue(values[0]));
113 }
114 
values()115 const std::vector<at::IValue>& RRefFetchRet::values() {
116   return values_;
117 }
118 
toMessageImpl()119 c10::intrusive_ptr<Message> RRefFetchRet::toMessageImpl() && {
120   return fromIValues(values_, type_);
121 }
122 
fromMessage(const Message & message)123 std::unique_ptr<ScriptRRefFetchRet> ScriptRRefFetchRet::fromMessage(
124     const Message& message) {
125   auto values = toIValues(message, MessageType::SCRIPT_RREF_FETCH_RET);
126   TORCH_INTERNAL_ASSERT(
127       values.size() == 1,
128       "RRef of IValue should contain a single IValue, but got ",
129       values.size());
130   return std::make_unique<ScriptRRefFetchRet>(std::move(values).vec());
131 }
132 
fromMessage(const Message & message)133 std::unique_ptr<PythonRRefFetchRet> PythonRRefFetchRet::fromMessage(
134     const Message& message) {
135   return std::make_unique<PythonRRefFetchRet>(
136       toIValues(message, MessageType::PYTHON_RREF_FETCH_RET).vec());
137 }
138 
fromMessage(const Message & message)139 std::unique_ptr<RRefUserDelete> RRefUserDelete::fromMessage(
140     const Message& message) {
141   auto pair =
142       ForkMessageBase::fromMessage(message, MessageType::RREF_USER_DELETE);
143   return std::make_unique<RRefUserDelete>(pair.first, pair.second);
144 }
145 
fromMessage(const Message & message)146 std::unique_ptr<RemoteRet> RemoteRet::fromMessage(const Message& message) {
147   auto pair = ForkMessageBase::fromMessage(message, MessageType::REMOTE_RET);
148   return std::make_unique<RemoteRet>(pair.first, pair.second);
149 }
150 
forkId() const151 const ForkId& RRefChildAccept::forkId() const {
152   return forkId_;
153 }
154 
toMessageImpl()155 c10::intrusive_ptr<Message> RRefChildAccept::toMessageImpl() && {
156   return fromIValues({forkId_.toIValue()}, MessageType::RREF_CHILD_ACCEPT);
157 }
158 
fromMessage(const Message & message)159 std::unique_ptr<RRefChildAccept> RRefChildAccept::fromMessage(
160     const Message& message) {
161   auto values = toIValues(message, MessageType::RREF_CHILD_ACCEPT);
162   TORCH_INTERNAL_ASSERT(values.size() == 1, "Expect 1 IValues from message.");
163 
164   return std::make_unique<RRefChildAccept>(ForkId::fromIValue(values.back()));
165 }
166 
fromMessage(const Message & message)167 std::unique_ptr<RRefForkRequest> RRefForkRequest::fromMessage(
168     const Message& message) {
169   auto pair =
170       ForkMessageBase::fromMessage(message, MessageType::RREF_FORK_REQUEST);
171   return std::make_unique<RRefForkRequest>(pair.first, pair.second);
172 }
173 
toMessageImpl()174 c10::intrusive_ptr<Message> RRefAck::toMessageImpl() && {
175   return c10::make_intrusive<Message>(
176       std::vector<char>{}, std::vector<torch::Tensor>{}, MessageType::RREF_ACK);
177 }
178 
fromMessage(const Message & message)179 std::unique_ptr<RRefAck> RRefAck::fromMessage(const Message& message) {
180   TORCH_INTERNAL_ASSERT(
181       message.type() == MessageType::RREF_ACK,
182       "Message type miss match, expect ",
183       MessageType::RREF_ACK,
184       ", but got ",
185       message.type());
186   return std::make_unique<RRefAck>();
187 }
188 
189 } // namespace torch::distributed::rpc
190