#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/utils.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <torch/csrc/utils/byte_order.h>

namespace torch {
namespace distributed {
namespace autograd {

using rpc::Message;
using rpc::MessageType;
using rpc::RpcCommandBase;
using rpc::worker_id_t;

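// Constructor used when the wrapped RPC is already available as a serialized
// Message; the wrapped message's tensors and type are cached here so they can
// be reused when this RPC is converted back into a Message.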
RpcWithAutograd::RpcWithAutograd(
    worker_id_t fromWorkerId,
    MessageType messageType,
    const AutogradMetadata& autogradMetadata,
    c10::intrusive_ptr<rpc::Message> wrappedMessage,
    rpc::DeviceMap deviceMap)
    : fromWorkerId_(fromWorkerId),
      messageType_(messageType),
      autogradMetadata_(autogradMetadata),
      wrappedMessage_(std::move(wrappedMessage)),
      deviceMap_(std::move(deviceMap)) {
  TORCH_INTERNAL_ASSERT(
      messageType_ == MessageType::FORWARD_AUTOGRAD_REQ ||
      messageType_ == MessageType::FORWARD_AUTOGRAD_RESP);
  tensors_ = wrappedMessage_->tensors();
  wrappedMessageType_ = wrappedMessage_->type();
}

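// Constructor used when the wrapped RPC has already been deserialized (e.g.
// in fromMessage() below); the wrapped message type and tensors are passed in
// explicitly since there is no wrapped Message object in this case.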
RpcWithAutograd::RpcWithAutograd(
    worker_id_t fromWorkerId,
    MessageType messageType,
    const AutogradMetadata& autogradMetadata,
    std::unique_ptr<RpcCommandBase> wrappedRpc,
    MessageType wrappedMessageType,
    std::vector<torch::Tensor> tensors,
    rpc::DeviceMap deviceMap)
    : fromWorkerId_(fromWorkerId),
      messageType_(messageType),
      autogradMetadata_(autogradMetadata),
      wrappedRpc_(std::move(wrappedRpc)),
      wrappedMessageType_(wrappedMessageType),
      tensors_(std::move(tensors)),
      deviceMap_(std::move(deviceMap)) {
  TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
  TORCH_INTERNAL_ASSERT(
      messageType_ == MessageType::FORWARD_AUTOGRAD_REQ ||
      messageType_ == MessageType::FORWARD_AUTOGRAD_RESP);
}

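// Serialization: appends a pickled tuple of (wrappedMessageType,
// autogradContextId, autogradMessageId, fromWorkerId, deviceMap) to the
// wrapped message's payload via writeWrappedPayload(), and reuses the wrapped
// message's id and tensors for the outgoing Message.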
c10::intrusive_ptr<Message> RpcWithAutograd::toMessageImpl() && {
  auto messageId = wrappedMessage_->id();
  auto wrappedMessageType = wrappedMessage_->type();

  auto payload = std::move(*wrappedMessage_).movePayload();
  TORCH_INTERNAL_ASSERT(!payload.empty());

  // Convert deviceMap to c10::Dict for serialization.
  c10::Dict<std::string, std::string> deviceMap;
  for (const auto& mapEntry : deviceMap_) {
    deviceMap.insert(mapEntry.first.str(), mapEntry.second.str());
  }

  std::vector<at::IValue> ivalues{
      wrappedMessageType,
      autogradMetadata_.autogradContextId,
      autogradMetadata_.autogradMessageId,
      fromWorkerId_,
      deviceMap};

  // Now pickle using JIT pickler.
  std::vector<torch::Tensor> tensorTable;
  std::vector<char> additionalPayload =
      jit::pickle(c10::ivalue::Tuple::create(std::move(ivalues)), &tensorTable);

  // We shouldn't have any tensors!
  TORCH_INTERNAL_ASSERT(tensorTable.empty());

  // This wraps additionalPayload into payload and takes care of resizing and
  // encoding.
  rpc::writeWrappedPayload(payload, additionalPayload);

  return c10::make_intrusive<Message>(
      std::move(payload), std::move(tensors_), messageType_, messageId);
}

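// Deserialization: the inverse of toMessageImpl(). Strips the pickled
// autograd fields off the payload with readWrappedPayload(), rebuilds the
// wrapped Message, and deserializes it as a request or response depending on
// the outer message type.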
std::unique_ptr<RpcWithAutograd> RpcWithAutograd::fromMessage(
    const Message& message) {
  MessageType originalMessageType = message.type();
  TORCH_INTERNAL_ASSERT(
      MessageType::FORWARD_AUTOGRAD_REQ == originalMessageType ||
      MessageType::FORWARD_AUTOGRAD_RESP == originalMessageType);

  std::vector<torch::Tensor> tensors = message.tensors();
  int64_t messageId = message.id();
  // Decode the wrapped message type, autograd context id, autograd message
  // id, and the id of the worker from which we received this message.
  auto payload = message.payload();
  auto tupleElements = rpc::readWrappedPayload(payload, message);

  // Gather all the fields.
  TORCH_INTERNAL_ASSERT(tupleElements.size() == 5);
  MessageType wrappedMessageType =
      static_cast<MessageType>(tupleElements[0].toInt());
  AutogradMetadata autogradMetadata(
      tupleElements[1].toInt(), tupleElements[2].toInt());
  worker_id_t workerId = tupleElements[3].toInt();
  auto c10DeviceMap =
      tupleElements[4].to<c10::Dict<std::string, std::string>>();

  // Convert to regular map.
  rpc::DeviceMap deviceMap;
  for (const auto& mapEntry : c10DeviceMap) {
    deviceMap.insert({mapEntry.key(), mapEntry.value()});
  }

  // Create the wrapped message and build the wrapped RPC.
  auto wrappedMessage = c10::make_intrusive<Message>(
      std::move(payload), std::move(tensors), wrappedMessageType, messageId);

  std::unique_ptr<RpcCommandBase> wrappedRpc;
  if (originalMessageType == MessageType::FORWARD_AUTOGRAD_REQ) {
    wrappedRpc = deserializeRequest(*wrappedMessage);
  } else {
    wrappedRpc = deserializeResponse(*wrappedMessage, wrappedMessageType);
  }

  return std::make_unique<RpcWithAutograd>(
      workerId,
      originalMessageType,
      autogradMetadata,
      std::move(wrappedRpc),
      wrappedMessageType,
      wrappedMessage->tensors(),
      deviceMap);
}

std::vector<torch::Tensor>& RpcWithAutograd::tensors() {
  return tensors_;
}

const AutogradMetadata& RpcWithAutograd::autogradMetadata() const {
  return autogradMetadata_;
}

RpcCommandBase& RpcWithAutograd::wrappedRpc() {
  TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
  return *wrappedRpc_;
}

void RpcWithAutograd::setWrappedRpc(
    std::unique_ptr<RpcCommandBase> wrappedRpc) {
  wrappedRpc_ = std::move(wrappedRpc);
}

std::unique_ptr<RpcCommandBase> RpcWithAutograd::moveWrappedRpc() && {
  TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
  return std::move(wrappedRpc_);
}

MessageType RpcWithAutograd::wrappedMessageType() const {
  return wrappedMessageType_;
}

rpc::worker_id_t RpcWithAutograd::fromWorkerId() const {
  return fromWorkerId_;
}

const rpc::DeviceMap& RpcWithAutograd::deviceMap() {
  return deviceMap_;
}

} // namespace autograd
} // namespace distributed
} // namespace torch