xref: /aosp_15_r20/external/tensorflow/tensorflow/core/data/service/test_cluster.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DATA_SERVICE_TEST_CLUSTER_H_
17 #define TENSORFLOW_CORE_DATA_SERVICE_TEST_CLUSTER_H_
18 
19 #include <cstdint>
20 #include <memory>
21 #include <optional>
22 #include <string>
23 #include <vector>
24 
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/core/data/service/common.pb.h"
28 #include "tensorflow/core/data/service/data_transfer.h"
29 #include "tensorflow/core/data/service/dispatcher.pb.h"
30 #include "tensorflow/core/data/service/dispatcher_client.h"
31 #include "tensorflow/core/data/service/export.pb.h"
32 #include "tensorflow/core/data/service/server_lib.h"
33 #include "tensorflow/core/data/service/test_util.h"
34 #include "tensorflow/core/data/service/worker.pb.h"
35 #include "tensorflow/core/data/service/worker_client.h"
36 #include "tensorflow/core/framework/tensor.h"
37 #include "tensorflow/core/platform/errors.h"
38 #include "tensorflow/core/platform/status.h"
39 #include "tensorflow/core/platform/statusor.h"
40 #include "tensorflow/core/platform/types.h"
41 #include "tensorflow/core/protobuf/data_service.pb.h"
42 
43 namespace tensorflow {
44 namespace data {
45 
46 // Helper class for unit testing a tf.data service cluster.
47 class TestCluster {
48  public:
49   struct Config {
50    public:
51     int num_workers = 3;
52     int64_t client_timeout_ms = 0;
53     int64_t worker_heartbeat_interval_ms = 0;
54     int64_t job_gc_check_interval_ms = 0;
55     int64_t job_gc_timeout_ms = 0;
56   };
57 
58   // Creates a new test cluster with a dispatcher and `num_workers` workers.
59   explicit TestCluster(int num_workers);
60   explicit TestCluster(const Config& config);
61 
62   // Initializes the test cluster. This must be called before interacting with
63   // the cluster. Initialize should be called only once.
64   Status Initialize();
65   // Adds a new worker to the cluster.
66   Status AddWorker();
67   // Returns the number of workers in this cluster.
NumWorkers()68   size_t NumWorkers() const { return workers_.size(); }
69   // Returns the number of active iterations.
NumActiveIterations()70   StatusOr<size_t> NumActiveIterations() const {
71     return dispatcher_->NumActiveIterations();
72   }
73   // Returns the dispatcher address in the form "hostname:port".
74   std::string DispatcherAddress() const;
75   // Returns the address of the worker at the specified index, in the form
76   // "hostname:port". The index must be non-negative and less than the number of
77   // workers in the cluster.
78   std::string WorkerAddress(int index) const;
79 
80   // Stops one worker.
81   void StopWorker(size_t index);
82   // Stops all workers.
83   void StopWorkers();
84 
85   // Returns the server state exports.
86   ServerStateExport ExportDispatcherState() const;
87   ServerStateExport ExportWorkerState(size_t index) const;
88 
89  private:
90   bool initialized_ = false;
91   int num_workers_;
92   Config config_;
93   std::unique_ptr<DispatchGrpcDataServer> dispatcher_;
94   std::string dispatcher_address_;
95   std::vector<std::unique_ptr<WorkerGrpcDataServer>> workers_;
96   std::vector<std::string> worker_addresses_;
97 };
98 
99 // A test utility to provide a `DatasetDef` to a `TestCluster` and generate data
100 // from each worker for verification. For example:
101 //
102 // TestCluster cluster(/*num_workers=*/2);
103 // TF_ASSERT_OK(cluster.Initialize());
104 // DatasetClient<int64_t> dataset_reader(cluster);
105 //
106 // EXPECT_THAT(
107 //     dataset_reader.Read(RangeDataset(4), ProcessingModeDef::DATA,
108 //                         TARGET_WORKERS_LOCAL),
109 //     IsOkAndHolds(UnorderedElementsAre(
110 //         Pair(cluster.WorkerAddress(0), ElementsAre(0, 2)),
111 //         Pair(cluster.WorkerAddress(1), ElementsAre(1, 3)))));
112 template <class T>
113 class DatasetClient {
114  public:
115   // Creates a dataset client. It will process datasets in `cluster`.
116   explicit DatasetClient(const TestCluster& cluster);
117 
118   // Maps a worker address to the data it produces when calling `Read`.
119   using WorkerResultMap = absl::flat_hash_map<std::string, std::vector<T>>;
120 
121   // Processes `dataset` and retrieves the data from workers. Returns the data
122   // produced by each worker, keyed by the worker address.
123   StatusOr<WorkerResultMap> Read(
124       const DatasetDef& dataset,
125       ProcessingModeDef::ShardingPolicy sharding_policy,
126       TargetWorkers target_workers);
127   // Creates an iteration and returns the iteration client ID.
128   StatusOr<int64_t> CreateIteration(const DatasetDef& dataset);
129   // Gets the tasks for iteration `iteration_client_id`. The iteration has one
130   // task processed by every worker.
131   StatusOr<std::vector<TaskInfo>> GetTasks(int64_t iteration_client_id);
132 
133  private:
134   // Registers the dataset and returns the dataset ID.
135   StatusOr<std::string> RegisterDataset(const DatasetDef& dataset);
136   // Creates an iteration and returns the iteration client ID.
137   StatusOr<int64_t> CreateIteration(
138       const std::string& dataset_id,
139       ProcessingModeDef::ShardingPolicy sharding_policy,
140       TargetWorkers target_workers);
141   // Reads values from `tasks`, one task at a time, until all tasks have
142   // finished.
143   StatusOr<WorkerResultMap> ReadFromTasks(const std::vector<TaskInfo>& tasks);
144   // Reads the next element from the specified task.
145   StatusOr<GetElementResult> ReadFromTask(const TaskInfo& task_info);
146 
147   const TestCluster& cluster_;
148   std::unique_ptr<DataServiceDispatcherClient> dispatcher_client_;
149   absl::flat_hash_map<std::string, std::unique_ptr<DataServiceWorkerClient>>
150       worker_clients_;
151 };
152 
153 template <class T>
DatasetClient(const TestCluster & cluster)154 DatasetClient<T>::DatasetClient(const TestCluster& cluster)
155     : cluster_(cluster) {
156   dispatcher_client_ = std::make_unique<DataServiceDispatcherClient>(
157       cluster_.DispatcherAddress(), "grpc");
158 
159   for (size_t i = 0; i < cluster.NumWorkers(); ++i) {
160     worker_clients_[cluster_.WorkerAddress(i)] =
161         std::make_unique<DataServiceWorkerClient>(cluster_.WorkerAddress(i),
162                                                   "grpc", "grpc");
163   }
164 }
165 
166 template <class T>
Read(const DatasetDef & dataset,ProcessingModeDef::ShardingPolicy sharding_policy,TargetWorkers target_workers)167 StatusOr<typename DatasetClient<T>::WorkerResultMap> DatasetClient<T>::Read(
168     const DatasetDef& dataset,
169     ProcessingModeDef::ShardingPolicy sharding_policy,
170     TargetWorkers target_workers) {
171   TF_ASSIGN_OR_RETURN(const std::string dataset_id, RegisterDataset(dataset));
172   TF_ASSIGN_OR_RETURN(
173       const int64_t iteration_client_id,
174       CreateIteration(dataset_id, sharding_policy, target_workers));
175   TF_ASSIGN_OR_RETURN(const std::vector<TaskInfo> tasks,
176                       GetTasks(iteration_client_id));
177   return ReadFromTasks(tasks);
178 }
179 
180 template <class T>
RegisterDataset(const DatasetDef & dataset)181 StatusOr<std::string> DatasetClient<T>::RegisterDataset(
182     const DatasetDef& dataset) {
183   std::string dataset_id;
184   TF_RETURN_IF_ERROR(dispatcher_client_->RegisterDataset(
185       dataset, DataServiceMetadata(), /*requested_dataset_id=*/std::nullopt,
186       dataset_id));
187   return dataset_id;
188 }
189 
190 template <class T>
CreateIteration(const std::string & dataset_id,ProcessingModeDef::ShardingPolicy sharding_policy,TargetWorkers target_workers)191 StatusOr<int64_t> DatasetClient<T>::CreateIteration(
192     const std::string& dataset_id,
193     ProcessingModeDef::ShardingPolicy sharding_policy,
194     TargetWorkers target_workers) {
195   ProcessingModeDef processing_mode_def;
196   processing_mode_def.set_sharding_policy(sharding_policy);
197   int64_t job_id;
198   TF_RETURN_IF_ERROR(dispatcher_client_->GetOrCreateJob(
199       dataset_id, processing_mode_def, /*job_name=*/std::nullopt,
200       /*num_consumers=*/std::nullopt, /*use_cross_trainer_cache=*/false,
201       target_workers, job_id));
202   int64_t iteration_client_id;
203   TF_RETURN_IF_ERROR(dispatcher_client_->GetOrCreateIteration(
204       job_id, /*repetition=*/0, iteration_client_id));
205   return iteration_client_id;
206 }
207 
208 template <class T>
CreateIteration(const DatasetDef & dataset)209 StatusOr<int64_t> DatasetClient<T>::CreateIteration(const DatasetDef& dataset) {
210   TF_ASSIGN_OR_RETURN(const std::string dataset_id, RegisterDataset(dataset));
211   return CreateIteration(dataset_id, ProcessingModeDef::OFF,
212                          TARGET_WORKERS_ANY);
213 }
214 
215 template <class T>
GetTasks(const int64_t iteration_client_id)216 StatusOr<std::vector<TaskInfo>> DatasetClient<T>::GetTasks(
217     const int64_t iteration_client_id) {
218   ClientHeartbeatRequest request;
219   ClientHeartbeatResponse response;
220   request.set_iteration_client_id(iteration_client_id);
221   TF_RETURN_IF_ERROR(dispatcher_client_->ClientHeartbeat(request, response));
222   if (response.task_info().empty()) {
223     return errors::NotFound("No task found for iteration ", iteration_client_id,
224                             ".");
225   }
226   return std::vector<TaskInfo>(response.task_info().begin(),
227                                response.task_info().end());
228 }
229 
230 template <class T>
231 StatusOr<typename DatasetClient<T>::WorkerResultMap>
ReadFromTasks(const std::vector<TaskInfo> & tasks)232 DatasetClient<T>::ReadFromTasks(const std::vector<TaskInfo>& tasks) {
233   WorkerResultMap result;
234   bool all_workers_finished = false;
235   while (!all_workers_finished) {
236     all_workers_finished = true;
237     for (const TaskInfo& task : tasks) {
238       StatusOr<GetElementResult> element_result = ReadFromTask(task);
239       // A task may be cancelled when it has finished but other workers are
240       // still producing data.
241       if (errors::IsCancelled(element_result.status())) {
242         continue;
243       }
244       TF_RETURN_IF_ERROR(element_result.status());
245       if (element_result->end_of_sequence) {
246         continue;
247       }
248       all_workers_finished = false;
249       result[task.worker_address()].push_back(
250           element_result->components[0].unaligned_flat<T>().data()[0]);
251     }
252   }
253   return result;
254 }
255 
256 template <class T>
ReadFromTask(const TaskInfo & task_info)257 StatusOr<GetElementResult> DatasetClient<T>::ReadFromTask(
258     const TaskInfo& task_info) {
259   GetElementRequest request;
260   GetElementResult element_result;
261   request.set_task_id(task_info.task_id());
262   TF_RETURN_IF_ERROR(worker_clients_[task_info.worker_address()]->GetElement(
263       request, element_result));
264   return element_result;
265 }
266 
267 }  // namespace data
268 }  // namespace tensorflow
269 
270 #endif  // TENSORFLOW_CORE_DATA_SERVICE_TEST_CLUSTER_H_
271