xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/service/dispatcher_state.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_STATE_H_
16 #define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_STATE_H_
17 
18 #include <memory>
19 #include <optional>
20 #include <queue>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/strings/string_view.h"
28 #include "tensorflow/core/data/service/auto_shard_rewriter.h"
29 #include "tensorflow/core/data/service/common.h"
30 #include "tensorflow/core/data/service/common.pb.h"
31 #include "tensorflow/core/data/service/journal.h"
32 #include "tensorflow/core/data/service/journal.pb.h"
33 #include "tensorflow/core/platform/status.h"
34 #include "tensorflow/core/protobuf/data_service.pb.h"
35 #include "tensorflow/core/protobuf/service_config.pb.h"
36 
37 namespace tensorflow {
38 namespace data {
39 
40 // A class encapsulating the journaled state of the dispatcher. All state
41 // modifications must be done via `Apply`. This helps to ensure that
42 // replaying the journal will allow us to restore the exact same state.
43 //
44 // The following usage pattern will keep the journal in sync with the state of
45 // the dispatcher:
46 // {
47 //   mutex_lock l(mu_);
48 //   Update update = ...  // create an update
49 //   dispatcher_state.Apply(update);
50 //   journal_writer.write(update);
51 //   // Unlock mu_
52 // }
53 //
54 // The division of functionality between DispatcherImpl and DispatcherState is
55 // as follows:
56 //   - DispatcherImpl is responsible for handling RPC requests, reading from
57 //     DispatcherState, and deciding what updates to apply to DispatcherState.
58 //     DispatcherImpl handles all synchronization.
59 //   - DispatcherState is responsible for making the state changes requested by
60 //     DispatcherImpl and for providing DispatcherImpl with read-only access to
61 //     the state.
62 //
63 // DispatcherState is thread-compatible but not thread-safe.
64 class DispatcherState {
65  public:
66   DispatcherState();
67   explicit DispatcherState(
68       const experimental::DispatcherConfig& dispatcher_config);
69   DispatcherState(const DispatcherState&) = delete;
70   DispatcherState& operator=(const DispatcherState&) = delete;
71 
72   // Applies the given update to the dispatcher's state.
73   Status Apply(const Update& update);
74 
75   // A dataset registered with the dispatcher.
76   struct Dataset {
DatasetDataset77     explicit Dataset(const std::string& dataset_id, int64_t fingerprint,
78                      const DataServiceMetadata& metadata)
79         : dataset_id(dataset_id),
80           fingerprint(fingerprint),
81           metadata(metadata) {}
82 
83     const std::string dataset_id;
84     const int64_t fingerprint;
85     const DataServiceMetadata metadata;
86   };
87 
88   // A worker registered with the dispatcher.
89   struct Worker {
WorkerWorker90     explicit Worker(const RegisterWorkerUpdate& register_worker)
91         : address(register_worker.worker_address()),
92           transfer_address(register_worker.transfer_address()),
93           tags(register_worker.worker_tags().begin(),
94                register_worker.worker_tags().end()),
95           uid(register_worker.worker_uid()) {}
96 
97     const std::string address;
98     const std::string transfer_address;
99     const std::vector<std::string> tags;
100     const int64_t uid;
101   };
102 
103   // A key for identifying an iteration. The key contains a job name,
104   // as well as a repetition number describing which repetition of the job
105   // we are on.
106   struct IterationKey {
IterationKeyIterationKey107     explicit IterationKey(absl::string_view name, int64_t repetition)
108         : name(name), repetition(repetition) {}
109 
110     friend bool operator==(const IterationKey& lhs, const IterationKey& rhs) {
111       return lhs.name == rhs.name && lhs.repetition == rhs.repetition;
112     }
113 
114     template <typename H>
AbslHashValueIterationKey115     friend H AbslHashValue(H h, const IterationKey& k) {
116       return H::combine(std::move(h), k.name, k.repetition);
117     }
118 
DebugStringIterationKey119     std::string DebugString() const {
120       return absl::StrCat(name, "/", repetition);
121     }
122 
123     const std::string name;
124     const int64_t repetition;
125   };
126 
127   struct DistributedEpochState {
DistributedEpochStateDistributedEpochState128     explicit DistributedEpochState(int64_t num_split_providers)
129         : repetitions(num_split_providers), indices(num_split_providers) {}
130 
131     // The current repetition for each split provider.
132     std::vector<int64_t> repetitions;
133     // Number of splits produced so far by each split provider.
134     std::vector<int64_t> indices;
135   };
136 
137   struct Task;
138 
139   struct PendingTask {
PendingTaskPendingTask140     explicit PendingTask(std::shared_ptr<Task> task, int64_t target_round)
141         : task(std::move(task)), target_round(target_round) {}
142 
143     std::shared_ptr<Task> task;
144     // The target round where we want to insert the task.
145     int64_t target_round;
146     // Which consumers have responded that they have successfully blocked
147     // before the target round.
148     absl::flat_hash_set<int64_t> ready_consumers;
149     // How many times we have failed to add the task.
150     int64_t failures = 0;
151   };
152 
153   struct Job {
JobJob154     explicit Job(int64_t id, const std::string& dataset_id,
155                  const ProcessingModeDef& processing_mode, std::string job_name,
156                  std::optional<int64_t> num_consumers,
157                  bool use_cross_trainer_cache, TargetWorkers target_workers)
158         : id(id),
159           dataset_id(dataset_id),
160           processing_mode(processing_mode),
161           job_name(job_name),
162           num_consumers(num_consumers),
163           use_cross_trainer_cache(use_cross_trainer_cache),
164           target_workers(target_workers) {}
165 
166     const int64_t id;
167     const std::string dataset_id;
168     const ProcessingModeDef processing_mode;
169     const std::string job_name;
170     const std::optional<int64_t> num_consumers;
171     const bool use_cross_trainer_cache;
172     const TargetWorkers target_workers;
173   };
174 
175   // An iteration for processing a dataset.
176   struct Iteration {
IterationIteration177     explicit Iteration(int64_t iteration_id, IterationKey iteration_key,
178                        int64_t num_split_providers, std::shared_ptr<Job> job)
179         : iteration_id(iteration_id), iteration_key(iteration_key), job(job) {
180       if (IsDynamicShard(job->processing_mode)) {
181         distributed_epoch_state = DistributedEpochState(num_split_providers);
182       }
183     }
184 
IsRoundRobinIteration185     bool IsRoundRobin() const { return job->num_consumers.has_value(); }
186 
DebugStringIteration187     std::string DebugString() const {
188       return absl::StrCat(iteration_key.name, "_", iteration_key.repetition);
189     }
190 
191     const int64_t iteration_id;
192     const IterationKey iteration_key;
193     const std::shared_ptr<Job> job;
194     std::optional<DistributedEpochState> distributed_epoch_state;
195     std::queue<PendingTask> pending_tasks;
196     int64_t num_clients = 0;
197     int64_t last_client_released_micros = -1;
198     bool finished = false;
199     // Indicates whether the iteration was garbage collected.
200     bool garbage_collected = false;
201   };
202 
203   struct Task {
204     template <class T>
TaskTask205     explicit Task(const T& create_task_update,
206                   const std::shared_ptr<Iteration>& iteration)
207         : task_id(create_task_update.task_id()),
208           iteration(iteration),
209           worker_address(create_task_update.worker_address()),
210           transfer_address(create_task_update.transfer_address()),
211           worker_tags(create_task_update.worker_tags().begin(),
212                       create_task_update.worker_tags().end()),
213           worker_uid(create_task_update.worker_uid()) {}
214 
215     const int64_t task_id;
216     const std::shared_ptr<Iteration> iteration;
217     const std::string worker_address;
218     const std::string transfer_address;
219     const std::vector<std::string> worker_tags;
220     const int64_t worker_uid;
221     int64_t starting_round = 0;
222     bool finished = false;
223     bool removed = false;
224   };
225 
226   using TasksById = absl::flat_hash_map<int64_t, std::shared_ptr<Task>>;
227 
228   // Returns the next available dataset ID.
229   std::string NextAvailableDatasetId() const;
230 
231   // Gets a dataset by id. Returns NOT_FOUND if there is no such dataset.
232   Status DatasetFromId(const std::string& id,
233                        std::shared_ptr<const Dataset>& dataset) const;
234   // Gets a dataset by fingerprint. Returns NOT_FOUND if there is no such
235   // dataset.
236   Status DatasetFromFingerprint(uint64 fingerprint,
237                                 std::shared_ptr<const Dataset>& dataset) const;
238 
239   // Gets a worker by address. Returns NOT_FOUND if there is no such worker.
240   Status WorkerFromAddress(const std::string& address,
241                            std::shared_ptr<const Worker>& worker) const;
242   // Lists all workers registered with the dispatcher.
243   std::vector<std::shared_ptr<const Worker>> ListWorkers() const;
244 
245   // Returns the next available job id.
246   int64_t NextAvailableJobId() const;
247   // Gets a job by id. Returns NOT_FOUND if there is no such job.
248   Status JobFromId(int64_t job_id, std::shared_ptr<const Job>& job) const;
249   // Gets a job by name. Returns NOT_FOUND if there is no such job.
250   Status JobByName(const std::string& job_name,
251                    std::shared_ptr<const Job>& job) const;
252 
253   // Returns the next available iteration id.
254   int64_t NextAvailableIterationId() const;
255   // Returns a list of all iterations.
256   std::vector<std::shared_ptr<const Iteration>> ListIterations() const;
257   // Gets an iteration by id. Returns NOT_FOUND if there is no such iteration.
258   Status IterationFromId(int64_t id,
259                          std::shared_ptr<const Iteration>& iteration) const;
260   // Gets an iteration by key. Returns NOT_FOUND if there is no such iteration.
261   Status IterationByKey(IterationKey key,
262                         std::shared_ptr<const Iteration>& iteration) const;
263 
264   // Returns the iteration associated with the given iteration client id.
265   // Returns NOT_FOUND if the iteration_client_id is unknown or has been
266   // released.
267   Status IterationForIterationClientId(
268       int64_t iteration_client_id, std::shared_ptr<const Iteration>& iteration);
269   // Returns a list of all active client ids.
270   std::vector<int64_t> ListActiveClientIds();
271   // Returns the next available iteration client id.
272   int64_t NextAvailableIterationClientId() const;
273 
274   // Returns the next available task id.
275   int64_t NextAvailableTaskId() const;
276   // Gets a task by id. Returns NOT_FOUND if there is no such task.
277   Status TaskFromId(int64_t id, std::shared_ptr<const Task>& task) const;
278   // Stores a list of all tasks for the given iteration to `tasks`. Returns
279   // NOT_FOUND if there is no such iteration.
280   Status TasksForIteration(
281       int64_t iteration_id,
282       std::vector<std::shared_ptr<const Task>>& tasks) const;
283   // Stores a list of all tasks for the given worker to `tasks`. Returns
284   // NOT_FOUND if there is no such worker.
285   Status TasksForWorker(const absl::string_view worker_address,
286                         std::vector<std::shared_ptr<const Task>>& tasks) const;
287 
288   // If the dispatcher config explicitly specifies a list of workers, validates
289   // `worker_address` is in the list.
290   Status ValidateWorker(absl::string_view worker_address) const;
291 
292   // If the dispatcher config specifies worker addresses, `GetWorkerIndex`
293   // returns the worker index according to the list. This is useful for
294   // deterministically sharding a dataset among a fixed set of workers.
295   StatusOr<int64_t> GetWorkerIndex(absl::string_view worker_address) const;
296 
297  private:
298   void RegisterDataset(const RegisterDatasetUpdate& register_dataset);
299   void RegisterWorker(const RegisterWorkerUpdate& register_worker);
300   void CreateJob(const CreateJobUpdate& create_job);
301   void CreateIteration(const CreateIterationUpdate& create_iteration);
302   void ProduceSplit(const ProduceSplitUpdate& produce_split);
303   void AcquireIterationClient(
304       const AcquireIterationClientUpdate& acquire_iteration_client);
305   void ReleaseIterationClient(
306       const ReleaseIterationClientUpdate& release_iteration_client);
307   void GarbageCollectIteration(
308       const GarbageCollectIterationUpdate& garbage_collect_iteration);
309   void RemoveTask(const RemoveTaskUpdate& remove_task);
310   void CreatePendingTask(const CreatePendingTaskUpdate& create_pending_task);
311   void ClientHeartbeat(const ClientHeartbeatUpdate& client_heartbeat);
312   void CreateTask(const CreateTaskUpdate& create_task);
313   void FinishTask(const FinishTaskUpdate& finish_task);
314   // Updates the next available dataset ID.
315   void UpdateNextAvailableDatasetId();
316 
317   int64_t next_available_dataset_id_ = 1000;
318   // Registered datasets, keyed by dataset ids.
319   absl::flat_hash_map<std::string, std::shared_ptr<Dataset>> datasets_by_id_;
320   // Registered datasets, keyed by dataset fingerprints.
321   absl::flat_hash_map<uint64, std::shared_ptr<Dataset>>
322       datasets_by_fingerprint_;
323 
324   // Registered workers, keyed by address.
325   absl::flat_hash_map<std::string, std::shared_ptr<Worker>> workers_;
326 
327   // Assigns an index to each worker according to worker addresses list
328   // specified in the dispatcher config.
329   WorkerIndexResolver worker_index_resolver_;
330 
331   int64_t next_available_job_id_ = 5000;
332   // Jobs, keyed by job ids.
333   absl::flat_hash_map<int64_t, std::shared_ptr<Job>> jobs_by_id_;
334   // Jobs, keyed by job names.
335   absl::flat_hash_map<std::string, std::shared_ptr<Job>> jobs_by_name_;
336 
337   int64_t next_available_iteration_id_ = 2000;
338   // Iterations, keyed by iteration ids.
339   absl::flat_hash_map<int64_t, std::shared_ptr<Iteration>> iterations_;
340   // Iterations, keyed by their iteration keys.
341   absl::flat_hash_map<IterationKey, std::shared_ptr<Iteration>>
342       iterations_by_key_;
343 
344   int64_t next_available_iteration_client_id_ = 3000;
345   // Mapping from client ids to the iterations they are associated with.
346   absl::flat_hash_map<int64_t, std::shared_ptr<Iteration>>
347       iterations_for_client_ids_;
348 
349   int64_t next_available_task_id_ = 4000;
350   // Tasks, keyed by task ids.
351   TasksById tasks_;
352   // List of tasks associated with each iteration.
353   absl::flat_hash_map<int64_t, std::vector<std::shared_ptr<Task>>>
354       tasks_by_iteration_;
355   // Tasks, keyed by worker addresses. The values are a map from task id to
356   // task.
357   absl::flat_hash_map<std::string, TasksById> tasks_by_worker_;
358 };
359 
360 }  // namespace data
361 }  // namespace tensorflow
362 
363 #endif  // TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_STATE_H_
364