1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/substitute.h"
26 #include "absl/synchronization/notification.h"
27 #include "absl/time/time.h"
28 #include "tensorflow/core/distributed_runtime/call_options.h"
29 #include "tensorflow/core/distributed_runtime/coordination/coordination_client.h"
30 #include "tensorflow/core/distributed_runtime/coordination/coordination_service_error_util.h"
31 #include "tensorflow/core/framework/cancellation.h"
32 #include "tensorflow/core/platform/env.h"
33 #include "tensorflow/core/platform/errors.h"
34 #include "tensorflow/core/platform/mutex.h"
35 #include "tensorflow/core/platform/random.h"
36 #include "tensorflow/core/platform/strcat.h"
37 #include "tensorflow/core/platform/thread_annotations.h"
38 #include "tensorflow/core/protobuf/config.pb.h"
39 #include "tensorflow/core/protobuf/coordination_config.pb.h"
40 #include "tensorflow/core/protobuf/coordination_service.pb.h"
41 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
42 
43 namespace tensorflow {
44 namespace {
45 
46 constexpr absl::Duration kDefaultClusterRegisterTimeout = absl::Hours(1);
47 constexpr absl::Duration kDefaultHeartbeatTimeout = absl::Seconds(10);
48 constexpr absl::Duration kDefaultShutdownTimeout = absl::Seconds(10);
49 constexpr char kHeartbeatThread[] = "CoordinationServiceHeartbeatLoop";
50 
// CoordinationServiceAgent implementation that talks to the coordination
// service through a single leader client (`leader_client_`).
//
// State transitions driven by the methods in this file:
//   UNINITIALIZED --Initialize()--> DISCONNECTED --Connect()--> CONNECTED
//   (any state) --SetError()--> ERROR --Reset()--> DISCONNECTED
// Shutdown() disconnects from the service and stops background work.
class CoordinationServiceAgentImpl : public CoordinationServiceAgent {
 public:
  CoordinationServiceAgentImpl() = default;
  // Best-effort teardown: a failing Shutdown() is logged, not propagated.
  ~CoordinationServiceAgentImpl() override {
    Status s = Shutdown();
    if (!s.ok()) {
      LOG(ERROR) << "Coordination agent shutdown failed with status: " << s;
    }
  }
  Status Initialize(Env* env, const ServerDef& server_def,
                    std::unique_ptr<CoordinationClientCache> client_cache,
                    StatusCallback error_fn) override;
  Status Initialize(Env* env, const std::string& job_name, int task_id,
                    const CoordinationServiceConfig& configs,
                    std::unique_ptr<CoordinationClient> leader_client,
                    StatusCallback error_fn) override;
  Status Initialize(Env* env, const CoordinatedTask& task,
                    const CoordinationServiceConfig& configs,
                    std::unique_ptr<CoordinationClient> leader_client,
                    StatusCallback error_fn) override;
  bool IsInitialized() override;

  Status Connect() override;
  Status WaitForAllTasks(
      const CoordinationServiceDeviceInfo& local_devices) override;
  const CoordinationServiceDeviceInfo& GetClusterDeviceInfo() override;
  StatusOr<CoordinatedTask> GetOwnTask() override;
  StatusOr<CoordinatedTaskState> GetTaskStatus(
      const CoordinatedTask& task) override;
  Status ReportError(const Status& error) override;
  Status Shutdown() override;
  Status Reset() override;

  // Key-value store access; all calls are forwarded to the leader.
  StatusOr<std::string> GetKeyValue(const std::string& key) override;
  StatusOr<std::string> GetKeyValue(const std::string& key,
                                    absl::Duration timeout) override;
  std::shared_ptr<CallOptions> GetKeyValueAsync(
      const std::string& key, StatusOrValueCallback done) override;
  StatusOr<std::string> TryGetKeyValue(const std::string& key) override;
  StatusOr<std::vector<KeyValueEntry>> GetKeyValueDir(
      const std::string& key) override;
  void GetKeyValueDirAsync(const std::string& key,
                           StatusOrValueDirCallback done) override;
  Status InsertKeyValue(const std::string& key,
                        const std::string& value) override;
  Status DeleteKeyValue(const std::string& key) override;
  Status UpdateKeyValue(const std::string& key,
                        const std::string& value) override;

  Status StartWatchKey(const std::string& key,
                       ChangedKeyValuesCallback on_change) override;
  Status StopWatchKey(const std::string& key) override;
  Status WaitAtBarrier(const std::string& barrier_id, absl::Duration timeout,
                       const std::vector<CoordinatedTask>& tasks) override;
  void WaitAtBarrierAsync(const std::string& barrier_id, absl::Duration timeout,
                          const std::vector<CoordinatedTask>& tasks,
                          StatusCallback done) override;
  Status CancelBarrier(const std::string& barrier_id) override;
  void CancelBarrierAsync(const std::string& barrier_id,
                          StatusCallback done) override;

  StatusOr<Env*> GetEnv() override;

 protected:
  // Moves the agent to ERROR (keeping only the first error) and notifies
  // `error_fn_`.
  void SetError(const Status& error) override;
  Status ActivateWatch(const std::string& key,
                       const std::map<std::string, std::string>&) override;
  // Returns an error if agent is not running. If `allow_disconnected` is true,
  // returns OK even if the agent is in DISCONNECTED state.
  Status ValidateRunningAgent(bool allow_disconnected = false);
  // Signals the heartbeat thread to exit and releases it.
  void StopHeartbeat();

 private:
  Env* env_ = nullptr;  // Not owned.
  // Random per-instance ID sent with RegisterTask and Heartbeat requests.
  const uint64_t incarnation_id_ = random::New64();
  // This agent's own task and config; assigned by Initialize().
  CoordinatedTask task_;
  CoordinationServiceConfig configs_;
  // Invoked by SetError() when the agent first enters ERROR.
  StatusCallback error_fn_;

  // Guards the lifecycle state, the recorded error and used barrier ids.
  mutable mutex state_mu_;
  CoordinatedTaskState state_ TF_GUARDED_BY(state_mu_) =
      CoordinatedTaskState::TASKSTATE_UNINITIALIZED;
  // Error recorded by SetError(); returned by Connect() on failure.
  Status status_ TF_GUARDED_BY(state_mu_) = OkStatus();
  // Note: this set grows without bounds. For now, this is okay as most users
  // require < 100 barriers. If there is a use case that requires many barriers,
  // consider using a monotonic sequence number to track instead.
  absl::flat_hash_set<std::string> used_barrier_ids_ TF_GUARDED_BY(state_mu_);

  // Leader incarnation observed at registration; the heartbeat loop reports an
  // error if the leader later answers with a different value.
  uint64_t leader_incarnation_ = 0;
  // Device info merged in by WaitForAllTasks(); not lock-protected.
  CoordinationServiceDeviceInfo cluster_devices_;

  // Heartbeat thread control: StopHeartbeat() sets `shutting_down_` and
  // signals the condition variable to wake the loop early.
  mutex heartbeat_thread_shutdown_mu_;
  condition_variable heartbeat_thread_cv_;
  bool shutting_down_ TF_GUARDED_BY(heartbeat_thread_shutdown_mu_) = false;
  std::unique_ptr<Thread> heartbeat_thread_;
  // Must outlive coordination client which may need to access it within
  // GetKeyValueAsync() callbacks.
  // (Members are destroyed in reverse declaration order, so declaring this
  // before `leader_client_` is what provides that guarantee.)
  CancellationManager cancellation_manager_;
  std::unique_ptr<CoordinationClient> leader_client_;

  TF_DISALLOW_COPY_AND_ASSIGN(CoordinationServiceAgentImpl);
};
153 
Initialize(Env * env,const ServerDef & server_def,std::unique_ptr<CoordinationClientCache> client_cache,StatusCallback error_fn)154 Status CoordinationServiceAgentImpl::Initialize(
155     Env* env, const ServerDef& server_def,
156     std::unique_ptr<CoordinationClientCache> client_cache,
157     StatusCallback error_fn) {
158   CoordinationServiceConfig configs =
159       server_def.default_session_config().experimental().coordination_config();
160   if (configs.service_leader().empty()) {
161     const std::string& collective_leader = server_def.default_session_config()
162                                                .experimental()
163                                                .collective_group_leader();
164     if (!collective_leader.empty()) {
165       configs.set_service_leader(collective_leader);
166       LOG(INFO) << "No coordination leader is set, using the collective leader "
167                 << collective_leader;
168     } else {
169       const std::string& default_leader =
170           strings::StrCat("/job:", server_def.job_name(), "/replica:0/task:0");
171       configs.set_service_leader(default_leader);
172       LOG(INFO) << "No coordination leader is set, using the default leader "
173                 << default_leader;
174     }
175   }
176   return Initialize(
177       env, server_def.job_name(), server_def.task_index(), configs,
178       client_cache->GetOwnedClient(configs.service_leader()), error_fn);
179 }
180 
Initialize(Env * env,const std::string & job_name,int task_id,const CoordinationServiceConfig & configs,std::unique_ptr<CoordinationClient> leader_client,StatusCallback error_fn)181 Status CoordinationServiceAgentImpl::Initialize(
182     Env* env, const std::string& job_name, int task_id,
183     const CoordinationServiceConfig& configs,
184     std::unique_ptr<CoordinationClient> leader_client,
185     StatusCallback error_fn) {
186   CoordinatedTask task;
187   task.set_job_name(job_name);
188   task.set_task_id(task_id);
189   return Initialize(env, task, configs, std::move(leader_client), error_fn);
190 }
191 
Initialize(Env * env,const CoordinatedTask & task,const CoordinationServiceConfig & configs,std::unique_ptr<CoordinationClient> leader_client,StatusCallback error_fn)192 Status CoordinationServiceAgentImpl::Initialize(
193     Env* env, const CoordinatedTask& task,
194     const CoordinationServiceConfig& configs,
195     std::unique_ptr<CoordinationClient> leader_client,
196     StatusCallback error_fn) {
197   mutex_lock l(state_mu_);
198   if (state_ != CoordinatedTaskState::TASKSTATE_UNINITIALIZED) {
199     return MakeCoordinationError(errors::FailedPrecondition(
200         "Coordination service agent has already been initialized."));
201   }
202 
203   env_ = env;
204   task_ = task;
205   configs_ = configs;
206   if (configs_.service_leader().empty()) {
207     return MakeCoordinationError(errors::InvalidArgument(
208         "CoordinationServiceAgent must be initialized with a valid leader."));
209   }
210   leader_client_ = std::move(leader_client);
211   if (leader_client_ == nullptr) {
212     return MakeCoordinationError(errors::InvalidArgument(
213         "CoordinationServiceAgent must have a valid leader client."));
214   }
215   error_fn_ = error_fn;
216   state_ = CoordinatedTaskState::TASKSTATE_DISCONNECTED;
217   return OkStatus();
218 }
219 
IsInitialized()220 bool CoordinationServiceAgentImpl::IsInitialized() {
221   mutex_lock l(state_mu_);
222   return state_ != CoordinatedTaskState::TASKSTATE_UNINITIALIZED;
223 }
224 
StopHeartbeat()225 void CoordinationServiceAgentImpl::StopHeartbeat() {
226   {
227     mutex_lock l(heartbeat_thread_shutdown_mu_);
228     shutting_down_ = true;
229     heartbeat_thread_cv_.notify_all();
230   }
231   heartbeat_thread_.reset();
232 }
233 
// Registers this task with the coordination service leader and, on success,
// starts the background heartbeat thread.
//
// Requires the agent to be DISCONNECTED; otherwise returns FailedPrecondition.
// The RegisterTask RPC uses `cluster_register_timeout_in_ms` from the config,
// or kDefaultClusterRegisterTimeout when that field is not positive. If
// registration fails, the agent is moved to ERROR via SetError() and the
// recorded error (`status_`) is returned.
Status CoordinationServiceAgentImpl::Connect() {
  {
    mutex_lock l(state_mu_);
    if (state_ != CoordinatedTaskState::TASKSTATE_DISCONNECTED) {
      return MakeCoordinationError(errors::FailedPrecondition(
          "Coordination service agent is not in DISCONNECTED state."));
    }
  }
  RegisterTaskRequest request;
  *request.mutable_source_task() = task_;
  request.set_incarnation(incarnation_id_);
  RegisterTaskResponse response;
  absl::Notification n;

  // Block until the remote service is up and the task is registered.
  CallOptions call_opts;
  const int64_t register_timeout =
      configs_.cluster_register_timeout_in_ms() > 0
          ? configs_.cluster_register_timeout_in_ms()
          : absl::ToInt64Milliseconds(kDefaultClusterRegisterTimeout);
  call_opts.SetTimeout(register_timeout);
  // The callback captures locals by reference; this is safe only because we
  // block on `n` below before any of them go out of scope.
  leader_client_->RegisterTaskAsync(
      &call_opts, &request, &response, [&](Status s) {
        if (!s.ok()) {
          SetError(s);
        } else {
          // Remembered so the heartbeat loop can detect a leader restart.
          leader_incarnation_ = response.leader_incarnation();
          {
            mutex_lock l(state_mu_);
            state_ = CoordinatedTaskState::TASKSTATE_CONNECTED;
          }
        }
        n.Notify();
      });
  n.WaitForNotification();
  {
    mutex_lock l(state_mu_);
    if (state_ == CoordinatedTaskState::TASKSTATE_ERROR) {
      return status_;
    }
  }

  LOG(INFO) << "Coordination agent has successfully connected.";
  // Heartbeat loop. The interval is half of `heartbeat_timeout_in_ms` (or half
  // of kDefaultHeartbeatTimeout when unset), and each Heartbeat RPC uses that
  // same interval as its timeout. A failed RPC or a changed leader incarnation
  // moves the agent to ERROR. The loop keeps running afterwards, and SetError()
  // keeps only the first error. It exits once StopHeartbeat() sets
  // `shutting_down_`.
  // NOTE(review): `shutting_down_` is cleared only by Reset(), not by
  // Shutdown(). A Connect() after Shutdown() therefore starts a thread that
  // exits after its first wait.
  heartbeat_thread_.reset(
      env_->StartThread(ThreadOptions(), kHeartbeatThread, [this]() -> void {
        HeartbeatRequest request;
        *request.mutable_source_task() = task_;
        request.set_incarnation(incarnation_id_);
        HeartbeatResponse response;
        const int64_t heartbeat_interval_ms =
            configs_.heartbeat_timeout_in_ms() > 0
                ? configs_.heartbeat_timeout_in_ms() / 2
                : absl::ToInt64Milliseconds(kDefaultHeartbeatTimeout) / 2;
        CallOptions call_opts;
        call_opts.SetTimeout(heartbeat_interval_ms);

        while (true) {
          {
            mutex_lock l(heartbeat_thread_shutdown_mu_);
            // Waits up to one interval; StopHeartbeat() may wake it early.
            heartbeat_thread_cv_.wait_for(
                l, std::chrono::milliseconds(heartbeat_interval_ms));
            if (shutting_down_) {
              return;
            }
          }
          Status status;
          absl::Notification n;
          // Heartbeat RPC implementation automatically retries to tolerate
          // transient network failures.
          leader_client_->HeartbeatAsync(&call_opts, &request, &response,
                                         [&](Status s) {
                                           status = s;
                                           n.Notify();
                                         });
          n.WaitForNotification();
          if (!status.ok()) {
            SetError(status);
          } else if (response.leader_incarnation() != leader_incarnation_) {
            SetError(MakeCoordinationError(
                errors::Aborted("Leader incarnation ID mismatch: the "
                                "coordination leader has restarted.")));
          }
        }
      }));
  return OkStatus();
}
320 
WaitForAllTasks(const CoordinationServiceDeviceInfo & local_devices)321 Status CoordinationServiceAgentImpl::WaitForAllTasks(
322     const CoordinationServiceDeviceInfo& local_devices) {
323   Status agent_running_status = ValidateRunningAgent();
324   if (!agent_running_status.ok()) {
325     return agent_running_status;
326   }
327   WaitForAllTasksRequest request;
328   *request.mutable_source_task() = task_;
329   *request.mutable_local_device_info() = local_devices;
330   WaitForAllTasksResponse response;
331   Status status;
332   absl::Notification n;
333   leader_client_->WaitForAllTasksAsync(&request, &response, [&](Status s) {
334     status = s;
335     n.Notify();
336   });
337   n.WaitForNotification();
338   if (!status.ok()) {
339     SetError(status);
340     return status;
341   }
342   cluster_devices_.MergeFrom(response.cluster_device_info());
343   return OkStatus();
344 }
345 
346 const CoordinationServiceDeviceInfo&
GetClusterDeviceInfo()347 CoordinationServiceAgentImpl::GetClusterDeviceInfo() {
348   return cluster_devices_;
349 }
350 
GetOwnTask()351 StatusOr<CoordinatedTask> CoordinationServiceAgentImpl::GetOwnTask() {
352   if (!IsInitialized()) {
353     return MakeCoordinationError(
354         errors::FailedPrecondition("Agent has not been initialized; we do not "
355                                    "know the associated task yet."));
356   }
357   return task_;
358 }
359 
GetTaskStatus(const CoordinatedTask & task)360 StatusOr<CoordinatedTaskState> CoordinationServiceAgentImpl::GetTaskStatus(
361     const CoordinatedTask& task) {
362   return MakeCoordinationError(errors::Unimplemented(
363       "CoordinationServiceAgentImpl::GetTaskStatus is not implemented."));
364 }
365 
ReportError(const Status & error)366 Status CoordinationServiceAgentImpl::ReportError(const Status& error) {
367   {
368     mutex_lock l(state_mu_);
369     if (state_ == CoordinatedTaskState::TASKSTATE_UNINITIALIZED) {
370       return MakeCoordinationError(errors::FailedPrecondition(
371           "Coordination service agent must be initialized first before "
372           "reporting error."));
373     } else if (state_ == CoordinatedTaskState::TASKSTATE_ERROR) {
374       return MakeCoordinationError(errors::FailedPrecondition(
375           "Coordination service agent is already in error state."));
376     }
377   }
378   SetError(MakeCoordinationError(error, task_,
379                                  /*is_reported_error=*/true));
380   LOG(INFO) << "Reporting error to coordination service: " << error;
381   ReportErrorToServiceRequest request;
382   request.set_error_code(error.code());
383   request.set_error_message(error.error_message());
384   *request.mutable_error_origin() = task_;
385   ReportErrorToServiceResponse response;
386 
387   absl::Notification n;
388   leader_client_->ReportErrorToServiceAsync(&request, &response, [&](Status s) {
389     if (!s.ok()) {
390       LOG(ERROR) << "Encountered another error when reporting error to "
391                     "coordination service: "
392                  << s;
393     }
394     n.Notify();
395   });
396   n.WaitForNotification();
397   return OkStatus();
398 }
399 
Shutdown()400 Status CoordinationServiceAgentImpl::Shutdown() {
401   LOG(INFO) << "Coordination agent has initiated Shutdown().";
402   Status status = OkStatus();
403   bool is_connected = false;
404   {
405     mutex_lock l(state_mu_);
406     is_connected = state_ == CoordinatedTaskState::TASKSTATE_CONNECTED;
407   }
408   // Disconnect agent from service.
409   if (!configs_.agent_destruction_without_shutdown() && is_connected) {
410     ShutdownTaskRequest request;
411     *request.mutable_source_task() = task_;
412     ShutdownTaskResponse response;
413     CallOptions call_opts;
414     const int64_t shutdown_timeout =
415         configs_.shutdown_barrier_timeout_in_ms() > 0
416             ? configs_.shutdown_barrier_timeout_in_ms()
417             : absl::ToInt64Milliseconds(kDefaultShutdownTimeout);
418     call_opts.SetTimeout(shutdown_timeout);
419 
420     absl::Notification n;
421     leader_client_->ShutdownTaskAsync(&call_opts, &request, &response,
422                                       [&status, &n](Status s) {
423                                         status = s;
424                                         n.Notify();
425                                       });
426     n.WaitForNotification();
427     if (status.ok()) {
428       LOG(INFO) << "Coordination agent has successfully shut down.";
429     } else {
430       LOG(ERROR)
431           << "Failed to disconnect from coordination service with status: "
432           << status << ". Proceeding with agent shutdown anyway.";
433     }
434   }
435 
436   // Tear down agent.
437   StopHeartbeat();
438   {
439     mutex_lock l(state_mu_);
440     if (state_ == CoordinatedTaskState::TASKSTATE_ERROR) {
441       status = MakeCoordinationError(errors::FailedPrecondition(absl::StrCat(
442           "Shutdown() was called while coordination agent is in error state, "
443           "implying that distributed execution failed. Note: agent will still "
444           "shutdown anyway. Agent status: ",
445           status_.ToString())));
446     }
447     state_ = CoordinatedTaskState::TASKSTATE_DISCONNECTED;
448   }
449 
450   // Cancel all pending GetKeyValue() RPC calls.
451   cancellation_manager_.StartCancel();
452   return status;
453 }
454 
Reset()455 Status CoordinationServiceAgentImpl::Reset() {
456   {
457     mutex_lock l(state_mu_);
458     if (state_ != CoordinatedTaskState::TASKSTATE_ERROR) {
459       return MakeCoordinationError(errors::FailedPrecondition(
460           "Reset() failed: coordination service agent is not in ERROR state."));
461     }
462   }
463 
464   ResetTaskRequest request;
465   *request.mutable_source_task() = task_;
466   ResetTaskResponse response;
467 
468   Status status;
469   absl::Notification n;
470   leader_client_->ResetTaskAsync(&request, &response, [&status, &n](Status s) {
471     status = s;
472     n.Notify();
473   });
474   n.WaitForNotification();
475   if (!status.ok()) {
476     return status;
477   }
478 
479   // Reset agent state.
480   StopHeartbeat();
481   {
482     mutex_lock l(state_mu_);
483     state_ = CoordinatedTaskState::TASKSTATE_DISCONNECTED;
484   }
485   {
486     mutex_lock l(heartbeat_thread_shutdown_mu_);
487     shutting_down_ = false;
488   }
489 
490   LOG(INFO) << "Coordination agent has been reset.";
491   return status;
492 }
493 
GetKeyValue(const std::string & key)494 StatusOr<std::string> CoordinationServiceAgentImpl::GetKeyValue(
495     const std::string& key) {
496   return GetKeyValue(key, /*timeout=*/absl::InfiniteDuration());
497 }
498 
GetKeyValue(const std::string & key,absl::Duration timeout)499 StatusOr<std::string> CoordinationServiceAgentImpl::GetKeyValue(
500     const std::string& key, absl::Duration timeout) {
501   auto n = std::make_shared<absl::Notification>();
502   auto result = std::make_shared<StatusOr<std::string>>();
503   GetKeyValueAsync(key,
504                    [n, result](const StatusOr<std::string>& status_or_value) {
505                      *result = status_or_value;
506                      n->Notify();
507                    });
508   bool call_completed_before_timeout =
509       n->WaitForNotificationWithTimeout(timeout);
510   if (!call_completed_before_timeout) {
511     return MakeCoordinationError(errors::DeadlineExceeded(absl::Substitute(
512         "GetKeyValue() timed out with key: $0 and duration: $1", key,
513         absl::FormatDuration(timeout))));
514   }
515   return *result;
516 }
517 
GetKeyValueAsync(const std::string & key,StatusOrValueCallback done)518 std::shared_ptr<CallOptions> CoordinationServiceAgentImpl::GetKeyValueAsync(
519     const std::string& key, StatusOrValueCallback done) {
520   auto request = std::make_shared<GetKeyValueRequest>();
521   request->set_key(key);
522   auto response = std::make_shared<GetKeyValueResponse>();
523   auto call_opts = std::make_shared<CallOptions>();
524 
525   const CancellationToken token =
526       cancellation_manager_.get_cancellation_token();
527   const bool already_cancelled = !cancellation_manager_.RegisterCallback(
528       token, [call_opts]() { call_opts->StartCancel(); });
529   if (already_cancelled) {
530     done(errors::Cancelled("GetKeyValueAsync() was cancelled."));
531     return call_opts;
532   }
533   leader_client_->GetKeyValueAsync(
534       call_opts.get(), request.get(), response.get(),
535       [call_opts, request, response, done = std::move(done),
536        &cm = cancellation_manager_, token](const Status& s) {
537         // RPC call has completed (no longer needs to be cancelled if agent is
538         // destroyed).
539         cm.TryDeregisterCallback(token);
540 
541         // Retrieve server response.
542         if (!s.ok()) {
543           done(s);
544         } else {
545           done(response->kv().value());
546         }
547       });
548   return call_opts;
549 }
550 
TryGetKeyValue(const std::string & key)551 StatusOr<std::string> CoordinationServiceAgentImpl::TryGetKeyValue(
552     const std::string& key) {
553   absl::Notification n;
554   StatusOr<std::string> result;
555   TryGetKeyValueRequest request;
556   request.set_key(key);
557   TryGetKeyValueResponse response;
558   leader_client_->TryGetKeyValueAsync(&request, &response,
559                                       [&](const Status& s) {
560                                         if (s.ok()) {
561                                           result = response.kv().value();
562                                         } else {
563                                           result = s;
564                                         }
565                                         n.Notify();
566                                       });
567   n.WaitForNotification();
568   return result;
569 }
570 
571 StatusOr<std::vector<KeyValueEntry>>
GetKeyValueDir(const std::string & key)572 CoordinationServiceAgentImpl::GetKeyValueDir(const std::string& key) {
573   absl::Notification n;
574   StatusOr<std::vector<KeyValueEntry>> result;
575   GetKeyValueDirAsync(
576       key, [&n, &result](StatusOr<std::vector<KeyValueEntry>> status_or_value) {
577         result = std::move(status_or_value);
578         n.Notify();
579       });
580 
581   n.WaitForNotification();
582   return result;
583 }
584 
GetKeyValueDirAsync(const std::string & key,StatusOrValueDirCallback done)585 void CoordinationServiceAgentImpl::GetKeyValueDirAsync(
586     const std::string& key, StatusOrValueDirCallback done) {
587   auto request = std::make_shared<GetKeyValueDirRequest>();
588   request->set_directory_key(key);
589   auto response = std::make_shared<GetKeyValueDirResponse>();
590   leader_client_->GetKeyValueDirAsync(
591       request.get(), response.get(),
592       [request, response, done = std::move(done)](const Status& s) {
593         if (!s.ok()) {
594           done(s);
595         } else {
596           std::vector<KeyValueEntry> kv_in_directory = {
597               std::make_move_iterator(response->kv().begin()),
598               std::make_move_iterator(response->kv().end())};
599           done(kv_in_directory);
600         }
601       });
602 }
603 
InsertKeyValue(const std::string & key,const std::string & value)604 Status CoordinationServiceAgentImpl::InsertKeyValue(const std::string& key,
605                                                     const std::string& value) {
606   InsertKeyValueRequest request;
607   request.mutable_kv()->set_key(key.data(), key.size());
608   request.mutable_kv()->set_value(value.data(), value.size());
609   InsertKeyValueResponse response;
610 
611   Status status;
612   absl::Notification n;
613   leader_client_->InsertKeyValueAsync(&request, &response, [&](Status s) {
614     status = s;
615     n.Notify();
616   });
617   n.WaitForNotification();
618   return status;
619 }
620 
DeleteKeyValue(const std::string & key)621 Status CoordinationServiceAgentImpl::DeleteKeyValue(const std::string& key) {
622   DeleteKeyValueRequest request;
623   request.set_key(key);
624   request.set_is_directory(true);
625   DeleteKeyValueResponse response;
626 
627   Status status;
628   absl::Notification n;
629   leader_client_->DeleteKeyValueAsync(&request, &response, [&](Status s) {
630     status = s;
631     n.Notify();
632   });
633   n.WaitForNotification();
634   return OkStatus();
635 }
636 
UpdateKeyValue(const std::string & key,const std::string & value)637 Status CoordinationServiceAgentImpl::UpdateKeyValue(const std::string& key,
638                                                     const std::string& value) {
639   return MakeCoordinationError(errors::Unimplemented(
640       "CoordinationServiceAgent::UpdateKeyValue is not implemented."));
641 }
642 
StartWatchKey(const std::string & key,CoordinationServiceAgentImpl::ChangedKeyValuesCallback on_change)643 Status CoordinationServiceAgentImpl::StartWatchKey(
644     const std::string& key,
645     CoordinationServiceAgentImpl::ChangedKeyValuesCallback on_change) {
646   return MakeCoordinationError(errors::Unimplemented(
647       "CoordinationServiceAgent::StartWatchKey is not implemented."));
648 }
649 
StopWatchKey(const std::string & key)650 Status CoordinationServiceAgentImpl::StopWatchKey(const std::string& key) {
651   return MakeCoordinationError(errors::Unimplemented(
652       "CoordinationServiceAgent::StopWatchKey is not implemented."));
653 }
654 
SetError(const Status & error)655 void CoordinationServiceAgentImpl::SetError(const Status& error) {
656   assert(!error.ok());
657   mutex_lock l(state_mu_);
658   if (state_ == CoordinatedTaskState::TASKSTATE_ERROR) return;
659 
660   LOG(ERROR) << "Coordination agent is in ERROR: " << error;
661   state_ = CoordinatedTaskState::TASKSTATE_ERROR;
662   status_ = error;
663   error_fn_(error);
664 }
665 
ActivateWatch(const std::string & key,const std::map<std::string,std::string> & kvs)666 Status CoordinationServiceAgentImpl::ActivateWatch(
667     const std::string& key, const std::map<std::string, std::string>& kvs) {
668   return MakeCoordinationError(errors::Unimplemented(
669       "CoordinationServiceAgent::ActivateWatch is not implemented."));
670 }
671 
WaitAtBarrier(const std::string & barrier_id,absl::Duration timeout,const std::vector<CoordinatedTask> & tasks)672 Status CoordinationServiceAgentImpl::WaitAtBarrier(
673     const std::string& barrier_id, absl::Duration timeout,
674     const std::vector<CoordinatedTask>& tasks) {
675   Status status;
676   absl::Notification n;
677   WaitAtBarrierAsync(barrier_id, timeout, tasks, [&](Status s) {
678     status = s;
679     n.Notify();
680   });
681   n.WaitForNotification();
682   return status;
683 }
684 
WaitAtBarrierAsync(const std::string & barrier_id,absl::Duration timeout,const std::vector<CoordinatedTask> & tasks,StatusCallback done)685 void CoordinationServiceAgentImpl::WaitAtBarrierAsync(
686     const std::string& barrier_id, absl::Duration timeout,
687     const std::vector<CoordinatedTask>& tasks, StatusCallback done) {
688   Status agent_running_status =
689       ValidateRunningAgent(/*allow_disconnected=*/true);
690   if (!agent_running_status.ok()) {
691     done(agent_running_status);
692     return;
693   }
694   {
695     mutex_lock l(state_mu_);
696     auto [it, inserted] = used_barrier_ids_.insert(barrier_id);
697     if (!inserted) {
698       done(errors::FailedPrecondition(
699           "WaitAtBarrier() should not be called with the same id more than "
700           "once. Barrier id: ",
701           barrier_id));
702       return;
703     }
704   }
705   auto request = std::make_shared<BarrierRequest>();
706   auto response = std::make_shared<BarrierResponse>();
707   request->set_barrier_id(barrier_id);
708   request->set_barrier_timeout_in_ms(timeout / absl::Milliseconds(1));
709   *request->mutable_source_task() = task_;
710   *request->mutable_tasks() = {tasks.begin(), tasks.end()};
711   leader_client_->BarrierAsync(request.get(), response.get(),
712                                [request, response, done = std::move(done)](
713                                    const Status& s) { done(s); });
714 }
715 
CancelBarrier(const std::string & barrier_id)716 Status CoordinationServiceAgentImpl::CancelBarrier(
717     const std::string& barrier_id) {
718   Status status;
719   absl::Notification n;
720   CancelBarrierAsync(barrier_id, [&](const Status& s) {
721     status = s;
722     n.Notify();
723   });
724   n.WaitForNotification();
725   return status;
726 }
727 
CancelBarrierAsync(const std::string & barrier_id,StatusCallback done)728 void CoordinationServiceAgentImpl::CancelBarrierAsync(
729     const std::string& barrier_id, StatusCallback done) {
730   Status agent_running_status =
731       ValidateRunningAgent(/*allow_disconnected=*/true);
732   if (!agent_running_status.ok()) {
733     done(agent_running_status);
734     return;
735   }
736   auto request = std::make_shared<CancelBarrierRequest>();
737   auto response = std::make_shared<CancelBarrierResponse>();
738   request->set_barrier_id(barrier_id);
739   *request->mutable_source_task() = task_;
740   leader_client_->CancelBarrierAsync(
741       request.get(), response.get(),
742       [request, response, done = std::move(done)](const Status& s) {
743         done(s);
744       });
745 }
746 
747 // Returns an error if agent is not running.
ValidateRunningAgent(bool allow_disconnected)748 Status CoordinationServiceAgentImpl::ValidateRunningAgent(
749     bool allow_disconnected) {
750   mutex_lock l(state_mu_);
751   switch (state_) {
752     case CoordinatedTaskState::TASKSTATE_CONNECTED:
753       return OkStatus();
754 
755     case CoordinatedTaskState::TASKSTATE_UNINITIALIZED:
756       return MakeCoordinationError(errors::FailedPrecondition(
757           "Agent must be in CONNECTED state. It is currently UNINITIALIZED."));
758 
759     case CoordinatedTaskState::TASKSTATE_DISCONNECTED:
760       if (allow_disconnected) return OkStatus();
761       return MakeCoordinationError(errors::FailedPrecondition(
762           "Agent must be in CONNECTED state. It is currently DISCONNECTED."));
763 
764     case CoordinatedTaskState::TASKSTATE_ERROR:
765       return MakeCoordinationError(errors::FailedPrecondition(
766           "Agent must be in CONNECTED state. It is currently in ERROR."));
767 
768     default:
769       return MakeCoordinationError(errors::FailedPrecondition(absl::StrCat(
770           "Agent is not in CONNECTED state. Current state: ", state_)));
771   }
772 }
773 
GetEnv()774 StatusOr<Env*> CoordinationServiceAgentImpl::GetEnv() {
775   if (!IsInitialized()) {
776     return MakeCoordinationError(errors::FailedPrecondition(
777         "Coordination service agent has not been initialized."));
778   }
779   if (env_ == nullptr) {
780     return MakeCoordinationError(
781         errors::FailedPrecondition("Coordination service agent was not "
782                                    "initialized with a valid Env* object."));
783   }
784   return env_;
785 }
786 
787 }  // namespace
788 
CreateCoordinationServiceAgent()789 std::unique_ptr<CoordinationServiceAgent> CreateCoordinationServiceAgent() {
790   return std::make_unique<CoordinationServiceAgentImpl>();
791 }
792 
793 }  // namespace tensorflow
794