xref: /aosp_15_r20/external/tensorflow/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
17 
18 #include <unordered_set>
19 #include <vector>
20 
21 #include "absl/container/flat_hash_set.h"
22 #include "tensorflow/core/common_runtime/copy_tensor.h"
23 #include "tensorflow/core/common_runtime/device.h"
24 #include "tensorflow/core/common_runtime/device_mgr.h"
25 #include "tensorflow/core/common_runtime/dma_helper.h"
26 #include "tensorflow/core/common_runtime/process_util.h"
27 #include "tensorflow/core/distributed_runtime/worker_cache.h"
28 #include "tensorflow/core/distributed_runtime/worker_interface.h"
29 #include "tensorflow/core/framework/cancellation.h"
30 #include "tensorflow/core/framework/types.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/core/status.h"
33 #include "tensorflow/core/lib/strings/numbers.h"
34 #include "tensorflow/core/lib/strings/str_util.h"
35 #include "tensorflow/core/platform/errors.h"
36 #include "tensorflow/core/platform/logging.h"
37 #include "tensorflow/core/platform/mutex.h"
38 #include "tensorflow/core/platform/types.h"
39 #include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h"
40 
41 namespace tensorflow {
42 
StartAbortRendevous(Rendezvous * rendez,const Status & s)43 static void StartAbortRendevous(Rendezvous* rendez, const Status& s) {
44   rendez->StartAbort(s);
45   rendez->Unref();
46 }
47 
BaseRendezvousMgr(const WorkerEnv * worker_env)48 BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env)
49     : worker_env_(worker_env) {}
50 
~BaseRendezvousMgr()51 BaseRendezvousMgr::~BaseRendezvousMgr() {
52   for (auto& p : table_) {
53     auto rendez = p.second;
54     StartAbortRendevous(rendez, errors::Aborted("Shutdown"));
55   }
56 }
57 
Find(int64_t step_id)58 RemoteRendezvous* BaseRendezvousMgr::Find(int64_t step_id) {
59   return FindOrCreate(step_id);
60 }
61 
FindOrCreate(int64_t step_id)62 BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64_t step_id) {
63   mutex_lock l(mu_);
64   auto iter = table_.find(step_id);
65   if (iter == table_.end()) {
66     auto rr = Create(step_id, worker_env_);
67     iter = table_.insert({step_id, rr}).first;
68   }
69   iter->second->Ref();
70   return iter->second;
71 }
72 
RecvLocalAsync(int64_t step_id,const Rendezvous::ParsedKey & parsed,Rendezvous::DoneCallback done)73 void BaseRendezvousMgr::RecvLocalAsync(int64_t step_id,
74                                        const Rendezvous::ParsedKey& parsed,
75                                        Rendezvous::DoneCallback done) {
76   auto rendez = FindOrCreate(step_id);
77   auto done_cb = [rendez, done = std::move(done)](
78                      const Status& s, const Rendezvous::Args& send_args,
79                      const Rendezvous::Args& recv_args, const Tensor& v,
80                      bool dead) {
81     rendez->Unref();
82     done(s, send_args, recv_args, v, dead);
83   };
84   rendez->RecvLocalAsync(parsed, std::move(done_cb));
85 }
86 
RecvLocal(int64_t step_id,const Rendezvous::ParsedKey & parsed,Tensor * val,bool * is_dead)87 Status BaseRendezvousMgr::RecvLocal(int64_t step_id,
88                                     const Rendezvous::ParsedKey& parsed,
89                                     Tensor* val, bool* is_dead) {
90   Status ret;
91   Notification n;
92   RecvLocalAsync(step_id, parsed,
93                  [val, is_dead, &ret, &n](const Status& s,
94                                           const Rendezvous::Args& send_args,
95                                           const Rendezvous::Args& recv_args,
96                                           const Tensor& v, const bool dead) {
97                    ret = s;
98                    *val = v;
99                    *is_dead = dead;
100                    n.Notify();
101                  });
102   n.WaitForNotification();
103   return ret;
104 }
105 
Cleanup(int64_t step_id)106 void BaseRendezvousMgr::Cleanup(int64_t step_id) {
107   Rendezvous* rendez = nullptr;
108   {
109     mutex_lock l(mu_);
110     auto iter = table_.find(step_id);
111     if (iter != table_.end()) {
112       rendez = iter->second;
113       table_.erase(iter);
114     }
115   }
116   if (rendez) {
117     StartAbortRendevous(rendez, errors::Aborted("Cleanup ", step_id));
118   }
119 }
120 
CleanupAll()121 void BaseRendezvousMgr::CleanupAll() {
122   mutex_lock l(mu_);
123   for (auto iter = table_.begin(); iter != table_.end(); iter++) {
124     iter->second->Unref();
125   }
126 }
127 
BaseRemoteRendezvous(const WorkerEnv * env,int64_t step_id)128 BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env,
129                                            int64_t step_id)
130     : env_(env),
131       step_id_(step_id),
132       local_(NewLocalRendezvous()),
133       session_(nullptr) {}
134 
~BaseRemoteRendezvous()135 BaseRemoteRendezvous::~BaseRemoteRendezvous() {
136   {
137     mutex_lock l(calls_mu_);
138     calls_.clear();
139   }
140   local_->Unref();
141 }
142 
143 // Returns true if "device_name" is a valid full name of local device
144 // of the "worker".  This helper is purely based on the worker name
145 // and device name and does no lookups in the worker->device_mgr.
IsLocalDevice(const StringPiece worker_name,const StringPiece device_name)146 static bool IsLocalDevice(const StringPiece worker_name,
147                           const StringPiece device_name) {
148   return absl::StartsWith(device_name, worker_name);
149 }
150 
Initialize(WorkerSession * session)151 Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
152   CHECK_NE(session, nullptr) << "session must not be null!";
153   std::vector<DeferredCall> deferred_calls;
154   {
155     mutex_lock l(mu_);
156     if (session_ != nullptr) {
157       if (session_->worker_name() == session->worker_name()) {
158         VLOG(1) << "Skipping rendezvous re-initialization.";
159         return OkStatus();
160       }
161       Status s = errors::Internal(
162           "Double init! Worker names would have changed from: ",
163           session_->worker_name(), " -> ", session->worker_name());
164       LOG(WARNING) << s;
165       return s;
166     }
167     session_ = session;
168     std::swap(deferred_calls, deferred_calls_);
169   }
170   for (auto& call : deferred_calls) {
171     RecvLocalAsyncInternal(call.parsed, std::move(call.done));
172   }
173   return OkStatus();
174 }
175 
session()176 WorkerSession* BaseRemoteRendezvous::session() {
177   tf_shared_lock l(mu_);
178   return session_;
179 }
180 
is_initialized()181 bool BaseRemoteRendezvous::is_initialized() {
182   tf_shared_lock l(mu_);
183   return is_initialized_locked();
184 }
185 
Send(const Rendezvous::ParsedKey & parsed,const Rendezvous::Args & args,const Tensor & val,const bool is_dead)186 Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
187                                   const Rendezvous::Args& args,
188                                   const Tensor& val, const bool is_dead) {
189   VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << parsed.FullKey();
190   WorkerSession* sess = nullptr;
191   {
192     tf_shared_lock l(mu_);
193     if (!status_.ok()) return status_;
194     DCHECK(is_initialized_locked());
195     sess = session_;
196   }
197 
198   if (!IsLocalDevice(sess->worker_name(), parsed.src_device)) {
199     return errors::InvalidArgument(
200         "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
201         sess->worker_name());
202   }
203 
204   // Buffers "val" and "device_context" in local_.
205   return local_->Send(parsed, args, val, is_dead);
206 }
207 
ValidateDevices(const ParsedKey & parsed,bool is_src)208 Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
209                                              bool is_src) {
210   // Cache session pointer to avoid repeatedly taking & releasing the lock
211   // (e.g. calling session())
212   WorkerSession* sess = nullptr;
213   {
214     tf_shared_lock l(mu_);
215     if (!status_.ok()) return status_;
216     if (!is_initialized_locked()) {
217       return errors::Internal("ValidateDevices called before initialization.");
218     }
219     sess = session_;
220   }
221   if (is_src && !IsLocalDevice(sess->worker_name(), parsed.src_device)) {
222     return errors::InvalidArgument(
223         "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
224         sess->worker_name());
225   }
226   if (!is_src && !IsLocalDevice(sess->worker_name(), parsed.dst_device)) {
227     return errors::InvalidArgument(
228         "Invalid rendezvous key (dst): ", parsed.FullKey(), " @ ",
229         sess->worker_name());
230   }
231   return OkStatus();
232 }
233 
SameWorkerRecvDone(const Rendezvous::ParsedKey & parsed,const Rendezvous::Args & send_args,const Rendezvous::Args & recv_args,const Tensor & in,Tensor * out,StatusCallback done)234 void BaseRemoteRendezvous::SameWorkerRecvDone(
235     const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args,
236     const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out,
237     StatusCallback done) {
238   // Do a quick copy (sharing the underlying buffer) if both tensors
239   // are on host memory.
240   const bool src_host =
241       (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
242   const bool dst_host =
243       (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
244   if (src_host && dst_host) {
245     *out = in;
246     done(OkStatus());
247     return;
248   }
249 
250   // This copy must involve a GPU. Hence, "in" must support DMA
251   // (e.g., string tensors do not work on GPU).  Variant copy DMA
252   // checks happen inside CopyTensor::ViaDMA.
253   if (!DMAHelper::CanUseDMA(&in) && in.dtype() != DT_VARIANT &&
254       in.dtype() != DT_RESOURCE) {
255     done(errors::InvalidArgument(
256         "Non-DMA-safe ", DataTypeString(in.dtype()),
257         " tensor may not be copied from/to a device. Key: ", parsed.FullKey()));
258     return;
259   }
260 
261   WorkerSession* sess = session();
262   Device* src_device;
263   Status s = sess->device_mgr()->LookupDevice(parsed.src_device, &src_device);
264   if (!s.ok()) {
265     done(s);
266     return;
267   }
268   Device* dst_device;
269   s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device);
270   if (!s.ok()) {
271     done(s);
272     return;
273   }
274 
275   profiler::ScopedMemoryDebugAnnotation op_annotation(
276       "SameWorkerRecvDone", step_id_, "dynamic", in.dtype(),
277       [&in]() { return in.shape().DebugString(); });
278   AllocatorAttributes attr = recv_args.alloc_attrs;
279   attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
280                           recv_args.alloc_attrs.gpu_compatible());
281   Allocator* out_allocator = dst_device->GetAllocator(attr);
282   AllocationAttributes allocation_attr;
283   uint64 safe_alloc_frontier = dst_device->SafeAllocFrontier(0);
284   bool sync_dst_compute = (safe_alloc_frontier == 0);
285   std::function<uint64()> freed_by_func = [dst_device, &safe_alloc_frontier]() {
286     safe_alloc_frontier = dst_device->SafeAllocFrontier(safe_alloc_frontier);
287     return safe_alloc_frontier;
288   };
289   if (!sync_dst_compute) {
290     allocation_attr.freed_by_func = &freed_by_func;
291   }
292   if (in.dtype() != DT_VARIANT) {
293     // Variants are handled by CopyTensor::ViaDMA.
294     Tensor copy(out_allocator, in.dtype(), in.shape(), allocation_attr);
295     *out = copy;
296   }
297 
298   // The following function takes care of cpu->gpu, gpu->cpu, gpu->gpu copies,
299   // etc.
300   CopyTensor::ViaDMA(
301       parsed.edge_name, send_args.device_context, recv_args.device_context,
302       src_device, dst_device, send_args.alloc_attrs, recv_args.alloc_attrs, &in,
303       out, 0 /*dev_to_dev_stream_index*/, std::move(done), sync_dst_compute);
304 }
305 
IsSameWorker(DeviceNameUtils::ParsedName src,DeviceNameUtils::ParsedName dst)306 bool BaseRemoteRendezvous::IsSameWorker(DeviceNameUtils::ParsedName src,
307                                         DeviceNameUtils::ParsedName dst) {
308   return DeviceNameUtils::IsSameAddressSpace(src, dst);
309 }
310 
RecvAsync(const ParsedKey & parsed,const Rendezvous::Args & recv_args,DoneCallback done)311 void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
312                                      const Rendezvous::Args& recv_args,
313                                      DoneCallback done) {
314   VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey();
315   Status s = ValidateDevices(parsed, false /*!is_src*/);
316   if (!s.ok()) {
317     done(s, Args(), recv_args, Tensor(), false);
318     return;
319   }
320 
321   // ValidateDevices() returns an error status if the rendezvous is not
322   // initialized.
323   DCHECK(is_initialized()) << "RecvAsync called when uninitialized (key: "
324                            << parsed.FullKey() << ").";
325 
326   profiler::ScopedMemoryDebugAnnotation op_annotation("RecvAsync", step_id_);
327   // Are src and dst in the same worker?
328   if (IsSameWorker(parsed.src, parsed.dst)) {
329     // Recv the tensor from local_.
330     local_->RecvAsync(
331         parsed, recv_args,
332         [this, parsed, done](
333             const Status& status, const Rendezvous::Args& send_args,
334             const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) {
335           VLOG(2) << "RemoteRendezvous Finished Recv " << this << " "
336                   << parsed.FullKey();
337           Tensor* out = new Tensor;
338           StatusCallback final_callback = [done, send_args, recv_args, out,
339                                            is_dead](const Status& s) {
340             done(s, send_args, recv_args, *out, is_dead);
341             delete out;
342           };
343 
344           if (status.ok()) {
345             SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
346                                std::move(final_callback));
347           } else {
348             final_callback(status);
349           }
350         });
351     return;
352   } else {
353     RecvFromRemoteAsync(parsed, recv_args, std::move(done));
354   }
355 }
356 
RecvLocalAsync(const ParsedKey & parsed,DoneCallback done)357 void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
358                                           DoneCallback done) {
359   // Test whether the rendezvous is initialized using a shared lock, to avoid
360   // the need for exclusive access in the common case.
361   if (TF_PREDICT_FALSE(!is_initialized())) {
362     mutex_lock l(mu_);
363     if (!is_initialized_locked()) {
364       // RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a
365       // remote worker) before the RunStep (or PartialRunStep) RPC from the
366       // master arrives. RecvLocalAsync thus buffers the arguments until after
367       // the RemoteRendezvous is Initialize()'d, when it completes the
368       // rendezvous logic. At some point after Initialize() is called, a Tensor
369       // is produced locally that will then be sent in response to the incoming
370       // RPC.
371       DeferredCall call(parsed, std::move(done));
372       deferred_calls_.push_back(call);
373       return;
374     }
375   }
376   RecvLocalAsyncInternal(parsed, std::move(done));
377 }
378 
RecvLocalAsyncInternal(const ParsedKey & parsed,DoneCallback done)379 void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed,
380                                                   DoneCallback done) {
381   Status s = ValidateDevices(parsed, true /* is_src */);
382   if (!s.ok()) {
383     done(s, Args(), Args(), Tensor(), false);
384     return;
385   }
386   local_->RecvAsync(parsed, Args(), std::move(done));
387 }
388 
StartAbort(const Status & s)389 void BaseRemoteRendezvous::StartAbort(const Status& s) {
390   CHECK(!s.ok());
391   // If the status passed in is a cancelled or aborted error, mark it as
392   // "derived" for the rendezvous. Derived status messages are ignored when
393   // aggregating errors across devices: this allows us to prefer our original
394   // status message over any cancellation related errors.
395   Status derived_status = s;
396   if (errors::IsCancelled(s) || errors::IsAborted(s)) {
397     derived_status = StatusGroup::MakeDerived(s);
398   }
399 
400   local_->StartAbort(derived_status);
401 
402   bool status_ok = false;
403   {
404     mutex_lock l(mu_);
405     status_ok = status_.ok();
406     if (status_ok) {
407       status_ = derived_status;
408     }
409   }
410 
411   if (status_ok) {
412     // Aborts all active RecvTensor calls.
413     mutex_lock l(calls_mu_);
414     for (auto& cm_and_token_and_calls : calls_) {
415       for (auto& call : cm_and_token_and_calls.second.second) {
416         call->StartAbort(derived_status);
417       }
418       auto* cm = cm_and_token_and_calls.first;
419       calls_[cm].second.clear();
420     }
421     calls_.clear();
422   }
423 }
424 
RegisterCall(BaseRecvTensorCall * call,const Rendezvous::Args & args)425 void BaseRemoteRendezvous::RegisterCall(BaseRecvTensorCall* call,
426                                         const Rendezvous::Args& args) {
427   CancellationManager* cm = args.cancellation_manager;
428   bool already_cancelled = false;
429   {
430     tf_shared_lock l(mu_);
431     if (!status_.ok()) {
432       call->StartAbort(status_);
433       return;
434     }
435   }
436 
437   CancellationToken token = CancellationManager::kInvalidToken;
438   if (cm != nullptr) {
439     mutex_lock l(calls_mu_);
440     auto it = calls_.find(cm);
441     if (it == calls_.end()) {
442       token = cm->get_cancellation_token();
443       already_cancelled = !cm->RegisterCallback(token, [this, cm]() {
444         mutex_lock l(calls_mu_);
445         // Abort all the RecvTensor calls associated with thie cancellation
446         // manager.
447         for (const auto& call : calls_[cm].second) {
448           call->StartAbort(
449               errors::Cancelled("RecvFromRemoteAsync is cancelled."));
450         }
451       });
452 
453       if (!already_cancelled) {
454         calls_.emplace(
455             cm,
456             std::make_pair(token, absl::flat_hash_set<BaseRecvTensorCall*>{}));
457       }
458     }
459   }
460 
461   if (already_cancelled) {
462     call->StartAbort(errors::Cancelled("RecvFromRemoteAsync is cancelled."));
463   } else {
464     mutex_lock l(calls_mu_);
465     bool emplaced = calls_[cm].second.emplace(call).second;
466     CHECK(emplaced);  // Crash OK.
467   }
468 }
469 
DeregisterCall(BaseRecvTensorCall * call,const Rendezvous::Args & args)470 void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call,
471                                           const Rendezvous::Args& args) {
472   auto cm = args.cancellation_manager;
473   mutex_lock l(calls_mu_);
474   CancellationToken token = calls_[cm].first;
475   calls_[cm].second.erase(call);
476   if (calls_[cm].second.empty()) {
477     calls_.erase(cm);
478     if (cm != nullptr) {
479       cm->TryDeregisterCallback(token);
480     }
481   }
482 }
483 
DeferredCall(const ParsedKey & parsed,DoneCallback done)484 BaseRemoteRendezvous::DeferredCall::DeferredCall(const ParsedKey& parsed,
485                                                  DoneCallback done)
486     : parsed(parsed), done(std::move(done)) {}
487 
488 }  // end namespace tensorflow
489