1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_H_ 18 19 #include <functional> 20 #include <string> 21 #include <utility> 22 #include <vector> 23 24 #include "absl/strings/string_view.h" 25 #include "absl/time/time.h" 26 #include "tensorflow/core/distributed_runtime/coordination/coordination_client.h" 27 #include "tensorflow/core/platform/status.h" 28 #include "tensorflow/core/platform/statusor.h" 29 30 namespace tensorflow { 31 class CoordinationServiceDeviceInfo; 32 class ServerDef; 33 class Env; 34 35 // Static registration for coordination service implementations. 
// Registers `factory_fn` under `service_type_name` during static
// initialization of the translation unit that invokes this macro.
#define REGISTER_COORDINATION_SERVICE(service_type_name, factory_fn)        \
  REGISTER_COORDINATION_SERVICE_UNIQ_HELPER(__COUNTER__, service_type_name, \
                                            factory_fn)

// Indirection layer: `counter` is not an operand of `##` here, so the
// preprocessor fully expands __COUNTER__ to its numeric value before it is
// forwarded. Pasting it directly in this macro would instead produce the
// literal identifier `static_coordination_service___COUNTER__` for every
// registration, which collides when a translation unit registers twice.
#define REGISTER_COORDINATION_SERVICE_UNIQ_HELPER(counter, service_type_name, \
                                                  factory_fn)                 \
  REGISTER_COORDINATION_SERVICE_UNIQ(counter, service_type_name, factory_fn)

// Defines a uniquely named static whose initializer performs the registration.
#define REGISTER_COORDINATION_SERVICE_UNIQ(counter, service_type_name,    \
                                           factory_fn)                    \
  static bool static_coordination_service_##counter TF_ATTRIBUTE_UNUSED = \
      []() {                                                              \
        ::tensorflow::CoordinationServiceInterface::                      \
            RegisterCoordinationService(service_type_name,                \
                                        std::move(factory_fn));           \
        return true;                                                      \
      }()

