xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/message.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/types.h>
4 #include <vector>
5 
6 namespace torch {
7 namespace distributed {
8 namespace rpc {
9 
// Categories of common RPC failures. Exposing them lets callers handle
// specific errors differently instead of treating every failure alike.
enum RPCErrorType {
  // The error type could not be parsed.
  UNKNOWN_ERROR = 0,
  // The RPC timed out.
  TIMEOUT = 1,
  // A deliberate failure, such as one injected by FaultyAgent for testing.
  INTENTIONAL_FAILURE = 2
};
17 
// Flags that are bitwise OR-ed into a MessageType value to mark it as a
// request or a response. The low byte (0x00-0xff) is reserved for the message
// type id, so every flag is a single bit starting at 0x100. New flags must
// continue the single-bit sequence: 0x100, 0x200, 0x400, 0x800, 0x1000, ...
enum MessageTypeFlags {
  REQUEST_TYPE = 1 << 8, // 0x100
  RESPONSE_TYPE = 1 << 9, // 0x200
};
25 
26 // Message types must have values between 0x00 to 0xff
27 enum MessageType {
28   // messages for dist.rpc on builtin operators
29   SCRIPT_CALL = 0x00 | MessageTypeFlags::REQUEST_TYPE,
30   SCRIPT_RET = 0x01 | MessageTypeFlags::RESPONSE_TYPE,
31 
32   // messages for dist.rpc on Python UDF
33   PYTHON_CALL = 0x02 | MessageTypeFlags::REQUEST_TYPE,
34   PYTHON_RET = 0x03 | MessageTypeFlags::RESPONSE_TYPE,
35 
36   // messages for dist.remote on builtin operators and Python UDF
37   SCRIPT_REMOTE_CALL = 0x04 |
38       MessageTypeFlags::REQUEST_TYPE, // A remote call on a builtin operator
39   PYTHON_REMOTE_CALL =
40       0x05 | MessageTypeFlags::REQUEST_TYPE, // A remote call on a Python UDF
41   REMOTE_RET =
42       0x06 | MessageTypeFlags::RESPONSE_TYPE, // Response for remote calls for
43                                               // UDF, builtin, or script
44 
45   // RRef related internal messages
46   SCRIPT_RREF_FETCH_CALL =
47       0x07 | MessageTypeFlags::REQUEST_TYPE, // A UserRRef<IValue> fetches value
48                                              // from owner
49   PYTHON_RREF_FETCH_CALL =
50       0x08 | MessageTypeFlags::REQUEST_TYPE, // A UserRRef<py::object> fetches
51                                              // value from owner
52   SCRIPT_RREF_FETCH_RET = 0x09 |
53       MessageTypeFlags::RESPONSE_TYPE, // An OwnerRRef sends ivalue to user
54   PYTHON_RREF_FETCH_RET = 0x0a |
55       MessageTypeFlags::RESPONSE_TYPE, // An OwnerRRef sends py::object to user
56   RREF_USER_DELETE = 0x0b |
57       MessageTypeFlags::REQUEST_TYPE, // A UserRRef tells the owner to deref
58   RREF_FORK_REQUEST =
59       0x0c | MessageTypeFlags::REQUEST_TYPE, // A child UserRRef tells the owner
60                                              // about itself
61   RREF_CHILD_ACCEPT =
62       0x0d | MessageTypeFlags::REQUEST_TYPE, // A child UserRRef tells parent
63                                              // that owner knows it
64   RREF_ACK =
65       0x0e | MessageTypeFlags::RESPONSE_TYPE, // ACK to internal RRef messages
66 
67   // Messages with autograd info
68   FORWARD_AUTOGRAD_REQ = 0x0f | MessageTypeFlags::REQUEST_TYPE,
69   FORWARD_AUTOGRAD_RESP = 0x10 | MessageTypeFlags::RESPONSE_TYPE,
70 
71   // Messages to propagate gradients on the backward pass.
72   BACKWARD_AUTOGRAD_REQ = 0x11 | MessageTypeFlags::REQUEST_TYPE,
73   BACKWARD_AUTOGRAD_RESP = 0x12 | MessageTypeFlags::RESPONSE_TYPE,
74 
75   // Messages to tell workers to clean up their autograd context.
76   CLEANUP_AUTOGRAD_CONTEXT_REQ = 0x13 | MessageTypeFlags::REQUEST_TYPE,
77   CLEANUP_AUTOGRAD_CONTEXT_RESP = 0x14 | MessageTypeFlags::RESPONSE_TYPE,
78 
79   // Messages that tell workers to run requests with profiling enabled.
80   RUN_WITH_PROFILING_REQ = 0x15 | MessageTypeFlags::REQUEST_TYPE,
81   RUN_WITH_PROFILING_RESP = 0x16 | MessageTypeFlags::RESPONSE_TYPE,
82 
83   // Messages to support RRef.backward().
84   RREF_BACKWARD_REQ = 0x17 | MessageTypeFlags::REQUEST_TYPE,
85   RREF_BACKWARD_RESP = 0x18 | MessageTypeFlags::RESPONSE_TYPE,
86 
87   // Other internal message types
88   EXCEPTION = 0x37 | MessageTypeFlags::RESPONSE_TYPE,
89   UNKNOWN = 0x3c
90 };
91 
92 // A message to be sent/received by an RpcAgent.
93 //
94 // A Message object contains 4 fields:
95 //    payload (std::vector<char>): a binary chunk of data.
96 //    tensors (std::vector<torch::Tensor>): all tensors. Tensor data are not
97 //        included in the payload, and it is up to the RpcAgent implementation
98 //        to determine how to serialize them. This design is helpful for
99 //        communicating super large tensors where serializing all the data at
100 //        once leads to excessively large memory footprint. An implementation
101 //        can then serialize and send tensors chunk-by-chunk, in the streaming
102 //        fashion.
103 //    type (MessageType): type of the message.
104 //    id (int64_t): message id, this is used to match request and response.
105 //               Other implementation can ignore it if they have their own
106 //               ways to do matching.
107 //
108 // Layers above ``RpcAgent`` only converts ScriptCall, ScriptResp, PythonCall,
109 // and PythonResp into a Message, and it is up to the RpcAgent
110 // implementation to determine how to serialize a message.
111 class TORCH_API Message final : public torch::CustomClassHolder {
112  private:
113   // Keep these private in order to force users to go through make_intrusive and
114   // thus prevent creating a Message that's not held by an intrusive_ptr.
115   Message();
116 
117   Message(
118       std::vector<char>&& payload,
119       std::vector<torch::Tensor>&& tensors,
120       MessageType type);
121 
122   Message(
123       std::vector<char>&& payload,
124       std::vector<torch::Tensor>&& tensors,
125       MessageType type,
126       int64_t id);
127 
128   friend c10::intrusive_ptr<Message>;
129 
130  public:
131   Message(const Message& other) = delete;
132   Message(Message&& other) = delete;
133   Message& operator=(Message const& rhs) = delete;
134   Message& operator=(Message&& rhs) = delete;
135 
136   // Destructively retrieves the payload.
137   std::vector<char>&& movePayload() &&;
138   std::vector<torch::Tensor>&& moveTensors() &&;
139 
140   std::vector<char>& payload();
141   const std::vector<char>& payload() const;
142   std::vector<torch::Tensor>& tensors();
143   const std::vector<torch::Tensor>& tensors() const;
144   MessageType type() const;
145 
146   bool isRequest() const;
147   bool isResponse() const;
148   bool isShutdown() const;
149 
150   // id is an optional field to match request/response. If an RpcAgent
151   // implementation is able to do the matching without using this id, it can be
152   // dropped during message serialization.
153   int64_t id() const;
154   void setId(int64_t id);
155 
156   std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> getStorages() const;
157 
158  private:
159   std::vector<char> payload_;
160   std::vector<torch::Tensor> tensors_;
161   MessageType type_ = MessageType::UNKNOWN;
162   int64_t id_ = -1;
163 };
164 
165 // Create a response Message of type Exception.
166 // The exception string representation will be used as the message's payload.
167 // A message ID corresponding to the request that resulted in this response can
168 // be provided for matching requests/responses.
169 TORCH_API c10::intrusive_ptr<Message> createExceptionResponse(
170     const std::exception& e,
171     int64_t id);
172 
173 // Create a response Message of type Exception.
174 // The passed in string representation will be used as the message's payload.
175 // A message ID corresponding to the request that resulted in this response can
176 // be provided for matching requests/responses.
177 TORCH_API c10::intrusive_ptr<Message> createExceptionResponse(
178     const std::string& exceptionStr,
179     int64_t id);
180 
181 inline std::tuple<
182     c10::intrusive_ptr<Message>,
183     std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>>>
withStorages(c10::intrusive_ptr<Message> message)184 withStorages(c10::intrusive_ptr<Message> message) {
185   auto storages = message->getStorages();
186   return std::make_tuple(std::move(message), std::move(storages));
187 }
188 
189 using JitFuture = c10::ivalue::Future;
190 
191 } // namespace rpc
192 } // namespace distributed
193 } // namespace torch
194