1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 Licensed under the Apache License, Version 2.0 (the "License");
3 you may not use this file except in compliance with the License.
4 You may obtain a copy of the License at
5 http://www.apache.org/licenses/LICENSE-2.0
6 Unless required by applicable law or agreed to in writing, software
7 distributed under the License is distributed on an "AS IS" BASIS,
8 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 See the License for the specific language governing permissions and
10 limitations under the License.
11 ==============================================================================*/
12 #include "tensorflow/core/kernels/data/batch_dataset_op.h"
13
14 #include <string>
15
16 #include "tensorflow/core/common_runtime/forward_type_inference.h"
17 #include "tensorflow/core/data/dataset_test_base.h"
18 #include "tensorflow/core/graph/node_builder.h"
19 #include "tensorflow/core/public/session_options.h"
20
21 namespace tensorflow {
22 namespace data {
23 namespace {
24
// Node name passed as `node_name` to every BatchDatasetParams built in this
// file.
constexpr char kNodeName[] = "batch_dataset";

// Test fixture for the BatchDataset op tests. It adds no members of its own on
// top of DatasetOpsTestBase.
class BatchDatasetOpTest : public DatasetOpsTestBase {};
28
29 // Test Case 1: test BatchDatasetV2 with `drop_remainder` = false and a batch
30 // size that can evenly split the input dataset.
BatchDatasetParams1()31 BatchDatasetParams BatchDatasetParams1() {
32 return BatchDatasetParams(RangeDatasetParams(0, 12, 1),
33 /*batch_size=*/4,
34 /*drop_remainder=*/false,
35 /*parallel_copy=*/true,
36 /*output_dtypes=*/{DT_INT64},
37 /*output_shapes=*/{PartialTensorShape({4})},
38 /*node_name=*/kNodeName);
39 }
40
41 // Test Case 2: test BatchDatasetV2 with `drop_remainder` = true and a batch
42 // size that can evenly split the input dataset.
BatchDatasetParams2()43 BatchDatasetParams BatchDatasetParams2() {
44 return BatchDatasetParams(RangeDatasetParams(0, 12, 1),
45 /*batch_size=*/4,
46 /*drop_remainder=*/true,
47 /*parallel_copy=*/false,
48 /*output_dtypes=*/{DT_INT64},
49 /*output_shapes=*/{PartialTensorShape({4})},
50 /*node_name=*/kNodeName);
51 }
52
53 // Test Case 3: test BatchDatasetV2 with `drop_remainder` = false and a batch
54 // size that can not evenly split the input dataset.
BatchDatasetParams3()55 BatchDatasetParams BatchDatasetParams3() {
56 return BatchDatasetParams(RangeDatasetParams(0, 10, 1),
57 /*batch_size=*/3,
58 /*drop_remainder=*/false,
59 /*parallel_copy=*/false,
60 /*output_dtypes=*/{DT_INT64},
61 /*output_shapes=*/{PartialTensorShape({-1})},
62 /*node_name=*/kNodeName);
63 }
64
65 // Test Case 4: test BatchDatasetV2 with `drop_remainder` = true and a batch
66 // size that can not evenly split the input dataset.
BatchDatasetParams4()67 BatchDatasetParams BatchDatasetParams4() {
68 return BatchDatasetParams(RangeDatasetParams(0, 10, 1),
69 /*batch_size=*/3,
70 /*drop_remainder=*/true,
71 /*parallel_copy=*/true,
72 /*output_dtypes=*/{DT_INT64},
73 /*output_shapes=*/{PartialTensorShape({3})},
74 /*node_name=*/kNodeName);
75 }
76
77 // Test Case 5: test BatchDatasetV2 with `drop_remainder` = true and
78 // `batch_size` > the cardinality of the input dataset.
BatchDatasetParams5()79 BatchDatasetParams BatchDatasetParams5() {
80 return BatchDatasetParams(RangeDatasetParams(0, 10, 1),
81 /*batch_size=*/12,
82 /*drop_remainder=*/true,
83 /*parallel_copy=*/true,
84 /*output_dtypes=*/{DT_INT64},
85 /*output_shapes=*/{PartialTensorShape({12})},
86 /*node_name=*/kNodeName);
87 }
88
89 // Test Case 6: test BatchDatasetV2 with `drop_remainder` = false and
90 // `batch_size` > the cardinality of the input dataset.
BatchDatasetParams6()91 BatchDatasetParams BatchDatasetParams6() {
92 return BatchDatasetParams(RangeDatasetParams(0, 10, 1),
93 /*batch_size=*/12,
94 /*drop_remainder=*/false,
95 /*parallel_copy=*/true,
96 /*output_dtypes=*/{DT_INT64},
97 /*output_shapes=*/{PartialTensorShape({-1})},
98 /*node_name=*/kNodeName);
99 }
100
101 // Test Case 7: test BatchDatasetV2 with `drop_remainder` = false and
102 // the output of the input dataset is empty.
BatchDatasetParams7()103 BatchDatasetParams BatchDatasetParams7() {
104 return BatchDatasetParams(RangeDatasetParams(0, 0, 1),
105 /*batch_size=*/4,
106 /*drop_remainder=*/false,
107 /*parallel_copy=*/false,
108 /*output_dtypes=*/{DT_INT64},
109 /*output_shapes=*/{PartialTensorShape({4})},
110 /*node_name=*/kNodeName);
111 }
112
113 // Test Case 8: test BatchDatasetV2 with an invalid batch size
InvalidBatchSizeBatchDatasetParams()114 BatchDatasetParams InvalidBatchSizeBatchDatasetParams() {
115 return BatchDatasetParams(RangeDatasetParams(0, 10, 1),
116 /*batch_size=*/-1,
117 /*drop_remainder=*/false,
118 /*parallel_copy=*/false,
119 /*output_dtypes=*/{DT_INT64},
120 /*output_shapes=*/{PartialTensorShape({3})},
121 /*node_name=*/kNodeName);
122 }
123
GetNextTestCases()124 std::vector<GetNextTestCase<BatchDatasetParams>> GetNextTestCases() {
125 return {{/*dataset_params=*/BatchDatasetParams1(),
126 /*expected_outputs=*/
127 CreateTensors<int64_t>(
128 TensorShape({4}), {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})},
129 {/*dataset_params=*/BatchDatasetParams2(),
130 /*expected_outputs=*/
131 CreateTensors<int64_t>(
132 TensorShape({4}), {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})},
133 {/*dataset_params=*/BatchDatasetParams3(),
134 /*expected_outputs=*/
135 {CreateTensor<int64_t>(TensorShape({3}), {0, 1, 2}),
136 CreateTensor<int64_t>(TensorShape({3}), {3, 4, 5}),
137 CreateTensor<int64_t>(TensorShape({3}), {6, 7, 8}),
138 CreateTensor<int64_t>(TensorShape({1}), {9})}},
139 {/*dataset_params=*/BatchDatasetParams4(),
140 /*expected_outputs=*/
141 CreateTensors<int64_t>(TensorShape({3}),
142 {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}})},
143 {/*dataset_params=*/BatchDatasetParams5(),
144 /*expected_outputs=*/{}},
145 {/*dataset_params=*/BatchDatasetParams6(),
146 /*expected_outputs=*/
147 CreateTensors<int64_t>(TensorShape({10}),
148 {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}})},
149
150 {/*dataset_params=*/BatchDatasetParams7(),
151 /*expected_outputs=*/{}}};
152 }
153
ITERATOR_GET_NEXT_TEST_P(BatchDatasetOpTest,BatchDatasetParams,GetNextTestCases ())154 ITERATOR_GET_NEXT_TEST_P(BatchDatasetOpTest, BatchDatasetParams,
155 GetNextTestCases())
156
157 TEST_F(BatchDatasetOpTest, DatasetNodeName) {
158 auto batch_dataset_params = BatchDatasetParams1();
159 TF_ASSERT_OK(Initialize(batch_dataset_params));
160 TF_ASSERT_OK(CheckDatasetNodeName(batch_dataset_params.node_name()));
161 }
162
// Checks the dataset type string against the op name derived from
// BatchDatasetOp::kDatasetType and the op version of parameter set 1.
TEST_F(BatchDatasetOpTest, DatasetTypeString) {
  BatchDatasetParams dataset_params = BatchDatasetParams1();
  TF_ASSERT_OK(Initialize(dataset_params));

  name_utils::OpNameParams op_name_params;
  op_name_params.op_version = dataset_params.op_version();
  const auto expected_type_string =
      name_utils::OpName(BatchDatasetOp::kDatasetType, op_name_params);
  TF_ASSERT_OK(CheckDatasetTypeString(expected_type_string));
}
171
// Checks that the dataset built from parameter set 1 reports a single
// DT_INT64 output dtype.
TEST_F(BatchDatasetOpTest, DatasetOutputDtypes) {
  BatchDatasetParams dataset_params = BatchDatasetParams1();
  const Status init_status = Initialize(dataset_params);
  TF_ASSERT_OK(init_status);
  TF_ASSERT_OK(CheckDatasetOutputDtypes({DT_INT64}));
}
177
178 std::vector<DatasetOutputShapesTestCase<BatchDatasetParams>>
DatasetOutputShapesTestCases()179 DatasetOutputShapesTestCases() {
180 return {{/*dataset_params=*/BatchDatasetParams1(),
181 /*expected_output_shapes=*/{PartialTensorShape({4})}},
182 {/*dataset_params=*/BatchDatasetParams2(),
183 /*expected_output_shapes=*/{PartialTensorShape({4})}},
184 {/*dataset_params=*/BatchDatasetParams3(),
185 /*expected_output_shapes=*/{PartialTensorShape({-1})}},
186 {/*dataset_params=*/BatchDatasetParams4(),
187 /*expected_output_shapes=*/{PartialTensorShape({3})}},
188 {/*dataset_params=*/BatchDatasetParams5(),
189 /*expected_output_shapes=*/{PartialTensorShape({12})}},
190 {/*dataset_params=*/BatchDatasetParams6(),
191 /*expected_output_shapes=*/{PartialTensorShape({-1})}},
192 {/*dataset_params=*/BatchDatasetParams7(),
193 /*expected_output_shapes=*/{PartialTensorShape({4})}}};
194 }
195
DATASET_OUTPUT_SHAPES_TEST_P(BatchDatasetOpTest,BatchDatasetParams,DatasetOutputShapesTestCases ())196 DATASET_OUTPUT_SHAPES_TEST_P(BatchDatasetOpTest, BatchDatasetParams,
197 DatasetOutputShapesTestCases())
198
199 std::vector<CardinalityTestCase<BatchDatasetParams>> CardinalityTestCases() {
200 return {
201 {/*dataset_params=*/BatchDatasetParams1(), /*expected_cardinality=*/3},
202 {/*dataset_params=*/BatchDatasetParams2(), /*expected_cardinality=*/3},
203 {/*dataset_params=*/BatchDatasetParams3(), /*expected_cardinality=*/4},
204 {/*dataset_params=*/BatchDatasetParams4(), /*expected_cardinality=*/3},
205 {/*dataset_params=*/BatchDatasetParams5(), /*expected_cardinality=*/0},
206 {/*dataset_params=*/BatchDatasetParams6(), /*expected_cardinality=*/1},
207 {/*dataset_params=*/BatchDatasetParams7(), /*expected_cardinality=*/0}};
208 }
209
DATASET_CARDINALITY_TEST_P(BatchDatasetOpTest,BatchDatasetParams,CardinalityTestCases ())210 DATASET_CARDINALITY_TEST_P(BatchDatasetOpTest, BatchDatasetParams,
211 CardinalityTestCases())
212
213 TEST_F(BatchDatasetOpTest, IteratorOutputDtypes) {
214 auto batch_dataset_params = BatchDatasetParams1();
215 TF_ASSERT_OK(Initialize(batch_dataset_params));
216 TF_ASSERT_OK(CheckIteratorOutputDtypes({DT_INT64}));
217 }
218
219 std::vector<IteratorOutputShapesTestCase<BatchDatasetParams>>
IteratorOutputShapesTestCases()220 IteratorOutputShapesTestCases() {
221 return {{/*dataset_params=*/BatchDatasetParams1(),
222 /*expected_output_shapes=*/{PartialTensorShape({4})}},
223 {/*dataset_params=*/BatchDatasetParams2(),
224 /*expected_output_shapes=*/{PartialTensorShape({4})}},
225 {/*dataset_params=*/BatchDatasetParams3(),
226 /*expected_output_shapes=*/{PartialTensorShape({-1})}},
227 {/*dataset_params=*/BatchDatasetParams4(),
228 /*expected_output_shapes=*/{PartialTensorShape({3})}},
229 {/*dataset_params=*/BatchDatasetParams5(),
230 /*expected_output_shapes=*/{PartialTensorShape({12})}},
231 {/*dataset_params=*/BatchDatasetParams6(),
232 /*expected_output_shapes=*/{PartialTensorShape({-1})}},
233 {/*dataset_params=*/BatchDatasetParams7(),
234 /*expected_output_shapes=*/{PartialTensorShape({4})}}};
235 }
236
ITERATOR_OUTPUT_SHAPES_TEST_P(BatchDatasetOpTest,BatchDatasetParams,IteratorOutputShapesTestCases ())237 ITERATOR_OUTPUT_SHAPES_TEST_P(BatchDatasetOpTest, BatchDatasetParams,
238 IteratorOutputShapesTestCases())
239
240 TEST_F(BatchDatasetOpTest, IteratorOutputPrefix) {
241 auto batch_dataset_params = BatchDatasetParams1();
242 TF_ASSERT_OK(Initialize(batch_dataset_params));
243 name_utils::IteratorPrefixParams params;
244 params.op_version = batch_dataset_params.op_version();
245 TF_ASSERT_OK(CheckIteratorPrefix(name_utils::IteratorPrefix(
246 BatchDatasetOp::kDatasetType, batch_dataset_params.iterator_prefix(),
247 params)));
248 }
249
250 std::vector<IteratorSaveAndRestoreTestCase<BatchDatasetParams>>
IteratorSaveAndRestoreTestCases()251 IteratorSaveAndRestoreTestCases() {
252 return {{/*dataset_params=*/BatchDatasetParams1(),
253 /*breakpoints=*/{0, 1, 5},
254 /*expected_outputs=*/
255 CreateTensors<int64_t>(
256 TensorShape({4}), {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})},
257 {/*dataset_params=*/BatchDatasetParams2(),
258 /*breakpoints=*/{0, 1, 5},
259 /*expected_outputs=*/
260 CreateTensors<int64_t>(
261 TensorShape({4}), {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})},
262 {/*dataset_params=*/BatchDatasetParams3(),
263 /*breakpoints=*/{0, 1, 5},
264 /*expected_outputs=*/
265 {CreateTensor<int64_t>(TensorShape({3}), {0, 1, 2}),
266 CreateTensor<int64_t>(TensorShape({3}), {3, 4, 5}),
267 CreateTensor<int64_t>(TensorShape({3}), {6, 7, 8}),
268 CreateTensor<int64_t>(TensorShape({1}), {9})}},
269 {/*dataset_params=*/BatchDatasetParams4(),
270 /*breakpoints=*/{0, 1, 5},
271 /*expected_outputs=*/
272 {CreateTensor<int64_t>(TensorShape({3}), {0, 1, 2}),
273 CreateTensor<int64_t>(TensorShape({3}), {3, 4, 5}),
274 CreateTensor<int64_t>(TensorShape({3}), {6, 7, 8})}},
275 {/*dataset_params=*/BatchDatasetParams5(),
276 /*breakpoints=*/{0, 1, 5},
277 /*expected_outputs=*/{}},
278 {/*dataset_params=*/BatchDatasetParams6(),
279 /*breakpoints=*/{0, 1, 5},
280 /*expected_outputs=*/
281 {CreateTensor<int64_t>(TensorShape({10}),
282 {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})}},
283 {/*dataset_params=*/BatchDatasetParams7(),
284 /*breakpoints=*/{0, 1, 5},
285 /*expected_outputs=*/{}}};
286 }
287
ITERATOR_SAVE_AND_RESTORE_TEST_P(BatchDatasetOpTest,BatchDatasetParams,IteratorSaveAndRestoreTestCases ())288 ITERATOR_SAVE_AND_RESTORE_TEST_P(BatchDatasetOpTest, BatchDatasetParams,
289 IteratorSaveAndRestoreTestCases())
290
291 TEST_F(BatchDatasetOpTest, InvalidBatchSize) {
292 auto batch_dataset_params = InvalidBatchSizeBatchDatasetParams();
293 EXPECT_EQ(Initialize(batch_dataset_params).code(),
294 tensorflow::error::INVALID_ARGUMENT);
295 }
296
// Test-only op with one output and `value`/`dtype` attrs, registered with a
// unary TFT_TENSOR type constructor keyed on the `dtype` attr. The
// TypeInference test below uses it in place of Const for the `batch_size` and
// `drop_remainder` inputs.
// TODO(b/222556529) when Const has type constructor, remove the following
REGISTER_OP("BatchDatasetOpTest>ConstTypeCtor")
    .Output("output: dtype")
    .Attr("value: tensor")
    .Attr("dtype: type")
    .SetTypeConstructor(full_type::Unary(TFT_TENSOR, "dtype"));
303
304 // Adds identity notes to all outputs of this node
add_identity_nodes(Node * node,Graph & graph,std::vector<Node * > & identity_nodes)305 static void add_identity_nodes(Node* node, Graph& graph,
306 std::vector<Node*>& identity_nodes) {
307 for (int i = 0; i < node->num_outputs(); i++) {
308 Node* new_node;
309 std::string name = absl::StrCat("Identity", i);
310 TF_EXPECT_OK(NodeBuilder(name, "Identity")
311 .Attr("T", node->output_type(i))
312 .Input(node, i)
313 .Finalize(&graph, &new_node));
314 identity_nodes.push_back(new_node);
315 }
316 }
317
318 // Runs type inference pass on graph
type_inference(Graph & graph)319 static Status type_inference(Graph& graph) {
320 GraphOptimizationPassOptions opt_options;
321 std::unique_ptr<Graph> graph_ptr(new Graph(OpRegistry::Global()));
322 graph_ptr->Copy(graph);
323 opt_options.graph = &graph_ptr;
324 opt_options.flib_def = graph.mutable_flib_def();
325 ForwardTypeInferencePass pass;
326 return pass.Run(opt_options);
327 }
328
// Builds a small graph (Const input dataset -> BatchDatasetV2 -> Identity),
// attaches a full type to the input dataset node, runs forward type inference,
// and expects the first Identity node to carry the same full type as the
// input dataset.
//
// Graph construction uses fatal assertions: the node pointers are only set by
// a successful Finalize(), and later steps dereference them or pass them as
// inputs, so continuing after a failed build would read invalid pointers.
//
// NOTE(review): the suite name is spelled `BatchDatsetOpTest`. It is left
// unchanged because renaming it to `BatchDatasetOpTest` would place this plain
// TEST in the same suite as the TEST_F cases, and googletest requires all
// tests in a suite to use the same fixture class. Confirm before renaming.
TEST(BatchDatsetOpTest, TypeInference) {
  Graph graph(OpRegistry::Global());
  Node* input_dataset = nullptr;
  Node* batch_size = nullptr;
  Node* drop_remainder = nullptr;
  Node* batch_dataset_v2 = nullptr;

  // Full type attached to the input dataset node and expected on the output.
  FullTypeDef input_dataset_t;
  protobuf::TextFormat::Parser parser;
  CHECK(parser.ParseFromString(
      R"pb(type_id: TFT_PRODUCT
           args {
             type_id: TFT_DATASET
             args {
               type_id: TFT_PRODUCT
               args {
                 type_id: TFT_RAGGED
                 args { type_id: TFT_STRING }
               }
             }
           })pb",
      &input_dataset_t));

  TensorProto tensor_proto;
  TF_ASSERT_OK(NodeBuilder("input_dataset", "Const")
                   .Attr("value", tensor_proto)
                   .Attr("dtype", DT_VARIANT)
                   .Finalize(&graph, &input_dataset));
  (*input_dataset->mutable_def()->mutable_experimental_type()) =
      input_dataset_t;
  // TODO(b/222556529) when Const has type constructor, use Const
  TF_ASSERT_OK(NodeBuilder("batch_size", "BatchDatasetOpTest>ConstTypeCtor")
                   .Attr("value", tensor_proto)
                   .Attr("dtype", DT_INT64)
                   .Finalize(&graph, &batch_size));
  // TODO(b/222556529) when Const has type constructor, use Const
  TF_ASSERT_OK(NodeBuilder("drop_remainder", "BatchDatasetOpTest>ConstTypeCtor")
                   .Attr("value", tensor_proto)
                   .Attr("dtype", DT_BOOL)
                   .Finalize(&graph, &drop_remainder));
  TF_ASSERT_OK(NodeBuilder("BatchDatasetV2", "BatchDatasetV2")
                   .Attr("output_types", {DT_VARIANT})
                   .Attr("output_shapes", {TensorShape({1})})
                   .Input(input_dataset)
                   .Input(batch_size)
                   .Input(drop_remainder)
                   .Finalize(&graph, &batch_dataset_v2));

  std::vector<Node*> identity_nodes;
  add_identity_nodes(batch_dataset_v2, graph, identity_nodes);
  // Guard the indexing below.
  ASSERT_FALSE(identity_nodes.empty());
  TF_EXPECT_OK(type_inference(graph));
  EXPECT_TRUE(full_type::IsEqual(identity_nodes[0]->def().experimental_type(),
                                 input_dataset_t))
      << "fulltype is\n"
      << identity_nodes[0]->def().experimental_type().DebugString()
      << "\nexpected\n"
      << input_dataset_t.DebugString();
}
385
386 } // namespace
387 } // namespace data
388 } // namespace tensorflow
389