1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 Licensed under the Apache License, Version 2.0 (the "License");
3 you may not use this file except in compliance with the License.
4 You may obtain a copy of the License at
5     http://www.apache.org/licenses/LICENSE-2.0
6 Unless required by applicable law or agreed to in writing, software
7 distributed under the License is distributed on an "AS IS" BASIS,
8 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 See the License for the specific language governing permissions and
10 limitations under the License.
11 ==============================================================================*/
12 #include "tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.h"
13 
14 #include <string>
15 
16 #include "tensorflow/core/common_runtime/forward_type_inference.h"
17 #include "tensorflow/core/data/dataset_test_base.h"
18 #include "tensorflow/core/graph/node_builder.h"
19 #include "tensorflow/core/kernels/data/shard_dataset_op.h"
20 #include "tensorflow/core/public/session_options.h"
21 
22 namespace tensorflow {
23 namespace data {
24 namespace experimental {
25 namespace {
26 
// Node name assigned to the AutoShardDataset node by every parameterized
// test case (AutoShardDatasetParams1..7).
constexpr char kNodeName[]{"auto_shard_dataset"};
28 
29 class AutoShardDatasetParams : public DatasetParams {
30  public:
31   template <typename T>
AutoShardDatasetParams(T input_dataset_params,int64_t num_workers,int64_t index,int auto_shard_policy,int64_t num_replicas,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)32   AutoShardDatasetParams(T input_dataset_params, int64_t num_workers,
33                          int64_t index, int auto_shard_policy,
34                          int64_t num_replicas, DataTypeVector output_dtypes,
35                          std::vector<PartialTensorShape> output_shapes,
36                          string node_name)
37       : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
38                       std::move(node_name)),
39         num_workers_(num_workers),
40         num_replicas_(num_replicas),
41         index_(index),
42         auto_shard_policy_(auto_shard_policy) {
43     input_dataset_params_.push_back(std::make_unique<T>(input_dataset_params));
44     iterator_prefix_ =
45         name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
46                                    input_dataset_params.iterator_prefix());
47   }
48 
GetInputTensors() const49   std::vector<Tensor> GetInputTensors() const override {
50     return CreateTensors<int64_t>(TensorShape({}), {{num_workers_}, {index_}});
51   }
52 
GetInputNames(std::vector<string> * input_names) const53   Status GetInputNames(std::vector<string>* input_names) const override {
54     input_names->clear();
55     input_names->emplace_back(AutoShardDatasetOp::kInputDataset);
56     input_names->emplace_back(AutoShardDatasetOp::kNumWorkers);
57     input_names->emplace_back(AutoShardDatasetOp::kIndex);
58     return OkStatus();
59   }
60 
GetAttributes(AttributeVector * attr_vector) const61   Status GetAttributes(AttributeVector* attr_vector) const override {
62     attr_vector->clear();
63     attr_vector->emplace_back(AutoShardDatasetOp::kAutoShardPolicy,
64                               auto_shard_policy_);
65     attr_vector->emplace_back(AutoShardDatasetOp::kNumReplicas, num_replicas_);
66     attr_vector->emplace_back(AutoShardDatasetOp::kOutputTypes, output_dtypes_);
67     attr_vector->emplace_back(AutoShardDatasetOp::kOutputShapes,
68                               output_shapes_);
69     return OkStatus();
70   }
71 
dataset_type() const72   string dataset_type() const override {
73     return AutoShardDatasetOp::kDatasetType;
74   }
75 
76  private:
77   int64_t num_workers_;
78   int64_t num_replicas_;
79   int64_t index_;
80   int auto_shard_policy_;
81 };
82 
83 class AutoShardDatasetOpTest : public DatasetOpsTestBase {};
84 
85 // Test Case 1: simple case.
AutoShardDatasetParams1()86 AutoShardDatasetParams AutoShardDatasetParams1() {
87   return AutoShardDatasetParams(RangeDatasetParams(0, 10, 1),
88                                 /*num_workers=*/5,
89                                 /*index=*/2,
90                                 /*auto_shard_policy=*/0,
91                                 /*num_replicas=*/5,
92                                 /*output_dtypes=*/{DT_INT64},
93                                 /*output_shapes=*/{PartialTensorShape({})},
94                                 /*node_name=*/kNodeName);
95 }
96 
97 // Test Case 2: the index is larger than the available elements.
AutoShardDatasetParams2()98 AutoShardDatasetParams AutoShardDatasetParams2() {
99   return AutoShardDatasetParams(RangeDatasetParams(0, 1, 1),
100                                 /*num_workers=*/5,
101                                 /*index=*/2,
102                                 /*auto_shard_policy=*/0,
103                                 /*num_replicas=*/5,
104                                 /*output_dtypes=*/{DT_INT64},
105                                 /*output_shapes=*/{PartialTensorShape({})},
106                                 /*node_name=*/kNodeName);
107 }
108 
109 // Test Case 3: the number of outputs could not be evenly divided by
110 // num_workers.
AutoShardDatasetParams3()111 AutoShardDatasetParams AutoShardDatasetParams3() {
112   return AutoShardDatasetParams(RangeDatasetParams(0, 10, 1),
113                                 /*num_workers=*/4,
114                                 /*index=*/3,
115                                 /*auto_shard_policy=*/0,
116                                 /*num_replicas=*/4,
117                                 /*output_dtypes=*/{DT_INT64},
118                                 /*output_shapes=*/{PartialTensorShape({})},
119                                 /*node_name=*/kNodeName);
120 }
121 
122 // TODO(feihugis): add more test cases that have ReaderDatasets (e.g. a
123 // CSVDataset or a TFRecordDataset) in the pipeline.
124 
125 // Test case 4: the index is greater than the number of workers.
AutoShardDatasetParams4()126 AutoShardDatasetParams AutoShardDatasetParams4() {
127   return AutoShardDatasetParams(RangeDatasetParams(0, 10, 1),
128                                 /*num_workers=*/5,
129                                 /*index=*/7,
130                                 /*auto_shard_policy=*/0,
131                                 /*num_replicas=*/5,
132                                 /*output_dtypes=*/{DT_INT64},
133                                 /*output_shapes=*/{PartialTensorShape({})},
134                                 /*node_name=*/kNodeName);
135 }
136 
137 // Test case 5: the index is negative.
AutoShardDatasetParams5()138 AutoShardDatasetParams AutoShardDatasetParams5() {
139   return AutoShardDatasetParams(RangeDatasetParams(0, 10, 1),
140                                 /*num_workers=*/5,
141                                 /*index=*/-3,
142                                 /*auto_shard_policy=*/0,
143                                 /*num_replicas=*/5,
144                                 /*output_dtypes=*/{DT_INT64},
145                                 /*output_shapes=*/{PartialTensorShape({})},
146                                 /*node_name=*/kNodeName);
147 }
148 
149 // Test case 6: num_workers is negative.
AutoShardDatasetParams6()150 AutoShardDatasetParams AutoShardDatasetParams6() {
151   return AutoShardDatasetParams(RangeDatasetParams(0, 10, 1),
152                                 /*num_workers=*/-3,
153                                 /*index=*/1,
154                                 /*auto_shard_policy=*/0,
155                                 /*num_replicas=*/5,
156                                 /*output_dtypes=*/{DT_INT64},
157                                 /*output_shapes=*/{PartialTensorShape({})},
158                                 /*node_name=*/kNodeName);
159 }
160 
161 // Test case 7: num_workers is zero.
AutoShardDatasetParams7()162 AutoShardDatasetParams AutoShardDatasetParams7() {
163   return AutoShardDatasetParams(RangeDatasetParams(0, 10, 1),
164                                 /*num_workers=*/0,
165                                 /*index=*/1,
166                                 /*auto_shard_policy=*/0,
167                                 /*num_replicas=*/5,
168                                 /*output_dtypes=*/{DT_INT64},
169                                 /*output_shapes=*/{PartialTensorShape({})},
170                                 /*node_name=*/kNodeName);
171 }
172 
GetNextTestCases()173 std::vector<GetNextTestCase<AutoShardDatasetParams>> GetNextTestCases() {
174   return {
175       {/*dataset_params=*/AutoShardDatasetParams1(),
176        /*expected_outputs=*/CreateTensors<int64_t>(TensorShape{}, {{2}, {7}})},
177       {/*dataset_params=*/AutoShardDatasetParams2(),
178        /*expected_outputs=*/{}},
179       {/*dataset_params=*/AutoShardDatasetParams3(),
180        /*expected_outputs=*/CreateTensors<int64_t>(TensorShape{}, {{3}, {7}})}};
181 }
182 
// Registers the parameterized iterator GetNext test for
// AutoShardDatasetOpTest, instantiated over the cases returned by
// GetNextTestCases().
ITERATOR_GET_NEXT_TEST_P(AutoShardDatasetOpTest, AutoShardDatasetParams,
                         GetNextTestCases())
185 
186 TEST_F(AutoShardDatasetOpTest, InvalidArguments) {
187   std::vector<AutoShardDatasetParams> invalid_dataset_params = {
188       AutoShardDatasetParams4(), AutoShardDatasetParams5(),
189       AutoShardDatasetParams6(), AutoShardDatasetParams7()};
190   for (const auto& dataset_params : invalid_dataset_params) {
191     EXPECT_EQ(Initialize(dataset_params).code(),
192               tensorflow::error::INVALID_ARGUMENT);
193   }
194 }
195 
// Test-only stand-in for `Const`. It has the same `value`/`dtype` attrs and a
// single `dtype`-typed output, plus a type constructor that gives the output
// the full type TFT_TENSOR[dtype]. The type inference tests below build their
// non-dataset inputs with it so those inputs presumably carry a full type
// during forward type inference.
// TODO(b/222556529) when Const has type constructor, remove the following
REGISTER_OP("AutoShardDatasetOpTest>ConstTypeCtor")
    .Output("output: dtype")
    .Attr("value: tensor")
    .Attr("dtype: type")
    .SetTypeConstructor(full_type::Unary(TFT_TENSOR, "dtype"));
202 
203 // Adds identity notes to all outputs of this node
add_identity_nodes(Node * node,Graph & graph,std::vector<Node * > & identity_nodes)204 static void add_identity_nodes(Node* node, Graph& graph,
205                                std::vector<Node*>& identity_nodes) {
206   for (int i = 0; i < node->num_outputs(); i++) {
207     Node* new_node;
208     std::string name = absl::StrCat("Identity", i);
209     TF_EXPECT_OK(NodeBuilder(name, "Identity")
210                      .Attr("T", node->output_type(i))
211                      .Input(node, i)
212                      .Finalize(&graph, &new_node));
213     identity_nodes.push_back(new_node);
214   }
215 }
216 
217 // Runs type inference pass on graph
type_inference(Graph & graph)218 static Status type_inference(Graph& graph) {
219   GraphOptimizationPassOptions opt_options;
220   std::unique_ptr<Graph> graph_ptr(new Graph(OpRegistry::Global()));
221   graph_ptr->Copy(graph);
222   opt_options.graph = &graph_ptr;
223   opt_options.flib_def = graph.mutable_flib_def();
224   ForwardTypeInferencePass pass;
225   return pass.Run(opt_options);
226 }
227 
// Forward type inference must carry the input dataset's full type (a
// TFT_PRODUCT of a TFT_DATASET of ragged strings) unchanged to the first
// output of AutoShardDataset.
TEST_F(AutoShardDatasetOpTest, AutoShardDatasetTypeInference) {
  Graph graph(OpRegistry::Global());
  Node* input_dataset = nullptr;
  Node* num_workers = nullptr;
  Node* index = nullptr;
  Node* auto_shard_dataset = nullptr;
  FullTypeDef input_dataset_t;
  protobuf::TextFormat::Parser parser;
  // ASSERT rather than CHECK, so a malformed literal fails only this test
  // instead of aborting the whole test binary.
  ASSERT_TRUE(parser.ParseFromString(
      R"pb(type_id: TFT_PRODUCT
           args {
             type_id: TFT_DATASET
             args {
               type_id: TFT_PRODUCT
               args {
                 type_id: TFT_RAGGED
                 args { type_id: TFT_STRING }
               }
             }
           })pb",
      &input_dataset_t));
  TensorProto tensor_proto;
  // Each node below is an input of a later one, so stop at the first
  // construction failure rather than wiring an unset Node* into a builder.
  TF_ASSERT_OK(NodeBuilder("input_dataset", "Const")
                   .Attr("value", tensor_proto)
                   .Attr("dtype", DT_VARIANT)
                   .Finalize(&graph, &input_dataset));
  // Attach the dataset full type that the output is checked against.
  (*input_dataset->mutable_def()->mutable_experimental_type()) =
      input_dataset_t;
  // TODO(b/222556529) when Const has type constructor, use Const
  TF_ASSERT_OK(
      NodeBuilder("num_workers", "AutoShardDatasetOpTest>ConstTypeCtor")
          .Attr("value", tensor_proto)
          .Attr("dtype", DT_INT64)
          .Finalize(&graph, &num_workers));
  // TODO(b/222556529) when Const has type constructor, use Const
  TF_ASSERT_OK(NodeBuilder("index", "AutoShardDatasetOpTest>ConstTypeCtor")
                   .Attr("value", tensor_proto)
                   .Attr("dtype", DT_INT64)
                   .Finalize(&graph, &index));
  TF_ASSERT_OK(NodeBuilder("AutoShardDataset", "AutoShardDataset")
                   .Attr("output_types", {DT_VARIANT})
                   .Attr("output_shapes", {TensorShape({1})})
                   .Input(input_dataset)
                   .Input(num_workers)
                   .Input(index)
                   .Finalize(&graph, &auto_shard_dataset));
  std::vector<Node*> identity_nodes;
  add_identity_nodes(auto_shard_dataset, graph, identity_nodes);
  // Guard the [0] access below in case no Identity node was added.
  ASSERT_FALSE(identity_nodes.empty());
  TF_ASSERT_OK(type_inference(graph));
  EXPECT_TRUE(full_type::IsEqual(identity_nodes[0]->def().experimental_type(),
                                 input_dataset_t))
      << "fulltype is\n"
      << identity_nodes[0]->def().experimental_type().DebugString()
      << "\nexpected\n"
      << input_dataset_t.DebugString();
}
284 
// Forward type inference must carry the input dataset's full type (a
// TFT_PRODUCT of a TFT_DATASET of ragged strings) unchanged to the first
// output of RebatchDataset.
TEST_F(AutoShardDatasetOpTest, RebatchDatasetTypeInference) {
  Graph graph(OpRegistry::Global());
  Node* input_dataset = nullptr;
  Node* num_replicas = nullptr;
  Node* rebatch_dataset = nullptr;
  FullTypeDef input_dataset_t;
  protobuf::TextFormat::Parser parser;
  // ASSERT rather than CHECK, so a malformed literal fails only this test
  // instead of aborting the whole test binary.
  ASSERT_TRUE(parser.ParseFromString(
      R"pb(type_id: TFT_PRODUCT
           args {
             type_id: TFT_DATASET
             args {
               type_id: TFT_PRODUCT
               args {
                 type_id: TFT_RAGGED
                 args { type_id: TFT_STRING }
               }
             }
           })pb",
      &input_dataset_t));
  TensorProto tensor_proto;
  // Each node below is an input of a later one, so stop at the first
  // construction failure rather than wiring an unset Node* into a builder.
  TF_ASSERT_OK(NodeBuilder("input_dataset", "Const")
                   .Attr("value", tensor_proto)
                   .Attr("dtype", DT_VARIANT)
                   .Finalize(&graph, &input_dataset));
  // Attach the dataset full type that the output is checked against.
  (*input_dataset->mutable_def()->mutable_experimental_type()) =
      input_dataset_t;
  // TODO(b/222556529) when Const has type constructor, use Const
  TF_ASSERT_OK(
      NodeBuilder("num_replicas", "AutoShardDatasetOpTest>ConstTypeCtor")
          .Attr("value", tensor_proto)
          .Attr("dtype", DT_INT64)
          .Finalize(&graph, &num_replicas));
  TF_ASSERT_OK(NodeBuilder("RebatchDataset", "RebatchDataset")
                   .Attr("output_types", {DT_VARIANT})
                   .Attr("output_shapes", {TensorShape({1})})
                   .Input(input_dataset)
                   .Input(num_replicas)
                   .Finalize(&graph, &rebatch_dataset));
  std::vector<Node*> identity_nodes;
  add_identity_nodes(rebatch_dataset, graph, identity_nodes);
  // Guard the [0] access below in case no Identity node was added.
  ASSERT_FALSE(identity_nodes.empty());
  TF_ASSERT_OK(type_inference(graph));
  EXPECT_TRUE(full_type::IsEqual(identity_nodes[0]->def().experimental_type(),
                                 input_dataset_t))
      << "fulltype is\n"
      << identity_nodes[0]->def().experimental_type().DebugString()
      << "\nexpected\n"
      << input_dataset_t.DebugString();
}
334 
// Forward type inference must carry the input dataset's full type (a
// TFT_PRODUCT of a TFT_DATASET of ragged strings) unchanged to the first
// output of RebatchDatasetV2.
TEST_F(AutoShardDatasetOpTest, RebatchDatasetV2TypeInference) {
  Graph graph(OpRegistry::Global());
  Node* input_dataset = nullptr;
  Node* batch_sizes = nullptr;
  Node* drop_remainder = nullptr;
  Node* rebatch_dataset_v2 = nullptr;
  FullTypeDef input_dataset_t;
  protobuf::TextFormat::Parser parser;
  // ASSERT rather than CHECK, so a malformed literal fails only this test
  // instead of aborting the whole test binary.
  ASSERT_TRUE(parser.ParseFromString(
      R"pb(type_id: TFT_PRODUCT
           args {
             type_id: TFT_DATASET
             args {
               type_id: TFT_PRODUCT
               args {
                 type_id: TFT_RAGGED
                 args { type_id: TFT_STRING }
               }
             }
           })pb",
      &input_dataset_t));
  TensorProto tensor_proto;
  // Each node below is an input of a later one, so stop at the first
  // construction failure rather than wiring an unset Node* into a builder.
  TF_ASSERT_OK(NodeBuilder("input_dataset", "Const")
                   .Attr("value", tensor_proto)
                   .Attr("dtype", DT_VARIANT)
                   .Finalize(&graph, &input_dataset));
  // Attach the dataset full type that the output is checked against.
  (*input_dataset->mutable_def()->mutable_experimental_type()) =
      input_dataset_t;
  // TODO(b/222556529) when Const has type constructor, use Const
  // Named after the `batch_sizes` input it feeds (previously copy-pasted as
  // "num_replicas" from the RebatchDataset test).
  TF_ASSERT_OK(
      NodeBuilder("batch_sizes", "AutoShardDatasetOpTest>ConstTypeCtor")
          .Attr("value", tensor_proto)
          .Attr("dtype", DT_INT64)
          .Finalize(&graph, &batch_sizes));
  // TODO(b/222556529) when Const has type constructor, use Const
  TF_ASSERT_OK(
      NodeBuilder("drop_remainder", "AutoShardDatasetOpTest>ConstTypeCtor")
          .Attr("value", tensor_proto)
          .Attr("dtype", DT_BOOL)
          .Finalize(&graph, &drop_remainder));
  TF_ASSERT_OK(NodeBuilder("RebatchDatasetV2", "RebatchDatasetV2")
                   .Attr("output_types", {DT_VARIANT})
                   .Attr("output_shapes", {TensorShape({1})})
                   .Input(input_dataset)
                   .Input(batch_sizes)
                   .Input(drop_remainder)
                   .Finalize(&graph, &rebatch_dataset_v2));
  std::vector<Node*> identity_nodes;
  add_identity_nodes(rebatch_dataset_v2, graph, identity_nodes);
  // Guard the [0] access below in case no Identity node was added.
  ASSERT_FALSE(identity_nodes.empty());
  TF_ASSERT_OK(type_inference(graph));
  EXPECT_TRUE(full_type::IsEqual(identity_nodes[0]->def().experimental_type(),
                                 input_dataset_t))
      << "fulltype is\n"
      << identity_nodes[0]->def().experimental_type().DebugString()
      << "\nexpected\n"
      << input_dataset_t.DebugString();
}
392 
393 }  // namespace
394 }  // namespace experimental
395 }  // namespace data
396 }  // namespace tensorflow
397