1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_STATE_H_ 16 #define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_STATE_H_ 17 18 #include <memory> 19 #include <optional> 20 #include <queue> 21 #include <string> 22 #include <utility> 23 #include <vector> 24 25 #include "absl/container/flat_hash_map.h" 26 #include "absl/container/flat_hash_set.h" 27 #include "absl/strings/string_view.h" 28 #include "tensorflow/core/data/service/auto_shard_rewriter.h" 29 #include "tensorflow/core/data/service/common.h" 30 #include "tensorflow/core/data/service/common.pb.h" 31 #include "tensorflow/core/data/service/journal.h" 32 #include "tensorflow/core/data/service/journal.pb.h" 33 #include "tensorflow/core/platform/status.h" 34 #include "tensorflow/core/protobuf/data_service.pb.h" 35 #include "tensorflow/core/protobuf/service_config.pb.h" 36 37 namespace tensorflow { 38 namespace data { 39 40 // A class encapsulating the journaled state of the dispatcher. All state 41 // modifications must be done via `Apply`. This helps to ensure that 42 // replaying the journal will allow us to restore the exact same state. 43 // 44 // The following usage pattern will keep the journal in sync with the state of 45 // the dispatcher: 46 // { 47 // mutex_lock l(mu_); 48 // Update update = ... // create an update 49 // dispatcher_state.Apply(update); 50 // journal_writer.write(Update); 51 // // Unlock mu_ 52 // } 53 // 54 // The division of functionality between DispatcherImpl and DispatcherState is 55 // as follows: 56 // - DispatcherImpl is responsible for handling RPC requests, reading from 57 // DispatcherState, and deciding what updates to apply to DispatcherState. 58 // DispatcherImpl handles all synchronization. 59 // - DispatcherState is responsible for making the state changes requested by 60 // DispatcherImpl and for providing DispatcherImpl with read-only access to 61 // the state. 62 // 63 // DispatcherState is thread-compatible but not thread-safe. 64 class DispatcherState { 65 public: 66 DispatcherState(); 67 explicit DispatcherState( 68 const experimental::DispatcherConfig& dispatcher_config); 69 DispatcherState(const DispatcherState&) = delete; 70 DispatcherState& operator=(const DispatcherState&) = delete; 71 72 // Applies the given update to the dispatcher's state. 73 Status Apply(const Update& update); 74 75 // A dataset registered with the dispatcher. 76 struct Dataset { DatasetDataset77 explicit Dataset(const std::string& dataset_id, int64_t fingerprint, 78 const DataServiceMetadata& metadata) 79 : dataset_id(dataset_id), 80 fingerprint(fingerprint), 81 metadata(metadata) {} 82 83 const std::string dataset_id; 84 const int64_t fingerprint; 85 const DataServiceMetadata metadata; 86 }; 87 88 // A worker registered with the dispatcher. 89 struct Worker { WorkerWorker90 explicit Worker(const RegisterWorkerUpdate& register_worker) 91 : address(register_worker.worker_address()), 92 transfer_address(register_worker.transfer_address()), 93 tags(register_worker.worker_tags().begin(), 94 register_worker.worker_tags().end()), 95 uid(register_worker.worker_uid()) {} 96 97 const std::string address; 98 const std::string transfer_address; 99 const std::vector<std::string> tags; 100 const int64_t uid; 101 }; 102 103 // A key for identifying an iteration. The key contains a job name, 104 // as well as a repetition number describing which repetition of the job 105 // we are on. 106 struct IterationKey { IterationKeyIterationKey107 explicit IterationKey(absl::string_view name, int64_t repetition) 108 : name(name), repetition(repetition) {} 109 110 friend bool operator==(const IterationKey& lhs, const IterationKey& rhs) { 111 return lhs.name == rhs.name && lhs.repetition == rhs.repetition; 112 } 113 114 template <typename H> AbslHashValueIterationKey115 friend H AbslHashValue(H h, const IterationKey& k) { 116 return H::combine(std::move(h), k.name, k.repetition); 117 } 118 DebugStringIterationKey119 std::string DebugString() const { 120 return absl::StrCat(name, "/", repetition); 121 } 122 123 const std::string name; 124 const int64_t repetition; 125 }; 126 127 struct DistributedEpochState { DistributedEpochStateDistributedEpochState128 explicit DistributedEpochState(int64_t num_split_providers) 129 : repetitions(num_split_providers), indices(num_split_providers) {} 130 131 // The current repetition for each split provider. 132 std::vector<int64_t> repetitions; 133 // Number of splits produced so far by each split provider. 134 std::vector<int64_t> indices; 135 }; 136 137 struct Task; 138 139 struct PendingTask { PendingTaskPendingTask140 explicit PendingTask(std::shared_ptr<Task> task, int64_t target_round) 141 : task(std::move(task)), target_round(target_round) {} 142 143 std::shared_ptr<Task> task; 144 // The target round where we want to insert the task. 145 int64_t target_round; 146 // Which consumers have responded that they have successfully blocked 147 // before the target round. 148 absl::flat_hash_set<int64_t> ready_consumers; 149 // How many times we have failed to add the task. 150 int64_t failures = 0; 151 }; 152 153 struct Job { JobJob154 explicit Job(int64_t id, const std::string& dataset_id, 155 const ProcessingModeDef& processing_mode, std::string job_name, 156 std::optional<int64_t> num_consumers, 157 bool use_cross_trainer_cache, TargetWorkers target_workers) 158 : id(id), 159 dataset_id(dataset_id), 160 processing_mode(processing_mode), 161 job_name(job_name), 162 num_consumers(num_consumers), 163 use_cross_trainer_cache(use_cross_trainer_cache), 164 target_workers(target_workers) {} 165 166 const int64_t id; 167 const std::string dataset_id; 168 const ProcessingModeDef processing_mode; 169 const std::string job_name; 170 const std::optional<int64_t> num_consumers; 171 const bool use_cross_trainer_cache; 172 const TargetWorkers target_workers; 173 }; 174 175 // An iteration for processing a dataset. 176 struct Iteration { IterationIteration177 explicit Iteration(int64_t iteration_id, IterationKey iteration_key, 178 int64_t num_split_providers, std::shared_ptr<Job> job) 179 : iteration_id(iteration_id), iteration_key(iteration_key), job(job) { 180 if (IsDynamicShard(job->processing_mode)) { 181 distributed_epoch_state = DistributedEpochState(num_split_providers); 182 } 183 } 184 IsRoundRobinIteration185 bool IsRoundRobin() const { return job->num_consumers.has_value(); } 186 DebugStringIteration187 std::string DebugString() const { 188 return absl::StrCat(iteration_key.name, "_", iteration_key.repetition); 189 } 190 191 const int64_t iteration_id; 192 const IterationKey iteration_key; 193 const std::shared_ptr<Job> job; 194 std::optional<DistributedEpochState> distributed_epoch_state; 195 std::queue<PendingTask> pending_tasks; 196 int64_t num_clients = 0; 197 int64_t last_client_released_micros = -1; 198 bool finished = false; 199 // Indicates whether the iteration was garbage collected. 200 bool garbage_collected = false; 201 }; 202 203 struct Task { 204 template <class T> TaskTask205 explicit Task(const T& create_task_update, 206 const std::shared_ptr<Iteration>& iteration) 207 : task_id(create_task_update.task_id()), 208 iteration(iteration), 209 worker_address(create_task_update.worker_address()), 210 transfer_address(create_task_update.transfer_address()), 211 worker_tags(create_task_update.worker_tags().begin(), 212 create_task_update.worker_tags().end()), 213 worker_uid(create_task_update.worker_uid()) {} 214 215 const int64_t task_id; 216 const std::shared_ptr<Iteration> iteration; 217 const std::string worker_address; 218 const std::string transfer_address; 219 const std::vector<std::string> worker_tags; 220 const int64_t worker_uid; 221 int64_t starting_round = 0; 222 bool finished = false; 223 bool removed = false; 224 }; 225 226 using TasksById = absl::flat_hash_map<int64_t, std::shared_ptr<Task>>; 227 228 // Returns the next available dataset ID. 229 std::string NextAvailableDatasetId() const; 230 231 // Gets a dataset by id. Returns NOT_FOUND if there is no such dataset. 232 Status DatasetFromId(const std::string& id, 233 std::shared_ptr<const Dataset>& dataset) const; 234 // Gets a dataset by fingerprint. Returns NOT_FOUND if there is no such 235 // dataset. 236 Status DatasetFromFingerprint(uint64 fingerprint, 237 std::shared_ptr<const Dataset>& dataset) const; 238 239 // Gets a worker by address. Returns NOT_FOUND if there is no such worker. 240 Status WorkerFromAddress(const std::string& address, 241 std::shared_ptr<const Worker>& worker) const; 242 // Lists all workers registered with the dispatcher. 243 std::vector<std::shared_ptr<const Worker>> ListWorkers() const; 244 245 // Returns the next available job id. 246 int64_t NextAvailableJobId() const; 247 // Gets a job by id. Returns NOT_FOUND if there is no such job. 248 Status JobFromId(int64_t job_id, std::shared_ptr<const Job>& job) const; 249 // Gets a job by name. Returns NOT_FOUND if there is no such job. 250 Status JobByName(const std::string& job_name, 251 std::shared_ptr<const Job>& job) const; 252 253 // Returns the next available iteration id. 254 int64_t NextAvailableIterationId() const; 255 // Returns a list of all iterations. 256 std::vector<std::shared_ptr<const Iteration>> ListIterations() const; 257 // Gets an iteration by id. Returns NOT_FOUND if there is no such iteration. 258 Status IterationFromId(int64_t id, 259 std::shared_ptr<const Iteration>& iteration) const; 260 // Gets an iteration by key. Returns NOT_FOUND if there is no such iteration. 261 Status IterationByKey(IterationKey key, 262 std::shared_ptr<const Iteration>& iteration) const; 263 264 // Returns the iteration associated with the given iteration client id. 265 // Returns NOT_FOUND if the iteration_client_id is unknown or has been 266 // released. 267 Status IterationForIterationClientId( 268 int64_t iteration_client_id, std::shared_ptr<const Iteration>& iteration); 269 // Returns a list of all active client ids. 270 std::vector<int64_t> ListActiveClientIds(); 271 // Returns the next available iteration client id. 272 int64_t NextAvailableIterationClientId() const; 273 274 // Returns the next available task id. 275 int64_t NextAvailableTaskId() const; 276 // Gets a task by id. Returns NOT_FOUND if there is no such task. 277 Status TaskFromId(int64_t id, std::shared_ptr<const Task>& task) const; 278 // Stores a list of all tasks for the given iteration to `tasks`. Returns 279 // NOT_FOUND if there is no such iteration. 280 Status TasksForIteration( 281 int64_t iteration_id, 282 std::vector<std::shared_ptr<const Task>>& tasks) const; 283 // Stores a list of all tasks for the given worker to `tasks`. Returns 284 // NOT_FOUND if there is no such worker. 285 Status TasksForWorker(const absl::string_view worker_address, 286 std::vector<std::shared_ptr<const Task>>& tasks) const; 287 288 // If the dispatcher config explicitly specifies a list of workers, validates 289 // `worker_address` is in the list. 290 Status ValidateWorker(absl::string_view worker_address) const; 291 292 // If the dispatcher config specifies worker addresses, `GetWorkerIndex` 293 // returns the worker index according to the list. This is useful for 294 // deterministically sharding a dataset among a fixed set of workers. 295 StatusOr<int64_t> GetWorkerIndex(absl::string_view worker_address) const; 296 297 private: 298 void RegisterDataset(const RegisterDatasetUpdate& register_dataset); 299 void RegisterWorker(const RegisterWorkerUpdate& register_worker); 300 void CreateJob(const CreateJobUpdate& create_job); 301 void CreateIteration(const CreateIterationUpdate& create_iteration); 302 void ProduceSplit(const ProduceSplitUpdate& produce_split); 303 void AcquireIterationClient( 304 const AcquireIterationClientUpdate& acquire_iteration_client); 305 void ReleaseIterationClient( 306 const ReleaseIterationClientUpdate& release_iteration_client); 307 void GarbageCollectIteration( 308 const GarbageCollectIterationUpdate& garbage_collect_iteration); 309 void RemoveTask(const RemoveTaskUpdate& remove_task); 310 void CreatePendingTask(const CreatePendingTaskUpdate& create_pending_task); 311 void ClientHeartbeat(const ClientHeartbeatUpdate& client_heartbeat); 312 void CreateTask(const CreateTaskUpdate& create_task); 313 void FinishTask(const FinishTaskUpdate& finish_task); 314 // Updates the next available dataset ID. 315 void UpdateNextAvailableDatasetId(); 316 317 int64_t next_available_dataset_id_ = 1000; 318 // Registered datasets, keyed by dataset ids. 319 absl::flat_hash_map<std::string, std::shared_ptr<Dataset>> datasets_by_id_; 320 // Registered datasets, keyed by dataset fingerprints. 321 absl::flat_hash_map<uint64, std::shared_ptr<Dataset>> 322 datasets_by_fingerprint_; 323 324 // Registered workers, keyed by address. 325 absl::flat_hash_map<std::string, std::shared_ptr<Worker>> workers_; 326 327 // Assigns an index to each worker according to worker addresses list 328 // specified in the dispatcher config. 329 WorkerIndexResolver worker_index_resolver_; 330 331 int64_t next_available_job_id_ = 5000; 332 // Jobs, keyed by job ids. 333 absl::flat_hash_map<int64_t, std::shared_ptr<Job>> jobs_by_id_; 334 // Jobs, keyed by job names. 335 absl::flat_hash_map<std::string, std::shared_ptr<Job>> jobs_by_name_; 336 337 int64_t next_available_iteration_id_ = 2000; 338 // Iterations, keyed by iteration ids. 339 absl::flat_hash_map<int64_t, std::shared_ptr<Iteration>> iterations_; 340 // Iterations, keyed by their iteration keys. 341 absl::flat_hash_map<IterationKey, std::shared_ptr<Iteration>> 342 iterations_by_key_; 343 344 int64_t next_available_iteration_client_id_ = 3000; 345 // Mapping from client ids to the iterations they are associated with. 346 absl::flat_hash_map<int64_t, std::shared_ptr<Iteration>> 347 iterations_for_client_ids_; 348 349 int64_t next_available_task_id_ = 4000; 350 // Tasks, keyed by task ids. 351 TasksById tasks_; 352 // List of tasks associated with each iteration. 353 absl::flat_hash_map<int64_t, std::vector<std::shared_ptr<Task>>> 354 tasks_by_iteration_; 355 // Tasks, keyed by worker addresses. The values are a map from task id to 356 // task. 357 absl::flat_hash_map<std::string, TasksById> tasks_by_worker_; 358 }; 359 360 } // namespace data 361 } // namespace tensorflow 362 363 #endif // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_STATE_H_ 364