1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h"
17
18 #include <algorithm>
19 #include <memory>
20 #include <string>
21 #include <utility>
22
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/substitute.h"
26 #include "absl/synchronization/notification.h"
27 #include "absl/time/time.h"
28 #include "tensorflow/core/distributed_runtime/call_options.h"
29 #include "tensorflow/core/distributed_runtime/coordination/coordination_client.h"
30 #include "tensorflow/core/distributed_runtime/coordination/coordination_service_error_util.h"
31 #include "tensorflow/core/framework/cancellation.h"
32 #include "tensorflow/core/platform/env.h"
33 #include "tensorflow/core/platform/errors.h"
34 #include "tensorflow/core/platform/mutex.h"
35 #include "tensorflow/core/platform/random.h"
36 #include "tensorflow/core/platform/strcat.h"
37 #include "tensorflow/core/platform/thread_annotations.h"
38 #include "tensorflow/core/protobuf/config.pb.h"
39 #include "tensorflow/core/protobuf/coordination_config.pb.h"
40 #include "tensorflow/core/protobuf/coordination_service.pb.h"
41 #include "tensorflow/core/protobuf/tensorflow_server.pb.h"
42
43 namespace tensorflow {
44 namespace {
45
46 constexpr absl::Duration kDefaultClusterRegisterTimeout = absl::Hours(1);
47 constexpr absl::Duration kDefaultHeartbeatTimeout = absl::Seconds(10);
48 constexpr absl::Duration kDefaultShutdownTimeout = absl::Seconds(10);
49 constexpr char kHeartbeatThread[] = "CoordinationServiceHeartbeatLoop";
50
51 class CoordinationServiceAgentImpl : public CoordinationServiceAgent {
52 public:
53 CoordinationServiceAgentImpl() = default;
~CoordinationServiceAgentImpl()54 ~CoordinationServiceAgentImpl() override {
55 Status s = Shutdown();
56 if (!s.ok()) {
57 LOG(ERROR) << "Coordination agent shutdown failed with status: " << s;
58 }
59 }
60 Status Initialize(Env* env, const ServerDef& server_def,
61 std::unique_ptr<CoordinationClientCache> client_cache,
62 StatusCallback error_fn) override;
63 Status Initialize(Env* env, const std::string& job_name, int task_id,
64 const CoordinationServiceConfig& configs,
65 std::unique_ptr<CoordinationClient> leader_client,
66 StatusCallback error_fn) override;
67 Status Initialize(Env* env, const CoordinatedTask& task,
68 const CoordinationServiceConfig& configs,
69 std::unique_ptr<CoordinationClient> leader_client,
70 StatusCallback error_fn) override;
71 bool IsInitialized() override;
72
73 Status Connect() override;
74 Status WaitForAllTasks(
75 const CoordinationServiceDeviceInfo& local_devices) override;
76 const CoordinationServiceDeviceInfo& GetClusterDeviceInfo() override;
77 StatusOr<CoordinatedTask> GetOwnTask() override;
78 StatusOr<CoordinatedTaskState> GetTaskStatus(
79 const CoordinatedTask& task) override;
80 Status ReportError(const Status& error) override;
81 Status Shutdown() override;
82 Status Reset() override;
83
84 StatusOr<std::string> GetKeyValue(const std::string& key) override;
85 StatusOr<std::string> GetKeyValue(const std::string& key,
86 absl::Duration timeout) override;
87 std::shared_ptr<CallOptions> GetKeyValueAsync(
88 const std::string& key, StatusOrValueCallback done) override;
89 StatusOr<std::string> TryGetKeyValue(const std::string& key) override;
90 StatusOr<std::vector<KeyValueEntry>> GetKeyValueDir(
91 const std::string& key) override;
92 void GetKeyValueDirAsync(const std::string& key,
93 StatusOrValueDirCallback done) override;
94 Status InsertKeyValue(const std::string& key,
95 const std::string& value) override;
96 Status DeleteKeyValue(const std::string& key) override;
97 Status UpdateKeyValue(const std::string& key,
98 const std::string& value) override;
99
100 Status StartWatchKey(const std::string& key,
101 ChangedKeyValuesCallback on_change) override;
102 Status StopWatchKey(const std::string& key) override;
103 Status WaitAtBarrier(const std::string& barrier_id, absl::Duration timeout,
104 const std::vector<CoordinatedTask>& tasks) override;
105 void WaitAtBarrierAsync(const std::string& barrier_id, absl::Duration timeout,
106 const std::vector<CoordinatedTask>& tasks,
107 StatusCallback done) override;
108 Status CancelBarrier(const std::string& barrier_id) override;
109 void CancelBarrierAsync(const std::string& barrier_id,
110 StatusCallback done) override;
111
112 StatusOr<Env*> GetEnv() override;
113
114 protected:
115 void SetError(const Status& error) override;
116 Status ActivateWatch(const std::string& key,
117 const std::map<std::string, std::string>&) override;
118 // Returns an error if agent is not running. If `allow_disconnected` is true,
119 // returns OK even if the agent is in DISCONNECTED state.
120 Status ValidateRunningAgent(bool allow_disconnected = false);
121 void StopHeartbeat();
122
123 private:
124 Env* env_ = nullptr; // Not owned.
125 const uint64_t incarnation_id_ = random::New64();
126 CoordinatedTask task_;
127 CoordinationServiceConfig configs_;
128 StatusCallback error_fn_;
129
130 mutable mutex state_mu_;
131 CoordinatedTaskState state_ TF_GUARDED_BY(state_mu_) =
132 CoordinatedTaskState::TASKSTATE_UNINITIALIZED;
133 Status status_ TF_GUARDED_BY(state_mu_) = OkStatus();
134 // Note: this set grows without bounds. For now, this is okay as most users
135 // require < 100 barriers. If there is a use case that requires many barriers,
136 // consider using a monotonic sequence number to track instead.
137 absl::flat_hash_set<std::string> used_barrier_ids_ TF_GUARDED_BY(state_mu_);
138
139 uint64_t leader_incarnation_ = 0;
140 CoordinationServiceDeviceInfo cluster_devices_;
141
142 mutex heartbeat_thread_shutdown_mu_;
143 condition_variable heartbeat_thread_cv_;
144 bool shutting_down_ TF_GUARDED_BY(heartbeat_thread_shutdown_mu_) = false;
145 std::unique_ptr<Thread> heartbeat_thread_;
146 // Must outlive coordination client which may need to access it within
147 // GetKeyValueAsync() callbacks.
148 CancellationManager cancellation_manager_;
149 std::unique_ptr<CoordinationClient> leader_client_;
150
151 TF_DISALLOW_COPY_AND_ASSIGN(CoordinationServiceAgentImpl);
152 };
153
Initialize(Env * env,const ServerDef & server_def,std::unique_ptr<CoordinationClientCache> client_cache,StatusCallback error_fn)154 Status CoordinationServiceAgentImpl::Initialize(
155 Env* env, const ServerDef& server_def,
156 std::unique_ptr<CoordinationClientCache> client_cache,
157 StatusCallback error_fn) {
158 CoordinationServiceConfig configs =
159 server_def.default_session_config().experimental().coordination_config();
160 if (configs.service_leader().empty()) {
161 const std::string& collective_leader = server_def.default_session_config()
162 .experimental()
163 .collective_group_leader();
164 if (!collective_leader.empty()) {
165 configs.set_service_leader(collective_leader);
166 LOG(INFO) << "No coordination leader is set, using the collective leader "
167 << collective_leader;
168 } else {
169 const std::string& default_leader =
170 strings::StrCat("/job:", server_def.job_name(), "/replica:0/task:0");
171 configs.set_service_leader(default_leader);
172 LOG(INFO) << "No coordination leader is set, using the default leader "
173 << default_leader;
174 }
175 }
176 return Initialize(
177 env, server_def.job_name(), server_def.task_index(), configs,
178 client_cache->GetOwnedClient(configs.service_leader()), error_fn);
179 }
180
Initialize(Env * env,const std::string & job_name,int task_id,const CoordinationServiceConfig & configs,std::unique_ptr<CoordinationClient> leader_client,StatusCallback error_fn)181 Status CoordinationServiceAgentImpl::Initialize(
182 Env* env, const std::string& job_name, int task_id,
183 const CoordinationServiceConfig& configs,
184 std::unique_ptr<CoordinationClient> leader_client,
185 StatusCallback error_fn) {
186 CoordinatedTask task;
187 task.set_job_name(job_name);
188 task.set_task_id(task_id);
189 return Initialize(env, task, configs, std::move(leader_client), error_fn);
190 }
191
Initialize(Env * env,const CoordinatedTask & task,const CoordinationServiceConfig & configs,std::unique_ptr<CoordinationClient> leader_client,StatusCallback error_fn)192 Status CoordinationServiceAgentImpl::Initialize(
193 Env* env, const CoordinatedTask& task,
194 const CoordinationServiceConfig& configs,
195 std::unique_ptr<CoordinationClient> leader_client,
196 StatusCallback error_fn) {
197 mutex_lock l(state_mu_);
198 if (state_ != CoordinatedTaskState::TASKSTATE_UNINITIALIZED) {
199 return MakeCoordinationError(errors::FailedPrecondition(
200 "Coordination service agent has already been initialized."));
201 }
202
203 env_ = env;
204 task_ = task;
205 configs_ = configs;
206 if (configs_.service_leader().empty()) {
207 return MakeCoordinationError(errors::InvalidArgument(
208 "CoordinationServiceAgent must be initialized with a valid leader."));
209 }
210 leader_client_ = std::move(leader_client);
211 if (leader_client_ == nullptr) {
212 return MakeCoordinationError(errors::InvalidArgument(
213 "CoordinationServiceAgent must have a valid leader client."));
214 }
215 error_fn_ = error_fn;
216 state_ = CoordinatedTaskState::TASKSTATE_DISCONNECTED;
217 return OkStatus();
218 }
219
IsInitialized()220 bool CoordinationServiceAgentImpl::IsInitialized() {
221 mutex_lock l(state_mu_);
222 return state_ != CoordinatedTaskState::TASKSTATE_UNINITIALIZED;
223 }
224
StopHeartbeat()225 void CoordinationServiceAgentImpl::StopHeartbeat() {
226 {
227 mutex_lock l(heartbeat_thread_shutdown_mu_);
228 shutting_down_ = true;
229 heartbeat_thread_cv_.notify_all();
230 }
231 heartbeat_thread_.reset();
232 }
233
Connect()234 Status CoordinationServiceAgentImpl::Connect() {
235 {
236 mutex_lock l(state_mu_);
237 if (state_ != CoordinatedTaskState::TASKSTATE_DISCONNECTED) {
238 return MakeCoordinationError(errors::FailedPrecondition(
239 "Coordination service agent is not in DISCONNECTED state."));
240 }
241 }
242 RegisterTaskRequest request;
243 *request.mutable_source_task() = task_;
244 request.set_incarnation(incarnation_id_);
245 RegisterTaskResponse response;
246 absl::Notification n;
247
248 // Block until the remote service is up and the task is registered.
249 CallOptions call_opts;
250 const int64_t register_timeout =
251 configs_.cluster_register_timeout_in_ms() > 0
252 ? configs_.cluster_register_timeout_in_ms()
253 : absl::ToInt64Milliseconds(kDefaultClusterRegisterTimeout);
254 call_opts.SetTimeout(register_timeout);
255 leader_client_->RegisterTaskAsync(
256 &call_opts, &request, &response, [&](Status s) {
257 if (!s.ok()) {
258 SetError(s);
259 } else {
260 leader_incarnation_ = response.leader_incarnation();
261 {
262 mutex_lock l(state_mu_);
263 state_ = CoordinatedTaskState::TASKSTATE_CONNECTED;
264 }
265 }
266 n.Notify();
267 });
268 n.WaitForNotification();
269 {
270 mutex_lock l(state_mu_);
271 if (state_ == CoordinatedTaskState::TASKSTATE_ERROR) {
272 return status_;
273 }
274 }
275
276 LOG(INFO) << "Coordination agent has successfully connected.";
277 heartbeat_thread_.reset(
278 env_->StartThread(ThreadOptions(), kHeartbeatThread, [this]() -> void {
279 HeartbeatRequest request;
280 *request.mutable_source_task() = task_;
281 request.set_incarnation(incarnation_id_);
282 HeartbeatResponse response;
283 const int64_t heartbeat_interval_ms =
284 configs_.heartbeat_timeout_in_ms() > 0
285 ? configs_.heartbeat_timeout_in_ms() / 2
286 : absl::ToInt64Milliseconds(kDefaultHeartbeatTimeout) / 2;
287 CallOptions call_opts;
288 call_opts.SetTimeout(heartbeat_interval_ms);
289
290 while (true) {
291 {
292 mutex_lock l(heartbeat_thread_shutdown_mu_);
293 heartbeat_thread_cv_.wait_for(
294 l, std::chrono::milliseconds(heartbeat_interval_ms));
295 if (shutting_down_) {
296 return;
297 }
298 }
299 Status status;
300 absl::Notification n;
301 // Heartbeat RPC implementation automatically retries to tolerate
302 // transient network failures.
303 leader_client_->HeartbeatAsync(&call_opts, &request, &response,
304 [&](Status s) {
305 status = s;
306 n.Notify();
307 });
308 n.WaitForNotification();
309 if (!status.ok()) {
310 SetError(status);
311 } else if (response.leader_incarnation() != leader_incarnation_) {
312 SetError(MakeCoordinationError(
313 errors::Aborted("Leader incarnation ID mismatch: the "
314 "coordination leader has restarted.")));
315 }
316 }
317 }));
318 return OkStatus();
319 }
320
WaitForAllTasks(const CoordinationServiceDeviceInfo & local_devices)321 Status CoordinationServiceAgentImpl::WaitForAllTasks(
322 const CoordinationServiceDeviceInfo& local_devices) {
323 Status agent_running_status = ValidateRunningAgent();
324 if (!agent_running_status.ok()) {
325 return agent_running_status;
326 }
327 WaitForAllTasksRequest request;
328 *request.mutable_source_task() = task_;
329 *request.mutable_local_device_info() = local_devices;
330 WaitForAllTasksResponse response;
331 Status status;
332 absl::Notification n;
333 leader_client_->WaitForAllTasksAsync(&request, &response, [&](Status s) {
334 status = s;
335 n.Notify();
336 });
337 n.WaitForNotification();
338 if (!status.ok()) {
339 SetError(status);
340 return status;
341 }
342 cluster_devices_.MergeFrom(response.cluster_device_info());
343 return OkStatus();
344 }
345
346 const CoordinationServiceDeviceInfo&
GetClusterDeviceInfo()347 CoordinationServiceAgentImpl::GetClusterDeviceInfo() {
348 return cluster_devices_;
349 }
350
GetOwnTask()351 StatusOr<CoordinatedTask> CoordinationServiceAgentImpl::GetOwnTask() {
352 if (!IsInitialized()) {
353 return MakeCoordinationError(
354 errors::FailedPrecondition("Agent has not been initialized; we do not "
355 "know the associated task yet."));
356 }
357 return task_;
358 }
359
GetTaskStatus(const CoordinatedTask & task)360 StatusOr<CoordinatedTaskState> CoordinationServiceAgentImpl::GetTaskStatus(
361 const CoordinatedTask& task) {
362 return MakeCoordinationError(errors::Unimplemented(
363 "CoordinationServiceAgentImpl::GetTaskStatus is not implemented."));
364 }
365
ReportError(const Status & error)366 Status CoordinationServiceAgentImpl::ReportError(const Status& error) {
367 {
368 mutex_lock l(state_mu_);
369 if (state_ == CoordinatedTaskState::TASKSTATE_UNINITIALIZED) {
370 return MakeCoordinationError(errors::FailedPrecondition(
371 "Coordination service agent must be initialized first before "
372 "reporting error."));
373 } else if (state_ == CoordinatedTaskState::TASKSTATE_ERROR) {
374 return MakeCoordinationError(errors::FailedPrecondition(
375 "Coordination service agent is already in error state."));
376 }
377 }
378 SetError(MakeCoordinationError(error, task_,
379 /*is_reported_error=*/true));
380 LOG(INFO) << "Reporting error to coordination service: " << error;
381 ReportErrorToServiceRequest request;
382 request.set_error_code(error.code());
383 request.set_error_message(error.error_message());
384 *request.mutable_error_origin() = task_;
385 ReportErrorToServiceResponse response;
386
387 absl::Notification n;
388 leader_client_->ReportErrorToServiceAsync(&request, &response, [&](Status s) {
389 if (!s.ok()) {
390 LOG(ERROR) << "Encountered another error when reporting error to "
391 "coordination service: "
392 << s;
393 }
394 n.Notify();
395 });
396 n.WaitForNotification();
397 return OkStatus();
398 }
399
Shutdown()400 Status CoordinationServiceAgentImpl::Shutdown() {
401 LOG(INFO) << "Coordination agent has initiated Shutdown().";
402 Status status = OkStatus();
403 bool is_connected = false;
404 {
405 mutex_lock l(state_mu_);
406 is_connected = state_ == CoordinatedTaskState::TASKSTATE_CONNECTED;
407 }
408 // Disconnect agent from service.
409 if (!configs_.agent_destruction_without_shutdown() && is_connected) {
410 ShutdownTaskRequest request;
411 *request.mutable_source_task() = task_;
412 ShutdownTaskResponse response;
413 CallOptions call_opts;
414 const int64_t shutdown_timeout =
415 configs_.shutdown_barrier_timeout_in_ms() > 0
416 ? configs_.shutdown_barrier_timeout_in_ms()
417 : absl::ToInt64Milliseconds(kDefaultShutdownTimeout);
418 call_opts.SetTimeout(shutdown_timeout);
419
420 absl::Notification n;
421 leader_client_->ShutdownTaskAsync(&call_opts, &request, &response,
422 [&status, &n](Status s) {
423 status = s;
424 n.Notify();
425 });
426 n.WaitForNotification();
427 if (status.ok()) {
428 LOG(INFO) << "Coordination agent has successfully shut down.";
429 } else {
430 LOG(ERROR)
431 << "Failed to disconnect from coordination service with status: "
432 << status << ". Proceeding with agent shutdown anyway.";
433 }
434 }
435
436 // Tear down agent.
437 StopHeartbeat();
438 {
439 mutex_lock l(state_mu_);
440 if (state_ == CoordinatedTaskState::TASKSTATE_ERROR) {
441 status = MakeCoordinationError(errors::FailedPrecondition(absl::StrCat(
442 "Shutdown() was called while coordination agent is in error state, "
443 "implying that distributed execution failed. Note: agent will still "
444 "shutdown anyway. Agent status: ",
445 status_.ToString())));
446 }
447 state_ = CoordinatedTaskState::TASKSTATE_DISCONNECTED;
448 }
449
450 // Cancel all pending GetKeyValue() RPC calls.
451 cancellation_manager_.StartCancel();
452 return status;
453 }
454
Reset()455 Status CoordinationServiceAgentImpl::Reset() {
456 {
457 mutex_lock l(state_mu_);
458 if (state_ != CoordinatedTaskState::TASKSTATE_ERROR) {
459 return MakeCoordinationError(errors::FailedPrecondition(
460 "Reset() failed: coordination service agent is not in ERROR state."));
461 }
462 }
463
464 ResetTaskRequest request;
465 *request.mutable_source_task() = task_;
466 ResetTaskResponse response;
467
468 Status status;
469 absl::Notification n;
470 leader_client_->ResetTaskAsync(&request, &response, [&status, &n](Status s) {
471 status = s;
472 n.Notify();
473 });
474 n.WaitForNotification();
475 if (!status.ok()) {
476 return status;
477 }
478
479 // Reset agent state.
480 StopHeartbeat();
481 {
482 mutex_lock l(state_mu_);
483 state_ = CoordinatedTaskState::TASKSTATE_DISCONNECTED;
484 }
485 {
486 mutex_lock l(heartbeat_thread_shutdown_mu_);
487 shutting_down_ = false;
488 }
489
490 LOG(INFO) << "Coordination agent has been reset.";
491 return status;
492 }
493
GetKeyValue(const std::string & key)494 StatusOr<std::string> CoordinationServiceAgentImpl::GetKeyValue(
495 const std::string& key) {
496 return GetKeyValue(key, /*timeout=*/absl::InfiniteDuration());
497 }
498
GetKeyValue(const std::string & key,absl::Duration timeout)499 StatusOr<std::string> CoordinationServiceAgentImpl::GetKeyValue(
500 const std::string& key, absl::Duration timeout) {
501 auto n = std::make_shared<absl::Notification>();
502 auto result = std::make_shared<StatusOr<std::string>>();
503 GetKeyValueAsync(key,
504 [n, result](const StatusOr<std::string>& status_or_value) {
505 *result = status_or_value;
506 n->Notify();
507 });
508 bool call_completed_before_timeout =
509 n->WaitForNotificationWithTimeout(timeout);
510 if (!call_completed_before_timeout) {
511 return MakeCoordinationError(errors::DeadlineExceeded(absl::Substitute(
512 "GetKeyValue() timed out with key: $0 and duration: $1", key,
513 absl::FormatDuration(timeout))));
514 }
515 return *result;
516 }
517
GetKeyValueAsync(const std::string & key,StatusOrValueCallback done)518 std::shared_ptr<CallOptions> CoordinationServiceAgentImpl::GetKeyValueAsync(
519 const std::string& key, StatusOrValueCallback done) {
520 auto request = std::make_shared<GetKeyValueRequest>();
521 request->set_key(key);
522 auto response = std::make_shared<GetKeyValueResponse>();
523 auto call_opts = std::make_shared<CallOptions>();
524
525 const CancellationToken token =
526 cancellation_manager_.get_cancellation_token();
527 const bool already_cancelled = !cancellation_manager_.RegisterCallback(
528 token, [call_opts]() { call_opts->StartCancel(); });
529 if (already_cancelled) {
530 done(errors::Cancelled("GetKeyValueAsync() was cancelled."));
531 return call_opts;
532 }
533 leader_client_->GetKeyValueAsync(
534 call_opts.get(), request.get(), response.get(),
535 [call_opts, request, response, done = std::move(done),
536 &cm = cancellation_manager_, token](const Status& s) {
537 // RPC call has completed (no longer needs to be cancelled if agent is
538 // destroyed).
539 cm.TryDeregisterCallback(token);
540
541 // Retrieve server response.
542 if (!s.ok()) {
543 done(s);
544 } else {
545 done(response->kv().value());
546 }
547 });
548 return call_opts;
549 }
550
TryGetKeyValue(const std::string & key)551 StatusOr<std::string> CoordinationServiceAgentImpl::TryGetKeyValue(
552 const std::string& key) {
553 absl::Notification n;
554 StatusOr<std::string> result;
555 TryGetKeyValueRequest request;
556 request.set_key(key);
557 TryGetKeyValueResponse response;
558 leader_client_->TryGetKeyValueAsync(&request, &response,
559 [&](const Status& s) {
560 if (s.ok()) {
561 result = response.kv().value();
562 } else {
563 result = s;
564 }
565 n.Notify();
566 });
567 n.WaitForNotification();
568 return result;
569 }
570
571 StatusOr<std::vector<KeyValueEntry>>
GetKeyValueDir(const std::string & key)572 CoordinationServiceAgentImpl::GetKeyValueDir(const std::string& key) {
573 absl::Notification n;
574 StatusOr<std::vector<KeyValueEntry>> result;
575 GetKeyValueDirAsync(
576 key, [&n, &result](StatusOr<std::vector<KeyValueEntry>> status_or_value) {
577 result = std::move(status_or_value);
578 n.Notify();
579 });
580
581 n.WaitForNotification();
582 return result;
583 }
584
GetKeyValueDirAsync(const std::string & key,StatusOrValueDirCallback done)585 void CoordinationServiceAgentImpl::GetKeyValueDirAsync(
586 const std::string& key, StatusOrValueDirCallback done) {
587 auto request = std::make_shared<GetKeyValueDirRequest>();
588 request->set_directory_key(key);
589 auto response = std::make_shared<GetKeyValueDirResponse>();
590 leader_client_->GetKeyValueDirAsync(
591 request.get(), response.get(),
592 [request, response, done = std::move(done)](const Status& s) {
593 if (!s.ok()) {
594 done(s);
595 } else {
596 std::vector<KeyValueEntry> kv_in_directory = {
597 std::make_move_iterator(response->kv().begin()),
598 std::make_move_iterator(response->kv().end())};
599 done(kv_in_directory);
600 }
601 });
602 }
603
InsertKeyValue(const std::string & key,const std::string & value)604 Status CoordinationServiceAgentImpl::InsertKeyValue(const std::string& key,
605 const std::string& value) {
606 InsertKeyValueRequest request;
607 request.mutable_kv()->set_key(key.data(), key.size());
608 request.mutable_kv()->set_value(value.data(), value.size());
609 InsertKeyValueResponse response;
610
611 Status status;
612 absl::Notification n;
613 leader_client_->InsertKeyValueAsync(&request, &response, [&](Status s) {
614 status = s;
615 n.Notify();
616 });
617 n.WaitForNotification();
618 return status;
619 }
620
DeleteKeyValue(const std::string & key)621 Status CoordinationServiceAgentImpl::DeleteKeyValue(const std::string& key) {
622 DeleteKeyValueRequest request;
623 request.set_key(key);
624 request.set_is_directory(true);
625 DeleteKeyValueResponse response;
626
627 Status status;
628 absl::Notification n;
629 leader_client_->DeleteKeyValueAsync(&request, &response, [&](Status s) {
630 status = s;
631 n.Notify();
632 });
633 n.WaitForNotification();
634 return OkStatus();
635 }
636
UpdateKeyValue(const std::string & key,const std::string & value)637 Status CoordinationServiceAgentImpl::UpdateKeyValue(const std::string& key,
638 const std::string& value) {
639 return MakeCoordinationError(errors::Unimplemented(
640 "CoordinationServiceAgent::UpdateKeyValue is not implemented."));
641 }
642
StartWatchKey(const std::string & key,CoordinationServiceAgentImpl::ChangedKeyValuesCallback on_change)643 Status CoordinationServiceAgentImpl::StartWatchKey(
644 const std::string& key,
645 CoordinationServiceAgentImpl::ChangedKeyValuesCallback on_change) {
646 return MakeCoordinationError(errors::Unimplemented(
647 "CoordinationServiceAgent::StartWatchKey is not implemented."));
648 }
649
StopWatchKey(const std::string & key)650 Status CoordinationServiceAgentImpl::StopWatchKey(const std::string& key) {
651 return MakeCoordinationError(errors::Unimplemented(
652 "CoordinationServiceAgent::StopWatchKey is not implemented."));
653 }
654
SetError(const Status & error)655 void CoordinationServiceAgentImpl::SetError(const Status& error) {
656 assert(!error.ok());
657 mutex_lock l(state_mu_);
658 if (state_ == CoordinatedTaskState::TASKSTATE_ERROR) return;
659
660 LOG(ERROR) << "Coordination agent is in ERROR: " << error;
661 state_ = CoordinatedTaskState::TASKSTATE_ERROR;
662 status_ = error;
663 error_fn_(error);
664 }
665
ActivateWatch(const std::string & key,const std::map<std::string,std::string> & kvs)666 Status CoordinationServiceAgentImpl::ActivateWatch(
667 const std::string& key, const std::map<std::string, std::string>& kvs) {
668 return MakeCoordinationError(errors::Unimplemented(
669 "CoordinationServiceAgent::ActivateWatch is not implemented."));
670 }
671
WaitAtBarrier(const std::string & barrier_id,absl::Duration timeout,const std::vector<CoordinatedTask> & tasks)672 Status CoordinationServiceAgentImpl::WaitAtBarrier(
673 const std::string& barrier_id, absl::Duration timeout,
674 const std::vector<CoordinatedTask>& tasks) {
675 Status status;
676 absl::Notification n;
677 WaitAtBarrierAsync(barrier_id, timeout, tasks, [&](Status s) {
678 status = s;
679 n.Notify();
680 });
681 n.WaitForNotification();
682 return status;
683 }
684
WaitAtBarrierAsync(const std::string & barrier_id,absl::Duration timeout,const std::vector<CoordinatedTask> & tasks,StatusCallback done)685 void CoordinationServiceAgentImpl::WaitAtBarrierAsync(
686 const std::string& barrier_id, absl::Duration timeout,
687 const std::vector<CoordinatedTask>& tasks, StatusCallback done) {
688 Status agent_running_status =
689 ValidateRunningAgent(/*allow_disconnected=*/true);
690 if (!agent_running_status.ok()) {
691 done(agent_running_status);
692 return;
693 }
694 {
695 mutex_lock l(state_mu_);
696 auto [it, inserted] = used_barrier_ids_.insert(barrier_id);
697 if (!inserted) {
698 done(errors::FailedPrecondition(
699 "WaitAtBarrier() should not be called with the same id more than "
700 "once. Barrier id: ",
701 barrier_id));
702 return;
703 }
704 }
705 auto request = std::make_shared<BarrierRequest>();
706 auto response = std::make_shared<BarrierResponse>();
707 request->set_barrier_id(barrier_id);
708 request->set_barrier_timeout_in_ms(timeout / absl::Milliseconds(1));
709 *request->mutable_source_task() = task_;
710 *request->mutable_tasks() = {tasks.begin(), tasks.end()};
711 leader_client_->BarrierAsync(request.get(), response.get(),
712 [request, response, done = std::move(done)](
713 const Status& s) { done(s); });
714 }
715
CancelBarrier(const std::string & barrier_id)716 Status CoordinationServiceAgentImpl::CancelBarrier(
717 const std::string& barrier_id) {
718 Status status;
719 absl::Notification n;
720 CancelBarrierAsync(barrier_id, [&](const Status& s) {
721 status = s;
722 n.Notify();
723 });
724 n.WaitForNotification();
725 return status;
726 }
727
CancelBarrierAsync(const std::string & barrier_id,StatusCallback done)728 void CoordinationServiceAgentImpl::CancelBarrierAsync(
729 const std::string& barrier_id, StatusCallback done) {
730 Status agent_running_status =
731 ValidateRunningAgent(/*allow_disconnected=*/true);
732 if (!agent_running_status.ok()) {
733 done(agent_running_status);
734 return;
735 }
736 auto request = std::make_shared<CancelBarrierRequest>();
737 auto response = std::make_shared<CancelBarrierResponse>();
738 request->set_barrier_id(barrier_id);
739 *request->mutable_source_task() = task_;
740 leader_client_->CancelBarrierAsync(
741 request.get(), response.get(),
742 [request, response, done = std::move(done)](const Status& s) {
743 done(s);
744 });
745 }
746
747 // Returns an error if agent is not running.
ValidateRunningAgent(bool allow_disconnected)748 Status CoordinationServiceAgentImpl::ValidateRunningAgent(
749 bool allow_disconnected) {
750 mutex_lock l(state_mu_);
751 switch (state_) {
752 case CoordinatedTaskState::TASKSTATE_CONNECTED:
753 return OkStatus();
754
755 case CoordinatedTaskState::TASKSTATE_UNINITIALIZED:
756 return MakeCoordinationError(errors::FailedPrecondition(
757 "Agent must be in CONNECTED state. It is currently UNINITIALIZED."));
758
759 case CoordinatedTaskState::TASKSTATE_DISCONNECTED:
760 if (allow_disconnected) return OkStatus();
761 return MakeCoordinationError(errors::FailedPrecondition(
762 "Agent must be in CONNECTED state. It is currently DISCONNECTED."));
763
764 case CoordinatedTaskState::TASKSTATE_ERROR:
765 return MakeCoordinationError(errors::FailedPrecondition(
766 "Agent must be in CONNECTED state. It is currently in ERROR."));
767
768 default:
769 return MakeCoordinationError(errors::FailedPrecondition(absl::StrCat(
770 "Agent is not in CONNECTED state. Current state: ", state_)));
771 }
772 }
773
GetEnv()774 StatusOr<Env*> CoordinationServiceAgentImpl::GetEnv() {
775 if (!IsInitialized()) {
776 return MakeCoordinationError(errors::FailedPrecondition(
777 "Coordination service agent has not been initialized."));
778 }
779 if (env_ == nullptr) {
780 return MakeCoordinationError(
781 errors::FailedPrecondition("Coordination service agent was not "
782 "initialized with a valid Env* object."));
783 }
784 return env_;
785 }
786
787 } // namespace
788
CreateCoordinationServiceAgent()789 std::unique_ptr<CoordinationServiceAgent> CreateCoordinationServiceAgent() {
790 return std::make_unique<CoordinationServiceAgentImpl>();
791 }
792
793 } // namespace tensorflow
794