#include <torch/csrc/distributed/rpc/message.h>
#include <torch/custom_class.h>

namespace torch::distributed::rpc {

Message::Message() = default;

Message::Message(
    std::vector<char>&& payload,
    std::vector<torch::Tensor>&& tensors,
    MessageType type)
    : payload_(std::move(payload)), tensors_(std::move(tensors)), type_(type) {}

Message::Message(
    std::vector<char>&& payload,
    std::vector<torch::Tensor>&& tensors,
    MessageType type,
    int64_t id)
    : payload_(std::move(payload)),
      tensors_(std::move(tensors)),
      type_(type),
      id_(id) {}

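// The rvalue-qualified accessors below transfer ownership of the underlying
// buffers out of a Message that the caller is done with; the lvalue overloads
// return references without moving anything.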
std::vector<char>&& Message::movePayload() && {
  return std::move(payload_);
}

std::vector<char>& Message::payload() {
  return payload_;
}

const std::vector<char>& Message::payload() const {
  return payload_;
}

std::vector<torch::Tensor>&& Message::moveTensors() && {
  return std::move(tensors_);
}

std::vector<torch::Tensor>& Message::tensors() {
  return tensors_;
}

const std::vector<torch::Tensor>& Message::tensors() const {
  return tensors_;
}

MessageType Message::type() const {
  return type_;
}

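// MessageType values carry REQUEST_TYPE / RESPONSE_TYPE flag bits, so a
// bitwise AND against the corresponding mask classifies a message.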
bool Message::isRequest() const {
  return MessageTypeFlags::REQUEST_TYPE & type_;
}

bool Message::isResponse() const {
  return MessageTypeFlags::RESPONSE_TYPE & type_;
}

int64_t Message::id() const {
  return id_;
}

void Message::setId(int64_t id) {
  id_ = id;
}

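// Collects weak references to the StorageImpls backing this message's
// tensors, so callers can observe those storages without extending their
// lifetime.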
std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> Message::getStorages()
    const {
  // Sparse tensors do not have storage. Instead, a sparse tensor contains
  // two tensors, indices and values, and both of those contain storage.
  std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> storages;
  storages.reserve(2 * tensors_.size());
  for (const auto& tensor : tensors_) {
    if (tensor.is_sparse()) {
      storages.emplace_back(tensor._indices().storage().getWeakStorageImpl());
      storages.emplace_back(tensor._values().storage().getWeakStorageImpl());
    } else {
      storages.emplace_back(tensor.storage().getWeakStorageImpl());
    }
  }
  return storages;
}

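// Helpers for building an EXCEPTION-typed response: the exception text
// becomes the message payload and no tensors are attached.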
c10::intrusive_ptr<Message> createExceptionResponse(
    const std::exception& e,
    int64_t id) {
  return createExceptionResponse(e.what(), id);
}

c10::intrusive_ptr<Message> createExceptionResponse(
    const std::string& exceptionStr,
    int64_t id) {
  std::vector<char> payload(exceptionStr.begin(), exceptionStr.end());
  return c10::make_intrusive<Message>(
      std::move(payload),
      std::vector<torch::Tensor>(),
      MessageType::EXCEPTION,
      id);
}

namespace {

// NB: We need to call torch::class_ to register Message in the map returned
// by c10::getCustomClassTypeMap(). Otherwise, Message cannot be wrapped
// within an IValue.
// NB: This line lives here instead of in rpc/init.cpp because 1) we have
// C++-only tests that won't run rpc/init.cpp, and 2) Message is not meant to
// be visible from Python.
static const auto message = torch::class_<Message>("rpc", "_Message");
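// With the registration above in place, a Message can be boxed into an
// IValue, e.g. (sketch): c10::IValue boxed(c10::make_intrusive<Message>(...));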

} // namespace

} // namespace torch::distributed::rpc