#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/utils.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <torch/csrc/utils/byte_order.h>

namespace torch {
namespace distributed {
namespace autograd {

using rpc::Message;
using rpc::MessageType;
using rpc::RpcCommandBase;
using rpc::worker_id_t;

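// Constructor used when wrapping an already-built Message (the send path):
// the wrapped message's tensors and type are captured here so that
// toMessageImpl() can later consume the wrapped message's payload.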
RpcWithAutograd::RpcWithAutograd(
    worker_id_t fromWorkerId,
    MessageType messageType,
    const AutogradMetadata& autogradMetadata,
    c10::intrusive_ptr<rpc::Message> wrappedMessage,
    rpc::DeviceMap deviceMap)
    : fromWorkerId_(fromWorkerId),
      messageType_(messageType),
      autogradMetadata_(autogradMetadata),
      wrappedMessage_(std::move(wrappedMessage)),
      deviceMap_(std::move(deviceMap)) {
  TORCH_INTERNAL_ASSERT(
      messageType_ == MessageType::FORWARD_AUTOGRAD_REQ ||
      messageType_ == MessageType::FORWARD_AUTOGRAD_RESP);
  tensors_ = wrappedMessage_->tensors();
  wrappedMessageType_ = wrappedMessage_->type();
}

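// Constructor used when the wrapped RPC has already been deserialized (the
// receive path): this is the overload fromMessage() below uses to rebuild an
// RpcWithAutograd from the wire format.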
RpcWithAutograd::RpcWithAutograd(
    worker_id_t fromWorkerId,
    MessageType messageType,
    const AutogradMetadata& autogradMetadata,
    std::unique_ptr<RpcCommandBase> wrappedRpc,
    MessageType wrappedMessageType,
    std::vector<torch::Tensor> tensors,
    rpc::DeviceMap deviceMap)
    : fromWorkerId_(fromWorkerId),
      messageType_(messageType),
      autogradMetadata_(autogradMetadata),
      wrappedRpc_(std::move(wrappedRpc)),
      wrappedMessageType_(wrappedMessageType),
      tensors_(std::move(tensors)),
      deviceMap_(std::move(deviceMap)) {
  TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
  TORCH_INTERNAL_ASSERT(
      messageType_ == MessageType::FORWARD_AUTOGRAD_REQ ||
      messageType_ == MessageType::FORWARD_AUTOGRAD_RESP);
}

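// Serialization: pickles the tuple (wrappedMessageType, autogradContextId,
// autogradMessageId, fromWorkerId, deviceMap) and appends it to the wrapped
// message's payload via writeWrappedPayload(), then re-emits the result as a
// FORWARD_AUTOGRAD_REQ/RESP message carrying the original tensors.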
c10::intrusive_ptr<Message> RpcWithAutograd::toMessageImpl() && {
  auto messageId = wrappedMessage_->id();
  auto wrappedMessageType = wrappedMessage_->type();

  auto payload = std::move(*wrappedMessage_).movePayload();
  TORCH_INTERNAL_ASSERT(!payload.empty());

  // Convert deviceMap to c10::Dict for serialization.
  c10::Dict<std::string, std::string> deviceMap;
  for (const auto& mapEntry : deviceMap_) {
    deviceMap.insert(mapEntry.first.str(), mapEntry.second.str());
  }

  std::vector<at::IValue> ivalues{
      wrappedMessageType,
      autogradMetadata_.autogradContextId,
      autogradMetadata_.autogradMessageId,
      fromWorkerId_,
      deviceMap};

  // Now pickle using JIT pickler.
  std::vector<torch::Tensor> tensorTable;
  std::vector<char> additionalPayload =
      jit::pickle(c10::ivalue::Tuple::create(std::move(ivalues)), &tensorTable);

  // We shouldn't have any tensors!
  TORCH_INTERNAL_ASSERT(tensorTable.empty());

  // This wraps additionalPayload into payload and takes care of resizing,
  // encoding.
  rpc::writeWrappedPayload(payload, additionalPayload);

  return c10::make_intrusive<Message>(
      std::move(payload), std::move(tensors_), messageType_, messageId);
}

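// Deserialization: the inverse of toMessageImpl(). Strips the appended
// autograd tuple from the payload with readWrappedPayload(), reconstructs the
// wrapped Message with its original type, and deserializes it back into an
// RpcCommandBase before rebuilding the RpcWithAutograd.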
std::unique_ptr<RpcWithAutograd> RpcWithAutograd::fromMessage(
    const Message& message) {
  MessageType originalMessageType = message.type();
  TORCH_INTERNAL_ASSERT(
      MessageType::FORWARD_AUTOGRAD_REQ == originalMessageType ||
      MessageType::FORWARD_AUTOGRAD_RESP == originalMessageType);

  std::vector<torch::Tensor> tensors = message.tensors();
  int64_t messageId = message.id();
  // Decode message type, autograd context id, autograd message id and worker
  // id from which we received this message.
  auto payload = message.payload();
  auto tupleElements = rpc::readWrappedPayload(payload, message);

  // Gather all the fields.
  TORCH_INTERNAL_ASSERT(tupleElements.size() == 5);
  MessageType wrappedMessageType =
      static_cast<MessageType>(tupleElements[0].toInt());
  AutogradMetadata autogradMetadata(
      tupleElements[1].toInt(), tupleElements[2].toInt());
  worker_id_t workerId = tupleElements[3].toInt();
  auto c10DeviceMap =
      tupleElements[4].to<c10::Dict<std::string, std::string>>();

  // Convert to regular map.
  rpc::DeviceMap deviceMap;
  for (const auto& mapEntry : c10DeviceMap) {
    deviceMap.insert({mapEntry.key(), mapEntry.value()});
  }

  // Create new message type and build wrapped RPC.
  auto wrappedMessage = c10::make_intrusive<Message>(
      std::move(payload), std::move(tensors), wrappedMessageType, messageId);

  std::unique_ptr<RpcCommandBase> wrappedRpc;
  if (originalMessageType == MessageType::FORWARD_AUTOGRAD_REQ) {
    wrappedRpc = deserializeRequest(*wrappedMessage);
  } else {
    wrappedRpc = deserializeResponse(*wrappedMessage, wrappedMessageType);
  }

  return std::make_unique<RpcWithAutograd>(
      workerId,
      originalMessageType,
      autogradMetadata,
      std::move(wrappedRpc),
      wrappedMessageType,
      wrappedMessage->tensors(),
      deviceMap);
}

std::vector<torch::Tensor>& RpcWithAutograd::tensors() {
  return tensors_;
}

const AutogradMetadata& RpcWithAutograd::autogradMetadata() const {
  return autogradMetadata_;
}

RpcCommandBase& RpcWithAutograd::wrappedRpc() {
  TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
  return *wrappedRpc_;
}

void RpcWithAutograd::setWrappedRpc(
    std::unique_ptr<RpcCommandBase> wrappedRpc) {
  wrappedRpc_ = std::move(wrappedRpc);
}

std::unique_ptr<RpcCommandBase> RpcWithAutograd::moveWrappedRpc() && {
  TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
  return std::move(wrappedRpc_);
}

MessageType RpcWithAutograd::wrappedMessageType() const {
  return wrappedMessageType_;
}

rpc::worker_id_t RpcWithAutograd::fromWorkerId() const {
  return fromWorkerId_;
}

const rpc::DeviceMap& RpcWithAutograd::deviceMap() {
  return deviceMap_;
}

} // namespace autograd
} // namespace distributed
} // namespace torch