1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_H_
18 
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "tensorflow/core/distributed_runtime/coordination/coordination_client.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
29 
30 namespace tensorflow {
31 class CoordinationServiceDeviceInfo;
32 class ServerDef;
33 class Env;
34 
// Static registration for coordination service implementations.
//
// REGISTER_COORDINATION_SERVICE(service_type_name, factory_fn) defines a
// uniquely named static bool whose initializer calls
// CoordinationServiceInterface::RegisterCoordinationService() with the given
// name and factory.
//
// The expansion goes through three macros on purpose: a macro argument that is
// an operand of `##` is NOT macro-expanded before pasting. __COUNTER__ must
// therefore pass through one plain forwarding macro (..._UNIQ_HELPER) so that
// it has already been replaced by its numeric value when ..._UNIQ pastes it
// into the variable name. Without the extra level every registration would be
// named `static_coordination_service___COUNTER__`, and two registrations in
// the same translation unit would collide.
#define REGISTER_COORDINATION_SERVICE(service_type_name, factory_fn)        \
  REGISTER_COORDINATION_SERVICE_UNIQ_HELPER(__COUNTER__, service_type_name, \
                                            factory_fn)
#define REGISTER_COORDINATION_SERVICE_UNIQ_HELPER(counter, service_type_name, \
                                                  factory_fn)                 \
  REGISTER_COORDINATION_SERVICE_UNIQ(counter, service_type_name, factory_fn)
#define REGISTER_COORDINATION_SERVICE_UNIQ(counter, service_type_name,    \
                                           factory_fn)                    \
  static bool static_coordination_service_##counter TF_ATTRIBUTE_UNUSED = \
      []() {                                                              \
        ::tensorflow::CoordinationServiceInterface::                      \
            RegisterCoordinationService(service_type_name,                \
                                        std::move(factory_fn));           \
        return true;                                                      \
      }()
