/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.h"

#include <string>

#include "absl/strings/str_cat.h"
#include "tensorflow/core/common_runtime/forward_type_inference.h"
#include "tensorflow/core/data/dataset_test_base.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/kernels/data/shard_dataset_op.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace data {
namespace experimental {
namespace {

constexpr char kNodeName[] = "auto_shard_dataset";
28
29 class AutoShardDatasetParams : public DatasetParams {
30 public:
31 template <typename T>
AutoShardDatasetParams(T input_dataset_params,int64_t num_workers,int64_t index,int auto_shard_policy,int64_t num_replicas,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)32 AutoShardDatasetParams(T input_dataset_params, int64_t num_workers,
33 int64_t index, int auto_shard_policy,
34 int64_t num_replicas, DataTypeVector output_dtypes,
35 std::vector<PartialTensorShape> output_shapes,
36 string node_name)
37 : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
38 std::move(node_name)),
39 num_workers_(num_workers),
40 num_replicas_(num_replicas),
41 index_(index),
42 auto_shard_policy_(auto_shard_policy) {
43 input_dataset_params_.push_back(std::make_unique<T>(input_dataset_params));
44 iterator_prefix_ =
45 name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
46 input_dataset_params.iterator_prefix());
47 }
48
GetInputTensors() const49 std::vector<Tensor> GetInputTensors() const override {
50 return CreateTensors<int64_t>(TensorShape({}), {{num_workers_}, {index_}});
51 }
52
GetInputNames(std::vector<string> * input_names) const53 Status GetInputNames(std::vector<string>* input_names) const override {
54 input_names->clear();
55 input_names->emplace_back(AutoShardDatasetOp::kInputDataset);
56 input_names->emplace_back(AutoShardDatasetOp::kNumWorkers);
57 input_names->emplace_back(AutoShardDatasetOp::kIndex);
58 return OkStatus();
59 }
60
GetAttributes(AttributeVector * attr_vector) const61 Status GetAttributes(AttributeVector* attr_vector) const override {
62 attr_vector->clear();
63 attr_vector->emplace_back(AutoShardDatasetOp::kAutoShardPolicy,
64 auto_shard_policy_);
65 attr_vector->emplace_back(AutoShardDatasetOp::kNumReplicas, num_replicas_);
66 attr_vector->emplace_back(AutoShardDatasetOp::kOutputTypes, output_dtypes_);
67 attr_vector->emplace_back(AutoShardDatasetOp::kOutputShapes,
68 output_shapes_);
69 return OkStatus();
70 }
71
dataset_type() const72 string dataset_type() const override {
73 return AutoShardDatasetOp::kDatasetType;
74 }
75
76 private:
77 int64_t num_workers_;
78 int64_t num_replicas_;
79 int64_t index_;
80 int auto_shard_policy_;
81 };
82
83 class AutoShardDatasetOpTest : public DatasetOpsTestBase {};
84
// Test Case 1: simple case.
AutoShardDatasetParams AutoShardDatasetParams1() {
  return AutoShardDatasetParams(RangeDatasetParams(0, 10, 1),
                                /*num_workers=*/5,
                                /*index=*/2,
                                /*auto_shard_policy=*/0,
                                /*num_replicas=*/5,
                                /*output_dtypes=*/{DT_INT64},
                                /*output_shapes=*/{PartialTensorShape({})},
                                /*node_name=*/kNodeName);
}
// Test Case 2: the worker index is larger than the number of available
// elements, so this worker's shard is empty.
AutoShardDatasetParams AutoShardDatasetParams2() {
  return AutoShardDatasetParams(RangeDatasetParams(0, 1, 1),
                                /*num_workers=*/5,
                                /*index=*/2,
                                /*auto_shard_policy=*/0,
                                /*num_replicas=*/5,
                                /*output_dtypes=*/{DT_INT64},
                                /*output_shapes=*/{PartialTensorShape({})},
                                /*node_name=*/kNodeName);
}

// Test Case 3: the number of elements cannot be evenly divided by
// num_workers.
AutoShardDatasetParams AutoShardDatasetParams3() {
  return AutoShardDatasetParams(RangeDatasetParams(0, 10, 1),
                                /*num_workers=*/4,
                                /*index=*/3,
                                /*auto_shard_policy=*/0,
                                /*num_replicas=*/4,
                                /*output_dtypes=*/{DT_INT64},
                                /*output_shapes=*/{PartialTensorShape({})},
                                /*node_name=*/kNodeName);
}
// TODO(feihugis): add more test cases that have ReaderDatasets (e.g. a
// CSVDataset or a TFRecordDataset) in the pipeline.

// Test Case 4: the index is greater than the number of workers.
AutoShardDatasetParams AutoShardDatasetParams4() {
  return AutoShardDatasetParams(RangeDatasetParams(0, 10, 1),
                                /*num_workers=*/5,
                                /*index=*/7,
                                /*auto_shard_policy=*/0,
                                /*num_replicas=*/5,
                                /*output_dtypes=*/{DT_INT64},
                                /*output_shapes=*/{PartialTensorShape({})},
                                /*node_name=*/kNodeName);
}

// Test Case 5: the index is negative.
AutoShardDatasetParams AutoShardDatasetParams5() {
  return AutoShardDatasetParams(RangeDatasetParams(0, 10, 1),
                                /*num_workers=*/5,
                                /*index=*/-3,
                                /*auto_shard_policy=*/0,
                                /*num_replicas=*/5,
                                /*output_dtypes=*/{DT_INT64},
                                /*output_shapes=*/{PartialTensorShape({})},
                                /*node_name=*/kNodeName);
}

// Test Case 6: num_workers is negative.
AutoShardDatasetParams AutoShardDatasetParams6() {
  return AutoShardDatasetParams(RangeDatasetParams(0, 10, 1),
                                /*num_workers=*/-3,
                                /*index=*/1,
                                /*auto_shard_policy=*/0,
                                /*num_replicas=*/5,
                                /*output_dtypes=*/{DT_INT64},
                                /*output_shapes=*/{PartialTensorShape({})},
                                /*node_name=*/kNodeName);
}

// Test Case 7: num_workers is zero.
AutoShardDatasetParams AutoShardDatasetParams7() {
  return AutoShardDatasetParams(RangeDatasetParams(0, 10, 1),
                                /*num_workers=*/0,
                                /*index=*/1,
                                /*auto_shard_policy=*/0,
                                /*num_replicas=*/5,
                                /*output_dtypes=*/{DT_INT64},
                                /*output_shapes=*/{PartialTensorShape({})},
                                /*node_name=*/kNodeName);
}

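// The expected outputs below correspond to index-based sharding of the
// RangeDataset input: element i is kept when i % num_workers == index
// (e.g. range(0, 10) with num_workers=5 and index=2 keeps {2, 7}).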
std::vector<GetNextTestCase<AutoShardDatasetParams>> GetNextTestCases() {
  return {
      {/*dataset_params=*/AutoShardDatasetParams1(),
       /*expected_outputs=*/CreateTensors<int64_t>(TensorShape{}, {{2}, {7}})},
      {/*dataset_params=*/AutoShardDatasetParams2(),
       /*expected_outputs=*/{}},
      {/*dataset_params=*/AutoShardDatasetParams3(),
       /*expected_outputs=*/CreateTensors<int64_t>(TensorShape{}, {{3}, {7}})}};
}

ITERATOR_GET_NEXT_TEST_P(AutoShardDatasetOpTest, AutoShardDatasetParams,
                         GetNextTestCases())

TEST_F(AutoShardDatasetOpTest, InvalidArguments) {
  std::vector<AutoShardDatasetParams> invalid_dataset_params = {
      AutoShardDatasetParams4(), AutoShardDatasetParams5(),
      AutoShardDatasetParams6(), AutoShardDatasetParams7()};
  for (const auto& dataset_params : invalid_dataset_params) {
    EXPECT_EQ(Initialize(dataset_params).code(),
              tensorflow::error::INVALID_ARGUMENT);
  }
}

// TODO(b/222556529) when Const has type constructor, remove the following
REGISTER_OP("AutoShardDatasetOpTest>ConstTypeCtor")
    .Output("output: dtype")
    .Attr("value: tensor")
    .Attr("dtype: type")
    .SetTypeConstructor(full_type::Unary(TFT_TENSOR, "dtype"));

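// Helpers for the type-inference tests below: Identity consumers are attached
// to the dataset op so its output fulltype can be read back after inference.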
// Adds identity nodes to all outputs of this node.
static void add_identity_nodes(Node* node, Graph& graph,
                               std::vector<Node*>& identity_nodes) {
  for (int i = 0; i < node->num_outputs(); i++) {
    Node* new_node;
    std::string name = absl::StrCat("Identity", i);
    TF_EXPECT_OK(NodeBuilder(name, "Identity")
                     .Attr("T", node->output_type(i))
                     .Input(node, i)
                     .Finalize(&graph, &new_node));
    identity_nodes.push_back(new_node);
  }
}

// Runs type inference pass on graph
static Status type_inference(Graph& graph) {
  GraphOptimizationPassOptions opt_options;
  std::unique_ptr<Graph> graph_ptr(new Graph(OpRegistry::Global()));
  graph_ptr->Copy(graph);
  opt_options.graph = &graph_ptr;
  opt_options.flib_def = graph.mutable_flib_def();
  ForwardTypeInferencePass pass;
  return pass.Run(opt_options);
}

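// Each test below builds a small graph whose input dataset carries a
// TFT_DATASET (ragged string) fulltype, runs type inference, and expects that
// fulltype to appear unchanged on the Identity node downstream of the dataset
// op under test.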
TEST_F(AutoShardDatasetOpTest, AutoShardDatasetTypeInference) {
  Graph graph(OpRegistry::Global());
  Node* input_dataset;
  Node* num_workers;
  Node* index;
  Node* auto_shard_dataset;
  FullTypeDef input_dataset_t;
  protobuf::TextFormat::Parser parser;
  CHECK(parser.ParseFromString(
      R"pb(type_id: TFT_PRODUCT
           args {
             type_id: TFT_DATASET
             args {
               type_id: TFT_PRODUCT
               args {
                 type_id: TFT_RAGGED
                 args { type_id: TFT_STRING }
               }
             }
           })pb",
      &input_dataset_t));
  TensorProto tensor_proto;
  TF_EXPECT_OK(NodeBuilder("input_dataset", "Const")
                   .Attr("value", tensor_proto)
                   .Attr("dtype", DT_VARIANT)
                   .Finalize(&graph, &input_dataset));
  (*input_dataset->mutable_def()->mutable_experimental_type()) =
      input_dataset_t;
  // TODO(b/222556529) when Const has type constructor, use Const
  TF_EXPECT_OK(
      NodeBuilder("num_workers", "AutoShardDatasetOpTest>ConstTypeCtor")
          .Attr("value", tensor_proto)
          .Attr("dtype", DT_INT64)
          .Finalize(&graph, &num_workers));
  // TODO(b/222556529) when Const has type constructor, use Const
  TF_EXPECT_OK(NodeBuilder("index", "AutoShardDatasetOpTest>ConstTypeCtor")
                   .Attr("value", tensor_proto)
                   .Attr("dtype", DT_INT64)
                   .Finalize(&graph, &index));
  TF_EXPECT_OK(NodeBuilder("AutoShardDataset", "AutoShardDataset")
                   .Attr("output_types", {DT_VARIANT})
                   .Attr("output_shapes", {TensorShape({1})})
                   .Input(input_dataset)
                   .Input(num_workers)
                   .Input(index)
                   .Finalize(&graph, &auto_shard_dataset));
  std::vector<Node*> identity_nodes;
  add_identity_nodes(auto_shard_dataset, graph, identity_nodes);
  TF_EXPECT_OK(type_inference(graph));
  EXPECT_TRUE(full_type::IsEqual(identity_nodes[0]->def().experimental_type(),
                                 input_dataset_t))
      << "fulltype is\n"
      << identity_nodes[0]->def().experimental_type().DebugString()
      << "\nexpected\n"
      << input_dataset_t.DebugString();
}

TEST_F(AutoShardDatasetOpTest, RebatchDatasetTypeInference) {
  Graph graph(OpRegistry::Global());
  Node* input_dataset;
  Node* num_replicas;
  Node* rebatch_dataset;
  FullTypeDef input_dataset_t;
  protobuf::TextFormat::Parser parser;
  CHECK(parser.ParseFromString(
      R"pb(type_id: TFT_PRODUCT
           args {
             type_id: TFT_DATASET
             args {
               type_id: TFT_PRODUCT
               args {
                 type_id: TFT_RAGGED
                 args { type_id: TFT_STRING }
               }
             }
           })pb",
      &input_dataset_t));
  TensorProto tensor_proto;
  TF_EXPECT_OK(NodeBuilder("input_dataset", "Const")
                   .Attr("value", tensor_proto)
                   .Attr("dtype", DT_VARIANT)
                   .Finalize(&graph, &input_dataset));
  (*input_dataset->mutable_def()->mutable_experimental_type()) =
      input_dataset_t;
  // TODO(b/222556529) when Const has type constructor, use Const
  TF_EXPECT_OK(
      NodeBuilder("num_replicas", "AutoShardDatasetOpTest>ConstTypeCtor")
          .Attr("value", tensor_proto)
          .Attr("dtype", DT_INT64)
          .Finalize(&graph, &num_replicas));
  TF_EXPECT_OK(NodeBuilder("RebatchDataset", "RebatchDataset")
                   .Attr("output_types", {DT_VARIANT})
                   .Attr("output_shapes", {TensorShape({1})})
                   .Input(input_dataset)
                   .Input(num_replicas)
                   .Finalize(&graph, &rebatch_dataset));
  std::vector<Node*> identity_nodes;
  add_identity_nodes(rebatch_dataset, graph, identity_nodes);
  TF_EXPECT_OK(type_inference(graph));
  EXPECT_TRUE(full_type::IsEqual(identity_nodes[0]->def().experimental_type(),
                                 input_dataset_t))
      << "fulltype is\n"
      << identity_nodes[0]->def().experimental_type().DebugString()
      << "\nexpected\n"
      << input_dataset_t.DebugString();
}

TEST_F(AutoShardDatasetOpTest, RebatchDatasetV2TypeInference) {
  Graph graph(OpRegistry::Global());
  Node* input_dataset;
  Node* batch_sizes;
  Node* drop_remainder;
  Node* rebatch_dataset_v2;
  FullTypeDef input_dataset_t;
  protobuf::TextFormat::Parser parser;
  CHECK(parser.ParseFromString(
      R"pb(type_id: TFT_PRODUCT
           args {
             type_id: TFT_DATASET
             args {
               type_id: TFT_PRODUCT
               args {
                 type_id: TFT_RAGGED
                 args { type_id: TFT_STRING }
               }
             }
           })pb",
      &input_dataset_t));
  TensorProto tensor_proto;
  TF_EXPECT_OK(NodeBuilder("input_dataset", "Const")
                   .Attr("value", tensor_proto)
                   .Attr("dtype", DT_VARIANT)
                   .Finalize(&graph, &input_dataset));
  (*input_dataset->mutable_def()->mutable_experimental_type()) =
      input_dataset_t;
  // TODO(b/222556529) when Const has type constructor, use Const
  TF_EXPECT_OK(
      NodeBuilder("batch_sizes", "AutoShardDatasetOpTest>ConstTypeCtor")
          .Attr("value", tensor_proto)
          .Attr("dtype", DT_INT64)
          .Finalize(&graph, &batch_sizes));
  // TODO(b/222556529) when Const has type constructor, use Const
  TF_EXPECT_OK(
      NodeBuilder("drop_remainder", "AutoShardDatasetOpTest>ConstTypeCtor")
          .Attr("value", tensor_proto)
          .Attr("dtype", DT_BOOL)
          .Finalize(&graph, &drop_remainder));
  TF_EXPECT_OK(NodeBuilder("RebatchDatasetV2", "RebatchDatasetV2")
                   .Attr("output_types", {DT_VARIANT})
                   .Attr("output_shapes", {TensorShape({1})})
                   .Input(input_dataset)
                   .Input(batch_sizes)
                   .Input(drop_remainder)
                   .Finalize(&graph, &rebatch_dataset_v2));
  std::vector<Node*> identity_nodes;
  add_identity_nodes(rebatch_dataset_v2, graph, identity_nodes);
  TF_EXPECT_OK(type_inference(graph));
  EXPECT_TRUE(full_type::IsEqual(identity_nodes[0]->def().experimental_type(),
                                 input_dataset_t))
      << "fulltype is\n"
      << identity_nodes[0]->def().experimental_type().DebugString()
      << "\nexpected\n"
      << input_dataset_t.DebugString();
}

}  // namespace
}  // namespace experimental
}  // namespace data
}  // namespace tensorflow