xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/message.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/rpc/message.h>
2 #include <torch/custom_class.h>
3 
4 namespace torch::distributed::rpc {
5 
6 Message::Message() = default;
7 
Message(std::vector<char> && payload,std::vector<torch::Tensor> && tensors,MessageType type)8 Message::Message(
9     std::vector<char>&& payload,
10     std::vector<torch::Tensor>&& tensors,
11     MessageType type)
12     : payload_(std::move(payload)), tensors_(std::move(tensors)), type_(type) {}
13 
Message(std::vector<char> && payload,std::vector<torch::Tensor> && tensors,MessageType type,int64_t id)14 Message::Message(
15     std::vector<char>&& payload,
16     std::vector<torch::Tensor>&& tensors,
17     MessageType type,
18     int64_t id)
19     : payload_(std::move(payload)),
20       tensors_(std::move(tensors)),
21       type_(type),
22       id_(id) {}
23 
movePayload()24 std::vector<char>&& Message::movePayload() && {
25   return std::move(payload_);
26 }
27 
payload()28 std::vector<char>& Message::payload() {
29   return payload_;
30 }
31 
payload() const32 const std::vector<char>& Message::payload() const {
33   return payload_;
34 }
35 
moveTensors()36 std::vector<torch::Tensor>&& Message::moveTensors() && {
37   return std::move(tensors_);
38 }
39 
tensors()40 std::vector<torch::Tensor>& Message::tensors() {
41   return tensors_;
42 }
43 
tensors() const44 const std::vector<torch::Tensor>& Message::tensors() const {
45   return tensors_;
46 }
47 
type() const48 MessageType Message::type() const {
49   return type_;
50 }
51 
isRequest() const52 bool Message::isRequest() const {
53   return MessageTypeFlags::REQUEST_TYPE & type_;
54 }
55 
isResponse() const56 bool Message::isResponse() const {
57   return MessageTypeFlags::RESPONSE_TYPE & type_;
58 }
59 
id() const60 int64_t Message::id() const {
61   return id_;
62 }
63 
setId(int64_t id)64 void Message::setId(int64_t id) {
65   id_ = id;
66 }
67 
getStorages() const68 std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> Message::getStorages()
69     const {
70   // Sparse tensors do not have storage. Instead, a sparse tensor
71   // contains two tensors indices and values, and both contain storage.
72   std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> storages;
73   storages.reserve(2 * tensors_.size());
74   for (const auto& tensor : tensors_) {
75     if (tensor.is_sparse()) {
76       storages.emplace_back(tensor._indices().storage().getWeakStorageImpl());
77       storages.emplace_back(tensor._values().storage().getWeakStorageImpl());
78     } else {
79       storages.emplace_back(tensor.storage().getWeakStorageImpl());
80     }
81   }
82   return storages;
83 }
84 
createExceptionResponse(const std::exception & e,int64_t id)85 c10::intrusive_ptr<Message> createExceptionResponse(
86     const std::exception& e,
87     int64_t id) {
88   return createExceptionResponse(e.what(), id);
89 }
90 
createExceptionResponse(const std::string & exceptionStr,int64_t id)91 c10::intrusive_ptr<Message> createExceptionResponse(
92     const std::string& exceptionStr,
93     int64_t id) {
94   std::vector<char> payload(exceptionStr.begin(), exceptionStr.end());
95   return c10::make_intrusive<Message>(
96       std::move(payload),
97       std::vector<torch::Tensor>(),
98       MessageType::EXCEPTION,
99       id);
100 }
101 
102 namespace {
103 
104 // NB: need to call torch::class_ to register Message in the map returned by
105 // c10::getCustomClassTypeMap(). Otherwise, Message cannot be wrapped within
106 // an IValue.
107 // NB: add this line here instead of in rpc/init.cpp because 1) we have C++
108 // only tests that won't run rpc/init.cpp; 2) Message is not meant to be
109 // visible from Python.
110 static const auto message = torch::class_<Message>("rpc", "_Message");
111 
112 } // namespace
113 
114 } // namespace torch::distributed::rpc
115