// xref: /aosp_15_r20/external/tensorflow/tensorflow/core/distributed_runtime/master.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Master implements the service MasterService.
//
// A Master maintains the state of live graph computation sessions;
// each session orchestrates both local and remote devices to carry
// out the graph computation.
//
// A Master knows ahead of time which local devices are available as
// client devices.
//
// A Master discovers remote devices on demand and keeps track of
// statistics of those remote devices.
//
// Each session analyzes the graph, places nodes across available
// devices, and ultimately drives the graph computation by initiating
// RunGraph requests on the workers.
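//
// The typical lifetime of a session, in terms of the RPC handlers defined
// below, looks roughly like this (a sketch, not an exhaustive protocol
// description):
//
//   CreateSession  -> discover devices, build a MasterSession, and return
//                     a session handle to the client.
//   ExtendSession  -> add nodes to the session's graph.
//   RunStep        -> execute one step of the graph on the workers.
//   CloseSession   -> tear the session down and release its resources.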

#include "tensorflow/core/distributed_runtime/master.h"

#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/remote_device.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

namespace {
constexpr char kGrpcPrefixRegex[] = "^grpc.*://";
}  // namespace

Master::Master(MasterEnv* env, double session_gc_seconds)
    : env_(env),
      last_1000_steps_(1000),
      step_count_(0),
      session_gc_seconds_(session_gc_seconds),
      recent_request_ids_(10000) {
  // Right now, a master service must be co-located with a device.
  // Otherwise, fetches do not work.
  CHECK(!env->local_devices.empty());

  if (session_gc_seconds_ > 0.0) {
    gc_thread_ = env_->env->StartThread(ThreadOptions(), "TF_master_GC",
                                        [this]() { GC(); });
  } else {
    gc_thread_ = nullptr;
  }
}

Master::~Master() {
  if (gc_thread_) {
    mutex_lock l(mu_);
    shutdown_ = true;
    shutdown_cv_.notify_all();
    delete gc_thread_;
  }
}

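// Periodically scans `sessions_` and garbage-collects any session that has
// not been accessed for more than `session_gc_seconds_`. Runs on the
// dedicated "TF_master_GC" thread until the destructor sets `shutdown_`.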
void Master::GC() {
  Env* env = Env::Default();
  while (true) {
    mutex_lock l(mu_);
    const int kTimeoutMilliseconds = 10 * 1000;  // 10 seconds.
    WaitForMilliseconds(&l, &shutdown_cv_, kTimeoutMilliseconds);
    if (shutdown_) {
      break;
    }
    std::vector<string> handles;
    const int64_t num_micros =
        static_cast<int64_t>(session_gc_seconds_ * 1000000);
    for (const auto& entry : sessions_) {
      int64_t lat = entry.second->last_access_time_usec();
      if (static_cast<int64_t>(env->NowMicros()) - lat > num_micros) {
        handles.push_back(entry.first);
        auto* sess = entry.second;
        SchedClosure([this, sess]() {
          LOG(WARNING) << "GC session " << sess->handle() << " after "
                       << session_gc_seconds_ << " seconds.  "
                       << "Note that if you are starting multiple replicas "
                       << "on a staggered delay, session_gc_seconds may need "
                       << "to be raised.";
          sess->GarbageCollect();
        });
      }
    }
    for (const auto& handle : handles) sessions_.erase(handle);
  }
}

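// Returns the session registered under `handle`, or nullptr if no such
// session exists. On success a reference is taken on the session; the
// caller is responsible for calling Unref() when done.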
MasterSession* Master::FindMasterSession(const string& handle) {
  MasterSession* session = nullptr;
  {
    mutex_lock l(mu_);
    session = gtl::FindPtrOrNull(sessions_, handle);
    if (session != nullptr) {
      session->Ref();
    }
  }
  return session;
}

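// DeviceFinder contacts the workers known to a WorkerCacheInterface,
// collects the devices they report, and filters the result against the
// session's device filters. It is used by CreateSession and ListDevices to
// assemble the set of devices a session may place nodes on.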
class DeviceFinder {
 public:
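  // Contacts every worker reachable through `worker_cache`, applies
  // `device_filters`, and returns the discovered remote devices in
  // `out_remote`. Blocks until every contacted worker has replied
  // (successfully or with an error).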
  static Status GetRemoteDevices(
      const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
      WorkerCacheInterface* worker_cache,
      std::vector<std::unique_ptr<Device>>* out_remote) {
    DeviceFinder finder(device_filters, env, worker_cache);
    finder.Start();
    TF_RETURN_IF_ERROR(finder.Wait());
    finder.GetRemoteDevices(env->local_devices, out_remote);
    return OkStatus();
  }

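  // Returns in `workers` the names of the workers that match
  // `device_filters` (plus the local worker). Unlike GetRemoteDevices, this
  // does not contact the workers; it only consults `worker_cache`.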
  static void GetRemoteWorkers(
      const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
      WorkerCacheInterface* worker_cache, std::vector<string>* workers) {
    DeviceFinder finder(device_filters, env, worker_cache);
    *workers = finder.targets_;
  }

 private:
  explicit DeviceFinder(
      const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
      WorkerCacheInterface* worker_cache)
      : env_(env), worker_cache_(worker_cache) {
    CHECK(worker_cache) << "Worker cache was null!";
    auto process_filter = [this](const string& filter) {
      DeviceNameUtils::ParsedName parsed;
      if (DeviceNameUtils::ParseFullName(filter, &parsed)) {
        filters_.push_back(parsed);
      } else {
        LOG(FATAL) << "Invalid filter: " << filter;
      }
    };
    for (const string& filter : device_filters) {
      process_filter(filter);
    }
    // Enumerates all known workers' target. A target name is a
    // prefix of a device name. E.g., /job:mnist/replica:0/task:10.
    if (filters_.empty()) {
      // If no filters were specified, we list all known workers in
      // `worker_cache`.
      std::vector<string> workers;
      worker_cache->ListWorkers(&workers);
      std::swap(workers, targets_);
    } else {
      // When applying filters, we must include the local worker, even if it
      // does not match any of the filters.
      CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
      const string& local_device_name = env_->local_devices[0]->name();
      DeviceNameUtils::ParsedName local_parsed_name;
      CHECK(DeviceNameUtils::ParseFullName(local_device_name,
                                           &local_parsed_name));
      bool all_filters_have_job = true;
      std::unordered_set<string> filter_job_names({local_parsed_name.job});
      for (const DeviceNameUtils::ParsedName& filter : filters_) {
        all_filters_have_job = all_filters_have_job && filter.has_job;
        if (filter.has_job) {
          filter_job_names.insert(filter.job);
        }
      }

      std::vector<string> workers;
      if (all_filters_have_job) {
        // If all of the device filters have a job specified, then we only need
        // to list the workers in the jobs named in the filter, because a worker
        // in any other job would not match any filter.
        for (const string& job_name : filter_job_names) {
          VLOG(2) << "Selectively listing workers in job: " << job_name;
          std::vector<string> workers_in_job;
          worker_cache->ListWorkersInJob(job_name, &workers_in_job);
          workers.insert(workers.end(), workers_in_job.begin(),
                         workers_in_job.end());
        }
      } else {
        // If any of the device filters does not have a job specified, then we
        // must list the workers from all jobs.
        VLOG(2) << "Listing workers in all jobs because some device "
                << "filter has no job specified. Filters were:";
        if (device_filters.empty()) {
          VLOG(2) << "- <NO FILTERS>";
        } else {
          for (const string& filter : device_filters) {
            VLOG(2) << "- " << filter;
          }
        }
        worker_cache->ListWorkers(&workers);
      }
      for (const string& name : workers) {
        if (MatchFilters(name) ||
            DeviceNameUtils::IsSameAddressSpace(name, local_device_name)) {
          targets_.push_back(name);
        }
      }
    }
    seen_targets_.assign(targets_.size(), false);
  }

  ~DeviceFinder() {
    for (Device* dev : found_) delete dev;
  }

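  // Sends a device-listing request to every target worker. Completion is
  // signalled asynchronously through WhenFound(); call Wait() afterwards to
  // block until every worker has replied.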
  void Start() {
    {
      mutex_lock l(mu_);
      num_pending_ = targets_.size();
      if (num_pending_ == 0) {
        pending_zero_.notify_all();
      }
    }
    // Talk to all workers to get the list of available devices.
    using std::placeholders::_1;
    using std::placeholders::_2;
    for (size_t i = 0; i < targets_.size(); ++i) {
      // TODO(mrry): Propagate a timeout here, since `this->WhenFound()` may
      // never be called.
      NewRemoteDevices(env_->env, worker_cache_, targets_[i],
                       std::bind(&ME::WhenFound, this, i, _1, _2));
    }
  }

  // Every `kLoggingPeriodMs`, while the DeviceFinder is still waiting
  // to hear from workers, log a list of the workers who have not
  // responded.
  const int32 kLoggingPeriodMs = 10 * 1000;

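  // Blocks until every target contacted by Start() has responded, logging
  // the list of unresponsive workers every `kLoggingPeriodMs`. Returns the
  // combined status reported by the workers.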
  Status Wait() {
    mutex_lock l(mu_);
    // TODO(mrry): Propagate a timeout here, since `num_pending_` may
    // never become zero.
    while (num_pending_ != 0) {
      pending_zero_.wait_for(l, std::chrono::milliseconds(kLoggingPeriodMs));
      if (num_pending_ != 0) {
        for (size_t i = 0; i < targets_.size(); ++i) {
          if (!seen_targets_[i]) {
            LOG(INFO)
                << "CreateSession still waiting for response from worker: "
                << targets_[i];
          }
        }
      }
    }
    return status_;
  }

  // The caller takes the ownership of returned remote devices.
  void GetRemoteDevices(const std::vector<Device*>& local,
                        std::vector<std::unique_ptr<Device>>* remote) {
    std::unordered_set<string> names(local.size());
    for (Device* dev : local) names.insert(dev->name());
    mutex_lock l(mu_);
    for (Device* dev : found_) {
      const string& name = dev->name();
      if (names.insert(name).second && MatchFilters(name)) {
        remote->push_back(std::unique_ptr<Device>(dev));
      } else {
        delete dev;
      }
    }
    found_.clear();
  }

  typedef DeviceFinder ME;
  const MasterEnv* env_;
  WorkerCacheInterface* worker_cache_;
  std::vector<DeviceNameUtils::ParsedName> filters_;

  mutex mu_;
  int num_pending_ TF_GUARDED_BY(mu_);
  condition_variable pending_zero_;
  std::vector<Device*> found_ TF_GUARDED_BY(mu_);
  // List of targets to be contacted by this DeviceFinder. The
  // respective `bool` in `seen_targets_` indicates whether we have
  // heard from this target or not.
  std::vector<string> targets_;
  std::vector<bool> seen_targets_ TF_GUARDED_BY(mu_);
  Status status_;

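  // Callback invoked (once per target) by NewRemoteDevices. Records the
  // devices reported by the worker, or folds a failure into `status_`, and
  // wakes Wait() when the last outstanding target has replied.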
  void WhenFound(int target_index, const Status& s,
                 std::vector<Device*>* devices) {
    mutex_lock l(mu_);
    seen_targets_[target_index] = true;
    if (!s.ok()) {
      LOG(ERROR) << "CreateSession failed because worker "
                 << targets_[target_index] << " returned error: " << s;
      status_.Update(s);
    } else {
      found_.insert(found_.end(), devices->begin(), devices->end());
      devices->clear();
    }
    --num_pending_;
    if (num_pending_ == 0) {
      pending_zero_.notify_all();
    }
  }

  // Returns true iff the set of devices allowed by 'x' intersects
  // with the set of devices allowed by 'y'.
  bool Intersects(const DeviceNameUtils::ParsedName& x,
                  const DeviceNameUtils::ParsedName& y) {
    return (!x.has_job || !y.has_job || x.job == y.job) &&
           (!x.has_replica || !y.has_replica || x.replica == y.replica) &&
           (!x.has_task || !y.has_task || x.task == y.task) &&
           (!x.has_type || !y.has_type || x.type == y.type) &&
           (!x.has_id || !y.has_id || x.id == y.id);
  }

  // Returns true iff 'name' matches one of the filters_.
  bool MatchFilters(const string& name) {
    if (filters_.empty()) return true;
    DeviceNameUtils::ParsedName x;
    if (DeviceNameUtils::ParseFullName(name, &x)) {
      for (const auto& filter : filters_) {
        if (Intersects(x, filter)) return true;
      }
    }
    return false;
  }

  TF_DISALLOW_COPY_AND_ASSIGN(DeviceFinder);
};

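// Handles the CreateSession RPC. Builds the worker cache (either the
// master's default one or, when the request carries a ClusterDef, a new one
// for ClusterSpec propagation), discovers the devices the session may use,
// creates a MasterSession for the request's GraphDef, and registers it in
// `sessions_` under the handle returned to the client.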
void Master::CreateSession(const CreateSessionRequest* req,
                           CreateSessionResponse* resp, MyClosure done) {
  SchedClosure([this, req, resp, done]() {
    Status status;
    WorkerCacheFactoryOptions worker_cache_factory_options;
    string grpc_protocol("grpc");
    worker_cache_factory_options.protocol = &grpc_protocol;
    auto call_done = gtl::MakeCleanup([&status, &done] { done(status); });
    status = ValidateExternalGraphDefSyntax(req->graph_def());
    if (!status.ok()) return;

    // The following 4 variables are set differently, depending on whether this
    // session uses a client-provided clusterspec or not.
    WorkerCacheInterface* worker_cache = nullptr;
    // Note: worker_cache_ptr will be null except if this session is using a
    // client-supplied ClusterDef (ClusterSpec propagation).
    std::unique_ptr<WorkerCacheInterface> worker_cache_ptr;
    std::unique_ptr<DeviceSet> device_set;
    // TODO(saeta): Convert to std::make_unique when available.
    std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devices(
        new std::vector<std::unique_ptr<Device>>());

    const ClusterDef& cluster_def = req->config().cluster_def();
    if (!cluster_def.job().empty()) {
      worker_cache_factory_options.cluster_def = &cluster_def;
      // If the target starts with gRPC protocol prefix, remove the prefix
      string normalized_string(req->target());
      RE2::Replace(&normalized_string, kGrpcPrefixRegex, "");

      // Set the server_def's job_name and task_index fields.
      for (auto&& job : cluster_def.job()) {
        for (auto&& task : job.tasks()) {
          if (task.second == normalized_string) {
            if (worker_cache_factory_options.job_name != nullptr) {
              status = errors::InvalidArgument(
                  "Found multiple matching tasks that correspond to "
                  "the master. Master target: '",
                  req->target(),
                  "'. ClusterDef: ", cluster_def.ShortDebugString());
              LOG(ERROR) << status;
              return;
            }
            if (env_->local_devices[0]->parsed_name().job == job.name() &&
                env_->local_devices[0]->parsed_name().task == task.first) {
              // TODO(b/37868888): Remove this limitation when resolved
              status = errors::InvalidArgument(
                  "The ClusterSpec names the job and task index to be the same "
                  "names that were provided when the server booted. This is "
                  "currently not allowed. Job: ",
                  job.name(), ", task index: ", task.first);
              return;
            }
            worker_cache_factory_options.job_name = &job.name();
            worker_cache_factory_options.task_index = task.first;
          }
        }
      }
      worker_cache_factory_options.rpc_options = &req->config().rpc_options();
      // Create the worker cache from the computed server_def.
      status = env_->worker_cache_factory(worker_cache_factory_options,
                                          &worker_cache);
      if (!status.ok()) return;
      worker_cache_ptr = std::unique_ptr<WorkerCacheInterface>(worker_cache);
      // Ping all the workers and build the list of devices that the
      // session will use.
      status =
          DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
                                         worker_cache, remote_devices.get());
      if (!status.ok()) return;
      device_set.reset(new DeviceSet);
      for (auto&& d : *remote_devices) {
        device_set->AddDevice(d.get());
        DeviceNameUtils::ParsedName name = d->parsed_name();
        if (name.job == *worker_cache_factory_options.job_name &&
            name.task == worker_cache_factory_options.task_index &&
            name.type == "CPU" && name.id == 0) {
          device_set->set_client_device(d.get());
        }
      }
    } else {
      worker_cache = env_->worker_cache;
      // Ping all the workers and build the list of devices that the
      // session will use.
      status =
          DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
                                         worker_cache, remote_devices.get());
      if (!status.ok()) return;
      device_set.reset(new DeviceSet);
      for (auto&& d : *remote_devices) {
        device_set->AddDevice(d.get());
      }
      int num_local_devices = 0;
      for (Device* d : env_->local_devices) {
        device_set->AddDevice(d);
        if (num_local_devices == 0) {
          // Uses the first local device as the client device.
          device_set->set_client_device(d);
        }
        num_local_devices++;
      }
    }

    CHECK(device_set->client_device()) << "No client device found. Missing "
                                       << "CPU:0 device?";

    SessionOptions options;
    options.target = req->target();
    options.config = req->config();

    std::vector<string> filtered_worker_list;
    DeviceFinder::GetRemoteWorkers(req->config().device_filters(), env_,
                                   worker_cache, &filtered_worker_list);

    MasterSession* session = env_->master_session_factory(
        options, env_, std::move(remote_devices), std::move(worker_cache_ptr),
        std::move(device_set), std::move(filtered_worker_list));

    GraphDef* gdef =
        const_cast<CreateSessionRequest*>(req)->mutable_graph_def();

    status = session->Create(std::move(*gdef), cluster_def);
    if (!status.ok()) {
      session->Close().IgnoreError();
      session->Unref();
      return;
    }
    resp->set_session_handle(session->handle());
    // Insert into the session map, which takes ownership of the session.
    {
      mutex_lock l(mu_);
      CHECK(sessions_.insert({session->handle(), session}).second);
    }
  });
}

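// Handles the ExtendSession RPC: validates the additional GraphDef and
// appends it to the named session's graph.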
void Master::ExtendSession(const ExtendSessionRequest* req,
                           ExtendSessionResponse* resp, MyClosure done) {
  auto session = FindMasterSession(req->session_handle());
  if (session == nullptr) {
    done(errors::Aborted("Session ", req->session_handle(), " is not found."));
    return;
  }

  SchedClosure([session, req, resp, done]() {
    Status status = ValidateExternalGraphDefSyntax(req->graph_def());
    if (status.ok()) {
      status = session->Extend(req, resp);
    }
    session->Unref();
    done(status);
  });
}

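// Handles the PartialRunSetup RPC: after de-duplicating the request id,
// asks the session to set up state for a subsequent partial run.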
void Master::PartialRunSetup(const PartialRunSetupRequest* req,
                             PartialRunSetupResponse* resp, MyClosure done) {
  Status s = recent_request_ids_.TrackUnique(req->request_id(),
                                             "PartialRunSetup (Master)", *req);
  if (!s.ok()) {
    done(s);
    return;
  }
  auto session = FindMasterSession(req->session_handle());
  if (session == nullptr) {
    done(errors::Aborted("Session ", req->session_handle(), " is not found."));
    return;
  }

  SchedClosure([session, req, resp, done]() {
    Status s = session->PartialRunSetup(req, resp);
    session->Unref();
    done(s);
  });
}

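// Handles the RunStep RPC: runs one step of the session's graph and records
// the step latency in `last_1000_steps_`.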
void Master::RunStep(CallOptions* opts, const RunStepRequestWrapper* req,
                     MutableRunStepResponseWrapper* resp, MyClosure done) {
  Status s = recent_request_ids_.TrackUnique(req->request_id(),
                                             "RunStep (Master)", req);
  if (!s.ok()) {
    done(s);
    return;
  }
  auto start_time = env_->env->NowMicros();
  auto session = FindMasterSession(req->session_handle());
  if (session == nullptr) {
    done(errors::Aborted("Session ", req->session_handle(), " is not found."));
    return;
  }

  SchedClosure([this, start_time, session, opts, req, resp, done]() {
    Status status = session->Run(opts, *req, resp);
    session->Unref();
    uint64 done_time = env_->env->NowMicros();
    done(status);
    mutex_lock l(mu_);
    last_1000_steps_.AddValue((done_time - start_time) / 1e9);
    ++step_count_;
  });
}

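// Handles the CloseSession RPC: removes the session from `sessions_` and
// closes it on a separate thread, since MasterSession::Close() blocks.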
void Master::CloseSession(const CloseSessionRequest* req,
                          CloseSessionResponse* resp, MyClosure done) {
  MasterSession* session = nullptr;
  {
    mu_.lock();
    auto iter = sessions_.find(req->session_handle());
    if (iter == sessions_.end()) {
      mu_.unlock();
      done(errors::Aborted(
          "Session ", req->session_handle(),
          " is not found. Possibly, this master has restarted."));
      return;
    }
    // NOTE(mrry): One reference to the session is transferred from
    // `sessions_[req->session_handle()]` to `session`.
    session = iter->second;
    sessions_.erase(iter);
    mu_.unlock();
  }

  // Session Close() blocks on thread shutdown. Therefore, we need to
  // delete it on a non-critical thread.
  SchedClosure([session, done]() {
    Status s = session->Close();
    session->Unref();
    done(s);
  });
}

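// Handles the ListDevices RPC. With a session handle, reports the devices
// visible to that session; otherwise reports the master's local devices
// plus every remote device reachable through the default worker cache.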
void Master::ListDevices(const ListDevicesRequest* req,
                         ListDevicesResponse* resp, MyClosure done) {
  SchedClosure([this, req, resp, done]() {
    if (!req->session_handle().empty()) {
      auto session = FindMasterSession(req->session_handle());
      if (session == nullptr) {
        done(errors::InvalidArgument(
            "Session ", req->session_handle(),
            " is not found. Possibly, this master has restarted."));
        return;
      }
      core::ScopedUnref ref(session);
      Status s = session->ListDevices(resp);
      done(s);
      return;
    }
    std::vector<std::unique_ptr<Device>> remote_devices;
    Status s = DeviceFinder::GetRemoteDevices({}, env_, env_->worker_cache,
                                              &remote_devices);
    if (s.ok()) {
      for (Device* dev : env_->local_devices) {
        *(resp->add_local_device()) = dev->attributes();
      }
      for (auto&& dev : remote_devices) {
        *(resp->add_remote_device()) = dev->attributes();
      }
    }
    done(s);
  });
}

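// Sends a CleanupAll request to every worker matched by the reset request's
// device filters and blocks until all of them have acknowledged it.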
void Master::CleanupWorkers(const ResetRequest& reset) {
  std::vector<string> worker_names;
  DeviceFinder::GetRemoteWorkers(reset.device_filters(), env_,
                                 env_->worker_cache, &worker_names);
  if (!worker_names.empty()) {
    const int num_workers = worker_names.size();
    std::vector<Notification> n(num_workers);
    CleanupAllRequest req;
    (*req.mutable_container()) = reset.container();
    std::vector<CleanupAllResponse> resp(num_workers);
    int c = 0;
    for (int i = 0; i < num_workers; ++i) {
      const string& worker_name = worker_names[i];
      auto worker = env_->worker_cache->GetOrCreateWorker(worker_name);
      if (worker) {
        worker->CleanupAllAsync(
            &req, &resp[i], [this, &n, worker_name, worker, c](Status s) {
              TF_CHECK_OK(s);
              env_->worker_cache->ReleaseWorker(worker_name, worker);
              n[c].Notify();
            });
      } else {
        n[c].Notify();
      }
      ++c;
    }
    for (size_t i = 0; i < n.size(); ++i) {
      n[i].WaitForNotification();
    }
  }
}

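// Handles the Reset RPC: empties the session map, cleans up resource
// containers on the matching workers, and then closes the detached sessions
// asynchronously.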
void Master::Reset(const ResetRequest* req, ResetResponse* resp,
                   MyClosure done) {
  // Vector to hold the session pointers present in the sessions_
  // (string->Session*) map.
  std::vector<MasterSession*> sessions_to_close;
  {
    mutex_lock l(mu_);
    // NOTE(mrry): Transfer one reference to each session from the
    // `sessions_` map to the `sessions_to_close` vector.
    for (const auto& entry : sessions_) {
      sessions_to_close.push_back(entry.second);
    }
    sessions_.clear();
  }

  CleanupWorkers(*req);

  SchedClosure([sessions_to_close, done]() {
    Status s;
    for (MasterSession* session : sessions_to_close) {
      s.Update(session->Close());
      session->Unref();
    }
    done(s);
  });
}

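// The next three handlers forward the callable API to the MasterSession:
// MakeCallable registers a new callable in the session, RunCallable
// executes a previously created callable, and ReleaseCallable frees it.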
void Master::MakeCallable(const MakeCallableRequest* req,
                          MakeCallableResponse* resp, MyClosure done) {
  Status s = recent_request_ids_.TrackUnique(req->request_id(),
                                             "MakeCallable (Master)", *req);
  if (!s.ok()) {
    done(s);
    return;
  }
  auto session = FindMasterSession(req->session_handle());
  if (session == nullptr) {
    done(errors::Aborted("Session ", req->session_handle(), " is not found."));
    return;
  }

  SchedClosure([session, req, resp, done = std::move(done)]() {
    Status s = session->MakeCallable(*req, resp);
    session->Unref();
    done(s);
  });
}

void Master::RunCallable(CallOptions* opts, const RunCallableRequest* req,
                         RunCallableResponse* resp, MyClosure done) {
  Status s = recent_request_ids_.TrackUnique(req->request_id(),
                                             "RunCallable (Master)", *req);
  if (!s.ok()) {
    done(s);
    return;
  }
  auto session = FindMasterSession(req->session_handle());
  if (session == nullptr) {
    done(errors::Aborted("Session ", req->session_handle(), " is not found."));
    return;
  }

  SchedClosure([session, opts, req, resp, done = std::move(done)]() {
    Status s = session->RunCallable(opts, *req, resp);
    session->Unref();
    done(s);
  });
}

void Master::ReleaseCallable(const ReleaseCallableRequest* req,
                             ReleaseCallableResponse* resp, MyClosure done) {
  auto session = FindMasterSession(req->session_handle());
  if (session == nullptr) {
    done(errors::Aborted("Session ", req->session_handle(), " is not found."));
    return;
  }

  SchedClosure([session, req, resp, done = std::move(done)]() {
    Status s = session->ReleaseCallable(*req, resp);
    session->Unref();
    done(s);
  });
}

}  // end namespace tensorflow