1 #include <c10/util/irange.h>
2 #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h>
3 #include <torch/csrc/distributed/rpc/utils.h>
4 #include <torch/csrc/jit/serialization/pickle.h>
5
6 namespace torch {
7 namespace distributed {
8 namespace autograd {
9 using rpc::RpcCommandBase;
10
// Position of the first serialized profiler event inside the pickled
// profiling tuple. Slots 0, 1 and 2 hold the wrapped message type, the
// profiling id and the number of events, respectively.
constexpr int kProfileEventsStartIdx = 3;
12 // This constructor is called when creating the RpcProfilingResp before sending
13 // it as a message over the wire.
RpcWithProfilingResp(rpc::MessageType messageType,c10::intrusive_ptr<rpc::Message> wrappedMessage,std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents,rpc::ProfilingId profilingId)14 RpcWithProfilingResp::RpcWithProfilingResp(
15 rpc::MessageType messageType,
16 c10::intrusive_ptr<rpc::Message> wrappedMessage,
17 std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents,
18 rpc::ProfilingId profilingId)
19 : messageType_(messageType),
20 wrappedMessage_(std::move(wrappedMessage)),
21 tensors_(wrappedMessage_->tensors()),
22 profiledEvents_(std::move(profiledEvents)),
23 profilingId_(profilingId) {
24 TORCH_INTERNAL_ASSERT(
25 messageType_ == rpc::MessageType::RUN_WITH_PROFILING_RESP,
26 "Incorrect Message type");
27 wrappedMessageType_ = wrappedMessage_->type();
28 }
29 // this constructor is called in fromMessage() which is called when
30 // reconstructing this RPC command when processing a message of this type
RpcWithProfilingResp(rpc::MessageType messageType,std::unique_ptr<rpc::RpcCommandBase> wrappedRpc,rpc::MessageType wrappedMessageType,std::vector<torch::Tensor> tensors,std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents,rpc::ProfilingId profilingId)31 RpcWithProfilingResp::RpcWithProfilingResp(
32 rpc::MessageType messageType,
33 std::unique_ptr<rpc::RpcCommandBase> wrappedRpc,
34 rpc::MessageType wrappedMessageType,
35 std::vector<torch::Tensor> tensors,
36 std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents,
37 rpc::ProfilingId profilingId)
38 : messageType_(messageType),
39 wrappedRpc_(std::move(wrappedRpc)),
40 wrappedMessageType_(wrappedMessageType),
41 tensors_(std::move(tensors)),
42 profiledEvents_(std::move(profiledEvents)),
43 profilingId_(profilingId) {
44 TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrapped RPC cannot be null");
45 }
46
moveWrappedRpc()47 std::unique_ptr<RpcCommandBase> RpcWithProfilingResp::moveWrappedRpc() && {
48 TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
49 return std::move(wrappedRpc_);
50 }
51
wrappedMessageType() const52 rpc::MessageType RpcWithProfilingResp::wrappedMessageType() const {
53 return wrappedMessageType_;
54 }
55
56 std::vector<torch::autograd::profiler::LegacyEvent> RpcWithProfilingResp::
getProfiledEvents() const57 getProfiledEvents() const {
58 return profiledEvents_;
59 }
60
getProfilingId() const61 const rpc::ProfilingId& RpcWithProfilingResp::getProfilingId() const {
62 return profilingId_;
63 }
64
setWrappedRpc(std::unique_ptr<RpcCommandBase> wrappedRpc)65 void RpcWithProfilingResp::setWrappedRpc(
66 std::unique_ptr<RpcCommandBase> wrappedRpc) {
67 wrappedRpc_ = std::move(wrappedRpc);
68 }
69
toMessageImpl()70 c10::intrusive_ptr<rpc::Message> RpcWithProfilingResp::toMessageImpl() && {
71 auto wrappedMsgId = wrappedMessage_->id();
72 auto wrappedMsgType = wrappedMessage_->type();
73 auto wrappedPayload = std::move(*wrappedMessage_).movePayload();
74 // Wrapped payload should not be empty
75 TORCH_INTERNAL_ASSERT(
76 !wrappedPayload.empty(), "Wrapped payload cannot be empty");
77 // Create ivalues to send over
78 std::vector<at::IValue> ivalues{wrappedMsgType, profilingId_.toIValue()};
79 // Attach the serialized events.
80 ivalues.emplace_back(static_cast<int32_t>(profiledEvents_.size()));
81 for (const auto& e : profiledEvents_) {
82 ivalues.emplace_back(e.toIValue());
83 }
84 std::vector<torch::Tensor> tensorTable;
85 std::vector<char> profilingPayload =
86 jit::pickle(c10::ivalue::Tuple::create(std::move(ivalues)), &tensorTable);
87 rpc::writeWrappedPayload(wrappedPayload, profilingPayload);
88
89 auto returnMsg = c10::make_intrusive<rpc::Message>(
90 std::move(wrappedPayload),
91 std::move(tensors_),
92 messageType_,
93 wrappedMsgId);
94 return returnMsg;
95 }
96
wrappedRpc()97 RpcCommandBase& RpcWithProfilingResp::wrappedRpc() {
98 TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
99 return *wrappedRpc_;
100 }
101
102 // Runs on client when deserializing this message.
fromMessage(const rpc::Message & message)103 std::unique_ptr<RpcWithProfilingResp> RpcWithProfilingResp::fromMessage(
104 const rpc::Message& message) {
105 rpc::MessageType origMsgType = message.type();
106 std::vector<torch::Tensor> tensors = message.tensors();
107 int64_t msgId = message.id();
108 auto payload = message.payload();
109 auto tupleElements = rpc::readWrappedPayload(payload, message);
110 // Ensure that we have the expected number of elements
111 TORCH_INTERNAL_ASSERT(
112 tupleElements.size() >= kProfileEventsStartIdx,
113 c10::str(
114 "Expected payload size of at least ",
115 kProfileEventsStartIdx,
116 " but got size ",
117 tupleElements.size()));
118 rpc::MessageType wrappedMsgType =
119 static_cast<rpc::MessageType>(tupleElements[0].toInt());
120 rpc::ProfilingId profilingId = rpc::ProfilingId::fromIValue(tupleElements[1]);
121 int profiledEventsSize = tupleElements[2].toInt();
122 std::vector<torch::autograd::profiler::LegacyEvent> remoteEvents;
123 remoteEvents.reserve(profiledEventsSize);
124 for (const auto i : c10::irange(
125 kProfileEventsStartIdx,
126 kProfileEventsStartIdx + profiledEventsSize)) {
127 TORCH_CHECK(static_cast<size_t>(i) < tupleElements.size());
128 // Reconstruct remote event from the ivalues.
129 torch::autograd::profiler::LegacyEvent fromIvalueEvent =
130 torch::autograd::profiler::LegacyEvent::fromIValue(tupleElements[i]);
131 remoteEvents.push_back(std::move(fromIvalueEvent));
132 }
133
134 auto wrappedMessage = c10::make_intrusive<rpc::Message>(
135 std::move(payload), std::move(tensors), wrappedMsgType, msgId);
136 TORCH_INTERNAL_ASSERT(
137 wrappedMessage->isResponse(),
138 "Messages wrapped with profiling response must be responses.");
139 std::unique_ptr<RpcCommandBase> wrappedRpc =
140 deserializeResponse(*wrappedMessage, wrappedMsgType);
141 return std::make_unique<RpcWithProfilingResp>(
142 origMsgType,
143 std::move(wrappedRpc),
144 wrappedMsgType,
145 std::move(wrappedMessage->tensors()),
146 std::move(remoteEvents),
147 profilingId);
148 }
149 } // namespace autograd
150 } // namespace distributed
151 } // namespace torch
152