xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/service/dispatcher_state.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/data/service/dispatcher_state.h"
16 
17 #include <algorithm>
18 #include <memory>
19 #include <optional>
20 #include <string>
21 #include <vector>
22 
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/string_view.h"
26 #include "tensorflow/core/data/service/common.h"
27 #include "tensorflow/core/data/service/journal.h"
28 #include "tensorflow/core/data/service/journal.pb.h"
29 #include "tensorflow/core/platform/errors.h"
30 #include "tensorflow/core/platform/status.h"
31 #include "tensorflow/core/protobuf/data_service.pb.h"
32 #include "tensorflow/core/protobuf/service_config.pb.h"
33 
34 namespace tensorflow {
35 namespace data {
36 
DispatcherState()37 DispatcherState::DispatcherState()
38     : worker_index_resolver_(std::vector<std::string>{}) {}
39 
DispatcherState(const experimental::DispatcherConfig & dispatcher_config)40 DispatcherState::DispatcherState(
41     const experimental::DispatcherConfig& dispatcher_config)
42     : worker_index_resolver_(dispatcher_config.worker_addresses()) {}
43 
Apply(const Update & update)44 Status DispatcherState::Apply(const Update& update) {
45   switch (update.update_type_case()) {
46     case Update::kRegisterDataset:
47       RegisterDataset(update.register_dataset());
48       break;
49     case Update::kRegisterWorker:
50       RegisterWorker(update.register_worker());
51       break;
52     case Update::kCreateJob:
53       CreateJob(update.create_job());
54       break;
55     case Update::kCreateIteration:
56       CreateIteration(update.create_iteration());
57       break;
58     case Update::kProduceSplit:
59       ProduceSplit(update.produce_split());
60       break;
61     case Update::kAcquireIterationClient:
62       AcquireIterationClient(update.acquire_iteration_client());
63       break;
64     case Update::kReleaseIterationClient:
65       ReleaseIterationClient(update.release_iteration_client());
66       break;
67     case Update::kGarbageCollectIteration:
68       GarbageCollectIteration(update.garbage_collect_iteration());
69       break;
70     case Update::kRemoveTask:
71       RemoveTask(update.remove_task());
72       break;
73     case Update::kCreatePendingTask:
74       CreatePendingTask(update.create_pending_task());
75       break;
76     case Update::kClientHeartbeat:
77       ClientHeartbeat(update.client_heartbeat());
78       break;
79     case Update::kCreateTask:
80       CreateTask(update.create_task());
81       break;
82     case Update::kFinishTask:
83       FinishTask(update.finish_task());
84       break;
85     case Update::UPDATE_TYPE_NOT_SET:
86       return errors::Internal("Update type not set.");
87   }
88 
89   return OkStatus();
90 }
91 
RegisterDataset(const RegisterDatasetUpdate & register_dataset)92 void DispatcherState::RegisterDataset(
93     const RegisterDatasetUpdate& register_dataset) {
94   std::string dataset_id = register_dataset.dataset_id();
95   int64_t fingerprint = register_dataset.fingerprint();
96   auto dataset = std::make_shared<Dataset>(dataset_id, fingerprint,
97                                            register_dataset.metadata());
98   DCHECK(!datasets_by_id_.contains(dataset_id));
99   datasets_by_id_[dataset_id] = dataset;
100   if (!register_dataset.dedupe_by_dataset_id()) {
101     // Only stores the fingerprint if the user has not requested a dataset ID.
102     // If the user has requested a dataset ID, we will look up datasets by their
103     // IDs, not by fingerprints. Otherwise, an anonymous dataset can refer to
104     // a dataset with an explicit dataset ID.
105     DCHECK(!datasets_by_fingerprint_.contains(fingerprint));
106     datasets_by_fingerprint_[fingerprint] = dataset;
107   }
108   UpdateNextAvailableDatasetId();
109 }
110 
RegisterWorker(const RegisterWorkerUpdate & register_worker)111 void DispatcherState::RegisterWorker(
112     const RegisterWorkerUpdate& register_worker) {
113   std::string address = register_worker.worker_address();
114   DCHECK(!workers_.contains(address));
115   workers_[address] = std::make_shared<Worker>(register_worker);
116   tasks_by_worker_[address] =
117       absl::flat_hash_map<int64_t, std::shared_ptr<Task>>();
118   worker_index_resolver_.AddWorker(address);
119 }
120 
CreateJob(const CreateJobUpdate & create_job)121 void DispatcherState::CreateJob(const CreateJobUpdate& create_job) {
122   int64_t job_id = create_job.job_id();
123   std::string job_name = create_job.job_name();
124   std::optional<int64_t> num_consumers;
125   if (create_job.optional_num_consumers_case() ==
126       CreateJobUpdate::kNumConsumers) {
127     num_consumers = create_job.num_consumers();
128   }
129   auto job = std::make_shared<Job>(
130       job_id, create_job.dataset_id(), create_job.processing_mode_def(),
131       job_name, num_consumers, create_job.use_cross_trainer_cache(),
132       create_job.target_workers());
133   DCHECK(!jobs_by_id_.contains(job_id));
134   jobs_by_id_[job_id] = job;
135   DCHECK(!jobs_by_name_.contains(job_name));
136   jobs_by_name_[job_name] = job;
137   next_available_job_id_ = std::max(next_available_job_id_, job_id + 1);
138 }
139 
JobFromId(int64_t job_id,std::shared_ptr<const Job> & job) const140 Status DispatcherState::JobFromId(int64_t job_id,
141                                   std::shared_ptr<const Job>& job) const {
142   auto it = jobs_by_id_.find(job_id);
143   if (it == jobs_by_id_.end()) {
144     return errors::NotFound("Job with id ", job_id, " not found");
145   }
146   job = it->second;
147   return Status::OK();
148 }
149 
JobByName(const std::string & job_name,std::shared_ptr<const Job> & job) const150 Status DispatcherState::JobByName(const std::string& job_name,
151                                   std::shared_ptr<const Job>& job) const {
152   auto it = jobs_by_name_.find(job_name);
153   if (it == jobs_by_name_.end()) {
154     return errors::NotFound("Job with name ", job_name, " not found");
155   }
156   job = it->second;
157   return Status::OK();
158 }
159 
CreateIteration(const CreateIterationUpdate & create_iteration)160 void DispatcherState::CreateIteration(
161     const CreateIterationUpdate& create_iteration) {
162   int64_t iteration_id = create_iteration.iteration_id();
163   int64_t job_id = create_iteration.job_id();
164   DCHECK(jobs_by_id_.contains(job_id));
165   auto& job = jobs_by_id_[job_id];
166   DCHECK(job);
167   IterationKey iteration_key(job->job_name, create_iteration.repetition());
168   auto iteration = std::make_shared<Iteration>(
169       iteration_id, iteration_key, create_iteration.num_split_providers(), job);
170   DCHECK(!iterations_.contains(iteration_id));
171   iterations_[iteration_id] = iteration;
172   tasks_by_iteration_[iteration_id] = std::vector<std::shared_ptr<Task>>();
173   DCHECK(!iterations_by_key_.contains(iteration_key) ||
174          iterations_by_key_[iteration_key]->garbage_collected);
175   iterations_by_key_[iteration_key] = iteration;
176   next_available_iteration_id_ =
177       std::max(next_available_iteration_id_, iteration_id + 1);
178 }
179 
ProduceSplit(const ProduceSplitUpdate & produce_split)180 void DispatcherState::ProduceSplit(const ProduceSplitUpdate& produce_split) {
181   std::shared_ptr<Iteration> iteration =
182       iterations_[produce_split.iteration_id()];
183   DCHECK(iteration->distributed_epoch_state.has_value());
184   DistributedEpochState& state = iteration->distributed_epoch_state.value();
185   int64_t provider_index = produce_split.split_provider_index();
186   DCHECK_EQ(produce_split.repetition(), state.repetitions[provider_index]);
187   if (produce_split.finished()) {
188     state.repetitions[provider_index]++;
189     state.indices[provider_index] = 0;
190     return;
191   }
192   state.indices[provider_index]++;
193 }
194 
AcquireIterationClient(const AcquireIterationClientUpdate & acquire_iteration_client)195 void DispatcherState::AcquireIterationClient(
196     const AcquireIterationClientUpdate& acquire_iteration_client) {
197   int64_t iteration_client_id = acquire_iteration_client.iteration_client_id();
198   std::shared_ptr<Iteration>& iteration =
199       iterations_for_client_ids_[iteration_client_id];
200   DCHECK(!iteration);
201   iteration = iterations_[acquire_iteration_client.iteration_id()];
202   DCHECK(iteration);
203   iteration->num_clients++;
204   next_available_iteration_client_id_ =
205       std::max(next_available_iteration_client_id_, iteration_client_id + 1);
206 }
207 
ReleaseIterationClient(const ReleaseIterationClientUpdate & release_iteration_client)208 void DispatcherState::ReleaseIterationClient(
209     const ReleaseIterationClientUpdate& release_iteration_client) {
210   int64_t iteration_client_id = release_iteration_client.iteration_client_id();
211   std::shared_ptr<Iteration>& iteration =
212       iterations_for_client_ids_[iteration_client_id];
213   DCHECK(iteration);
214   iteration->num_clients--;
215   DCHECK_GE(iteration->num_clients, 0);
216   iteration->last_client_released_micros =
217       release_iteration_client.time_micros();
218   iterations_for_client_ids_.erase(iteration_client_id);
219 }
220 
GarbageCollectIteration(const GarbageCollectIterationUpdate & garbage_collect_iteration)221 void DispatcherState::GarbageCollectIteration(
222     const GarbageCollectIterationUpdate& garbage_collect_iteration) {
223   int64_t iteration_id = garbage_collect_iteration.iteration_id();
224   for (auto& task : tasks_by_iteration_[iteration_id]) {
225     task->finished = true;
226     tasks_by_worker_[task->worker_address].erase(task->task_id);
227   }
228   iterations_[iteration_id]->finished = true;
229   iterations_[iteration_id]->garbage_collected = true;
230 }
231 
RemoveTask(const RemoveTaskUpdate & remove_task)232 void DispatcherState::RemoveTask(const RemoveTaskUpdate& remove_task) {
233   std::shared_ptr<Task>& task = tasks_[remove_task.task_id()];
234   DCHECK(task);
235   task->removed = true;
236   auto& tasks_for_iteration =
237       tasks_by_iteration_[task->iteration->iteration_id];
238   for (auto it = tasks_for_iteration.begin(); it != tasks_for_iteration.end();
239        ++it) {
240     if ((*it)->task_id == task->task_id) {
241       tasks_for_iteration.erase(it);
242       break;
243     }
244   }
245   tasks_by_worker_[task->worker_address].erase(task->task_id);
246   tasks_.erase(task->task_id);
247   VLOG(1) << "Removed task " << remove_task.task_id() << " from worker "
248           << task->worker_address;
249 }
250 
CreatePendingTask(const CreatePendingTaskUpdate & create_pending_task)251 void DispatcherState::CreatePendingTask(
252     const CreatePendingTaskUpdate& create_pending_task) {
253   int64_t task_id = create_pending_task.task_id();
254   auto& task = tasks_[task_id];
255   DCHECK_EQ(task, nullptr);
256   auto& iteration = iterations_[create_pending_task.iteration_id()];
257   DCHECK_NE(iteration, nullptr);
258   task = std::make_shared<Task>(create_pending_task, iteration);
259   iteration->pending_tasks.emplace(task, create_pending_task.starting_round());
260   tasks_by_worker_[create_pending_task.worker_address()][task->task_id] = task;
261   next_available_task_id_ = std::max(next_available_task_id_, task_id + 1);
262 }
263 
ClientHeartbeat(const ClientHeartbeatUpdate & client_heartbeat)264 void DispatcherState::ClientHeartbeat(
265     const ClientHeartbeatUpdate& client_heartbeat) {
266   int64_t iteration_client_id = client_heartbeat.iteration_client_id();
267   auto& iteration = iterations_for_client_ids_[iteration_client_id];
268   DCHECK(!iteration->pending_tasks.empty());
269   auto& task = iteration->pending_tasks.front();
270   if (client_heartbeat.has_task_rejected()) {
271     task.failures++;
272     task.ready_consumers.clear();
273     task.target_round = client_heartbeat.task_rejected().new_target_round();
274   }
275   if (client_heartbeat.task_accepted()) {
276     task.ready_consumers.insert(iteration_client_id);
277     if (task.ready_consumers.size() == iteration->job->num_consumers.value()) {
278       VLOG(1) << "Promoting task " << task.task->task_id
279               << " from pending to active";
280       task.task->starting_round = task.target_round;
281       tasks_by_iteration_[iteration->iteration_id].push_back(task.task);
282       iteration->pending_tasks.pop();
283     }
284   }
285 }
286 
CreateTask(const CreateTaskUpdate & create_task)287 void DispatcherState::CreateTask(const CreateTaskUpdate& create_task) {
288   int64_t task_id = create_task.task_id();
289   auto& task = tasks_[task_id];
290   DCHECK_EQ(task, nullptr);
291   auto& iteration = iterations_[create_task.iteration_id()];
292   DCHECK_NE(iteration, nullptr);
293   task = std::make_shared<Task>(create_task, iteration);
294   tasks_by_iteration_[create_task.iteration_id()].push_back(task);
295   tasks_by_worker_[create_task.worker_address()][task->task_id] = task;
296   next_available_task_id_ = std::max(next_available_task_id_, task_id + 1);
297 }
298 
FinishTask(const FinishTaskUpdate & finish_task)299 void DispatcherState::FinishTask(const FinishTaskUpdate& finish_task) {
300   VLOG(2) << "Marking task " << finish_task.task_id() << " as finished";
301   int64_t task_id = finish_task.task_id();
302   auto& task = tasks_[task_id];
303   DCHECK(task != nullptr);
304   task->finished = true;
305   tasks_by_worker_[task->worker_address].erase(task->task_id);
306   bool all_finished = true;
307   for (const auto& task_for_iteration :
308        tasks_by_iteration_[task->iteration->iteration_id]) {
309     if (!task_for_iteration->finished) {
310       all_finished = false;
311     }
312   }
313   VLOG(3) << "Iteration " << task->iteration->iteration_id
314           << " finished: " << all_finished;
315   iterations_[task->iteration->iteration_id]->finished = all_finished;
316 }
317 
NextAvailableDatasetId() const318 std::string DispatcherState::NextAvailableDatasetId() const {
319   return absl::StrCat(next_available_dataset_id_);
320 }
321 
UpdateNextAvailableDatasetId()322 void DispatcherState::UpdateNextAvailableDatasetId() {
323   while (datasets_by_id_.contains(absl::StrCat(next_available_dataset_id_))) {
324     ++next_available_dataset_id_;
325   }
326 }
327 
DatasetFromId(const std::string & id,std::shared_ptr<const Dataset> & dataset) const328 Status DispatcherState::DatasetFromId(
329     const std::string& id, std::shared_ptr<const Dataset>& dataset) const {
330   auto it = datasets_by_id_.find(id);
331   if (it == datasets_by_id_.end()) {
332     return errors::NotFound("Dataset id ", id, " not found");
333   }
334   dataset = it->second;
335   return OkStatus();
336 }
337 
DatasetFromFingerprint(uint64 fingerprint,std::shared_ptr<const Dataset> & dataset) const338 Status DispatcherState::DatasetFromFingerprint(
339     uint64 fingerprint, std::shared_ptr<const Dataset>& dataset) const {
340   auto it = datasets_by_fingerprint_.find(fingerprint);
341   if (it == datasets_by_fingerprint_.end()) {
342     return errors::NotFound("Dataset fingerprint ", fingerprint, " not found");
343   }
344   dataset = it->second;
345   return OkStatus();
346 }
347 
WorkerFromAddress(const std::string & address,std::shared_ptr<const Worker> & worker) const348 Status DispatcherState::WorkerFromAddress(
349     const std::string& address, std::shared_ptr<const Worker>& worker) const {
350   auto it = workers_.find(address);
351   if (it == workers_.end()) {
352     return errors::NotFound("Worker with address ", address, " not found.");
353   }
354   worker = it->second;
355   return OkStatus();
356 }
357 
358 std::vector<std::shared_ptr<const DispatcherState::Worker>>
ListWorkers() const359 DispatcherState::ListWorkers() const {
360   std::vector<std::shared_ptr<const Worker>> workers;
361   workers.reserve(workers_.size());
362   for (const auto& it : workers_) {
363     workers.push_back(it.second);
364   }
365   return workers;
366 }
367 
368 std::vector<std::shared_ptr<const DispatcherState::Iteration>>
ListIterations() const369 DispatcherState::ListIterations() const {
370   std::vector<std::shared_ptr<const DispatcherState::Iteration>> iterations;
371   iterations.reserve(iterations_.size());
372   for (const auto& it : iterations_) {
373     iterations.push_back(it.second);
374   }
375   return iterations;
376 }
377 
IterationFromId(int64_t id,std::shared_ptr<const Iteration> & iteration) const378 Status DispatcherState::IterationFromId(
379     int64_t id, std::shared_ptr<const Iteration>& iteration) const {
380   auto it = iterations_.find(id);
381   if (it == iterations_.end()) {
382     return errors::NotFound("Iteration id ", id, " not found");
383   }
384   iteration = it->second;
385   return OkStatus();
386 }
387 
IterationByKey(IterationKey iteration_key,std::shared_ptr<const Iteration> & iteration) const388 Status DispatcherState::IterationByKey(
389     IterationKey iteration_key,
390     std::shared_ptr<const Iteration>& iteration) const {
391   auto it = iterations_by_key_.find(iteration_key);
392   if (it == iterations_by_key_.end()) {
393     return errors::NotFound("Iteration key ", iteration_key.DebugString(),
394                             " not found");
395   }
396   iteration = it->second;
397   return OkStatus();
398 }
399 
NextAvailableJobId() const400 int64_t DispatcherState::NextAvailableJobId() const {
401   return next_available_job_id_;
402 }
403 
NextAvailableIterationId() const404 int64_t DispatcherState::NextAvailableIterationId() const {
405   return next_available_iteration_id_;
406 }
407 
IterationForIterationClientId(int64_t iteration_client_id,std::shared_ptr<const Iteration> & iteration)408 Status DispatcherState::IterationForIterationClientId(
409     int64_t iteration_client_id, std::shared_ptr<const Iteration>& iteration) {
410   iteration = iterations_for_client_ids_[iteration_client_id];
411   if (!iteration) {
412     return errors::NotFound("Iteration client id not found: ",
413                             iteration_client_id);
414   }
415   return OkStatus();
416 }
417 
ListActiveClientIds()418 std::vector<int64_t> DispatcherState::ListActiveClientIds() {
419   std::vector<int64_t> ids;
420   for (const auto& it : iterations_for_client_ids_) {
421     if (it.second && !it.second->finished) {
422       ids.push_back(it.first);
423     }
424   }
425   return ids;
426 }
427 
NextAvailableIterationClientId() const428 int64_t DispatcherState::NextAvailableIterationClientId() const {
429   return next_available_iteration_client_id_;
430 }
431 
TaskFromId(int64_t id,std::shared_ptr<const Task> & task) const432 Status DispatcherState::TaskFromId(int64_t id,
433                                    std::shared_ptr<const Task>& task) const {
434   auto it = tasks_.find(id);
435   if (it == tasks_.end()) {
436     return errors::NotFound("Task ", id, " not found");
437   }
438   task = it->second;
439   return OkStatus();
440 }
441 
TasksForIteration(int64_t iteration_id,std::vector<std::shared_ptr<const Task>> & tasks) const442 Status DispatcherState::TasksForIteration(
443     int64_t iteration_id,
444     std::vector<std::shared_ptr<const Task>>& tasks) const {
445   auto it = tasks_by_iteration_.find(iteration_id);
446   if (it == tasks_by_iteration_.end()) {
447     return errors::NotFound("Iteration ", iteration_id, " not found");
448   }
449   tasks.clear();
450   tasks.reserve(it->second.size());
451   for (const auto& task : it->second) {
452     tasks.push_back(task);
453   }
454   return OkStatus();
455 }
456 
TasksForWorker(absl::string_view worker_address,std::vector<std::shared_ptr<const Task>> & tasks) const457 Status DispatcherState::TasksForWorker(
458     absl::string_view worker_address,
459     std::vector<std::shared_ptr<const Task>>& tasks) const {
460   tasks.clear();
461   auto it = tasks_by_worker_.find(worker_address);
462   if (it == tasks_by_worker_.end()) {
463     return errors::NotFound("Worker ", worker_address, " not found");
464   }
465   const absl::flat_hash_map<int64_t, std::shared_ptr<Task>>& worker_tasks =
466       it->second;
467   tasks.reserve(worker_tasks.size());
468   for (const auto& task : worker_tasks) {
469     tasks.push_back(task.second);
470   }
471   return OkStatus();
472 }
473 
NextAvailableTaskId() const474 int64_t DispatcherState::NextAvailableTaskId() const {
475   return next_available_task_id_;
476 }
477 
ValidateWorker(absl::string_view worker_address) const478 Status DispatcherState::ValidateWorker(absl::string_view worker_address) const {
479   return worker_index_resolver_.ValidateWorker(worker_address);
480 }
481 
GetWorkerIndex(absl::string_view worker_address) const482 StatusOr<int64_t> DispatcherState::GetWorkerIndex(
483     absl::string_view worker_address) const {
484   return worker_index_resolver_.GetWorkerIndex(worker_address);
485 }
486 
487 }  // namespace data
488 }  // namespace tensorflow
489