1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Master implements the service MasterService.
17 //
// A Master maintains the state of live graph computation
// sessions. Each session orchestrates both local and remote devices
// to carry out the graph computation.
21 //
// A Master knows ahead of time which local devices are available as
// client devices.
24 //
25 // A Master discovers remote devices on-demand and keeps track of
26 // statistics of those remote devices.
27 //
28 // Each session analyzes the graph, places nodes across available
29 // devices, and ultimately drives the graph computation by initiating
30 // RunGraph on the workers.
31
32 #include "tensorflow/core/distributed_runtime/master.h"
33
34 #include <unordered_set>
35 #include <vector>
36
37 #include "tensorflow/core/common_runtime/device_set.h"
38 #include "tensorflow/core/common_runtime/process_util.h"
39 #include "tensorflow/core/distributed_runtime/remote_device.h"
40 #include "tensorflow/core/distributed_runtime/worker_cache.h"
41 #include "tensorflow/core/distributed_runtime/worker_interface.h"
42 #include "tensorflow/core/framework/graph_def_util.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/lib/core/notification.h"
45 #include "tensorflow/core/lib/gtl/array_slice.h"
46 #include "tensorflow/core/lib/gtl/cleanup.h"
47 #include "tensorflow/core/lib/gtl/map_util.h"
48 #include "tensorflow/core/lib/strings/str_util.h"
49 #include "tensorflow/core/platform/macros.h"
50 #include "tensorflow/core/platform/mutex.h"
51 #include "tensorflow/core/platform/regexp.h"
52 #include "tensorflow/core/platform/types.h"
53 #include "tensorflow/core/protobuf/cluster.pb.h"
54 #include "tensorflow/core/protobuf/master.pb.h"
55 #include "tensorflow/core/protobuf/worker.pb.h"
56 #include "tensorflow/core/public/session_options.h"
57 #include "tensorflow/core/util/device_name_utils.h"
58
59 namespace tensorflow {
60
namespace {
// Regular expression for a leading gRPC scheme prefix on a session target.
// CreateSession passes it to RE2::Replace to strip that prefix before the
// target is compared with the task addresses of a client-supplied ClusterDef.
constexpr char kGrpcPrefixRegex[] = "^grpc.*://";
}  // namespace
64
Master(MasterEnv * env,double session_gc_seconds)65 Master::Master(MasterEnv* env, double session_gc_seconds)
66 : env_(env),
67 last_1000_steps_(1000),
68 step_count_(0),
69 session_gc_seconds_(session_gc_seconds),
70 recent_request_ids_(10000) {
71 // Right now, a master service must be co-located with a device.
72 // Otherwise, fetches do not work.
73 CHECK(!env->local_devices.empty());
74
75 if (session_gc_seconds_ > 0.0) {
76 gc_thread_ = env_->env->StartThread(ThreadOptions(), "TF_master_GC",
77 [this]() { GC(); });
78 } else {
79 gc_thread_ = nullptr;
80 }
81 }
82
~Master()83 Master::~Master() {
84 if (gc_thread_) {
85 mutex_lock l(mu_);
86 shutdown_ = true;
87 shutdown_cv_.notify_all();
88 delete gc_thread_;
89 }
90 }
91
GC()92 void Master::GC() {
93 Env* env = Env::Default();
94 while (true) {
95 mutex_lock l(mu_);
96 const int kTimeoutMilliseconds = 10 * 1000; // 10 seconds.
97 WaitForMilliseconds(&l, &shutdown_cv_, kTimeoutMilliseconds);
98 if (shutdown_) {
99 break;
100 }
101 std::vector<string> handles;
102 const int64_t num_micros =
103 static_cast<int64_t>(session_gc_seconds_ * 1000000);
104 for (const auto& entry : sessions_) {
105 int64_t lat = entry.second->last_access_time_usec();
106 if (static_cast<int64_t>(env->NowMicros()) - lat > num_micros) {
107 handles.push_back(entry.first);
108 auto* sess = entry.second;
109 SchedClosure([this, sess]() {
110 LOG(WARNING) << "GC session " << sess->handle() << " after "
111 << session_gc_seconds_ << " seconds. "
112 << "Note that if you are starting multiple replicas "
113 << "on a staggered delay, session_gc_seconds may need "
114 << "to be raised.";
115 sess->GarbageCollect();
116 });
117 }
118 }
119 for (const auto& handle : handles) sessions_.erase(handle);
120 }
121 }
122
FindMasterSession(const string & handle)123 MasterSession* Master::FindMasterSession(const string& handle) {
124 MasterSession* session = nullptr;
125 {
126 mutex_lock l(mu_);
127 session = gtl::FindPtrOrNull(sessions_, handle);
128 if (session != nullptr) {
129 session->Ref();
130 }
131 }
132 return session;
133 }
134
135 class DeviceFinder {
136 public:
GetRemoteDevices(const protobuf::RepeatedPtrField<string> & device_filters,MasterEnv * env,WorkerCacheInterface * worker_cache,std::vector<std::unique_ptr<Device>> * out_remote)137 static Status GetRemoteDevices(
138 const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
139 WorkerCacheInterface* worker_cache,
140 std::vector<std::unique_ptr<Device>>* out_remote) {
141 DeviceFinder finder(device_filters, env, worker_cache);
142 finder.Start();
143 TF_RETURN_IF_ERROR(finder.Wait());
144 finder.GetRemoteDevices(env->local_devices, out_remote);
145 return OkStatus();
146 }
147
GetRemoteWorkers(const protobuf::RepeatedPtrField<string> & device_filters,MasterEnv * env,WorkerCacheInterface * worker_cache,std::vector<string> * workers)148 static void GetRemoteWorkers(
149 const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
150 WorkerCacheInterface* worker_cache, std::vector<string>* workers) {
151 DeviceFinder finder(device_filters, env, worker_cache);
152 *workers = finder.targets_;
153 }
154
155 private:
DeviceFinder(const protobuf::RepeatedPtrField<string> & device_filters,MasterEnv * env,WorkerCacheInterface * worker_cache)156 explicit DeviceFinder(
157 const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env,
158 WorkerCacheInterface* worker_cache)
159 : env_(env), worker_cache_(worker_cache) {
160 CHECK(worker_cache) << "Worker cache was null!";
161 auto process_filter = [this](const string& filter) {
162 DeviceNameUtils::ParsedName parsed;
163 if (DeviceNameUtils::ParseFullName(filter, &parsed)) {
164 filters_.push_back(parsed);
165 } else {
166 LOG(FATAL) << "Skipping invalid filter: " << filter;
167 }
168 };
169 for (const string& filter : device_filters) {
170 process_filter(filter);
171 }
172 // Enumerates all known workers' target. A target name is a
173 // prefix of a device name. E.g., /job:mnist/replica:0/task:10.
174 if (filters_.empty()) {
175 // If no filters were specified, we list all known workers in
176 // `worker_cache`.
177 std::vector<string> workers;
178 worker_cache->ListWorkers(&workers);
179 std::swap(workers, targets_);
180 } else {
181 // When applying filters, we must include the local worker, even if it
182 // does not match any of the filters.
183 CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
184 const string& local_device_name = env_->local_devices[0]->name();
185 DeviceNameUtils::ParsedName local_parsed_name;
186 CHECK(DeviceNameUtils::ParseFullName(local_device_name,
187 &local_parsed_name));
188 bool all_filters_have_job = true;
189 std::unordered_set<string> filter_job_names({local_parsed_name.job});
190 for (const DeviceNameUtils::ParsedName& filter : filters_) {
191 all_filters_have_job = all_filters_have_job && filter.has_job;
192 if (filter.has_job) {
193 filter_job_names.insert(filter.job);
194 }
195 }
196
197 std::vector<string> workers;
198 if (all_filters_have_job) {
199 // If all of the device filters have a job specified, then we only need
200 // to list the workers in the jobs named in the filter, because a worker
201 // in any other job would not match any filter.
202 for (const string& job_name : filter_job_names) {
203 VLOG(2) << "Selectively listing workers in job: " << job_name;
204 std::vector<string> workers_in_job;
205 worker_cache->ListWorkersInJob(job_name, &workers_in_job);
206 workers.insert(workers.end(), workers_in_job.begin(),
207 workers_in_job.end());
208 }
209 } else {
210 // If any of the device filters does not have a job specified, then we
211 // must list the workers from all jobs.
212 VLOG(2) << "Listing workers in all jobs because some device "
213 << "filter has no job specified. Filters were:";
214 if (device_filters.empty()) {
215 VLOG(2) << "- <NO FILTERS>";
216 } else {
217 for (const string& filter : device_filters) {
218 VLOG(2) << "- " << filter;
219 }
220 }
221 worker_cache->ListWorkers(&workers);
222 }
223 for (const string& name : workers) {
224 if (MatchFilters(name) ||
225 DeviceNameUtils::IsSameAddressSpace(name, local_device_name)) {
226 targets_.push_back(name);
227 }
228 }
229 }
230 seen_targets_.assign(targets_.size(), false);
231 }
232
~DeviceFinder()233 ~DeviceFinder() {
234 for (Device* dev : found_) delete dev;
235 }
236
Start()237 void Start() {
238 {
239 mutex_lock l(mu_);
240 num_pending_ = targets_.size();
241 if (num_pending_ == 0) {
242 pending_zero_.notify_all();
243 }
244 }
245 // Talk to all workers to get the list of available devices.
246 using std::placeholders::_1;
247 using std::placeholders::_2;
248 for (size_t i = 0; i < targets_.size(); ++i) {
249 // TODO(mrry): Propagate a timeout here, since `this->WhenFound()` may
250 // never be called.
251 NewRemoteDevices(env_->env, worker_cache_, targets_[i],
252 std::bind(&ME::WhenFound, this, i, _1, _2));
253 }
254 }
255
256 // Every `kLoggingPeriodMs`, while the DeviceFinder is still waiting
257 // to hear from workers, log a list of the workers who have not
258 // responded.
259 const int32 kLoggingPeriodMs = 10 * 1000;
260
Wait()261 Status Wait() {
262 mutex_lock l(mu_);
263 // TODO(mrry): Propagate a timeout here, since `num_pending_` may
264 // never become zero.
265 while (num_pending_ != 0) {
266 pending_zero_.wait_for(l, std::chrono::milliseconds(kLoggingPeriodMs));
267 if (num_pending_ != 0) {
268 for (size_t i = 0; i < targets_.size(); ++i) {
269 if (!seen_targets_[i]) {
270 LOG(INFO)
271 << "CreateSession still waiting for response from worker: "
272 << targets_[i];
273 }
274 }
275 }
276 }
277 return status_;
278 }
279
280 // The caller takes the ownership of returned remote devices.
GetRemoteDevices(const std::vector<Device * > & local,std::vector<std::unique_ptr<Device>> * remote)281 void GetRemoteDevices(const std::vector<Device*>& local,
282 std::vector<std::unique_ptr<Device>>* remote) {
283 std::unordered_set<string> names(local.size());
284 for (Device* dev : local) names.insert(dev->name());
285 mutex_lock l(mu_);
286 for (Device* dev : found_) {
287 const string& name = dev->name();
288 if (names.insert(name).second && MatchFilters(name)) {
289 remote->push_back(std::unique_ptr<Device>(dev));
290 } else {
291 delete dev;
292 }
293 }
294 found_.clear();
295 }
296
297 typedef DeviceFinder ME;
298 const MasterEnv* env_;
299 WorkerCacheInterface* worker_cache_;
300 std::vector<DeviceNameUtils::ParsedName> filters_;
301
302 mutex mu_;
303 int num_pending_ TF_GUARDED_BY(mu_);
304 condition_variable pending_zero_;
305 std::vector<Device*> found_ TF_GUARDED_BY(mu_);
306 // List of targets to be contacted by this DeviceFinder. The
307 // respective `bool` in `seen_targets_` indicates whether we have
308 // heard from this target or not.
309 std::vector<string> targets_;
310 std::vector<bool> seen_targets_ TF_GUARDED_BY(mu_);
311 Status status_;
312
WhenFound(int target_index,const Status & s,std::vector<Device * > * devices)313 void WhenFound(int target_index, const Status& s,
314 std::vector<Device*>* devices) {
315 mutex_lock l(mu_);
316 seen_targets_[target_index] = true;
317 if (!s.ok()) {
318 LOG(ERROR) << "CreateSession failed because worker "
319 << targets_[target_index] << " returned error: " << s;
320 status_.Update(s);
321 } else {
322 found_.insert(found_.end(), devices->begin(), devices->end());
323 devices->clear();
324 }
325 --num_pending_;
326 if (num_pending_ == 0) {
327 pending_zero_.notify_all();
328 }
329 }
330
331 // Returns true iff the set of devices allowed by 'x' intersects
332 // with the set of devices allowed by 'y'.
Intersects(const DeviceNameUtils::ParsedName & x,const DeviceNameUtils::ParsedName & y)333 bool Intersects(const DeviceNameUtils::ParsedName& x,
334 const DeviceNameUtils::ParsedName& y) {
335 return (!x.has_job || !y.has_job || x.job == y.job) &&
336 (!x.has_replica || !y.has_replica || x.replica == y.replica) &&
337 (!x.has_task || !y.has_task || x.task == y.task) &&
338 (!x.has_type || !y.has_type || x.type == y.type) &&
339 (!x.has_id || !y.has_id || x.id == y.id);
340 }
341
342 // Returns true iff 'name' matches one of the filters_.
MatchFilters(const string & name)343 bool MatchFilters(const string& name) {
344 if (filters_.empty()) return true;
345 DeviceNameUtils::ParsedName x;
346 if (DeviceNameUtils::ParseFullName(name, &x)) {
347 for (const auto& filter : filters_) {
348 if (Intersects(x, filter)) return true;
349 }
350 }
351 return false;
352 }
353
354 TF_DISALLOW_COPY_AND_ASSIGN(DeviceFinder);
355 };
356
CreateSession(const CreateSessionRequest * req,CreateSessionResponse * resp,MyClosure done)357 void Master::CreateSession(const CreateSessionRequest* req,
358 CreateSessionResponse* resp, MyClosure done) {
359 SchedClosure([this, req, resp, done]() {
360 Status status;
361 WorkerCacheFactoryOptions worker_cache_factory_options;
362 string grpc_protocol("grpc");
363 worker_cache_factory_options.protocol = &grpc_protocol;
364 auto call_done = gtl::MakeCleanup([&status, &done] { done(status); });
365 status = ValidateExternalGraphDefSyntax(req->graph_def());
366 if (!status.ok()) return;
367
368 // The following 4 variables are set differently, depending on whether this
369 // session uses a client-provided clusterspec or not.
370 WorkerCacheInterface* worker_cache = nullptr;
371 // Note: worker_cache_ptr will be null except if this session is using a
372 // client-supplied ClusterDef (ClusterSpec propagation).
373 std::unique_ptr<WorkerCacheInterface> worker_cache_ptr;
374 std::unique_ptr<DeviceSet> device_set;
375 // TODO(saeta): Convert to std::make_unique when available.
376 std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devices(
377 new std::vector<std::unique_ptr<Device>>());
378
379 const ClusterDef& cluster_def = req->config().cluster_def();
380 if (!cluster_def.job().empty()) {
381 worker_cache_factory_options.cluster_def = &cluster_def;
382 // If the target starts with gRPC protocol prefix, remove the prefix
383 string normalized_string(req->target());
384 RE2::Replace(&normalized_string, kGrpcPrefixRegex, "");
385
386 // Set the server_def's job_name and task_index fields.
387 for (auto&& job : cluster_def.job()) {
388 for (auto&& task : job.tasks()) {
389 if (task.second == normalized_string) {
390 if (worker_cache_factory_options.job_name != nullptr) {
391 status = errors::InvalidArgument(
392 "Found multiple matching tasks that correspond to "
393 "to the master. Master target: '",
394 req->target(),
395 "'. ClusterDef: ", cluster_def.ShortDebugString());
396 LOG(ERROR) << status;
397 return;
398 }
399 if (env_->local_devices[0]->parsed_name().job == job.name() &&
400 env_->local_devices[0]->parsed_name().task == task.first) {
401 // TODO(b/37868888): Remove this limitation when resolved
402 status = errors::InvalidArgument(
403 "The ClusterSpec names the job and task index to be the same "
404 "names that were provided when the server booted. This is "
405 "currently not allowed. Job: ",
406 job.name(), ", task index: ", task.first);
407 return;
408 }
409 worker_cache_factory_options.job_name = &job.name();
410 worker_cache_factory_options.task_index = task.first;
411 }
412 }
413 }
414 worker_cache_factory_options.rpc_options = &req->config().rpc_options();
415 // Create the worker cache from the computed server_def.
416 status = env_->worker_cache_factory(worker_cache_factory_options,
417 &worker_cache);
418 if (!status.ok()) return;
419 worker_cache_ptr = std::unique_ptr<WorkerCacheInterface>(worker_cache);
420 // Ping all the workers and build the list of devices that the
421 // session will use.
422 status =
423 DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
424 worker_cache, remote_devices.get());
425 if (!status.ok()) return;
426 device_set.reset(new DeviceSet);
427 for (auto&& d : *remote_devices) {
428 device_set->AddDevice(d.get());
429 DeviceNameUtils::ParsedName name = d->parsed_name();
430 if (name.job == *worker_cache_factory_options.job_name &&
431 name.task == worker_cache_factory_options.task_index &&
432 name.type == "CPU" && name.id == 0) {
433 device_set->set_client_device(d.get());
434 }
435 }
436 } else {
437 worker_cache = env_->worker_cache;
438 // Ping all the workers and build the list of devices that the
439 // session will use.
440 status =
441 DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
442 worker_cache, remote_devices.get());
443 if (!status.ok()) return;
444 device_set.reset(new DeviceSet);
445 for (auto&& d : *remote_devices) {
446 device_set->AddDevice(d.get());
447 }
448 int num_local_devices = 0;
449 for (Device* d : env_->local_devices) {
450 device_set->AddDevice(d);
451 if (num_local_devices == 0) {
452 // Uses the first local device as the client device.
453 device_set->set_client_device(d);
454 }
455 num_local_devices++;
456 }
457 }
458
459 CHECK(device_set->client_device()) << "No client device found. Missing "
460 << "CPU:0 device?";
461
462 SessionOptions options;
463 options.target = req->target();
464 options.config = req->config();
465
466 std::vector<string> filtered_worker_list;
467 DeviceFinder::GetRemoteWorkers(req->config().device_filters(), env_,
468 worker_cache, &filtered_worker_list);
469
470 MasterSession* session = env_->master_session_factory(
471 options, env_, std::move(remote_devices), std::move(worker_cache_ptr),
472 std::move(device_set), std::move(filtered_worker_list));
473
474 GraphDef* gdef =
475 const_cast<CreateSessionRequest*>(req)->mutable_graph_def();
476
477 status = session->Create(std::move(*gdef), cluster_def);
478 if (!status.ok()) {
479 session->Close().IgnoreError();
480 session->Unref();
481 return;
482 }
483 resp->set_session_handle(session->handle());
484 // Insert into the session map, which takes ownership of the session.
485 {
486 mutex_lock l(mu_);
487 CHECK(sessions_.insert({session->handle(), session}).second);
488 }
489 });
490 }
491
ExtendSession(const ExtendSessionRequest * req,ExtendSessionResponse * resp,MyClosure done)492 void Master::ExtendSession(const ExtendSessionRequest* req,
493 ExtendSessionResponse* resp, MyClosure done) {
494 auto session = FindMasterSession(req->session_handle());
495 if (session == nullptr) {
496 done(errors::Aborted("Session ", req->session_handle(), " is not found."));
497 return;
498 }
499
500 SchedClosure([session, req, resp, done]() {
501 Status status = ValidateExternalGraphDefSyntax(req->graph_def());
502 if (status.ok()) {
503 status = session->Extend(req, resp);
504 }
505 session->Unref();
506 done(status);
507 });
508 }
509
PartialRunSetup(const PartialRunSetupRequest * req,PartialRunSetupResponse * resp,MyClosure done)510 void Master::PartialRunSetup(const PartialRunSetupRequest* req,
511 PartialRunSetupResponse* resp, MyClosure done) {
512 Status s = recent_request_ids_.TrackUnique(req->request_id(),
513 "PartialRunSetup (Master)", *req);
514 if (!s.ok()) {
515 done(s);
516 return;
517 }
518 auto session = FindMasterSession(req->session_handle());
519 if (session == nullptr) {
520 done(errors::Aborted("Session ", req->session_handle(), " is not found."));
521 return;
522 }
523
524 SchedClosure([session, req, resp, done]() {
525 Status s = session->PartialRunSetup(req, resp);
526 session->Unref();
527 done(s);
528 });
529 }
530
RunStep(CallOptions * opts,const RunStepRequestWrapper * req,MutableRunStepResponseWrapper * resp,MyClosure done)531 void Master::RunStep(CallOptions* opts, const RunStepRequestWrapper* req,
532 MutableRunStepResponseWrapper* resp, MyClosure done) {
533 Status s = recent_request_ids_.TrackUnique(req->request_id(),
534 "RunStep (Master)", req);
535 if (!s.ok()) {
536 done(s);
537 return;
538 }
539 auto start_time = env_->env->NowMicros();
540 auto session = FindMasterSession(req->session_handle());
541 if (session == nullptr) {
542 done(errors::Aborted("Session ", req->session_handle(), " is not found."));
543 return;
544 }
545
546 SchedClosure([this, start_time, session, opts, req, resp, done]() {
547 Status status = session->Run(opts, *req, resp);
548 session->Unref();
549 uint64 done_time = env_->env->NowMicros();
550 done(status);
551 mutex_lock l(mu_);
552 last_1000_steps_.AddValue((done_time - start_time) / 1e9);
553 ++step_count_;
554 });
555 }
556
CloseSession(const CloseSessionRequest * req,CloseSessionResponse * resp,MyClosure done)557 void Master::CloseSession(const CloseSessionRequest* req,
558 CloseSessionResponse* resp, MyClosure done) {
559 MasterSession* session = nullptr;
560 {
561 mu_.lock();
562 auto iter = sessions_.find(req->session_handle());
563 if (iter == sessions_.end()) {
564 mu_.unlock();
565 done(errors::Aborted(
566 "Session ", req->session_handle(),
567 " is not found. Possibly, this master has restarted."));
568 return;
569 }
570 // NOTE(mrry): One reference to the session is transferred from
571 // `sessions_[req->session_handle()]` to `session`.
572 session = iter->second;
573 sessions_.erase(iter);
574 mu_.unlock();
575 }
576
577 // Session Close() blocks on thread shutdown. Therefore, we need to
578 // delete it in non-critical thread.
579 SchedClosure([session, done]() {
580 Status s = session->Close();
581 session->Unref();
582 done(s);
583 });
584 }
585
ListDevices(const ListDevicesRequest * req,ListDevicesResponse * resp,MyClosure done)586 void Master::ListDevices(const ListDevicesRequest* req,
587 ListDevicesResponse* resp, MyClosure done) {
588 SchedClosure([this, req, resp, done]() {
589 if (!req->session_handle().empty()) {
590 auto session = FindMasterSession(req->session_handle());
591 if (session == nullptr) {
592 done(errors::InvalidArgument(
593 "Session ", req->session_handle(),
594 " is not found. Possibly, this master has restarted."));
595 return;
596 }
597 core::ScopedUnref ref(session);
598 Status s = session->ListDevices(resp);
599 done(s);
600 return;
601 }
602 std::vector<std::unique_ptr<Device>> remote_devices;
603 Status s = DeviceFinder::GetRemoteDevices({}, env_, env_->worker_cache,
604 &remote_devices);
605 if (s.ok()) {
606 for (Device* dev : env_->local_devices) {
607 *(resp->add_local_device()) = dev->attributes();
608 }
609 for (auto&& dev : remote_devices) {
610 *(resp->add_remote_device()) = dev->attributes();
611 }
612 }
613 done(s);
614 });
615 }
616
CleanupWorkers(const ResetRequest & reset)617 void Master::CleanupWorkers(const ResetRequest& reset) {
618 std::vector<string> worker_names;
619 DeviceFinder::GetRemoteWorkers(reset.device_filters(), env_,
620 env_->worker_cache, &worker_names);
621 if (!worker_names.empty()) {
622 const int num_workers = worker_names.size();
623 std::vector<Notification> n(num_workers);
624 CleanupAllRequest req;
625 (*req.mutable_container()) = reset.container();
626 std::vector<CleanupAllResponse> resp(num_workers);
627 int c = 0;
628 for (int i = 0; i < num_workers; ++i) {
629 const string& worker_name = worker_names[i];
630 auto worker = env_->worker_cache->GetOrCreateWorker(worker_name);
631 if (worker) {
632 worker->CleanupAllAsync(
633 &req, &resp[i], [this, &n, worker_name, worker, c](Status s) {
634 TF_CHECK_OK(s);
635 env_->worker_cache->ReleaseWorker(worker_name, worker);
636 n[c].Notify();
637 });
638 } else {
639 n[c].Notify();
640 }
641 ++c;
642 }
643 for (size_t i = 0; i < n.size(); ++i) {
644 n[i].WaitForNotification();
645 }
646 }
647 }
648
Reset(const ResetRequest * req,ResetResponse * resp,MyClosure done)649 void Master::Reset(const ResetRequest* req, ResetResponse* resp,
650 MyClosure done) {
651 // Vector to hold the session pointers present in the sessions_
652 // (string->Session*) map.
653 std::vector<MasterSession*> sessions_to_close;
654 {
655 mutex_lock l(mu_);
656 // NOTE(mrry): Transfer one reference to each session from the
657 // `sessions_` map to the `sessions_to_close` vector.
658 for (const auto& entry : sessions_) {
659 sessions_to_close.push_back(entry.second);
660 }
661 sessions_.clear();
662 }
663
664 CleanupWorkers(*req);
665
666 SchedClosure([sessions_to_close, done]() {
667 Status s;
668 for (MasterSession* session : sessions_to_close) {
669 s.Update(session->Close());
670 session->Unref();
671 }
672 done(s);
673 });
674 }
675
MakeCallable(const MakeCallableRequest * req,MakeCallableResponse * resp,MyClosure done)676 void Master::MakeCallable(const MakeCallableRequest* req,
677 MakeCallableResponse* resp, MyClosure done) {
678 Status s = recent_request_ids_.TrackUnique(req->request_id(),
679 "MakeCallable (Master)", *req);
680 if (!s.ok()) {
681 done(s);
682 return;
683 }
684 auto session = FindMasterSession(req->session_handle());
685 if (session == nullptr) {
686 done(errors::Aborted("Session ", req->session_handle(), " is not found."));
687 return;
688 }
689
690 SchedClosure([session, req, resp, done = std::move(done)]() {
691 Status s = session->MakeCallable(*req, resp);
692 session->Unref();
693 done(s);
694 });
695 }
696
RunCallable(CallOptions * opts,const RunCallableRequest * req,RunCallableResponse * resp,MyClosure done)697 void Master::RunCallable(CallOptions* opts, const RunCallableRequest* req,
698 RunCallableResponse* resp, MyClosure done) {
699 Status s = recent_request_ids_.TrackUnique(req->request_id(),
700 "RunCallable (Master)", *req);
701 if (!s.ok()) {
702 done(s);
703 return;
704 }
705 auto session = FindMasterSession(req->session_handle());
706 if (session == nullptr) {
707 done(errors::Aborted("Session ", req->session_handle(), " is not found."));
708 return;
709 }
710
711 SchedClosure([session, opts, req, resp, done = std::move(done)]() {
712 Status s = session->RunCallable(opts, *req, resp);
713 session->Unref();
714 done(s);
715 });
716 }
717
ReleaseCallable(const ReleaseCallableRequest * req,ReleaseCallableResponse * resp,MyClosure done)718 void Master::ReleaseCallable(const ReleaseCallableRequest* req,
719 ReleaseCallableResponse* resp, MyClosure done) {
720 auto session = FindMasterSession(req->session_handle());
721 if (session == nullptr) {
722 done(errors::Aborted("Session ", req->session_handle(), " is not found."));
723 return;
724 }
725
726 SchedClosure([session, req, resp, done = std::move(done)]() {
727 Status s = session->ReleaseCallable(*req, resp);
728 session->Unref();
729 done(s);
730 });
731 }
732
733 } // end namespace tensorflow
734