1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_DATA_SERVICE_TEST_CLUSTER_H_
17 #define TENSORFLOW_CORE_DATA_SERVICE_TEST_CLUSTER_H_
18
19 #include <cstdint>
20 #include <memory>
21 #include <optional>
22 #include <string>
23 #include <vector>
24
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/core/data/service/common.pb.h"
28 #include "tensorflow/core/data/service/data_transfer.h"
29 #include "tensorflow/core/data/service/dispatcher.pb.h"
30 #include "tensorflow/core/data/service/dispatcher_client.h"
31 #include "tensorflow/core/data/service/export.pb.h"
32 #include "tensorflow/core/data/service/server_lib.h"
33 #include "tensorflow/core/data/service/test_util.h"
34 #include "tensorflow/core/data/service/worker.pb.h"
35 #include "tensorflow/core/data/service/worker_client.h"
36 #include "tensorflow/core/framework/tensor.h"
37 #include "tensorflow/core/platform/errors.h"
38 #include "tensorflow/core/platform/status.h"
39 #include "tensorflow/core/platform/statusor.h"
40 #include "tensorflow/core/platform/types.h"
41 #include "tensorflow/core/protobuf/data_service.pb.h"
42
43 namespace tensorflow {
44 namespace data {
45
46 // Helper class for unit testing a tf.data service cluster.
47 class TestCluster {
48 public:
49 struct Config {
50 public:
51 int num_workers = 3;
52 int64_t client_timeout_ms = 0;
53 int64_t worker_heartbeat_interval_ms = 0;
54 int64_t job_gc_check_interval_ms = 0;
55 int64_t job_gc_timeout_ms = 0;
56 };
57
58 // Creates a new test cluster with a dispatcher and `num_workers` workers.
59 explicit TestCluster(int num_workers);
60 explicit TestCluster(const Config& config);
61
62 // Initializes the test cluster. This must be called before interacting with
63 // the cluster. Initialize should be called only once.
64 Status Initialize();
65 // Adds a new worker to the cluster.
66 Status AddWorker();
67 // Returns the number of workers in this cluster.
NumWorkers()68 size_t NumWorkers() const { return workers_.size(); }
69 // Returns the number of active iterations.
NumActiveIterations()70 StatusOr<size_t> NumActiveIterations() const {
71 return dispatcher_->NumActiveIterations();
72 }
73 // Returns the dispatcher address in the form "hostname:port".
74 std::string DispatcherAddress() const;
75 // Returns the address of the worker at the specified index, in the form
76 // "hostname:port". The index must be non-negative and less than the number of
77 // workers in the cluster.
78 std::string WorkerAddress(int index) const;
79
80 // Stops one worker.
81 void StopWorker(size_t index);
82 // Stops all workers.
83 void StopWorkers();
84
85 // Returns the server state exports.
86 ServerStateExport ExportDispatcherState() const;
87 ServerStateExport ExportWorkerState(size_t index) const;
88
89 private:
90 bool initialized_ = false;
91 int num_workers_;
92 Config config_;
93 std::unique_ptr<DispatchGrpcDataServer> dispatcher_;
94 std::string dispatcher_address_;
95 std::vector<std::unique_ptr<WorkerGrpcDataServer>> workers_;
96 std::vector<std::string> worker_addresses_;
97 };
98
99 // A test utility to provide a `DatasetDef` to a `TestCluster` and generate data
100 // from each worker for verification. For example:
101 //
102 // TestCluster cluster(/*num_workers=*/2);
103 // TF_ASSERT_OK(cluster.Initialize());
104 // DatasetClient<int64_t> dataset_reader(cluster);
105 //
106 // EXPECT_THAT(
107 // dataset_reader.Read(RangeDataset(4), ProcessingModeDef::DATA,
108 // TARGET_WORKERS_LOCAL),
109 // IsOkAndHolds(UnorderedElementsAre(
110 // Pair(cluster.WorkerAddress(0), ElementsAre(0, 2)),
111 // Pair(cluster.WorkerAddress(1), ElementsAre(1, 3)))));
112 template <class T>
113 class DatasetClient {
114 public:
115 // Creates a dataset client. It will process datasets in `cluster`.
116 explicit DatasetClient(const TestCluster& cluster);
117
118 // Maps a worker address to the data it produces when calling `Read`.
119 using WorkerResultMap = absl::flat_hash_map<std::string, std::vector<T>>;
120
121 // Processes `dataset` and retrieves the data from workers. Returns the data
122 // produced by each worker, keyed by the worker address.
123 StatusOr<WorkerResultMap> Read(
124 const DatasetDef& dataset,
125 ProcessingModeDef::ShardingPolicy sharding_policy,
126 TargetWorkers target_workers);
127 // Creates an iteration and returns the iteration client ID.
128 StatusOr<int64_t> CreateIteration(const DatasetDef& dataset);
129 // Gets the tasks for iteration `iteration_client_id`. The iteration has one
130 // task processed by every worker.
131 StatusOr<std::vector<TaskInfo>> GetTasks(int64_t iteration_client_id);
132
133 private:
134 // Registers the dataset and returns the dataset ID.
135 StatusOr<std::string> RegisterDataset(const DatasetDef& dataset);
136 // Creates an iteration and returns the iteration client ID.
137 StatusOr<int64_t> CreateIteration(
138 const std::string& dataset_id,
139 ProcessingModeDef::ShardingPolicy sharding_policy,
140 TargetWorkers target_workers);
141 // Reads values from `tasks`, one task at a time, until all tasks have
142 // finished.
143 StatusOr<WorkerResultMap> ReadFromTasks(const std::vector<TaskInfo>& tasks);
144 // Reads the next element from the specified task.
145 StatusOr<GetElementResult> ReadFromTask(const TaskInfo& task_info);
146
147 const TestCluster& cluster_;
148 std::unique_ptr<DataServiceDispatcherClient> dispatcher_client_;
149 absl::flat_hash_map<std::string, std::unique_ptr<DataServiceWorkerClient>>
150 worker_clients_;
151 };
152
153 template <class T>
DatasetClient(const TestCluster & cluster)154 DatasetClient<T>::DatasetClient(const TestCluster& cluster)
155 : cluster_(cluster) {
156 dispatcher_client_ = std::make_unique<DataServiceDispatcherClient>(
157 cluster_.DispatcherAddress(), "grpc");
158
159 for (size_t i = 0; i < cluster.NumWorkers(); ++i) {
160 worker_clients_[cluster_.WorkerAddress(i)] =
161 std::make_unique<DataServiceWorkerClient>(cluster_.WorkerAddress(i),
162 "grpc", "grpc");
163 }
164 }
165
166 template <class T>
Read(const DatasetDef & dataset,ProcessingModeDef::ShardingPolicy sharding_policy,TargetWorkers target_workers)167 StatusOr<typename DatasetClient<T>::WorkerResultMap> DatasetClient<T>::Read(
168 const DatasetDef& dataset,
169 ProcessingModeDef::ShardingPolicy sharding_policy,
170 TargetWorkers target_workers) {
171 TF_ASSIGN_OR_RETURN(const std::string dataset_id, RegisterDataset(dataset));
172 TF_ASSIGN_OR_RETURN(
173 const int64_t iteration_client_id,
174 CreateIteration(dataset_id, sharding_policy, target_workers));
175 TF_ASSIGN_OR_RETURN(const std::vector<TaskInfo> tasks,
176 GetTasks(iteration_client_id));
177 return ReadFromTasks(tasks);
178 }
179
180 template <class T>
RegisterDataset(const DatasetDef & dataset)181 StatusOr<std::string> DatasetClient<T>::RegisterDataset(
182 const DatasetDef& dataset) {
183 std::string dataset_id;
184 TF_RETURN_IF_ERROR(dispatcher_client_->RegisterDataset(
185 dataset, DataServiceMetadata(), /*requested_dataset_id=*/std::nullopt,
186 dataset_id));
187 return dataset_id;
188 }
189
190 template <class T>
CreateIteration(const std::string & dataset_id,ProcessingModeDef::ShardingPolicy sharding_policy,TargetWorkers target_workers)191 StatusOr<int64_t> DatasetClient<T>::CreateIteration(
192 const std::string& dataset_id,
193 ProcessingModeDef::ShardingPolicy sharding_policy,
194 TargetWorkers target_workers) {
195 ProcessingModeDef processing_mode_def;
196 processing_mode_def.set_sharding_policy(sharding_policy);
197 int64_t job_id;
198 TF_RETURN_IF_ERROR(dispatcher_client_->GetOrCreateJob(
199 dataset_id, processing_mode_def, /*job_name=*/std::nullopt,
200 /*num_consumers=*/std::nullopt, /*use_cross_trainer_cache=*/false,
201 target_workers, job_id));
202 int64_t iteration_client_id;
203 TF_RETURN_IF_ERROR(dispatcher_client_->GetOrCreateIteration(
204 job_id, /*repetition=*/0, iteration_client_id));
205 return iteration_client_id;
206 }
207
208 template <class T>
CreateIteration(const DatasetDef & dataset)209 StatusOr<int64_t> DatasetClient<T>::CreateIteration(const DatasetDef& dataset) {
210 TF_ASSIGN_OR_RETURN(const std::string dataset_id, RegisterDataset(dataset));
211 return CreateIteration(dataset_id, ProcessingModeDef::OFF,
212 TARGET_WORKERS_ANY);
213 }
214
215 template <class T>
GetTasks(const int64_t iteration_client_id)216 StatusOr<std::vector<TaskInfo>> DatasetClient<T>::GetTasks(
217 const int64_t iteration_client_id) {
218 ClientHeartbeatRequest request;
219 ClientHeartbeatResponse response;
220 request.set_iteration_client_id(iteration_client_id);
221 TF_RETURN_IF_ERROR(dispatcher_client_->ClientHeartbeat(request, response));
222 if (response.task_info().empty()) {
223 return errors::NotFound("No task found for iteration ", iteration_client_id,
224 ".");
225 }
226 return std::vector<TaskInfo>(response.task_info().begin(),
227 response.task_info().end());
228 }
229
230 template <class T>
231 StatusOr<typename DatasetClient<T>::WorkerResultMap>
ReadFromTasks(const std::vector<TaskInfo> & tasks)232 DatasetClient<T>::ReadFromTasks(const std::vector<TaskInfo>& tasks) {
233 WorkerResultMap result;
234 bool all_workers_finished = false;
235 while (!all_workers_finished) {
236 all_workers_finished = true;
237 for (const TaskInfo& task : tasks) {
238 StatusOr<GetElementResult> element_result = ReadFromTask(task);
239 // A task may be cancelled when it has finished but other workers are
240 // still producing data.
241 if (errors::IsCancelled(element_result.status())) {
242 continue;
243 }
244 TF_RETURN_IF_ERROR(element_result.status());
245 if (element_result->end_of_sequence) {
246 continue;
247 }
248 all_workers_finished = false;
249 result[task.worker_address()].push_back(
250 element_result->components[0].unaligned_flat<T>().data()[0]);
251 }
252 }
253 return result;
254 }
255
256 template <class T>
ReadFromTask(const TaskInfo & task_info)257 StatusOr<GetElementResult> DatasetClient<T>::ReadFromTask(
258 const TaskInfo& task_info) {
259 GetElementRequest request;
260 GetElementResult element_result;
261 request.set_task_id(task_info.task_id());
262 TF_RETURN_IF_ERROR(worker_clients_[task_info.worker_address()]->GetElement(
263 request, element_result));
264 return element_result;
265 }
266
267 } // namespace data
268 } // namespace tensorflow
269
270 #endif // TENSORFLOW_CORE_DATA_SERVICE_TEST_CLUSTER_H_
271