48 
49 // Coordination service is used for controlling and coordinating distributed
50 // execution in a cluster of multiple tasks.
51 //
// When enabled, the service keeps track of cluster configurations and the state
// of cluster members. TF runtime and libraries can use it to orchestrate
// cluster initialization, check the healthiness of tasks, and propagate error
// messages to the cluster.
56 //
57 // Normally, the service should first Start(), then perform the supported
58 // coordination operations, and finally Stop(). When service runs into error or
59 // SetError() is called, all subsequent operations will be in error state.
60 //
// CoordinationServiceInterface defines the service interface for distributed
// coordination. One instance of the service should be deployed in a cluster,
// handling various requests and storing configuration key-value data for the
// tasks. Each task interacts with the service through CoordinationServiceAgent.
65 class CoordinationServiceInterface {
66  public:
67   using CoordinationServiceFactory =
68       std::function<std::unique_ptr<CoordinationServiceInterface>(
69           Env* env, const ServerDef& server_def,
70           std::unique_ptr<CoordinationClientCache> cache)>;
71 
72   using StatusOrValueCallback =
73       std::function<void(const StatusOr<std::string>&)>;
74 
~CoordinationServiceInterface()75   virtual ~CoordinationServiceInterface() {}
76 
RegisterCoordinationService(const std::string & service_type_name,CoordinationServiceFactory factory_fn)77   static void RegisterCoordinationService(
78       const std::string& service_type_name,
79       CoordinationServiceFactory factory_fn) {
80     auto factories = GetCoordinationServiceFactories();
81     factories->emplace(service_type_name, factory_fn);
82   }
83 
84   static std::unique_ptr<CoordinationServiceInterface>
EnableCoordinationService(const std::string & service_type,Env * env,const ServerDef & server_def,std::unique_ptr<CoordinationClientCache> cache)85   EnableCoordinationService(const std::string& service_type, Env* env,
86                             const ServerDef& server_def,
87                             std::unique_ptr<CoordinationClientCache> cache) {
88     const auto* factories = GetCoordinationServiceFactories();
89     auto factories_iter = factories->find(service_type);
90     if (factories_iter == factories->end()) {
91       LOG(ERROR) << "No coordination service factory found for service type "
92                  << service_type;
93       return nullptr;
94     }
95     auto service = factories_iter->second(env, server_def, std::move(cache));
96     if (service != nullptr) {
97       *GetCoordinationServiceInstancePtr() = service.get();
98     }
99     return service;
100   }
101 
GetCoordinationServiceInstance()102   static CoordinationServiceInterface* GetCoordinationServiceInstance() {
103     return *GetCoordinationServiceInstancePtr();
104   }
105 
106   // Register a task to the service.
107   virtual Status RegisterTask(const CoordinatedTask& task,
108                               uint64_t incarnation) = 0;
109 
110   // Wait for all tasks to be up and running, and register local device
111   // info. The callback is invoked when all tasks are up and registered, or some
112   // error occurs.
113   virtual void WaitForAllTasks(const CoordinatedTask& task,
114                                const CoordinationServiceDeviceInfo& devices,
115                                StatusCallback done) = 0;
116 
117   // Disconnects task from the service. If `shutdown_barrier_timeout_in_ms` is
118   // specified in the config, blocks until all tasks reach the barrier before
119   // disconnecting together.
120   // Possible service errors:
121   //   - InvalidArgument: Unexpected task request.
122   //   - FailedPrecondition: task has already disconnected.
123   virtual void ShutdownTaskAsync(const CoordinatedTask& task,
124                                  StatusCallback done) = 0;
125 
126   // Disconnects task from the service and cleans up its internal error state.
127   // Possible service errors:
128   //   - InvalidArgument: Unexpected task request.
129   //   - FailedPrecondition: task has already disconnected.
130   virtual Status ResetTask(const CoordinatedTask& task) = 0;
131 
132   // Update the heartbeat timestamp of a task. This should only be invoked on
133   // the leader of the cluster.
134   virtual Status RecordHeartbeat(const CoordinatedTask& task,
135                                  uint64_t incarnation) = 0;
136 
137   // Set a task in error state permanently.
138   virtual Status ReportTaskError(const CoordinatedTask& task, Status error) = 0;
139 
140   // Insert a configuration key-value in the coordination service.
141   // For now, a key-value can only be inserted once and cannot be updated.
142   // The key-values are not persisted and will be lost if the leader fails.
143   virtual Status InsertKeyValue(const std::string& key,
144                                 const std::string& value) = 0;
145 
146   // Get a configuration key-value from the coordination service. The `done`
147   // callback is invoked when the key-value becomes available.
148   virtual void GetKeyValueAsync(const std::string& key,
149                                 StatusOrValueCallback done) = 0;
150 
151   // Get a configuration key-value from the coordination service. If the key
152   // does not exist, return NotFound error.
153   virtual StatusOr<std::string> TryGetKeyValue(const std::string& key) = 0;
154 
155   // Gets all values under a directory (key).
156   // A value is considered to be in the directory if its key is prefixed with
157   // the directory. This is not a blocking call. Agent does not need to be
158   // connected to utilize the distributed key-value store.
159   virtual std::vector<KeyValueEntry> GetKeyValueDir(
160       absl::string_view directory_key) = 0;
161 
162   // Delete configuration key-value. If key is a directory, recursively clean
163   // up all key-values under the directory.
164   virtual Status DeleteKeyValue(const std::string& key) = 0;
165 
166   // Blocks until all (or a subset of) tasks are at the barrier or the barrier
167   // fails.
168   //
169   // `barrier_id` should be unique across barriers. Once the barrier has passed
170   // or failed, subsequent calls will not block, and immediately respond with
171   // the previous response.
172   //
173   // The first WaitAtBarrier() call received by the service for a particular
174   // barrier id is special in that it determines the barrier deadline based on
175   // timeout duration.
176   // However, if subsequent calls by different agents specify a different set of
177   // `participating_tasks` for the same `barrier_id`, the barrier will fail
178   // instantly.
179   //
180   // If no tasks are specified (default), the barrier will block for all the
181   // connected tasks.
182   //
183   // Possible service errors:
184   //   - DeadlineExceeded: Timed out waiting for specified tasks at the barrier.
185   //      Deadline is determined by the server timestamp when it receives the
186   //      first WaitAtBarrier() + timeout duration.
187   //   - Cancelled: One of the tasks called CancelBarrier().
188   //   - Aborted: Service is shutting down.
189   //   - Internal: Any participating task is in ERROR state.
190   //   - InvalidArgument: (1) Conflicting tasks specified by different agents
191   //       for the same barrier, (2) one of the participating tasks is not in
192   //       the cluster, or (3) task making the request is not included in the
193   //       list of participating tasks.
194   //   - FailedPrecondition: Agent is in UNINITIALIZED or ERROR state.
195   virtual void BarrierAsync(
196       const std::string& barrier_id, absl::Duration timeout,
197       const CoordinatedTask& task,
198       const std::vector<CoordinatedTask>& participating_tasks,
199       StatusCallback done) = 0;
200 
201   // Aborts the barrier if it is ongoing.
202   // Current and future WaitAtBarrier() calls with the same id will return a
203   // CANCELLED error status.
204   // Possible service errors:
205   //   - FailedPrecondition: Barrier has already been passed.
206   virtual Status CancelBarrier(const std::string& barrier_id,
207                                const CoordinatedTask& task) = 0;
208 
209  private:
210   friend class CoordinationServiceRpcHandler;
211   friend class CoordinationServiceTest_ListClusterDevices_TfDevice_Test;
212   friend class CoordinationServiceTest_ListClusterDevices_XlaDevice_Test;
213   friend class
214       CoordinationServiceTest_ListClusterDevices_DevicesAreNotAddedTwice_Test;
215 
216   virtual const CoordinationServiceDeviceInfo& ListClusterDevices() = 0;
217   virtual uint64_t GetServiceIncarnation() = 0;
218 
219   static std::unordered_map<std::string, CoordinationServiceFactory>*
GetCoordinationServiceFactories()220   GetCoordinationServiceFactories() {
221     static auto* coordination_service_factories =
222         new std::unordered_map<std::string, CoordinationServiceFactory>();
223     return coordination_service_factories;
224   }
225 
GetCoordinationServiceInstancePtr()226   static CoordinationServiceInterface** GetCoordinationServiceInstancePtr() {
227     static CoordinationServiceInterface* instance = nullptr;
228     return &instance;
229   }
230 };
231 
232 }  // namespace tensorflow
233 
234 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_H_
235