1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_AGENT_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_AGENT_H_ 18 19 #include <functional> 20 #include <memory> 21 #include <string> 22 #include <utility> 23 24 #include "absl/time/time.h" 25 #include "tensorflow/core/distributed_runtime/coordination/coordination_client.h" 26 #include "tensorflow/core/platform/status.h" 27 #include "tensorflow/core/platform/statusor.h" 28 #include "tensorflow/core/protobuf/coordination_service.pb.h" 29 30 namespace tensorflow { 31 class CoordinationServiceConfig; 32 class CoordinatedTask; 33 class Env; 34 class ServerDef; 35 36 // CoordinationServiceAgent defines the interface for tasks to communicate with 37 // the coordination service instance (which implements 38 // CoordinationServiceInterface). One instance of the agent should be deployed 39 // on each task for it to send various requests and stores / retrieves config 40 // key-value data to the service. 41 // 42 // See CoordinationServiceInterface for more details on coordination service. 43 // 44 // All coordination service errors will have an additional 45 // CoordinationServiceError payload to distinguish themselves from RPC failures. 46 // The payload can optionally specify the error origin, and if the error is 47 // reported by the user via `agent->ReportError()`. 48 // 49 // Possible service errors: 50 // - errors::Internal: Coordination service is not enabled. 51 // If it was previously accessible, coordination service 52 // has been shut down. 53 // - errors::Aborted: Incarnation mismatch during heartbeat (either remote 54 // task or coordination service has restarted). 55 // - errors::Unavailable: Heartbeat timeout from remote task (failed, 56 // crashed or got preempted). 57 // - errors::InvalidArgument: Unexpected heartbeat from remote task (not 58 // registered or wrong config). 59 class CoordinationServiceAgent { 60 public: 61 using StatusOrValueCallback = 62 std::function<void(const StatusOr<std::string>&)>; 63 // Collection of key-value pairs in the same directory. 64 using StatusOrValueDirCallback = 65 std::function<void(const StatusOr<std::vector<KeyValueEntry>>&)>; 66 using ChangedKeyValuesCallback = 67 std::function<void(const std::map<std::string, std::string>&)>; 68 ~CoordinationServiceAgent()69 virtual ~CoordinationServiceAgent() {} 70 71 // Initialize coordination service agent. 72 virtual Status Initialize( 73 Env* env, const ServerDef& server_def, 74 std::unique_ptr<CoordinationClientCache> client_cache, 75 StatusCallback error_fn) = 0; 76 virtual Status Initialize(Env* env, const std::string& job_name, int task_id, 77 const CoordinationServiceConfig& configs, 78 std::unique_ptr<CoordinationClient> leader_client, 79 StatusCallback error_fn) = 0; 80 virtual Status Initialize(Env* env, const CoordinatedTask& task, 81 const CoordinationServiceConfig& configs, 82 std::unique_ptr<CoordinationClient> leader_client, 83 StatusCallback error_fn) = 0; 84 85 // Return true if the coordination service agent has been initialized. 86 virtual bool IsInitialized() = 0; 87 88 // Connect to coordination service with the following steps: 89 // - connect to service address specified in the config of `server_def` 90 // - register itself as a task to the service 91 // - start a thread to periodically send heartbeat message with the service 92 // Possible service errors: 93 // - FailedPrecondition: Agent is not in DISCONNECTED state. 94 // - InvalidArgument: Unexpected task registration 95 // - Aborted: Duplicate task registration 96 virtual Status Connect() = 0; 97 98 // Wait for all tasks to be up and registered. The call blocks until all tasks 99 // in the cluster are up, or some error occurs. 100 // Possible service errors: 101 // - FailedPrecondition: Agent is not in CONNECTED state. 102 // - InvalidArgument: Unexpected task request 103 virtual Status WaitForAllTasks( 104 const CoordinationServiceDeviceInfo& local_devices) = 0; 105 106 // Get the device attributes of tasks from remote tasks in the cluster. 107 virtual const CoordinationServiceDeviceInfo& GetClusterDeviceInfo() = 0; 108 109 // State transition in coordination service agent: 110 // 111 // Init Connect SetError 112 // UNINITIALIZED ---> DISCONNECTED ------> CONNECTED -------> ERROR 113 // ^ | 114 // |__________________________________| 115 // Reset 116 117 // Get task associated with this agent. 118 virtual StatusOr<CoordinatedTask> GetOwnTask() = 0; 119 120 // Get status of a remote task. 121 virtual StatusOr<CoordinatedTaskState> GetTaskStatus( 122 const CoordinatedTask& task) = 0; 123 124 // Report error to coordination service. This will invoke the error callback. 125 // Note that the error payload will set `is_reported_error` to true, to 126 // distinguish user-specified errors from internal service or RPC failures. 127 // Possible service errors: 128 // - FailedPrecondition: Uninitialized/disconnected/already in error state. 129 // - InvalidArgument: Unexpected task request 130 virtual Status ReportError(const Status& error) = 0; 131 132 // Shuts down by disconnecting from the service. Should only be called if 133 // agent is connected and no further agent calls (except the destructor) are 134 // expected. If `shutdown_barrier_timeout_in_ms` is specified in the config, 135 // blocks until all tasks reach the barrier before shutting down together. If 136 // the barrier times out, this agent will still disconnect, while an error is 137 // reported to other agents that did not reach the barrier on time. 138 // Possible service errors: 139 // - InvalidArgument: Unexpected task request. 140 // - FailedPrecondition: Task was in error state (note: agent is still 141 // shut down forcefully). 142 virtual Status Shutdown() = 0; 143 144 // Disconnect from the service, and clean up the internal error status. 145 // Possible service errors: 146 // - InvalidArgument: Unexpected task request. 147 // - FailedPrecondition: task is not in error state/has already 148 // disconnected. 149 virtual Status Reset() = 0; 150 151 // Get config key-value from the service. 152 // If the key-value is not inserted yet, this is a blocking call that waits 153 // until the corresponding key is inserted. 154 // Agent does not need to be connected to utilize the distributed key-value 155 // store. 156 // - errors::DeadlineExceeded: timed out waiting for key. 157 virtual StatusOr<std::string> GetKeyValue(const std::string& key) = 0; 158 virtual StatusOr<std::string> GetKeyValue(const std::string& key, 159 absl::Duration timeout) = 0; 160 // Note: Cancel the underlying RPC call with `call_opts->StartCancel()` and 161 // `call_opts->ClearCancelCallback()`. 162 virtual std::shared_ptr<CallOptions> GetKeyValueAsync( 163 const std::string& key, StatusOrValueCallback done) = 0; 164 165 // Get config key-value from the service. 166 // If the key-value does not exist, this call returns NotFound error. 167 // Agent does not need to be connected to utilize the distributed key-value 168 // store. 169 // - errors::NotFound: the requested key does not exist. 170 virtual StatusOr<std::string> TryGetKeyValue(const std::string& key) = 0; 171 172 // Get all values under a directory (key). 173 // A value is considered to be in the directory if its key is prefixed with 174 // the directory. 175 // This is not a blocking call. 176 // Agent does not need to be connected to utilize the distributed key-value 177 // store. 178 virtual StatusOr<std::vector<KeyValueEntry>> GetKeyValueDir( 179 const std::string& key) = 0; 180 virtual void GetKeyValueDirAsync(const std::string& key, 181 StatusOrValueDirCallback done) = 0; 182 183 // Insert config key-value to the service. 184 // - errors::AlreadyExists: key is already set. 185 virtual Status InsertKeyValue(const std::string& key, 186 const std::string& value) = 0; 187 188 // Delete config keys in the coordination service. 189 virtual Status DeleteKeyValue(const std::string& key) = 0; 190 191 // Update the value of a config key. 192 virtual Status UpdateKeyValue(const std::string& key, 193 const std::string& value) = 0; 194 195 // Register a callback that will be invoked when the key or keys under the key 196 // directory are changed (inserted, deleted, or updated). 197 virtual Status StartWatchKey(const std::string& key, 198 ChangedKeyValuesCallback on_change) = 0; 199 virtual Status StopWatchKey(const std::string& key) = 0; 200 201 // Blocks until all (or a subset of) tasks are at the barrier or the barrier 202 // fails. 203 // 204 // `barrier_id` should be unique across barriers. 205 // 206 // The first WaitAtBarrier() call received by the service for a particular 207 // barrier_id is special in that it determines the barrier deadline based on 208 // timeout duration. 209 // However, if subsequent calls by different agents specify a different set of 210 // `tasks` for the same `barrier_id`, the barrier will fail instantly. 211 // For example, 212 // agent_1->WaitAtBarrier(“barrier”, 10min, <<”worker”, 1>, <”worker”, 2>>); 213 // agent_2->WaitAtBarrier(“barrier”, 10min, <<”worker”, 2>, <”worker”, 3>>); 214 // Barrier fails after agent_2’s call because it specifies a different set of 215 // participating tasks. 216 // 217 // If no tasks are specified (default), the barrier will block for all the 218 // connected tasks. 219 // 220 // Possible service errors: 221 // - DeadlineExceeded: Timed out waiting for specified tasks at the barrier. 222 // Deadline is determined by the server timestamp when it receives the 223 // first WaitAtBarrier() + timeout duration. 224 // - Cancelled: One of the tasks called CancelBarrier(). 225 // - Aborted: Service is shutting down. 226 // - Internal: Any participating task is in ERROR state. 227 // - InvalidArgument: (1) Conflicting tasks specified by different agents 228 // for the same barrier, (2) one of the participating tasks is not in 229 // the cluster, or (3) task making the request is not included in the 230 // list of participating tasks. 231 // - FailedPrecondition: Agent is in UNINITIALIZED or ERROR state. Or the 232 // same barrier_id was already used previously. 233 virtual Status WaitAtBarrier(const std::string& barrier_id, 234 absl::Duration timeout, 235 const std::vector<CoordinatedTask>& tasks) = 0; 236 237 virtual void WaitAtBarrierAsync(const std::string& barrier_id, 238 absl::Duration timeout, 239 const std::vector<CoordinatedTask>& tasks, 240 StatusCallback done) = 0; 241 242 // Aborts the barrier if it is ongoing. 243 // Current and future WaitAtBarrier() calls with the same id will return a 244 // CANCELLED error status. 245 // Possible service errors: 246 // - FailedPrecondition: Barrier has already been passed. 247 virtual Status CancelBarrier(const std::string& barrier_id) = 0; 248 virtual void CancelBarrierAsync(const std::string& barrier_id, 249 StatusCallback done) = 0; 250 251 // Get unowned Env* that the agent was initialized with. 252 virtual StatusOr<Env*> GetEnv() = 0; 253 254 protected: 255 // Set the service agent to error status and invoke the error callback. 256 // Note: different from ReportError, this does not report the error status to 257 // remote coordination service. 258 virtual void SetError(const Status& error) = 0; 259 260 // Activate the key-value callback watch. 261 virtual Status ActivateWatch(const std::string& key, 262 const std::map<std::string, std::string>&) = 0; 263 264 private: 265 friend class CoordinationServiceRpcHandler; 266 }; 267 268 std::unique_ptr<CoordinationServiceAgent> CreateCoordinationServiceAgent(); 269 270 } // namespace tensorflow 271 272 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_AGENT_H_ 273