xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/local_rendezvous.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/core/framework/local_rendezvous.h"
17 
18 #include "tensorflow/core/framework/allocator.h"
19 #include "tensorflow/core/framework/types.h"
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/lib/core/notification.h"
22 #include "tensorflow/core/lib/gtl/manual_constructor.h"
23 #include "tensorflow/core/lib/monitoring/counter.h"
24 #include "tensorflow/core/lib/strings/numbers.h"
25 #include "tensorflow/core/lib/strings/str_util.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/platform/mutex.h"
28 #include "tensorflow/core/platform/refcount.h"
29 #include "tensorflow/core/platform/types.h"
30 
31 namespace tensorflow {
32 
33 // Represents a blocked Send() or Recv() call in the rendezvous.
34 struct LocalRendezvous::Item {
35   enum Type { kSend = 0, kRecv = 1 };
36 
Itemtensorflow::LocalRendezvous::Item37   Item(Rendezvous::Args send_args, const Tensor& value, bool is_dead)
38       : Item(send_args, kSend) {
39     send_state.value.Init(value);
40     send_state.is_dead = is_dead;
41   }
42 
Itemtensorflow::LocalRendezvous::Item43   Item(Rendezvous::Args recv_args, Rendezvous::DoneCallback waiter,
44        CancellationToken cancellation_token)
45       : Item(recv_args, kRecv) {
46     recv_state.waiter.Init(std::move(waiter));
47     recv_state.cancellation_token = cancellation_token;
48   }
49 
~Itemtensorflow::LocalRendezvous::Item50   ~Item() {
51     if (args.device_context) {
52       args.device_context->Unref();
53     }
54     if (type == kSend) {
55       send_state.value.Destroy();
56     } else {
57       recv_state.waiter.Destroy();
58     }
59   }
60 
61   const Rendezvous::Args args;
62   const Type type;
63 
64   // Link to next item in an ItemQueue.
65   Item* next = nullptr;
66 
67   // The validity of `send_state` or `recv_state` is determined by `type ==
68   // kSend` or `type == kRecv` respectively.
69   union {
70     struct {
71       ManualConstructor<Tensor> value;
72       bool is_dead;
73     } send_state;
74     struct {
75       ManualConstructor<Rendezvous::DoneCallback> waiter;
76       CancellationToken cancellation_token;
77     } recv_state;
78   };
79 
80  private:
Itemtensorflow::LocalRendezvous::Item81   Item(Rendezvous::Args args, Type type) : args(args), type(type) {
82     if (args.device_context) {
83       args.device_context->Ref();
84     }
85   }
86 };
87 
push_back(Item * item)88 void LocalRendezvous::ItemQueue::push_back(Item* item) {
89   if (TF_PREDICT_TRUE(head == nullptr)) {
90     // The queue is empty.
91     head = item;
92     tail = item;
93   } else {
94     DCHECK_EQ(tail->type, item->type);
95     tail->next = item;
96     tail = item;
97   }
98 }
99 
~LocalRendezvous()100 LocalRendezvous::~LocalRendezvous() {
101   // Before destroying this rendezvous instance, make sure all the done-callback
102   // calls have finished and the tensors have been released from the queue.
103   {
104     mutex_lock l(mu_);
105     while (pending_callback_counter_ != 0) {
106       pending_callback_cond_var_.wait_for(l, std::chrono::milliseconds(50));
107     }
108   }
109 
110   if (!table_.empty()) {
111     StartAbort(errors::Cancelled("LocalRendezvous deleted"));
112   }
113 }
114 
115 namespace {
KeyHash(const StringPiece & k)116 uint64 KeyHash(const StringPiece& k) { return Hash64(k.data(), k.size()); }
117 }  // namespace
118 
Send(const Rendezvous::ParsedKey & key,const Rendezvous::Args & send_args,const Tensor & val,const bool is_dead)119 Status LocalRendezvous::Send(const Rendezvous::ParsedKey& key,
120                              const Rendezvous::Args& send_args,
121                              const Tensor& val, const bool is_dead) {
122   uint64 key_hash = KeyHash(key.FullKey());
123   DVLOG(2) << "Send " << this << " " << key_hash << " " << key.FullKey();
124 
125   if (is_dead) {
126     static auto* rendezvous_dead_values_sent = monitoring::Counter<2>::New(
127         "/tensorflow/core/rendezvous_dead_values_sent",
128         "The number of dead values sent between a pair of devices.",
129         "send_device", "recv_device");
130     rendezvous_dead_values_sent
131         ->GetCell(string(key.src_device), string(key.dst_device))
132         ->IncrementBy(1);
133   }
134 
135   mu_.lock();
136   if (!status_.ok()) {
137     // Rendezvous has been aborted.
138     Status s = status_;
139     mu_.unlock();
140     return s;
141   }
142 
143   ItemQueue* queue = &table_[key_hash];
144   if (queue->head == nullptr || queue->head->type == Item::kSend) {
145     // There is no waiter for this message. Append the message
146     // into the queue. The waiter will pick it up when arrives.
147     // Only send-related fields need to be filled.
148     // TODO(b/143786186): Investigate moving the allocation of `Item` outside
149     // the lock.
150     DVLOG(2) << "Enqueue Send Item (key:" << key.FullKey() << "). ";
151     queue->push_back(new Item(send_args, val, is_dead));
152     mu_.unlock();
153     return OkStatus();
154   }
155 
156   DVLOG(2) << "Consume Recv Item (key:" << key.FullKey() << "). ";
157   // There is an earliest waiter to consume this message.
158   Item* item = queue->head;
159 
160   // Delete the queue when the last element has been consumed.
161   if (item->next == nullptr) {
162     DVLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). ";
163     table_.erase(key_hash);
164   } else {
165     queue->head = item->next;
166   }
167 
168   // Make sure the ref-count of the rendezvous won't reach 0 while the
169   // done_callback is running, which would otherwise become deadlock:
170   // the done_callback waits for the Unref() to return, while the destructor
171   // wiats for the pending_callback_counter to reach 0.
172   core::RefCountPtr<const Rendezvous> rc_owner_ref;
173   if (rc_owner_) {
174     rc_owner_ref.reset(rc_owner_);
175     rc_owner_->Ref();
176   }
177   pending_callback_counter_++;
178   // Invoke the done-callback, without holding the lock.
179   mu_.unlock();
180   DCHECK_EQ(item->type, Item::kRecv);
181   (*item->recv_state.waiter)(OkStatus(), send_args, item->args, val, is_dead);
182   delete item;
183   {
184     mutex_lock l(mu_);
185     pending_callback_counter_--;
186     if (pending_callback_counter_ == 0) {
187       pending_callback_cond_var_.notify_all();
188     }
189   }
190   return OkStatus();
191 }
192 
RecvAsync(const Rendezvous::ParsedKey & key,const Rendezvous::Args & recv_args,Rendezvous::DoneCallback done)193 void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key,
194                                 const Rendezvous::Args& recv_args,
195                                 Rendezvous::DoneCallback done) {
196   uint64 key_hash = KeyHash(key.FullKey());
197   DVLOG(2) << "Recv " << this << " " << key_hash << " " << key.FullKey();
198 
199   mu_.lock();
200   if (!status_.ok()) {
201     // Rendezvous has been aborted.
202     Status s = status_;
203     mu_.unlock();
204     done(s, Rendezvous::Args(), recv_args, Tensor(), false);
205     return;
206   }
207 
208   ItemQueue* queue = &table_[key_hash];
209   if (queue->head == nullptr || queue->head->type == Item::kRecv) {
210     // There is no message to pick up.
211     // Only recv-related fields need to be filled.
212     CancellationManager* cm = recv_args.cancellation_manager;
213     CancellationToken token = CancellationManager::kInvalidToken;
214     bool already_cancelled = false;
215     if (cm != nullptr) {
216       // Increment the refcount when cancellation manager is present, to make
217       // sure the rendezvous outlives the recv and its cancel callbacks.
218       // This refcount is dropped in exactly one of the following cases:
219       // (1) Recv registers cancellation callback to cm, and then cm is
220       //     cancelled, unref in the cancellation callback;
221       // (2) Recv registers cancellation callback to cm, but cm is already
222       //     cancelled, unref in the already_cancelled check;
223       // (3) Recv is successful, and item done callback finishes deregistering
224       //     the cancellation callback, unref in the item done callback;
225       // (4) Recv is successful, but the item done callback fails to deregister
226       //     the cancellation callback because cm already StartCancel, in this
227       //     case the cancellation callback will be invoked by the cm anyway,
228       //     unref in the cancellation callback.
229       if (rc_owner_) rc_owner_->Ref();
230       token = cm->get_cancellation_token();
231       already_cancelled = !cm->RegisterCallback(token, [this, token, key_hash] {
232         Item* item = nullptr;
233         {
234           mutex_lock l(mu_);
235           ItemQueue* queue = &table_[key_hash];
236           // Find an item in the queue with a cancellation token that matches
237           // `token`, and remove it.
238           if (queue->head != nullptr && queue->head->type == Item::kRecv) {
239             for (Item *prev = nullptr, *curr = queue->head; curr != nullptr;
240                  prev = curr, curr = curr->next) {
241               if (curr->recv_state.cancellation_token == token) {
242                 item = curr;
243                 if (queue->head->next == nullptr) {
244                   // We have a single-element queue, so we can erase it from
245                   // the table.
246                   table_.erase(key_hash);
247                 } else {
248                   // Remove the current item from the queue.
249                   if (curr == queue->head) {
250                     DCHECK_EQ(prev, nullptr);
251                     queue->head = curr->next;
252                   } else {
253                     DCHECK_NE(prev, nullptr);
254                     prev->next = curr->next;
255                   }
256                   if (queue->tail == curr) {
257                     queue->tail = prev;
258                   }
259                 }
260                 break;
261               }
262             }
263           }
264         }
265 
266         if (item != nullptr) {
267           (*item->recv_state.waiter)(
268               StatusGroup::MakeDerived(
269                   errors::Cancelled("RecvAsync is cancelled.")),
270               Rendezvous::Args(), item->args, Tensor(), /*is_dead=*/false);
271           delete item;
272         }
273         // Unref case (1) and (4)
274         if (rc_owner_) rc_owner_->Unref();
275       });
276     }
277     if (already_cancelled) {
278       mu_.unlock();
279       // Unref case (2)
280       if (rc_owner_) rc_owner_->Unref();
281       done(StatusGroup::MakeDerived(
282                errors::Cancelled("RecvAsync is cancelled.")),
283            Rendezvous::Args(), recv_args, Tensor(), /*is_dead=*/false);
284       return;
285     }
286 
287     DVLOG(2) << "Enqueue Recv Item (key:" << key.FullKey() << "). ";
288 
289     // TODO(b/143786186): Investigate moving the allocation of `Item` outside
290     // the lock.
291     if (cm != nullptr) {
292       // NOTE(mrry): We must wrap `done` with code that deregisters the
293       // cancellation callback before calling the `done` callback, because the
294       // cancellation manager may no longer be live after `done` is called.
295       queue->push_back(new Item(
296           recv_args,
297           [this, cm, token, done = std::move(done)](
298               const Status& s, const Rendezvous::Args& send_args,
299               const Rendezvous::Args& recv_args, const Tensor& v, bool dead) {
300             // TryDeregisterCallback returns true when the cancellation callback
301             // is successfully deregistered. If it fails because the CM already
302             // StartAbort, Unref will happen inside the cancellation callback
303             // when called by the CM.
304             if (cm->TryDeregisterCallback(token)) {
305               // Unref case (3)
306               if (this->rc_owner_) this->rc_owner_->Unref();
307             }
308             done(s, send_args, recv_args, v, dead);
309           },
310           token));
311     } else {
312       queue->push_back(new Item(recv_args, std::move(done), token));
313     }
314 
315     mu_.unlock();
316     return;
317   }
318 
319   DVLOG(2) << "Consume Send Item (key:" << key.FullKey() << "). ";
320   // A message has already arrived and is queued in the table under
321   // this key.  Consumes the message and invokes the done closure.
322   Item* item = queue->head;
323 
324   // Delete the queue when the last element has been consumed.
325   if (item->next == nullptr) {
326     DVLOG(2) << "Clean up Send/Recv queue (key:" << key.FullKey() << "). ";
327     table_.erase(key_hash);
328   } else {
329     queue->head = item->next;
330   }
331 
332   // Make sure the ref-count of the rendezvous won't reach 0 while the
333   // done_callback is running, which would otherwise become deadlock:
334   // the done_callback waits for the Unref() to return, while the destructor
335   // wiats for the pending_callback_counter to reach 0.
336   core::RefCountPtr<const Rendezvous> rc_owner_ref;
337   if (rc_owner_) {
338     rc_owner_ref.reset(rc_owner_);
339     rc_owner_->Ref();
340   }
341   pending_callback_counter_++;
342   // Invoke the done-callback, without holding the lock.
343   mu_.unlock();
344   DCHECK_EQ(item->type, Item::kSend);
345   done(OkStatus(), item->args, recv_args, *item->send_state.value,
346        item->send_state.is_dead);
347   delete item;
348   {
349     mutex_lock l(mu_);
350     pending_callback_counter_--;
351     if (pending_callback_counter_ == 0) {
352       pending_callback_cond_var_.notify_all();
353     }
354   }
355 }
356 
StartAbort(const Status & status)357 void LocalRendezvous::StartAbort(const Status& status) {
358   CHECK(!status.ok());
359   Table table;
360   {
361     mutex_lock l(mu_);
362     status_.Update(status);
363     table_.swap(table);
364   }
365   for (auto& p : table) {
366     Item* item = p.second.head;
367     while (item != nullptr) {
368       if (item->type == Item::kRecv) {
369         (*item->recv_state.waiter)(status, Rendezvous::Args(),
370                                    Rendezvous::Args(), Tensor(), false);
371       }
372       Item* to_delete = item;
373       item = item->next;
374       delete to_delete;
375     }
376   }
377 }
378 
status()379 Status LocalRendezvous::status() {
380   mu_.lock();
381   Status s = status_;
382   mu_.unlock();
383   return s;
384 }
385 
386 }  // namespace tensorflow
387