1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_AGENT_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_AGENT_H_
18 
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/time/time.h"
#include "tensorflow/core/distributed_runtime/coordination/coordination_client.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/coordination_service.pb.h"
29 
30 namespace tensorflow {
// Forward declarations for types referenced by the agent interface below.
class CoordinatedTask;
class CoordinationServiceConfig;
class Env;
class ServerDef;
35 
// CoordinationServiceAgent defines the interface for tasks to communicate with
// the coordination service instance (which implements
// CoordinationServiceInterface). One instance of the agent should be deployed
// on each task so that it can send requests to the service and store /
// retrieve config key-value data.
41 //
42 // See CoordinationServiceInterface for more details on coordination service.
43 //
// All coordination service errors will have an additional
// CoordinationServiceError payload to distinguish themselves from RPC failures.
// The payload can optionally specify the error origin, and whether the error
// was reported by the user via `agent->ReportError()`.
48 //
49 // Possible service errors:
50 //    - errors::Internal: Coordination service is not enabled.
51 //                        If it was previously accessible, coordination service
52 //                        has been shut down.
53 //    - errors::Aborted: Incarnation mismatch during heartbeat (either remote
54 //                       task or coordination service has restarted).
55 //    - errors::Unavailable: Heartbeat timeout from remote task (failed,
56 //                           crashed or got preempted).
57 //    - errors::InvalidArgument: Unexpected heartbeat from remote task (not
58 //                               registered or wrong config).
59 class CoordinationServiceAgent {
60  public:
61   using StatusOrValueCallback =
62       std::function<void(const StatusOr<std::string>&)>;
63   // Collection of key-value pairs in the same directory.
64   using StatusOrValueDirCallback =
65       std::function<void(const StatusOr<std::vector<KeyValueEntry>>&)>;
66   using ChangedKeyValuesCallback =
67       std::function<void(const std::map<std::string, std::string>&)>;
68 
~CoordinationServiceAgent()69   virtual ~CoordinationServiceAgent() {}
70 
71   // Initialize coordination service agent.
72   virtual Status Initialize(
73       Env* env, const ServerDef& server_def,
74       std::unique_ptr<CoordinationClientCache> client_cache,
75       StatusCallback error_fn) = 0;
76   virtual Status Initialize(Env* env, const std::string& job_name, int task_id,
77                             const CoordinationServiceConfig& configs,
78                             std::unique_ptr<CoordinationClient> leader_client,
79                             StatusCallback error_fn) = 0;
80   virtual Status Initialize(Env* env, const CoordinatedTask& task,
81                             const CoordinationServiceConfig& configs,
82                             std::unique_ptr<CoordinationClient> leader_client,
83                             StatusCallback error_fn) = 0;
84 
85   // Return true if the coordination service agent has been initialized.
86   virtual bool IsInitialized() = 0;
87 
88   // Connect to coordination service with the following steps:
89   //   - connect to service address specified in the config of `server_def`
90   //   - register itself as a task to the service
91   //   - start a thread to periodically send heartbeat message with the service
92   // Possible service errors:
93   //   - FailedPrecondition: Agent is not in DISCONNECTED state.
94   //   - InvalidArgument: Unexpected task registration
95   //   - Aborted: Duplicate task registration
96   virtual Status Connect() = 0;
97 
98   // Wait for all tasks to be up and registered. The call blocks until all tasks
99   // in the cluster are up, or some error occurs.
100   // Possible service errors:
101   //   - FailedPrecondition: Agent is not in CONNECTED state.
102   //   - InvalidArgument: Unexpected task request
103   virtual Status WaitForAllTasks(
104       const CoordinationServiceDeviceInfo& local_devices) = 0;
105 
106   // Get the device attributes of tasks from remote tasks in the cluster.
107   virtual const CoordinationServiceDeviceInfo& GetClusterDeviceInfo() = 0;
108 
109   // State transition in coordination service agent:
110   //
111   //                 Init              Connect           SetError
112   //   UNINITIALIZED ---> DISCONNECTED ------> CONNECTED -------> ERROR
113   //                           ^                                  |
114   //                           |__________________________________|
115   //                                         Reset
116 
117   // Get task associated with this agent.
118   virtual StatusOr<CoordinatedTask> GetOwnTask() = 0;
119 
120   // Get status of a remote task.
121   virtual StatusOr<CoordinatedTaskState> GetTaskStatus(
122       const CoordinatedTask& task) = 0;
123 
124   // Report error to coordination service. This will invoke the error callback.
125   // Note that the error payload will set `is_reported_error` to true, to
126   // distinguish user-specified errors from internal service or RPC failures.
127   // Possible service errors:
128   //   - FailedPrecondition: Uninitialized/disconnected/already in error state.
129   //   - InvalidArgument: Unexpected task request
130   virtual Status ReportError(const Status& error) = 0;
131 
132   // Shuts down by disconnecting from the service. Should only be called if
133   // agent is connected and no further agent calls (except the destructor) are
134   // expected. If `shutdown_barrier_timeout_in_ms` is specified in the config,
135   // blocks until all tasks reach the barrier before shutting down together. If
136   // the barrier times out, this agent will still disconnect, while an error is
137   // reported to other agents that did not reach the barrier on time.
138   // Possible service errors:
139   //   - InvalidArgument: Unexpected task request.
140   //   - FailedPrecondition: Task was in error state (note: agent is still
141   //                         shut down forcefully).
142   virtual Status Shutdown() = 0;
143 
144   // Disconnect from the service, and clean up the internal error status.
145   // Possible service errors:
146   //   - InvalidArgument: Unexpected task request.
147   //   - FailedPrecondition: task is not in error state/has already
148   //       disconnected.
149   virtual Status Reset() = 0;
150 
151   // Get config key-value from the service.
152   // If the key-value is not inserted yet, this is a blocking call that waits
153   // until the corresponding key is inserted.
154   // Agent does not need to be connected to utilize the distributed key-value
155   // store.
156   //   - errors::DeadlineExceeded: timed out waiting for key.
157   virtual StatusOr<std::string> GetKeyValue(const std::string& key) = 0;
158   virtual StatusOr<std::string> GetKeyValue(const std::string& key,
159                                             absl::Duration timeout) = 0;
160   // Note: Cancel the underlying RPC call with `call_opts->StartCancel()` and
161   // `call_opts->ClearCancelCallback()`.
162   virtual std::shared_ptr<CallOptions> GetKeyValueAsync(
163       const std::string& key, StatusOrValueCallback done) = 0;
164 
165   // Get config key-value from the service.
166   // If the key-value does not exist, this call returns NotFound error.
167   // Agent does not need to be connected to utilize the distributed key-value
168   // store.
169   //   - errors::NotFound: the requested key does not exist.
170   virtual StatusOr<std::string> TryGetKeyValue(const std::string& key) = 0;
171 
172   // Get all values under a directory (key).
173   // A value is considered to be in the directory if its key is prefixed with
174   // the directory.
175   // This is not a blocking call.
176   // Agent does not need to be connected to utilize the distributed key-value
177   // store.
178   virtual StatusOr<std::vector<KeyValueEntry>> GetKeyValueDir(
179       const std::string& key) = 0;
180   virtual void GetKeyValueDirAsync(const std::string& key,
181                                    StatusOrValueDirCallback done) = 0;
182 
183   // Insert config key-value to the service.
184   //   - errors::AlreadyExists: key is already set.
185   virtual Status InsertKeyValue(const std::string& key,
186                                 const std::string& value) = 0;
187 
188   // Delete config keys in the coordination service.
189   virtual Status DeleteKeyValue(const std::string& key) = 0;
190 
191   // Update the value of a config key.
192   virtual Status UpdateKeyValue(const std::string& key,
193                                 const std::string& value) = 0;
194 
195   // Register a callback that will be invoked when the key or keys under the key
196   // directory are changed (inserted, deleted, or updated).
197   virtual Status StartWatchKey(const std::string& key,
198                                ChangedKeyValuesCallback on_change) = 0;
199   virtual Status StopWatchKey(const std::string& key) = 0;
200 
201   // Blocks until all (or a subset of) tasks are at the barrier or the barrier
202   // fails.
203   //
204   // `barrier_id` should be unique across barriers.
205   //
206   // The first WaitAtBarrier() call received by the service for a particular
207   // barrier_id is special in that it determines the barrier deadline based on
208   // timeout duration.
209   // However, if subsequent calls by different agents specify a different set of
210   // `tasks` for the same `barrier_id`, the barrier will fail instantly.
211   // For example,
212   //   agent_1->WaitAtBarrier(“barrier”, 10min, <<”worker”, 1>, <”worker”, 2>>);
213   //   agent_2->WaitAtBarrier(“barrier”, 10min, <<”worker”, 2>, <”worker”, 3>>);
214   // Barrier fails after agent_2’s call because it specifies a different set of
215   // participating tasks.
216   //
217   // If no tasks are specified (default), the barrier will block for all the
218   // connected tasks.
219   //
220   // Possible service errors:
221   //   - DeadlineExceeded: Timed out waiting for specified tasks at the barrier.
222   //      Deadline is determined by the server timestamp when it receives the
223   //      first WaitAtBarrier() + timeout duration.
224   //   - Cancelled: One of the tasks called CancelBarrier().
225   //   - Aborted: Service is shutting down.
226   //   - Internal: Any participating task is in ERROR state.
227   //   - InvalidArgument: (1) Conflicting tasks specified by different agents
228   //       for the same barrier, (2) one of the participating tasks is not in
229   //       the cluster, or (3) task making the request is not included in the
230   //       list of participating tasks.
231   //   - FailedPrecondition: Agent is in UNINITIALIZED or ERROR state. Or the
232   //       same barrier_id was already used previously.
233   virtual Status WaitAtBarrier(const std::string& barrier_id,
234                                absl::Duration timeout,
235                                const std::vector<CoordinatedTask>& tasks) = 0;
236 
237   virtual void WaitAtBarrierAsync(const std::string& barrier_id,
238                                   absl::Duration timeout,
239                                   const std::vector<CoordinatedTask>& tasks,
240                                   StatusCallback done) = 0;
241 
242   // Aborts the barrier if it is ongoing.
243   // Current and future WaitAtBarrier() calls with the same id will return a
244   // CANCELLED error status.
245   // Possible service errors:
246   //   - FailedPrecondition: Barrier has already been passed.
247   virtual Status CancelBarrier(const std::string& barrier_id) = 0;
248   virtual void CancelBarrierAsync(const std::string& barrier_id,
249                                   StatusCallback done) = 0;
250 
251   // Get unowned Env* that the agent was initialized with.
252   virtual StatusOr<Env*> GetEnv() = 0;
253 
254  protected:
255   // Set the service agent to error status and invoke the error callback.
256   // Note: different from ReportError, this does not report the error status to
257   // remote coordination service.
258   virtual void SetError(const Status& error) = 0;
259 
260   // Activate the key-value callback watch.
261   virtual Status ActivateWatch(const std::string& key,
262                                const std::map<std::string, std::string>&) = 0;
263 
264  private:
265   friend class CoordinationServiceRpcHandler;
266 };
267 
268 std::unique_ptr<CoordinationServiceAgent> CreateCoordinationServiceAgent();
269 
270 }  // namespace tensorflow
271 
272 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_AGENT_H_
273