xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/rref_impl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/jit_type.h>
4 #include <ATen/core/rref_interface.h>
5 #include <c10/core/Event.h>
6 #include <torch/csrc/distributed/rpc/message.h>
7 #include <torch/csrc/distributed/rpc/rpc_agent.h>
8 #include <torch/csrc/distributed/rpc/types.h>
9 #include <optional>
10 
11 #include <atomic>
12 
13 namespace torch::distributed::rpc {
14 
// Forward declarations. RRefContext is defined elsewhere; the RRef classes in
// this header befriend it.
class RRefContext;
class RRef;
class UserRRef;
18 
// Positions of the RRefForkData fields inside the serialized tuple.
constexpr int OWNER_IDX = 0; // index of ownerId in the tuple
constexpr int RREFID_ON_IDX = 1; // index of RRefId.createdOn_ in the tuple
constexpr int RREFID_ID_IDX = 2; // index of RRefId.localId_ in the tuple
constexpr int FORKID_ON_IDX = 3; // index of ForkId.createdOn_ in the tuple
constexpr int FORKID_ID_IDX = 4; // index of ForkId.localId_ in the tuple
constexpr int PARENT_IDX = 5; // index of parent in the tuple
constexpr int TYPE_IDX = 6; // index of typeStr in the tuple

// NB: if more fields are added, make sure this field is also bumped
constexpr int RFD_TUPLE_SIZE = 7; // number of RRefForkData fields in py::tuple
// Enforce the NB above at compile time: the tuple size must be exactly one
// past the last field index, so forgetting to bump it breaks the build.
static_assert(
    RFD_TUPLE_SIZE == TYPE_IDX + 1,
    "RFD_TUPLE_SIZE must be bumped when RRefForkData tuple fields are added");
29 
30 // Represents fork of an RRef to be sent over the wire.
31 struct TORCH_API RRefForkData {
32   const worker_id_t ownerId_;
33   const RRefId rrefId_;
34   const ForkId forkId_;
35   const worker_id_t parent_;
36   const std::string typeStr_;
37 
38   RRefForkData(
39       worker_id_t ownerId,
40       const RRefId& rrefId,
41       const ForkId& forkId,
42       worker_id_t parent,
43       std::string typeStr);
44 };
45 
46 // Note [RRef Protocol]
47 // ~~~~~~~~~~~~~~~~~~~~~~~~~~
48 //
49 // [Background]
50 //
51 // RRef stands for Remote REFerence. Each RRef is owned by a single worker
52 // (i.e., owner) and can be used by multiple users. The owner stores the real
53 // data referenced by its RRefs. RRef needs to support fast and scalable RPC.
54 // Hence, in the design, we avoid using a single global master to keep RRef
55 // states, instead owners will keep track of the global reference counts
56 // for its RRefs. Every RRef can be uniquely identified by a global RRefId,
57 // which is assigned at the time it is first created either on a user or on the
58 // owner.
59 //
60 // On the owner worker, there is only one OwnerRRef instance, which contains the
61 // real data, while on user workers, there can be as many UserRRefs as
62 // necessary, and UserRRef does not hold the data. All usage on the OwnerRRef
63 // should retrieve the unique OwnerRRef instance using the globally unique
64 // RRefId. A UserRRef will be created when it is used as an argument or return
65 // value in dist.rpc or dist.remote call, but RRef forking and reference
66 // counting (RC) are completely transparent to applications. Every UserRRef will
67 // also have its globally unique ForkId.
68 //
69 // [Assumptions]
70 //
71 // 1. Transient Network Failures
72 //
73 // TODO: current RRef implementation does not tolerate failures
74 //
75 // The RRef design handles transient network failures by retrying
76 // messages. Node crashes or permanent network partitions are beyond the scope.
77 // When those incidents occur, the application may take down all workers, revert
78 // to the previous checkpoint, and resume training.
79 //
80 // 2. Non-idempotent UDFs
81 //
82 // We assume UDFs are not idempotent and therefore cannot be retried. However,
83 // internal RRef control messages are idempotent and retried upon message
84 // failure.
85 //
86 // TODO: RRef internal messages are not yet idempotent
87 //
88 // 3. Out of Order Message Delivery
89 //
90 // We do not assume message delivery order between any pair of nodes, because
91 // both sender and receiver are using multiple threads. There is no guarantee on
92 // which message will be processed first.
93 //
94 // [RRef Lifetime]
95 //
96 // The goal of the protocol is to delete an OwnerRRef at an appropriate time.
97 // The right time to delete an OwnerRRef is when there are no living UserRRefs
98 // and Python GC also agrees to delete the OwnerRRef instance on the owner. The
99 // tricky part is to determine if there are any living UserRRefs.
100 //
101 // A user can get a UserRRef in three situations:
102 //
103 // (1). Receiving a UserRRef from the owner.
104 // (2). Receiving a UserRRef from another user.
105 // (3). Creating a new UserRRef owned by another worker.
106 //
107 // (1) is the simplest case where the owner initiates the fork, and hence it can
108 // easily increment local RC. The only requirement is that any UserRRef must
109 // notify the owner before destruction. Hence, we need the first guarantee:
110 //
111 // G1. The owner will be notified when any UserRRef is deleted.
112 //
113 // As messages might arrive delayed or out-of-order, we need one more guarantee to
114 // make sure the delete message is not sent out too soon. Let us first introduce
115 // a new concept. If A sends an RPC to B that involves an RRef, we call the RRef
116 // on A the parent RRef and the RRef on B the child RRef.
117 //
118 // G2. Parent RRef cannot be deleted until the child RRef is confirmed by the
119 //     owner.
120 //
121 // Under (1), where the caller is UserRRef and callee is OwnerRRef, it simply
122 // means that the user will not send out the delete message until all previous
123 // messages are ACKed. Note that ACKed does not mean the owner finishes
124 // executing the function, instead, it only means the owner has retrieved its
125 // local OwnerRRef and is about to pass it to the function, which is sufficient to
126 // keep the OwnerRRef alive even if the delete message from the user arrives at
127 // the owner before the function finishes execution.
128 //
129 // With (2) and (3), it is possible that the owner only partially knows the RRef
130 // fork graph or does not know it at all. For example, the RRef could be
131 // constructed on a user, and before the owner receives the RPC call, the
132 // creator user might have already shared the RRef with other users, and those
133 // users could further share the RRef. One invariant is that the fork graph of
134 // any RRef is always a tree rooted at the owner, because forking an RRef always
135 // creates a new RRef instance, and hence every RRef has a single parent. One
136 // nasty detail is that when an RRef is created on a user, technically the owner
137 // is not its parent but we still consider it that way and it does not break the
138 // argument below.
139 //
140 // The owner's view on any node (fork) in the tree has three stages:
141 //
142 //       1) unknown -> 2) known -> 3) deleted.
143 //
144 // The owner's view on the entire tree keeps changing. The owner deletes its
145 // OwnerRRef instance when it thinks there are no living UserRRefs, i.e., when
146 // OwnerRRef is deleted, all UserRRefs could be either indeed deleted or
147 // unknown. The dangerous case is when some forks are unknown and others are
148 // deleted.
149 //
150 // G2 trivially guarantees that no parent UserRRef Y can be deleted before the
151 // owner knows all of Y's children UserRRefs.
152 //
153 // However, it is possible that the child UserRRef Z may be deleted before the
154 // owner knows its parent Y. More specifically, this can happen when all of Z's
155 // messages are processed by the owner before all messages from Y, including the
156 // delete message. Nevertheless, this does not cause any problem, because at
157 // least one of Y's ancestors will be alive, and it will prevent the owner from
158 // deleting the OwnerRRef. Consider the following example: (NB: this scenario
159 // will no longer be relevant when we block UDF until all RRefs are confirmed by
160 // the owner)
161 //
162 //     OwnerRRef -> A -> Y -> Z
163 //
164 // OwnerRRef forks to A, then A forks to Y, and Y forks to Z. Z can be deleted
165 // without OwnerRRef knowing Y. However, the OwnerRRef will at least know A, as
166 // the owner directly forks the RRef to A. A won't die before the owner knows Y.
167 //
168 // Things get a little trickier if the RRef is created on a user:
169 //
170 //  OwnerRRef
171 //      ^
172 //      |
173 //      A -> Y -> Z
174 //
175 // If Z calls to_here on the UserRRef, the owner at least knows A when Z is
176 // deleted, because otherwise to_here wouldn't finish. If Z does not call
177 // to_here, it is possible that the owner receives all messages from Z before
178 // any message from A and Y. In this case, as the real data of the OwnerRRef has
179 // not been created yet, there is nothing to be deleted either. It is the same
180 // as if Z did not exist at all. Hence, it's still OK.
181 //
182 // See #26759 for more details and discussions.
183 //
184 // TODO: make RRef an IValue, and edit createStackForSchema accordingly
185 // TODO: make RRef system messages idempotent and retry on failures.
186 //
187 // ``RRef`` is the base type for both ``UserRRef`` and ``OwnerRRef``.
188 // Each ``RRef`` has a globally unique ``RRefId``.
189 class TORCH_API RRef : public RRefInterface {
190  public:
191   // RRef is made NOT copyable NOT movable to prevent messing up reference
192   // counting.
193   explicit RRef(const RRef& other) = delete;
194   explicit RRef(RRef&& other) = delete;
195   RRef& operator=(RRef&& other) = delete;
196 
197   ~RRef() override = default;
198 
199   // returns the worker id of the owner
owner()200   inline worker_id_t owner() const override {
201     return ownerId_;
202   }
203 
204   // returns the worker name of the owner
ownerName()205   inline std::string ownerName() const override {
206     return RpcAgent::getCurrentRpcAgent()->getWorkerInfo(ownerId_).name_;
207   }
208 
209   // returns the worker info of the owner
ownerWorkerInfo()210   inline WorkerInfo ownerWorkerInfo() const {
211     return RpcAgent::getCurrentRpcAgent()->getWorkerInfo(ownerId_);
212   }
213 
214   // Returns the globally unique RRefId of this RRef
rrefId()215   inline const RRefId& rrefId() const {
216     return rrefId_;
217   }
218 
isPyObj()219   inline bool isPyObj() const {
220     return type_ == PyObjectType::get();
221   }
type()222   inline const TypePtr type() const override {
223     return type_;
224   }
225 
226   // Save the future corresponding to the creation of this RRef on a remote
227   // node. Note that this is only set when processing requests invoked with
228   // rpc.remote. This is only used to get the future corresponding to the rref
229   // for profiling use cases.
registerOwnerCreationFuture(c10::intrusive_ptr<JitFuture> fut)230   inline void registerOwnerCreationFuture(c10::intrusive_ptr<JitFuture> fut) {
231     ownerCreationFuture_ = std::move(fut);
232   }
233 
234   // Get the future corresponding to the creation of this rref.
getOwnerCreationFuture()235   inline c10::intrusive_ptr<JitFuture> getOwnerCreationFuture() const {
236     return ownerCreationFuture_;
237   }
238 
239   // Check if creation of this RRef on owner node has timed out.
getTimedOut()240   inline bool getTimedOut() const {
241     return timedOut_.load();
242   }
243 
244   // Dispatches an error to the correct handler based on its RPCErrorType.
245   void handleError(RPCErrorType errorType, const JitFuture& JitFuture);
246 
247   // Send delete UserRRef request to Owner,
248   // if the request hasn't been sent yet.
249   // There are 2 cases to call it,
250   // 1, Python GC decides end of UserRRef lifetime, calling destructor.
251   // 2, RPC module graceful shutdown calls it on all UserRRefs tracked
252   //    in the RRefContext.
tryDel()253   virtual void tryDel() {}
254 
255  protected:
256   // Indicates that the creation of this RRef on owner node has timed out.
setTimedOut()257   inline void setTimedOut() {
258     timedOut_ = true;
259   }
260   friend class RRefContext;
261 
262   RRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type);
263 
264   virtual RRefForkData fork() const;
265 
266   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
267   const worker_id_t ownerId_;
268   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
269   const RRefId rrefId_;
270   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
271   std::atomic<bool> timedOut_{false};
272 
273   // type field to denote the type of the element that the RRef is holding
274   // it could be any TypePtr that JIT support, including PyObjectType
275   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
276   const TypePtr type_;
277   // Future corresponding to request to create RRef on remote node.
278   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
279   c10::intrusive_ptr<JitFuture> ownerCreationFuture_;
280 };
281 
282 // ``UserRRef`` represents a user of an RRef. Besides the ``RRefId``, each user
283 // also has a globally unique ``ForkId`` to identify this user. ``UserRRef``
284 // never owns the real value, the only way to get the value of the ``RRef`` is
285 // to call ``to_here()`` and get a copy..
286 class TORCH_API UserRRef final : public RRef {
287  public:
288   UserRRef(const UserRRef& other) = delete;
289   UserRRef(UserRRef&& other) = delete;
290   UserRRef& operator=(const UserRRef& other) = delete;
291   UserRRef& operator=(UserRRef&& other) = delete;
292 
293   UserRRef(
294       worker_id_t ownerId,
295       const RRefId& rrefId,
296       const ForkId& forkId,
297       TypePtr type);
298 
isOwner()299   inline bool isOwner() const override {
300     return false;
301   }
302 
confirmedByOwner()303   inline bool confirmedByOwner() const override {
304     return confirmedByOwner_;
305   }
306 
307   // Returns the globally unique ForkId of this RRef
308   const ForkId& forkId() const;
309 
310   // Get of copy of the value from the ``OwnerRRef``. If the value is not ready
311   // yet, this call will block.
312   IValue toHere(
313       const float timeoutSeconds =
314           torch::distributed::rpc::kUnsetRpcTimeout) const;
315 
316   void tryDel() override;
317 
318   // Will be called when refcount reaches 0.
319   // Upon destruction, this ``UserRRef`` will tell the owner to deref.
320   void release_resources() override;
321 
322   // Will be called when both refcount and weakcount reach 0. See
323   // https://github.com/pytorch/pytorch/blob/9116f02bebf3a5260feef5732d36c54ecb3b4033/c10/util/intrusive_ptr.h#L204
324   // This is called on destructing the wrapping intrusive_ptr_target instance
325   // and it's data members.
326   ~UserRRef() override;
327 
328  private:
329   friend class RRefContext;
330 
331   RRefForkData fork() const override;
confirm()332   inline void confirm() {
333     confirmedByOwner_ = true;
334   }
335 
336   const ForkId forkId_;
337 
338   // Indicates if this user has sent delete message to it's owner.
339   // Note, thread safety is needed because delete message could be sent by
340   // either the destructor called by Python garbage collection or RRefContext
341   // proactive cleanup on RPC graceful shutdown.
342   std::mutex deletedOnOwnerMutex_;
343   bool deletedOnOwner_{false};
344   // Indicating whether this UserRRef has been confirmed by its owner.
345   std::atomic<bool> confirmedByOwner_;
346 };
347 
348 // Keep the template only on the derived class because ``RRefContext`` needs to
349 // erase the type on ``RRef`` and keep them in one map.
350 class TORCH_API OwnerRRef final : public RRef {
351  public:
352   OwnerRRef(const OwnerRRef& other) = delete;
353   OwnerRRef(OwnerRRef&& other) = delete;
354   OwnerRRef& operator=(const OwnerRRef& other) = delete;
355   OwnerRRef& operator=(OwnerRRef&& other) = delete;
356 
357   OwnerRRef(
358       worker_id_t ownerId,
359       const RRefId& rrefId,
360       TypePtr type,
361       std::vector<c10::Device> devices);
362 
363   OwnerRRef(
364       worker_id_t ownerId,
365       const RRefId& rrefId,
366       TypePtr type,
367       std::optional<IValue> value,
368       std::vector<c10::Device> devices);
369 
isOwner()370   inline bool isOwner() const override {
371     return true;
372   }
373 
374   // OwnerRRef is always confirmed, while UserRRef is only confirmed when the
375   // owner knows about it.
confirmedByOwner()376   inline bool confirmedByOwner() const override {
377     return true;
378   }
379 
380   // Get a constant reference of the real value. This method will block if the
381   // value is not ready. This method does not need GIL as it does not create
382   // any new py::object. It will throw if there is an error.
383   const IValue& getValue() const;
384 
385   // Set the value of this ``OwnerRRef``. This method does not need GIL as it
386   // does not create any new py::object.
387   void setValue(IValue&& value);
388   // Sets the value of this ``OwnerRRef`` to contain an exception.
389   void setError(std::exception_ptr eptr);
390 
391   // Has a value or error been set?
392   bool hasValue() const;
393   // Gets a future that is satisfied when the value or error is set.
394   c10::intrusive_ptr<JitFuture> getFuture();
395 
396  private:
397   friend class RRefContext;
398 
399   c10::intrusive_ptr<JitFuture> future_;
400 };
401 
402 TORCH_API std::ostream& operator<<(std::ostream& os, const RRef& rref);
403 
404 // Helper function that casts from c10::RRefInterface to OwnerRRef
fromRRefInterface(const c10::intrusive_ptr<c10::RRefInterface> & rrefInterface)405 inline TORCH_API c10::intrusive_ptr<OwnerRRef> fromRRefInterface(
406     const c10::intrusive_ptr<c10::RRefInterface>& rrefInterface) {
407   return c10::static_intrusive_pointer_cast<OwnerRRef>(rrefInterface);
408 }
409 
410 // Helper function that casts from OwnerRRef to c10::RRefInterface
fromOwnerRRef(const c10::intrusive_ptr<RRef> & ownerRRef)411 inline TORCH_API c10::intrusive_ptr<c10::RRefInterface> fromOwnerRRef(
412     const c10::intrusive_ptr<RRef>& ownerRRef) {
413   return c10::static_intrusive_pointer_cast<c10::RRefInterface>(ownerRRef);
414 }
415 
416 } // namespace torch::distributed::rpc
417