// xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/irange.h>
2 #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h>
3 #include <torch/csrc/distributed/rpc/utils.h>
4 #include <torch/csrc/jit/serialization/pickle.h>
5 
6 namespace torch {
7 namespace distributed {
8 namespace autograd {
9 using rpc::RpcCommandBase;
10 
11 constexpr auto kProfileEventsStartIdx = 3;
12 // This constructor is called when creating the RpcProfilingResp before sending
13 // it as a message over the wire.
RpcWithProfilingResp(rpc::MessageType messageType,c10::intrusive_ptr<rpc::Message> wrappedMessage,std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents,rpc::ProfilingId profilingId)14 RpcWithProfilingResp::RpcWithProfilingResp(
15     rpc::MessageType messageType,
16     c10::intrusive_ptr<rpc::Message> wrappedMessage,
17     std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents,
18     rpc::ProfilingId profilingId)
19     : messageType_(messageType),
20       wrappedMessage_(std::move(wrappedMessage)),
21       tensors_(wrappedMessage_->tensors()),
22       profiledEvents_(std::move(profiledEvents)),
23       profilingId_(profilingId) {
24   TORCH_INTERNAL_ASSERT(
25       messageType_ == rpc::MessageType::RUN_WITH_PROFILING_RESP,
26       "Incorrect Message type");
27   wrappedMessageType_ = wrappedMessage_->type();
28 }
29 // this constructor is called in fromMessage() which is called when
30 // reconstructing this RPC command when processing a message of this type
RpcWithProfilingResp(rpc::MessageType messageType,std::unique_ptr<rpc::RpcCommandBase> wrappedRpc,rpc::MessageType wrappedMessageType,std::vector<torch::Tensor> tensors,std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents,rpc::ProfilingId profilingId)31 RpcWithProfilingResp::RpcWithProfilingResp(
32     rpc::MessageType messageType,
33     std::unique_ptr<rpc::RpcCommandBase> wrappedRpc,
34     rpc::MessageType wrappedMessageType,
35     std::vector<torch::Tensor> tensors,
36     std::vector<torch::autograd::profiler::LegacyEvent> profiledEvents,
37     rpc::ProfilingId profilingId)
38     : messageType_(messageType),
39       wrappedRpc_(std::move(wrappedRpc)),
40       wrappedMessageType_(wrappedMessageType),
41       tensors_(std::move(tensors)),
42       profiledEvents_(std::move(profiledEvents)),
43       profilingId_(profilingId) {
44   TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrapped RPC cannot be null");
45 }
46 
moveWrappedRpc()47 std::unique_ptr<RpcCommandBase> RpcWithProfilingResp::moveWrappedRpc() && {
48   TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
49   return std::move(wrappedRpc_);
50 }
51 
wrappedMessageType() const52 rpc::MessageType RpcWithProfilingResp::wrappedMessageType() const {
53   return wrappedMessageType_;
54 }
55 
56 std::vector<torch::autograd::profiler::LegacyEvent> RpcWithProfilingResp::
getProfiledEvents() const57     getProfiledEvents() const {
58   return profiledEvents_;
59 }
60 
getProfilingId() const61 const rpc::ProfilingId& RpcWithProfilingResp::getProfilingId() const {
62   return profilingId_;
63 }
64 
setWrappedRpc(std::unique_ptr<RpcCommandBase> wrappedRpc)65 void RpcWithProfilingResp::setWrappedRpc(
66     std::unique_ptr<RpcCommandBase> wrappedRpc) {
67   wrappedRpc_ = std::move(wrappedRpc);
68 }
69 
toMessageImpl()70 c10::intrusive_ptr<rpc::Message> RpcWithProfilingResp::toMessageImpl() && {
71   auto wrappedMsgId = wrappedMessage_->id();
72   auto wrappedMsgType = wrappedMessage_->type();
73   auto wrappedPayload = std::move(*wrappedMessage_).movePayload();
74   // Wrapped payload should not be empty
75   TORCH_INTERNAL_ASSERT(
76       !wrappedPayload.empty(), "Wrapped payload cannot be empty");
77   // Create ivalues to send over
78   std::vector<at::IValue> ivalues{wrappedMsgType, profilingId_.toIValue()};
79   // Attach the serialized events.
80   ivalues.emplace_back(static_cast<int32_t>(profiledEvents_.size()));
81   for (const auto& e : profiledEvents_) {
82     ivalues.emplace_back(e.toIValue());
83   }
84   std::vector<torch::Tensor> tensorTable;
85   std::vector<char> profilingPayload =
86       jit::pickle(c10::ivalue::Tuple::create(std::move(ivalues)), &tensorTable);
87   rpc::writeWrappedPayload(wrappedPayload, profilingPayload);
88 
89   auto returnMsg = c10::make_intrusive<rpc::Message>(
90       std::move(wrappedPayload),
91       std::move(tensors_),
92       messageType_,
93       wrappedMsgId);
94   return returnMsg;
95 }
96 
wrappedRpc()97 RpcCommandBase& RpcWithProfilingResp::wrappedRpc() {
98   TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
99   return *wrappedRpc_;
100 }
101 
102 // Runs on client when deserializing this message.
fromMessage(const rpc::Message & message)103 std::unique_ptr<RpcWithProfilingResp> RpcWithProfilingResp::fromMessage(
104     const rpc::Message& message) {
105   rpc::MessageType origMsgType = message.type();
106   std::vector<torch::Tensor> tensors = message.tensors();
107   int64_t msgId = message.id();
108   auto payload = message.payload();
109   auto tupleElements = rpc::readWrappedPayload(payload, message);
110   // Ensure that we have the expected number of elements
111   TORCH_INTERNAL_ASSERT(
112       tupleElements.size() >= kProfileEventsStartIdx,
113       c10::str(
114           "Expected payload size of at least ",
115           kProfileEventsStartIdx,
116           " but got size ",
117           tupleElements.size()));
118   rpc::MessageType wrappedMsgType =
119       static_cast<rpc::MessageType>(tupleElements[0].toInt());
120   rpc::ProfilingId profilingId = rpc::ProfilingId::fromIValue(tupleElements[1]);
121   int profiledEventsSize = tupleElements[2].toInt();
122   std::vector<torch::autograd::profiler::LegacyEvent> remoteEvents;
123   remoteEvents.reserve(profiledEventsSize);
124   for (const auto i : c10::irange(
125            kProfileEventsStartIdx,
126            kProfileEventsStartIdx + profiledEventsSize)) {
127     TORCH_CHECK(static_cast<size_t>(i) < tupleElements.size());
128     // Reconstruct remote event from the ivalues.
129     torch::autograd::profiler::LegacyEvent fromIvalueEvent =
130         torch::autograd::profiler::LegacyEvent::fromIValue(tupleElements[i]);
131     remoteEvents.push_back(std::move(fromIvalueEvent));
132   }
133 
134   auto wrappedMessage = c10::make_intrusive<rpc::Message>(
135       std::move(payload), std::move(tensors), wrappedMsgType, msgId);
136   TORCH_INTERNAL_ASSERT(
137       wrappedMessage->isResponse(),
138       "Messages wrapped with profiling response must be responses.");
139   std::unique_ptr<RpcCommandBase> wrappedRpc =
140       deserializeResponse(*wrappedMessage, wrappedMsgType);
141   return std::make_unique<RpcWithProfilingResp>(
142       origMsgType,
143       std::move(wrappedRpc),
144       wrappedMsgType,
145       std::move(wrappedMessage->tensors()),
146       std::move(remoteEvents),
147       profilingId);
148 }
149 } // namespace autograd
150 } // namespace distributed
151 } // namespace torch
152