1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/data/service/dispatcher_state.h"
16
17 #include <algorithm>
18 #include <memory>
19 #include <optional>
20 #include <string>
21 #include <vector>
22
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/strings/str_cat.h"
25 #include "absl/strings/string_view.h"
26 #include "tensorflow/core/data/service/common.h"
27 #include "tensorflow/core/data/service/journal.h"
28 #include "tensorflow/core/data/service/journal.pb.h"
29 #include "tensorflow/core/platform/errors.h"
30 #include "tensorflow/core/platform/status.h"
31 #include "tensorflow/core/protobuf/data_service.pb.h"
32 #include "tensorflow/core/protobuf/service_config.pb.h"
33
34 namespace tensorflow {
35 namespace data {
36
DispatcherState()37 DispatcherState::DispatcherState()
38 : worker_index_resolver_(std::vector<std::string>{}) {}
39
DispatcherState(const experimental::DispatcherConfig & dispatcher_config)40 DispatcherState::DispatcherState(
41 const experimental::DispatcherConfig& dispatcher_config)
42 : worker_index_resolver_(dispatcher_config.worker_addresses()) {}
43
Apply(const Update & update)44 Status DispatcherState::Apply(const Update& update) {
45 switch (update.update_type_case()) {
46 case Update::kRegisterDataset:
47 RegisterDataset(update.register_dataset());
48 break;
49 case Update::kRegisterWorker:
50 RegisterWorker(update.register_worker());
51 break;
52 case Update::kCreateJob:
53 CreateJob(update.create_job());
54 break;
55 case Update::kCreateIteration:
56 CreateIteration(update.create_iteration());
57 break;
58 case Update::kProduceSplit:
59 ProduceSplit(update.produce_split());
60 break;
61 case Update::kAcquireIterationClient:
62 AcquireIterationClient(update.acquire_iteration_client());
63 break;
64 case Update::kReleaseIterationClient:
65 ReleaseIterationClient(update.release_iteration_client());
66 break;
67 case Update::kGarbageCollectIteration:
68 GarbageCollectIteration(update.garbage_collect_iteration());
69 break;
70 case Update::kRemoveTask:
71 RemoveTask(update.remove_task());
72 break;
73 case Update::kCreatePendingTask:
74 CreatePendingTask(update.create_pending_task());
75 break;
76 case Update::kClientHeartbeat:
77 ClientHeartbeat(update.client_heartbeat());
78 break;
79 case Update::kCreateTask:
80 CreateTask(update.create_task());
81 break;
82 case Update::kFinishTask:
83 FinishTask(update.finish_task());
84 break;
85 case Update::UPDATE_TYPE_NOT_SET:
86 return errors::Internal("Update type not set.");
87 }
88
89 return OkStatus();
90 }
91
RegisterDataset(const RegisterDatasetUpdate & register_dataset)92 void DispatcherState::RegisterDataset(
93 const RegisterDatasetUpdate& register_dataset) {
94 std::string dataset_id = register_dataset.dataset_id();
95 int64_t fingerprint = register_dataset.fingerprint();
96 auto dataset = std::make_shared<Dataset>(dataset_id, fingerprint,
97 register_dataset.metadata());
98 DCHECK(!datasets_by_id_.contains(dataset_id));
99 datasets_by_id_[dataset_id] = dataset;
100 if (!register_dataset.dedupe_by_dataset_id()) {
101 // Only stores the fingerprint if the user has not requested a dataset ID.
102 // If the user has requested a dataset ID, we will look up datasets by their
103 // IDs, not by fingerprints. Otherwise, an anonymous dataset can refer to
104 // a dataset with an explicit dataset ID.
105 DCHECK(!datasets_by_fingerprint_.contains(fingerprint));
106 datasets_by_fingerprint_[fingerprint] = dataset;
107 }
108 UpdateNextAvailableDatasetId();
109 }
110
RegisterWorker(const RegisterWorkerUpdate & register_worker)111 void DispatcherState::RegisterWorker(
112 const RegisterWorkerUpdate& register_worker) {
113 std::string address = register_worker.worker_address();
114 DCHECK(!workers_.contains(address));
115 workers_[address] = std::make_shared<Worker>(register_worker);
116 tasks_by_worker_[address] =
117 absl::flat_hash_map<int64_t, std::shared_ptr<Task>>();
118 worker_index_resolver_.AddWorker(address);
119 }
120
CreateJob(const CreateJobUpdate & create_job)121 void DispatcherState::CreateJob(const CreateJobUpdate& create_job) {
122 int64_t job_id = create_job.job_id();
123 std::string job_name = create_job.job_name();
124 std::optional<int64_t> num_consumers;
125 if (create_job.optional_num_consumers_case() ==
126 CreateJobUpdate::kNumConsumers) {
127 num_consumers = create_job.num_consumers();
128 }
129 auto job = std::make_shared<Job>(
130 job_id, create_job.dataset_id(), create_job.processing_mode_def(),
131 job_name, num_consumers, create_job.use_cross_trainer_cache(),
132 create_job.target_workers());
133 DCHECK(!jobs_by_id_.contains(job_id));
134 jobs_by_id_[job_id] = job;
135 DCHECK(!jobs_by_name_.contains(job_name));
136 jobs_by_name_[job_name] = job;
137 next_available_job_id_ = std::max(next_available_job_id_, job_id + 1);
138 }
139
JobFromId(int64_t job_id,std::shared_ptr<const Job> & job) const140 Status DispatcherState::JobFromId(int64_t job_id,
141 std::shared_ptr<const Job>& job) const {
142 auto it = jobs_by_id_.find(job_id);
143 if (it == jobs_by_id_.end()) {
144 return errors::NotFound("Job with id ", job_id, " not found");
145 }
146 job = it->second;
147 return Status::OK();
148 }
149
JobByName(const std::string & job_name,std::shared_ptr<const Job> & job) const150 Status DispatcherState::JobByName(const std::string& job_name,
151 std::shared_ptr<const Job>& job) const {
152 auto it = jobs_by_name_.find(job_name);
153 if (it == jobs_by_name_.end()) {
154 return errors::NotFound("Job with name ", job_name, " not found");
155 }
156 job = it->second;
157 return Status::OK();
158 }
159
CreateIteration(const CreateIterationUpdate & create_iteration)160 void DispatcherState::CreateIteration(
161 const CreateIterationUpdate& create_iteration) {
162 int64_t iteration_id = create_iteration.iteration_id();
163 int64_t job_id = create_iteration.job_id();
164 DCHECK(jobs_by_id_.contains(job_id));
165 auto& job = jobs_by_id_[job_id];
166 DCHECK(job);
167 IterationKey iteration_key(job->job_name, create_iteration.repetition());
168 auto iteration = std::make_shared<Iteration>(
169 iteration_id, iteration_key, create_iteration.num_split_providers(), job);
170 DCHECK(!iterations_.contains(iteration_id));
171 iterations_[iteration_id] = iteration;
172 tasks_by_iteration_[iteration_id] = std::vector<std::shared_ptr<Task>>();
173 DCHECK(!iterations_by_key_.contains(iteration_key) ||
174 iterations_by_key_[iteration_key]->garbage_collected);
175 iterations_by_key_[iteration_key] = iteration;
176 next_available_iteration_id_ =
177 std::max(next_available_iteration_id_, iteration_id + 1);
178 }
179
ProduceSplit(const ProduceSplitUpdate & produce_split)180 void DispatcherState::ProduceSplit(const ProduceSplitUpdate& produce_split) {
181 std::shared_ptr<Iteration> iteration =
182 iterations_[produce_split.iteration_id()];
183 DCHECK(iteration->distributed_epoch_state.has_value());
184 DistributedEpochState& state = iteration->distributed_epoch_state.value();
185 int64_t provider_index = produce_split.split_provider_index();
186 DCHECK_EQ(produce_split.repetition(), state.repetitions[provider_index]);
187 if (produce_split.finished()) {
188 state.repetitions[provider_index]++;
189 state.indices[provider_index] = 0;
190 return;
191 }
192 state.indices[provider_index]++;
193 }
194
AcquireIterationClient(const AcquireIterationClientUpdate & acquire_iteration_client)195 void DispatcherState::AcquireIterationClient(
196 const AcquireIterationClientUpdate& acquire_iteration_client) {
197 int64_t iteration_client_id = acquire_iteration_client.iteration_client_id();
198 std::shared_ptr<Iteration>& iteration =
199 iterations_for_client_ids_[iteration_client_id];
200 DCHECK(!iteration);
201 iteration = iterations_[acquire_iteration_client.iteration_id()];
202 DCHECK(iteration);
203 iteration->num_clients++;
204 next_available_iteration_client_id_ =
205 std::max(next_available_iteration_client_id_, iteration_client_id + 1);
206 }
207
ReleaseIterationClient(const ReleaseIterationClientUpdate & release_iteration_client)208 void DispatcherState::ReleaseIterationClient(
209 const ReleaseIterationClientUpdate& release_iteration_client) {
210 int64_t iteration_client_id = release_iteration_client.iteration_client_id();
211 std::shared_ptr<Iteration>& iteration =
212 iterations_for_client_ids_[iteration_client_id];
213 DCHECK(iteration);
214 iteration->num_clients--;
215 DCHECK_GE(iteration->num_clients, 0);
216 iteration->last_client_released_micros =
217 release_iteration_client.time_micros();
218 iterations_for_client_ids_.erase(iteration_client_id);
219 }
220
GarbageCollectIteration(const GarbageCollectIterationUpdate & garbage_collect_iteration)221 void DispatcherState::GarbageCollectIteration(
222 const GarbageCollectIterationUpdate& garbage_collect_iteration) {
223 int64_t iteration_id = garbage_collect_iteration.iteration_id();
224 for (auto& task : tasks_by_iteration_[iteration_id]) {
225 task->finished = true;
226 tasks_by_worker_[task->worker_address].erase(task->task_id);
227 }
228 iterations_[iteration_id]->finished = true;
229 iterations_[iteration_id]->garbage_collected = true;
230 }
231
RemoveTask(const RemoveTaskUpdate & remove_task)232 void DispatcherState::RemoveTask(const RemoveTaskUpdate& remove_task) {
233 std::shared_ptr<Task>& task = tasks_[remove_task.task_id()];
234 DCHECK(task);
235 task->removed = true;
236 auto& tasks_for_iteration =
237 tasks_by_iteration_[task->iteration->iteration_id];
238 for (auto it = tasks_for_iteration.begin(); it != tasks_for_iteration.end();
239 ++it) {
240 if ((*it)->task_id == task->task_id) {
241 tasks_for_iteration.erase(it);
242 break;
243 }
244 }
245 tasks_by_worker_[task->worker_address].erase(task->task_id);
246 tasks_.erase(task->task_id);
247 VLOG(1) << "Removed task " << remove_task.task_id() << " from worker "
248 << task->worker_address;
249 }
250
CreatePendingTask(const CreatePendingTaskUpdate & create_pending_task)251 void DispatcherState::CreatePendingTask(
252 const CreatePendingTaskUpdate& create_pending_task) {
253 int64_t task_id = create_pending_task.task_id();
254 auto& task = tasks_[task_id];
255 DCHECK_EQ(task, nullptr);
256 auto& iteration = iterations_[create_pending_task.iteration_id()];
257 DCHECK_NE(iteration, nullptr);
258 task = std::make_shared<Task>(create_pending_task, iteration);
259 iteration->pending_tasks.emplace(task, create_pending_task.starting_round());
260 tasks_by_worker_[create_pending_task.worker_address()][task->task_id] = task;
261 next_available_task_id_ = std::max(next_available_task_id_, task_id + 1);
262 }
263
ClientHeartbeat(const ClientHeartbeatUpdate & client_heartbeat)264 void DispatcherState::ClientHeartbeat(
265 const ClientHeartbeatUpdate& client_heartbeat) {
266 int64_t iteration_client_id = client_heartbeat.iteration_client_id();
267 auto& iteration = iterations_for_client_ids_[iteration_client_id];
268 DCHECK(!iteration->pending_tasks.empty());
269 auto& task = iteration->pending_tasks.front();
270 if (client_heartbeat.has_task_rejected()) {
271 task.failures++;
272 task.ready_consumers.clear();
273 task.target_round = client_heartbeat.task_rejected().new_target_round();
274 }
275 if (client_heartbeat.task_accepted()) {
276 task.ready_consumers.insert(iteration_client_id);
277 if (task.ready_consumers.size() == iteration->job->num_consumers.value()) {
278 VLOG(1) << "Promoting task " << task.task->task_id
279 << " from pending to active";
280 task.task->starting_round = task.target_round;
281 tasks_by_iteration_[iteration->iteration_id].push_back(task.task);
282 iteration->pending_tasks.pop();
283 }
284 }
285 }
286
CreateTask(const CreateTaskUpdate & create_task)287 void DispatcherState::CreateTask(const CreateTaskUpdate& create_task) {
288 int64_t task_id = create_task.task_id();
289 auto& task = tasks_[task_id];
290 DCHECK_EQ(task, nullptr);
291 auto& iteration = iterations_[create_task.iteration_id()];
292 DCHECK_NE(iteration, nullptr);
293 task = std::make_shared<Task>(create_task, iteration);
294 tasks_by_iteration_[create_task.iteration_id()].push_back(task);
295 tasks_by_worker_[create_task.worker_address()][task->task_id] = task;
296 next_available_task_id_ = std::max(next_available_task_id_, task_id + 1);
297 }
298
FinishTask(const FinishTaskUpdate & finish_task)299 void DispatcherState::FinishTask(const FinishTaskUpdate& finish_task) {
300 VLOG(2) << "Marking task " << finish_task.task_id() << " as finished";
301 int64_t task_id = finish_task.task_id();
302 auto& task = tasks_[task_id];
303 DCHECK(task != nullptr);
304 task->finished = true;
305 tasks_by_worker_[task->worker_address].erase(task->task_id);
306 bool all_finished = true;
307 for (const auto& task_for_iteration :
308 tasks_by_iteration_[task->iteration->iteration_id]) {
309 if (!task_for_iteration->finished) {
310 all_finished = false;
311 }
312 }
313 VLOG(3) << "Iteration " << task->iteration->iteration_id
314 << " finished: " << all_finished;
315 iterations_[task->iteration->iteration_id]->finished = all_finished;
316 }
317
NextAvailableDatasetId() const318 std::string DispatcherState::NextAvailableDatasetId() const {
319 return absl::StrCat(next_available_dataset_id_);
320 }
321
UpdateNextAvailableDatasetId()322 void DispatcherState::UpdateNextAvailableDatasetId() {
323 while (datasets_by_id_.contains(absl::StrCat(next_available_dataset_id_))) {
324 ++next_available_dataset_id_;
325 }
326 }
327
DatasetFromId(const std::string & id,std::shared_ptr<const Dataset> & dataset) const328 Status DispatcherState::DatasetFromId(
329 const std::string& id, std::shared_ptr<const Dataset>& dataset) const {
330 auto it = datasets_by_id_.find(id);
331 if (it == datasets_by_id_.end()) {
332 return errors::NotFound("Dataset id ", id, " not found");
333 }
334 dataset = it->second;
335 return OkStatus();
336 }
337
DatasetFromFingerprint(uint64 fingerprint,std::shared_ptr<const Dataset> & dataset) const338 Status DispatcherState::DatasetFromFingerprint(
339 uint64 fingerprint, std::shared_ptr<const Dataset>& dataset) const {
340 auto it = datasets_by_fingerprint_.find(fingerprint);
341 if (it == datasets_by_fingerprint_.end()) {
342 return errors::NotFound("Dataset fingerprint ", fingerprint, " not found");
343 }
344 dataset = it->second;
345 return OkStatus();
346 }
347
WorkerFromAddress(const std::string & address,std::shared_ptr<const Worker> & worker) const348 Status DispatcherState::WorkerFromAddress(
349 const std::string& address, std::shared_ptr<const Worker>& worker) const {
350 auto it = workers_.find(address);
351 if (it == workers_.end()) {
352 return errors::NotFound("Worker with address ", address, " not found.");
353 }
354 worker = it->second;
355 return OkStatus();
356 }
357
358 std::vector<std::shared_ptr<const DispatcherState::Worker>>
ListWorkers() const359 DispatcherState::ListWorkers() const {
360 std::vector<std::shared_ptr<const Worker>> workers;
361 workers.reserve(workers_.size());
362 for (const auto& it : workers_) {
363 workers.push_back(it.second);
364 }
365 return workers;
366 }
367
368 std::vector<std::shared_ptr<const DispatcherState::Iteration>>
ListIterations() const369 DispatcherState::ListIterations() const {
370 std::vector<std::shared_ptr<const DispatcherState::Iteration>> iterations;
371 iterations.reserve(iterations_.size());
372 for (const auto& it : iterations_) {
373 iterations.push_back(it.second);
374 }
375 return iterations;
376 }
377
IterationFromId(int64_t id,std::shared_ptr<const Iteration> & iteration) const378 Status DispatcherState::IterationFromId(
379 int64_t id, std::shared_ptr<const Iteration>& iteration) const {
380 auto it = iterations_.find(id);
381 if (it == iterations_.end()) {
382 return errors::NotFound("Iteration id ", id, " not found");
383 }
384 iteration = it->second;
385 return OkStatus();
386 }
387
IterationByKey(IterationKey iteration_key,std::shared_ptr<const Iteration> & iteration) const388 Status DispatcherState::IterationByKey(
389 IterationKey iteration_key,
390 std::shared_ptr<const Iteration>& iteration) const {
391 auto it = iterations_by_key_.find(iteration_key);
392 if (it == iterations_by_key_.end()) {
393 return errors::NotFound("Iteration key ", iteration_key.DebugString(),
394 " not found");
395 }
396 iteration = it->second;
397 return OkStatus();
398 }
399
NextAvailableJobId() const400 int64_t DispatcherState::NextAvailableJobId() const {
401 return next_available_job_id_;
402 }
403
NextAvailableIterationId() const404 int64_t DispatcherState::NextAvailableIterationId() const {
405 return next_available_iteration_id_;
406 }
407
IterationForIterationClientId(int64_t iteration_client_id,std::shared_ptr<const Iteration> & iteration)408 Status DispatcherState::IterationForIterationClientId(
409 int64_t iteration_client_id, std::shared_ptr<const Iteration>& iteration) {
410 iteration = iterations_for_client_ids_[iteration_client_id];
411 if (!iteration) {
412 return errors::NotFound("Iteration client id not found: ",
413 iteration_client_id);
414 }
415 return OkStatus();
416 }
417
ListActiveClientIds()418 std::vector<int64_t> DispatcherState::ListActiveClientIds() {
419 std::vector<int64_t> ids;
420 for (const auto& it : iterations_for_client_ids_) {
421 if (it.second && !it.second->finished) {
422 ids.push_back(it.first);
423 }
424 }
425 return ids;
426 }
427
NextAvailableIterationClientId() const428 int64_t DispatcherState::NextAvailableIterationClientId() const {
429 return next_available_iteration_client_id_;
430 }
431
TaskFromId(int64_t id,std::shared_ptr<const Task> & task) const432 Status DispatcherState::TaskFromId(int64_t id,
433 std::shared_ptr<const Task>& task) const {
434 auto it = tasks_.find(id);
435 if (it == tasks_.end()) {
436 return errors::NotFound("Task ", id, " not found");
437 }
438 task = it->second;
439 return OkStatus();
440 }
441
TasksForIteration(int64_t iteration_id,std::vector<std::shared_ptr<const Task>> & tasks) const442 Status DispatcherState::TasksForIteration(
443 int64_t iteration_id,
444 std::vector<std::shared_ptr<const Task>>& tasks) const {
445 auto it = tasks_by_iteration_.find(iteration_id);
446 if (it == tasks_by_iteration_.end()) {
447 return errors::NotFound("Iteration ", iteration_id, " not found");
448 }
449 tasks.clear();
450 tasks.reserve(it->second.size());
451 for (const auto& task : it->second) {
452 tasks.push_back(task);
453 }
454 return OkStatus();
455 }
456
TasksForWorker(absl::string_view worker_address,std::vector<std::shared_ptr<const Task>> & tasks) const457 Status DispatcherState::TasksForWorker(
458 absl::string_view worker_address,
459 std::vector<std::shared_ptr<const Task>>& tasks) const {
460 tasks.clear();
461 auto it = tasks_by_worker_.find(worker_address);
462 if (it == tasks_by_worker_.end()) {
463 return errors::NotFound("Worker ", worker_address, " not found");
464 }
465 const absl::flat_hash_map<int64_t, std::shared_ptr<Task>>& worker_tasks =
466 it->second;
467 tasks.reserve(worker_tasks.size());
468 for (const auto& task : worker_tasks) {
469 tasks.push_back(task.second);
470 }
471 return OkStatus();
472 }
473
NextAvailableTaskId() const474 int64_t DispatcherState::NextAvailableTaskId() const {
475 return next_available_task_id_;
476 }
477
ValidateWorker(absl::string_view worker_address) const478 Status DispatcherState::ValidateWorker(absl::string_view worker_address) const {
479 return worker_index_resolver_.ValidateWorker(worker_address);
480 }
481
GetWorkerIndex(absl::string_view worker_address) const482 StatusOr<int64_t> DispatcherState::GetWorkerIndex(
483 absl::string_view worker_address) const {
484 return worker_index_resolver_.GetWorkerIndex(worker_address);
485 }
486
487 } // namespace data
488 } // namespace tensorflow
489