1 #pragma once
2
3 #include <ATen/core/jit_type.h>
4 #include <ATen/core/rref_interface.h>
5 #include <c10/core/Event.h>
6 #include <torch/csrc/distributed/rpc/message.h>
7 #include <torch/csrc/distributed/rpc/rpc_agent.h>
8 #include <torch/csrc/distributed/rpc/types.h>
9 #include <optional>
10
11 #include <atomic>
12
13 namespace torch::distributed::rpc {
14
15 class RRef;
16 class RRefContext;
17 class UserRRef;
18
// Positions of the individual RRefForkData fields inside its py::tuple
// representation.

// Index of ownerId in the tuple.
inline constexpr int OWNER_IDX = 0;
// Index of RRefId.createdOn_ in the tuple.
inline constexpr int RREFID_ON_IDX = 1;
// Index of RRefId.localId_ in the tuple.
inline constexpr int RREFID_ID_IDX = 2;
// Index of ForkId.createdOn_ in the tuple.
inline constexpr int FORKID_ON_IDX = 3;
// Index of ForkId.localId_ in the tuple.
inline constexpr int FORKID_ID_IDX = 4;
// Index of parent in the tuple.
inline constexpr int PARENT_IDX = 5;
// Index of the type string (typeStr) in the tuple.
inline constexpr int TYPE_IDX = 6;

// Number of RRefForkData fields in the py::tuple.
// NB: if more fields are added, make sure this field is also bumped.
inline constexpr int RFD_TUPLE_SIZE = 7;

// TYPE_IDX is the last field, so the tuple size must be exactly one past it.
// This turns the "NB" above into a compile-time check.
static_assert(
    RFD_TUPLE_SIZE == TYPE_IDX + 1,
    "RFD_TUPLE_SIZE must be bumped whenever an RRefForkData field is added");
29
30 // Represents fork of an RRef to be sent over the wire.
31 struct TORCH_API RRefForkData {
32 const worker_id_t ownerId_;
33 const RRefId rrefId_;
34 const ForkId forkId_;
35 const worker_id_t parent_;
36 const std::string typeStr_;
37
38 RRefForkData(
39 worker_id_t ownerId,
40 const RRefId& rrefId,
41 const ForkId& forkId,
42 worker_id_t parent,
43 std::string typeStr);
44 };
45
46 // Note [RRef Protocol]
47 // ~~~~~~~~~~~~~~~~~~~~~~~~~~
48 //
49 // [Background]
50 //
51 // RRef stands for Remote REFerence. Each RRef is owned by a single worker
52 // (i.e., owner) and can be used by multiple users. The owner stores the real
53 // data referenced by its RRefs. RRef needs to support fast and scalable RPC.
54 // Hence, in the design, we avoid using a single global master to keep RRef
55 // states, instead owners will keep track of the global reference counts
56 // for its RRefs. Every RRef can be uniquely identified by a global RRefId,
57 // which is assigned at the time it is first created either on a user or on the
58 // owner.
59 //
60 // On the owner worker, there is only one OwnerRRef instance, which contains the
61 // real data, while on user workers, there can be as many UserRRefs as
62 // necessary, and UserRRef does not hold the data. All usage on the OwnerRRef
63 // should retrieve the unique OwnerRRef instance using the globally unique
// RRefId. A UserRRef will be created when it is used as an argument or return
65 // value in dist.rpc or dist.remote call, but RRef forking and reference
66 // counting (RC) are completely transparent to applications. Every UserRRef will
67 // also have its globally unique ForkId.
68 //
69 // [Assumptions]
70 //
71 // 1. Transient Network Failures
72 //
73 // TODO: current RRef implementation does not tolerate failures
74 //
75 // The RRef design handles transient network failures by retrying
76 // messages. Node crashes or permanent network partition is beyond the scope.
77 // When those incidents occur, the application may take down all workers, revert
78 // to the previous checkpoint, and resume training.
79 //
80 // 2. Non-idempotent UDFs
81 //
82 // We assume UDFs are not idempotent and therefore cannot be retried. However,
83 // internal RRef control messages are idempotent and retried upon message
84 // failure.
85 //
86 // TODO: RRef internal messages are not yet idempotent
87 //
88 // 3. Out of Order Message Delivery
89 //
90 // We do not assume message delivery order between any pair of nodes, because
91 // both sender and receiver are using multiple threads. There is no guarantee on
92 // which message will be processed first.
93 //
94 // [RRef Lifetime]
95 //
96 // The goal of the protocol is to delete an OwnerRRef at an appropriate time.
97 // The right time to delete an OwnerRRef is when there are no living UserRRefs
98 // and Python GC also agrees to delete the OwnerRRef instance on the owner. The
99 // tricky part is to determine if there are any living UserRRefs.
100 //
101 // A user can get a UserRRef in three situations:
102 //
103 // (1). Receiving a UserRRef from the owner.
104 // (2). Receiving a UserRRef from another user.
105 // (3). Creating a new UserRRef owned by another worker.
106 //
107 // (1) is the simplest case where the owner initiates the fork, and hence it can
108 // easily increment local RC. The only requirement is that any UserRRef must
109 // notify the owner before destruction. Hence, we need the first guarantee:
110 //
111 // G1. The owner will be notified when any UserRRef is deleted.
112 //
// As messages might come delayed or out-of-order, we need one more guarantee to
114 // make sure the delete message is not sent out too soon. Let us first introduce
115 // a new concept. If A sends an RPC to B that involves an RRef, we call the RRef
116 // on A the parent RRef and the RRef on B the child RRef.
117 //
118 // G2. Parent RRef cannot be deleted until the child RRef is confirmed by the
119 // owner.
120 //
121 // Under (1), where the caller is UserRRef and callee is OwnerRRef, it simply
122 // means that the user will not send out the delete message until all previous
123 // messages are ACKed. Note that ACKed does not mean the owner finishes
124 // executing the function, instead, it only means the owner has retrieved its
// local OwnerRRef and is about to pass it to the function, which is sufficient
// to keep the OwnerRRef alive even if the delete message from the user arrives
// at the owner before the function finishes execution.
128 //
129 // With (2) and (3), it is possible that the owner only partially knows the RRef
// fork graph or does not know it at all. For example, the RRef could be
131 // constructed on a user, and before the owner receives the RPC call, the
132 // creator user might have already shared the RRef with other users, and those
133 // users could further share the RRef. One invariant is that the fork graph of
134 // any RRef is always a tree rooted at the owner, because forking an RRef always
135 // creates a new RRef instance, and hence every RRef has a single parent. One
136 // nasty detail is that when an RRef is created on a user, technically the owner
137 // is not its parent but we still consider it that way and it does not break the
138 // argument below.
139 //
140 // The owner's view on any node (fork) in the tree has three stages:
141 //
142 // 1) unknown -> 2) known -> 3) deleted.
143 //
144 // The owner's view on the entire tree keeps changing. The owner deletes its
145 // OwnerRRef instance when it thinks there are no living UserRRefs, i.e., when
146 // OwnerRRef is deleted, all UserRRefs could be either indeed deleted or
147 // unknown. The dangerous case is when some forks are unknown and others are
148 // deleted.
149 //
150 // G2 trivially guarantees that no parent UserRRef Y can be deleted before the
151 // owner knows all of Y's children UserRRefs.
152 //
153 // However, it is possible that the child UserRRef Z may be deleted before the
154 // owner knows its parent Y. More specifically, this can happen when all of Z's
155 // messages are processed by the owner before all messages from Y, including the
// delete message. Nevertheless, this does not cause any problem, because at
// least one of Y's ancestors will be alive, and it will prevent the owner from
// deleting the OwnerRRef. Consider the following example: (NB: this scenario
// will no longer be relevant when we block UDF until all RRefs are confirmed by
// the owner)
161 //
162 // OwnerRRef -> A -> Y -> Z
163 //
164 // OwnerRRef forks to A, then A forks to Y, and Y forks to Z. Z can be deleted
165 // without OwnerRRef knowing Y. However, the OwnerRRef will at least know A, as
166 // the owner directly forks the RRef to A. A won't die before the owner knows Y.
167 //
168 // Things get a little trickier if the RRef is created on a user:
169 //
170 // OwnerRRef
171 // ^
172 // |
173 // A -> Y -> Z
174 //
175 // If Z calls to_here on the UserRRef, the owner at least knows A when Z is
176 // deleted, because otherwise to_here wouldn't finish. If Z does not call
177 // to_here, it is possible that the owner receives all messages from Z before
178 // any message from A and Y. In this case, as the real data of the OwnerRRef has
179 // not been created yet, there is nothing to be deleted either. It is the same
// as if Z does not exist at all. Hence, it's still OK.
181 //
182 // See #26759 for more details and discussions.
183 //
184 // TODO: make RRef an IValue, and edit createStackForSchema accordingly
185 // TODO: make RRef system messages idempotent and retry on failures.
186 //
187 // ``RRef`` is the base type for both ``UserRRef`` and ``OwnerRRef``.
188 // Each ``RRef`` has a globally unique ``RRefId``.
189 class TORCH_API RRef : public RRefInterface {
190 public:
191 // RRef is made NOT copyable NOT movable to prevent messing up reference
192 // counting.
193 explicit RRef(const RRef& other) = delete;
194 explicit RRef(RRef&& other) = delete;
195 RRef& operator=(RRef&& other) = delete;
196
197 ~RRef() override = default;
198
199 // returns the worker id of the owner
owner()200 inline worker_id_t owner() const override {
201 return ownerId_;
202 }
203
204 // returns the worker name of the owner
ownerName()205 inline std::string ownerName() const override {
206 return RpcAgent::getCurrentRpcAgent()->getWorkerInfo(ownerId_).name_;
207 }
208
209 // returns the worker info of the owner
ownerWorkerInfo()210 inline WorkerInfo ownerWorkerInfo() const {
211 return RpcAgent::getCurrentRpcAgent()->getWorkerInfo(ownerId_);
212 }
213
214 // Returns the globally unique RRefId of this RRef
rrefId()215 inline const RRefId& rrefId() const {
216 return rrefId_;
217 }
218
isPyObj()219 inline bool isPyObj() const {
220 return type_ == PyObjectType::get();
221 }
type()222 inline const TypePtr type() const override {
223 return type_;
224 }
225
226 // Save the future corresponding to the creation of this RRef on a remote
227 // node. Note that this is only set when processing requests invoked with
228 // rpc.remote. This is only used to get the future corresponding to the rref
229 // for profiling use cases.
registerOwnerCreationFuture(c10::intrusive_ptr<JitFuture> fut)230 inline void registerOwnerCreationFuture(c10::intrusive_ptr<JitFuture> fut) {
231 ownerCreationFuture_ = std::move(fut);
232 }
233
234 // Get the future corresponding to the creation of this rref.
getOwnerCreationFuture()235 inline c10::intrusive_ptr<JitFuture> getOwnerCreationFuture() const {
236 return ownerCreationFuture_;
237 }
238
239 // Check if creation of this RRef on owner node has timed out.
getTimedOut()240 inline bool getTimedOut() const {
241 return timedOut_.load();
242 }
243
244 // Dispatches an error to the correct handler based on its RPCErrorType.
245 void handleError(RPCErrorType errorType, const JitFuture& JitFuture);
246
247 // Send delete UserRRef request to Owner,
248 // if the request hasn't been sent yet.
249 // There are 2 cases to call it,
250 // 1, Python GC decides end of UserRRef lifetime, calling destructor.
251 // 2, RPC module graceful shutdown calls it on all UserRRefs tracked
252 // in the RRefContext.
tryDel()253 virtual void tryDel() {}
254
255 protected:
256 // Indicates that the creation of this RRef on owner node has timed out.
setTimedOut()257 inline void setTimedOut() {
258 timedOut_ = true;
259 }
260 friend class RRefContext;
261
262 RRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type);
263
264 virtual RRefForkData fork() const;
265
266 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
267 const worker_id_t ownerId_;
268 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
269 const RRefId rrefId_;
270 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
271 std::atomic<bool> timedOut_{false};
272
273 // type field to denote the type of the element that the RRef is holding
274 // it could be any TypePtr that JIT support, including PyObjectType
275 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
276 const TypePtr type_;
277 // Future corresponding to request to create RRef on remote node.
278 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
279 c10::intrusive_ptr<JitFuture> ownerCreationFuture_;
280 };
281
282 // ``UserRRef`` represents a user of an RRef. Besides the ``RRefId``, each user
283 // also has a globally unique ``ForkId`` to identify this user. ``UserRRef``
284 // never owns the real value, the only way to get the value of the ``RRef`` is
285 // to call ``to_here()`` and get a copy..
286 class TORCH_API UserRRef final : public RRef {
287 public:
288 UserRRef(const UserRRef& other) = delete;
289 UserRRef(UserRRef&& other) = delete;
290 UserRRef& operator=(const UserRRef& other) = delete;
291 UserRRef& operator=(UserRRef&& other) = delete;
292
293 UserRRef(
294 worker_id_t ownerId,
295 const RRefId& rrefId,
296 const ForkId& forkId,
297 TypePtr type);
298
isOwner()299 inline bool isOwner() const override {
300 return false;
301 }
302
confirmedByOwner()303 inline bool confirmedByOwner() const override {
304 return confirmedByOwner_;
305 }
306
307 // Returns the globally unique ForkId of this RRef
308 const ForkId& forkId() const;
309
310 // Get of copy of the value from the ``OwnerRRef``. If the value is not ready
311 // yet, this call will block.
312 IValue toHere(
313 const float timeoutSeconds =
314 torch::distributed::rpc::kUnsetRpcTimeout) const;
315
316 void tryDel() override;
317
318 // Will be called when refcount reaches 0.
319 // Upon destruction, this ``UserRRef`` will tell the owner to deref.
320 void release_resources() override;
321
322 // Will be called when both refcount and weakcount reach 0. See
323 // https://github.com/pytorch/pytorch/blob/9116f02bebf3a5260feef5732d36c54ecb3b4033/c10/util/intrusive_ptr.h#L204
324 // This is called on destructing the wrapping intrusive_ptr_target instance
325 // and it's data members.
326 ~UserRRef() override;
327
328 private:
329 friend class RRefContext;
330
331 RRefForkData fork() const override;
confirm()332 inline void confirm() {
333 confirmedByOwner_ = true;
334 }
335
336 const ForkId forkId_;
337
338 // Indicates if this user has sent delete message to it's owner.
339 // Note, thread safety is needed because delete message could be sent by
340 // either the destructor called by Python garbage collection or RRefContext
341 // proactive cleanup on RPC graceful shutdown.
342 std::mutex deletedOnOwnerMutex_;
343 bool deletedOnOwner_{false};
344 // Indicating whether this UserRRef has been confirmed by its owner.
345 std::atomic<bool> confirmedByOwner_;
346 };
347
348 // Keep the template only on the derived class because ``RRefContext`` needs to
349 // erase the type on ``RRef`` and keep them in one map.
350 class TORCH_API OwnerRRef final : public RRef {
351 public:
352 OwnerRRef(const OwnerRRef& other) = delete;
353 OwnerRRef(OwnerRRef&& other) = delete;
354 OwnerRRef& operator=(const OwnerRRef& other) = delete;
355 OwnerRRef& operator=(OwnerRRef&& other) = delete;
356
357 OwnerRRef(
358 worker_id_t ownerId,
359 const RRefId& rrefId,
360 TypePtr type,
361 std::vector<c10::Device> devices);
362
363 OwnerRRef(
364 worker_id_t ownerId,
365 const RRefId& rrefId,
366 TypePtr type,
367 std::optional<IValue> value,
368 std::vector<c10::Device> devices);
369
isOwner()370 inline bool isOwner() const override {
371 return true;
372 }
373
374 // OwnerRRef is always confirmed, while UserRRef is only confirmed when the
375 // owner knows about it.
confirmedByOwner()376 inline bool confirmedByOwner() const override {
377 return true;
378 }
379
380 // Get a constant reference of the real value. This method will block if the
381 // value is not ready. This method does not need GIL as it does not create
382 // any new py::object. It will throw if there is an error.
383 const IValue& getValue() const;
384
385 // Set the value of this ``OwnerRRef``. This method does not need GIL as it
386 // does not create any new py::object.
387 void setValue(IValue&& value);
388 // Sets the value of this ``OwnerRRef`` to contain an exception.
389 void setError(std::exception_ptr eptr);
390
391 // Has a value or error been set?
392 bool hasValue() const;
393 // Gets a future that is satisfied when the value or error is set.
394 c10::intrusive_ptr<JitFuture> getFuture();
395
396 private:
397 friend class RRefContext;
398
399 c10::intrusive_ptr<JitFuture> future_;
400 };
401
// Stream-insertion operator for ``RRef``. Only the declaration is visible in
// this header; the output format is decided by the out-of-line definition.
// Presumably returns ``os`` to allow chaining, as is conventional -- confirm.
TORCH_API std::ostream& operator<<(std::ostream& os, const RRef& rref);
403
404 // Helper function that casts from c10::RRefInterface to OwnerRRef
fromRRefInterface(const c10::intrusive_ptr<c10::RRefInterface> & rrefInterface)405 inline TORCH_API c10::intrusive_ptr<OwnerRRef> fromRRefInterface(
406 const c10::intrusive_ptr<c10::RRefInterface>& rrefInterface) {
407 return c10::static_intrusive_pointer_cast<OwnerRRef>(rrefInterface);
408 }
409
410 // Helper function that casts from OwnerRRef to c10::RRefInterface
fromOwnerRRef(const c10::intrusive_ptr<RRef> & ownerRRef)411 inline TORCH_API c10::intrusive_ptr<c10::RRefInterface> fromOwnerRRef(
412 const c10::intrusive_ptr<RRef>& ownerRRef) {
413 return c10::static_intrusive_pointer_cast<c10::RRefInterface>(ownerRRef);
414 }
415
416 } // namespace torch::distributed::rpc
417