// Coordination service is used for controlling and coordinating distributed
// execution in a cluster of multiple tasks.
//
// When enabled, the service keeps track of cluster configurations and the state
// of cluster members. TF runtime and libraries can use it to orchestrate
// cluster initialization, check the healthiness of tasks, and propagate error
// messages to the cluster.
//
// Normally, the service should first Start(), then perform the supported
// coordination operations, and finally Stop(). When service runs into error or
// SetError() is called, all subsequent operations will be in error state.
//
// CoordinationServiceInterface defines the service interface for distributed
// coordination. One instance of the service should be deployed in a cluster,
// handling various requests and storing configuration key-value data for the
// tasks. Each task interacts with the service through CoordinationServiceAgent.
65 class CoordinationServiceInterface { 66 public: 67 using CoordinationServiceFactory = 68 std::function<std::unique_ptr<CoordinationServiceInterface>( 69 Env* env, const ServerDef& server_def, 70 std::unique_ptr<CoordinationClientCache> cache)>; 71 72 using StatusOrValueCallback = 73 std::function<void(const StatusOr<std::string>&)>; 74 ~CoordinationServiceInterface()75 virtual ~CoordinationServiceInterface() {} 76 RegisterCoordinationService(const std::string & service_type_name,CoordinationServiceFactory factory_fn)77 static void RegisterCoordinationService( 78 const std::string& service_type_name, 79 CoordinationServiceFactory factory_fn) { 80 auto factories = GetCoordinationServiceFactories(); 81 factories->emplace(service_type_name, factory_fn); 82 } 83 84 static std::unique_ptr<CoordinationServiceInterface> EnableCoordinationService(const std::string & service_type,Env * env,const ServerDef & server_def,std::unique_ptr<CoordinationClientCache> cache)85 EnableCoordinationService(const std::string& service_type, Env* env, 86 const ServerDef& server_def, 87 std::unique_ptr<CoordinationClientCache> cache) { 88 const auto* factories = GetCoordinationServiceFactories(); 89 auto factories_iter = factories->find(service_type); 90 if (factories_iter == factories->end()) { 91 LOG(ERROR) << "No coordination service factory found for service type " 92 << service_type; 93 return nullptr; 94 } 95 auto service = factories_iter->second(env, server_def, std::move(cache)); 96 if (service != nullptr) { 97 *GetCoordinationServiceInstancePtr() = service.get(); 98 } 99 return service; 100 } 101 GetCoordinationServiceInstance()102 static CoordinationServiceInterface* GetCoordinationServiceInstance() { 103 return *GetCoordinationServiceInstancePtr(); 104 } 105 106 // Register a task to the service. 107 virtual Status RegisterTask(const CoordinatedTask& task, 108 uint64_t incarnation) = 0; 109 110 // Wait for all tasks to be up and running, and register local device 111 // info. 
The callback is invoked when all tasks are up and registered, or some 112 // error occurs. 113 virtual void WaitForAllTasks(const CoordinatedTask& task, 114 const CoordinationServiceDeviceInfo& devices, 115 StatusCallback done) = 0; 116 117 // Disconnects task from the service. If `shutdown_barrier_timeout_in_ms` is 118 // specified in the config, blocks until all tasks reach the barrier before 119 // disconnecting together. 120 // Possible service errors: 121 // - InvalidArgument: Unexpected task request. 122 // - FailedPrecondition: task has already disconnected. 123 virtual void ShutdownTaskAsync(const CoordinatedTask& task, 124 StatusCallback done) = 0; 125 126 // Disconnects task from the service and cleans up its internal error state. 127 // Possible service errors: 128 // - InvalidArgument: Unexpected task request. 129 // - FailedPrecondition: task has already disconnected. 130 virtual Status ResetTask(const CoordinatedTask& task) = 0; 131 132 // Update the heartbeat timestamp of a task. This should only be invoked on 133 // the leader of the cluster. 134 virtual Status RecordHeartbeat(const CoordinatedTask& task, 135 uint64_t incarnation) = 0; 136 137 // Set a task in error state permanently. 138 virtual Status ReportTaskError(const CoordinatedTask& task, Status error) = 0; 139 140 // Insert a configuration key-value in the coordination service. 141 // For now, a key-value can only be inserted once and cannot be updated. 142 // The key-values are not persisted and will be lost if the leader fails. 143 virtual Status InsertKeyValue(const std::string& key, 144 const std::string& value) = 0; 145 146 // Get a configuration key-value from the coordination service. The `done` 147 // callback is invoked when the key-value becomes available. 148 virtual void GetKeyValueAsync(const std::string& key, 149 StatusOrValueCallback done) = 0; 150 151 // Get a configuration key-value from the coordination service. If the key 152 // does not exist, return NotFound error. 
153 virtual StatusOr<std::string> TryGetKeyValue(const std::string& key) = 0; 154 155 // Gets all values under a directory (key). 156 // A value is considered to be in the directory if its key is prefixed with 157 // the directory. This is not a blocking call. Agent does not need to be 158 // connected to utilize the distributed key-value store. 159 virtual std::vector<KeyValueEntry> GetKeyValueDir( 160 absl::string_view directory_key) = 0; 161 162 // Delete configuration key-value. If key is a directory, recursively clean 163 // up all key-values under the directory. 164 virtual Status DeleteKeyValue(const std::string& key) = 0; 165 166 // Blocks until all (or a subset of) tasks are at the barrier or the barrier 167 // fails. 168 // 169 // `barrier_id` should be unique across barriers. Once the barrier has passed 170 // or failed, subsequent calls will not block, and immediately respond with 171 // the previous response. 172 // 173 // The first WaitAtBarrier() call received by the service for a particular 174 // barrier id is special in that it determines the barrier deadline based on 175 // timeout duration. 176 // However, if subsequent calls by different agents specify a different set of 177 // `participating_tasks` for the same `barrier_id`, the barrier will fail 178 // instantly. 179 // 180 // If no tasks are specified (default), the barrier will block for all the 181 // connected tasks. 182 // 183 // Possible service errors: 184 // - DeadlineExceeded: Timed out waiting for specified tasks at the barrier. 185 // Deadline is determined by the server timestamp when it receives the 186 // first WaitAtBarrier() + timeout duration. 187 // - Cancelled: One of the tasks called CancelBarrier(). 188 // - Aborted: Service is shutting down. 189 // - Internal: Any participating task is in ERROR state. 
190 // - InvalidArgument: (1) Conflicting tasks specified by different agents 191 // for the same barrier, (2) one of the participating tasks is not in 192 // the cluster, or (3) task making the request is not included in the 193 // list of participating tasks. 194 // - FailedPrecondition: Agent is in UNINITIALIZED or ERROR state. 195 virtual void BarrierAsync( 196 const std::string& barrier_id, absl::Duration timeout, 197 const CoordinatedTask& task, 198 const std::vector<CoordinatedTask>& participating_tasks, 199 StatusCallback done) = 0; 200 201 // Aborts the barrier if it is ongoing. 202 // Current and future WaitAtBarrier() calls with the same id will return a 203 // CANCELLED error status. 204 // Possible service errors: 205 // - FailedPrecondition: Barrier has already been passed. 206 virtual Status CancelBarrier(const std::string& barrier_id, 207 const CoordinatedTask& task) = 0; 208 209 private: 210 friend class CoordinationServiceRpcHandler; 211 friend class CoordinationServiceTest_ListClusterDevices_TfDevice_Test; 212 friend class CoordinationServiceTest_ListClusterDevices_XlaDevice_Test; 213 friend class 214 CoordinationServiceTest_ListClusterDevices_DevicesAreNotAddedTwice_Test; 215 216 virtual const CoordinationServiceDeviceInfo& ListClusterDevices() = 0; 217 virtual uint64_t GetServiceIncarnation() = 0; 218 219 static std::unordered_map<std::string, CoordinationServiceFactory>* GetCoordinationServiceFactories()220 GetCoordinationServiceFactories() { 221 static auto* coordination_service_factories = 222 new std::unordered_map<std::string, CoordinationServiceFactory>(); 223 return coordination_service_factories; 224 } 225 GetCoordinationServiceInstancePtr()226 static CoordinationServiceInterface** GetCoordinationServiceInstancePtr() { 227 static CoordinationServiceInterface* instance = nullptr; 228 return &instance; 229 } 230 }; 231 232 } // namespace tensorflow 233 234 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_H_ 
235