// xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/data/batch_dataset_op_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
12 #include "tensorflow/core/kernels/data/batch_dataset_op.h"
13 
14 #include <string>
15 
16 #include "tensorflow/core/common_runtime/forward_type_inference.h"
17 #include "tensorflow/core/data/dataset_test_base.h"
18 #include "tensorflow/core/graph/node_builder.h"
19 #include "tensorflow/core/public/session_options.h"
20 
21 namespace tensorflow {
22 namespace data {
23 namespace {
24 
25 constexpr char kNodeName[] = "batch_dataset";
26 
27 class BatchDatasetOpTest : public DatasetOpsTestBase {};
28 
29 // Test Case 1: test BatchDatasetV2 with `drop_remainder` = false and a batch
30 // size that can evenly split the input dataset.
BatchDatasetParams1()31 BatchDatasetParams BatchDatasetParams1() {
32   return BatchDatasetParams(RangeDatasetParams(0, 12, 1),
33                             /*batch_size=*/4,
34                             /*drop_remainder=*/false,
35                             /*parallel_copy=*/true,
36                             /*output_dtypes=*/{DT_INT64},
37                             /*output_shapes=*/{PartialTensorShape({4})},
38                             /*node_name=*/kNodeName);
39 }
40 
41 // Test Case 2: test BatchDatasetV2 with `drop_remainder` = true and a batch
42 // size that can evenly split the input dataset.
BatchDatasetParams2()43 BatchDatasetParams BatchDatasetParams2() {
44   return BatchDatasetParams(RangeDatasetParams(0, 12, 1),
45                             /*batch_size=*/4,
46                             /*drop_remainder=*/true,
47                             /*parallel_copy=*/false,
48                             /*output_dtypes=*/{DT_INT64},
49                             /*output_shapes=*/{PartialTensorShape({4})},
50                             /*node_name=*/kNodeName);
51 }
52 
53 // Test Case 3: test BatchDatasetV2 with `drop_remainder` = false and a batch
54 // size that can not evenly split the input dataset.
BatchDatasetParams3()55 BatchDatasetParams BatchDatasetParams3() {
56   return BatchDatasetParams(RangeDatasetParams(0, 10, 1),
57                             /*batch_size=*/3,
58                             /*drop_remainder=*/false,
59                             /*parallel_copy=*/false,
60                             /*output_dtypes=*/{DT_INT64},
61                             /*output_shapes=*/{PartialTensorShape({-1})},
62                             /*node_name=*/kNodeName);
63 }
64 
65 // Test Case 4: test BatchDatasetV2 with `drop_remainder` = true and a batch
66 // size that can not evenly split the input dataset.
BatchDatasetParams4()67 BatchDatasetParams BatchDatasetParams4() {
68   return BatchDatasetParams(RangeDatasetParams(0, 10, 1),
69                             /*batch_size=*/3,
70                             /*drop_remainder=*/true,
71                             /*parallel_copy=*/true,
72                             /*output_dtypes=*/{DT_INT64},
73                             /*output_shapes=*/{PartialTensorShape({3})},
74                             /*node_name=*/kNodeName);
75 }
76 
77 // Test Case 5: test BatchDatasetV2 with `drop_remainder` = true and
78 // `batch_size` > the cardinality of the input dataset.
BatchDatasetParams5()79 BatchDatasetParams BatchDatasetParams5() {
80   return BatchDatasetParams(RangeDatasetParams(0, 10, 1),
81                             /*batch_size=*/12,
82                             /*drop_remainder=*/true,
83                             /*parallel_copy=*/true,
84                             /*output_dtypes=*/{DT_INT64},
85                             /*output_shapes=*/{PartialTensorShape({12})},
86                             /*node_name=*/kNodeName);
87 }
88 
89 // Test Case 6: test BatchDatasetV2 with `drop_remainder` = false and
90 // `batch_size` > the cardinality of the input dataset.
BatchDatasetParams6()91 BatchDatasetParams BatchDatasetParams6() {
92   return BatchDatasetParams(RangeDatasetParams(0, 10, 1),
93                             /*batch_size=*/12,
94                             /*drop_remainder=*/false,
95                             /*parallel_copy=*/true,
96                             /*output_dtypes=*/{DT_INT64},
97                             /*output_shapes=*/{PartialTensorShape({-1})},
98                             /*node_name=*/kNodeName);
99 }
100 
101 // Test Case 7: test BatchDatasetV2 with `drop_remainder` = false and
102 // the output of the input dataset is empty.
BatchDatasetParams7()103 BatchDatasetParams BatchDatasetParams7() {
104   return BatchDatasetParams(RangeDatasetParams(0, 0, 1),
105                             /*batch_size=*/4,
106                             /*drop_remainder=*/false,
107                             /*parallel_copy=*/false,
108                             /*output_dtypes=*/{DT_INT64},
109                             /*output_shapes=*/{PartialTensorShape({4})},
110                             /*node_name=*/kNodeName);
111 }
112 
113 // Test Case 8: test BatchDatasetV2 with an invalid batch size
InvalidBatchSizeBatchDatasetParams()114 BatchDatasetParams InvalidBatchSizeBatchDatasetParams() {
115   return BatchDatasetParams(RangeDatasetParams(0, 10, 1),
116                             /*batch_size=*/-1,
117                             /*drop_remainder=*/false,
118                             /*parallel_copy=*/false,
119                             /*output_dtypes=*/{DT_INT64},
120                             /*output_shapes=*/{PartialTensorShape({3})},
121                             /*node_name=*/kNodeName);
122 }
123 
GetNextTestCases()124 std::vector<GetNextTestCase<BatchDatasetParams>> GetNextTestCases() {
125   return {{/*dataset_params=*/BatchDatasetParams1(),
126            /*expected_outputs=*/
127            CreateTensors<int64_t>(
128                TensorShape({4}), {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})},
129           {/*dataset_params=*/BatchDatasetParams2(),
130            /*expected_outputs=*/
131            CreateTensors<int64_t>(
132                TensorShape({4}), {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})},
133           {/*dataset_params=*/BatchDatasetParams3(),
134            /*expected_outputs=*/
135            {CreateTensor<int64_t>(TensorShape({3}), {0, 1, 2}),
136             CreateTensor<int64_t>(TensorShape({3}), {3, 4, 5}),
137             CreateTensor<int64_t>(TensorShape({3}), {6, 7, 8}),
138             CreateTensor<int64_t>(TensorShape({1}), {9})}},
139           {/*dataset_params=*/BatchDatasetParams4(),
140            /*expected_outputs=*/
141            CreateTensors<int64_t>(TensorShape({3}),
142                                   {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}})},
143           {/*dataset_params=*/BatchDatasetParams5(),
144            /*expected_outputs=*/{}},
145           {/*dataset_params=*/BatchDatasetParams6(),
146            /*expected_outputs=*/
147            CreateTensors<int64_t>(TensorShape({10}),
148                                   {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}})},
149 
150           {/*dataset_params=*/BatchDatasetParams7(),
151            /*expected_outputs=*/{}}};
152 }
153 
ITERATOR_GET_NEXT_TEST_P(BatchDatasetOpTest,BatchDatasetParams,GetNextTestCases ())154 ITERATOR_GET_NEXT_TEST_P(BatchDatasetOpTest, BatchDatasetParams,
155                          GetNextTestCases())
156 
157 TEST_F(BatchDatasetOpTest, DatasetNodeName) {
158   auto batch_dataset_params = BatchDatasetParams1();
159   TF_ASSERT_OK(Initialize(batch_dataset_params));
160   TF_ASSERT_OK(CheckDatasetNodeName(batch_dataset_params.node_name()));
161 }
162 
// The dataset type string must be BatchDatasetOp::kDatasetType formatted
// for the op version of the params under test.
TEST_F(BatchDatasetOpTest, DatasetTypeString) {
  auto dataset_params = BatchDatasetParams1();
  TF_ASSERT_OK(Initialize(dataset_params));
  name_utils::OpNameParams name_params;
  name_params.op_version = dataset_params.op_version();
  const std::string expected_type =
      name_utils::OpName(BatchDatasetOp::kDatasetType, name_params);
  TF_ASSERT_OK(CheckDatasetTypeString(expected_type));
}
171 
// Batching int64 range elements must keep a single DT_INT64 output.
TEST_F(BatchDatasetOpTest, DatasetOutputDtypes) {
  auto dataset_params = BatchDatasetParams1();
  TF_ASSERT_OK(Initialize(dataset_params));
  TF_ASSERT_OK(CheckDatasetOutputDtypes({DT_INT64}));
}
177 
178 std::vector<DatasetOutputShapesTestCase<BatchDatasetParams>>
DatasetOutputShapesTestCases()179 DatasetOutputShapesTestCases() {
180   return {{/*dataset_params=*/BatchDatasetParams1(),
181            /*expected_output_shapes=*/{PartialTensorShape({4})}},
182           {/*dataset_params=*/BatchDatasetParams2(),
183            /*expected_output_shapes=*/{PartialTensorShape({4})}},
184           {/*dataset_params=*/BatchDatasetParams3(),
185            /*expected_output_shapes=*/{PartialTensorShape({-1})}},
186           {/*dataset_params=*/BatchDatasetParams4(),
187            /*expected_output_shapes=*/{PartialTensorShape({3})}},
188           {/*dataset_params=*/BatchDatasetParams5(),
189            /*expected_output_shapes=*/{PartialTensorShape({12})}},
190           {/*dataset_params=*/BatchDatasetParams6(),
191            /*expected_output_shapes=*/{PartialTensorShape({-1})}},
192           {/*dataset_params=*/BatchDatasetParams7(),
193            /*expected_output_shapes=*/{PartialTensorShape({4})}}};
194 }
195 
DATASET_OUTPUT_SHAPES_TEST_P(BatchDatasetOpTest,BatchDatasetParams,DatasetOutputShapesTestCases ())196 DATASET_OUTPUT_SHAPES_TEST_P(BatchDatasetOpTest, BatchDatasetParams,
197                              DatasetOutputShapesTestCases())
198 
199 std::vector<CardinalityTestCase<BatchDatasetParams>> CardinalityTestCases() {
200   return {
201       {/*dataset_params=*/BatchDatasetParams1(), /*expected_cardinality=*/3},
202       {/*dataset_params=*/BatchDatasetParams2(), /*expected_cardinality=*/3},
203       {/*dataset_params=*/BatchDatasetParams3(), /*expected_cardinality=*/4},
204       {/*dataset_params=*/BatchDatasetParams4(), /*expected_cardinality=*/3},
205       {/*dataset_params=*/BatchDatasetParams5(), /*expected_cardinality=*/0},
206       {/*dataset_params=*/BatchDatasetParams6(), /*expected_cardinality=*/1},
207       {/*dataset_params=*/BatchDatasetParams7(), /*expected_cardinality=*/0}};
208 }
209 
DATASET_CARDINALITY_TEST_P(BatchDatasetOpTest,BatchDatasetParams,CardinalityTestCases ())210 DATASET_CARDINALITY_TEST_P(BatchDatasetOpTest, BatchDatasetParams,
211                            CardinalityTestCases())
212 
213 TEST_F(BatchDatasetOpTest, IteratorOutputDtypes) {
214   auto batch_dataset_params = BatchDatasetParams1();
215   TF_ASSERT_OK(Initialize(batch_dataset_params));
216   TF_ASSERT_OK(CheckIteratorOutputDtypes({DT_INT64}));
217 }
218 
219 std::vector<IteratorOutputShapesTestCase<BatchDatasetParams>>
IteratorOutputShapesTestCases()220 IteratorOutputShapesTestCases() {
221   return {{/*dataset_params=*/BatchDatasetParams1(),
222            /*expected_output_shapes=*/{PartialTensorShape({4})}},
223           {/*dataset_params=*/BatchDatasetParams2(),
224            /*expected_output_shapes=*/{PartialTensorShape({4})}},
225           {/*dataset_params=*/BatchDatasetParams3(),
226            /*expected_output_shapes=*/{PartialTensorShape({-1})}},
227           {/*dataset_params=*/BatchDatasetParams4(),
228            /*expected_output_shapes=*/{PartialTensorShape({3})}},
229           {/*dataset_params=*/BatchDatasetParams5(),
230            /*expected_output_shapes=*/{PartialTensorShape({12})}},
231           {/*dataset_params=*/BatchDatasetParams6(),
232            /*expected_output_shapes=*/{PartialTensorShape({-1})}},
233           {/*dataset_params=*/BatchDatasetParams7(),
234            /*expected_output_shapes=*/{PartialTensorShape({4})}}};
235 }
236 
ITERATOR_OUTPUT_SHAPES_TEST_P(BatchDatasetOpTest,BatchDatasetParams,IteratorOutputShapesTestCases ())237 ITERATOR_OUTPUT_SHAPES_TEST_P(BatchDatasetOpTest, BatchDatasetParams,
238                               IteratorOutputShapesTestCases())
239 
240 TEST_F(BatchDatasetOpTest, IteratorOutputPrefix) {
241   auto batch_dataset_params = BatchDatasetParams1();
242   TF_ASSERT_OK(Initialize(batch_dataset_params));
243   name_utils::IteratorPrefixParams params;
244   params.op_version = batch_dataset_params.op_version();
245   TF_ASSERT_OK(CheckIteratorPrefix(name_utils::IteratorPrefix(
246       BatchDatasetOp::kDatasetType, batch_dataset_params.iterator_prefix(),
247       params)));
248 }
249 
250 std::vector<IteratorSaveAndRestoreTestCase<BatchDatasetParams>>
IteratorSaveAndRestoreTestCases()251 IteratorSaveAndRestoreTestCases() {
252   return {{/*dataset_params=*/BatchDatasetParams1(),
253            /*breakpoints=*/{0, 1, 5},
254            /*expected_outputs=*/
255            CreateTensors<int64_t>(
256                TensorShape({4}), {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})},
257           {/*dataset_params=*/BatchDatasetParams2(),
258            /*breakpoints=*/{0, 1, 5},
259            /*expected_outputs=*/
260            CreateTensors<int64_t>(
261                TensorShape({4}), {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 10, 11}})},
262           {/*dataset_params=*/BatchDatasetParams3(),
263            /*breakpoints=*/{0, 1, 5},
264            /*expected_outputs=*/
265            {CreateTensor<int64_t>(TensorShape({3}), {0, 1, 2}),
266             CreateTensor<int64_t>(TensorShape({3}), {3, 4, 5}),
267             CreateTensor<int64_t>(TensorShape({3}), {6, 7, 8}),
268             CreateTensor<int64_t>(TensorShape({1}), {9})}},
269           {/*dataset_params=*/BatchDatasetParams4(),
270            /*breakpoints=*/{0, 1, 5},
271            /*expected_outputs=*/
272            {CreateTensor<int64_t>(TensorShape({3}), {0, 1, 2}),
273             CreateTensor<int64_t>(TensorShape({3}), {3, 4, 5}),
274             CreateTensor<int64_t>(TensorShape({3}), {6, 7, 8})}},
275           {/*dataset_params=*/BatchDatasetParams5(),
276            /*breakpoints=*/{0, 1, 5},
277            /*expected_outputs=*/{}},
278           {/*dataset_params=*/BatchDatasetParams6(),
279            /*breakpoints=*/{0, 1, 5},
280            /*expected_outputs=*/
281            {CreateTensor<int64_t>(TensorShape({10}),
282                                   {0, 1, 2, 3, 4, 5, 6, 7, 8, 9})}},
283           {/*dataset_params=*/BatchDatasetParams7(),
284            /*breakpoints=*/{0, 1, 5},
285            /*expected_outputs=*/{}}};
286 }
287 
ITERATOR_SAVE_AND_RESTORE_TEST_P(BatchDatasetOpTest,BatchDatasetParams,IteratorSaveAndRestoreTestCases ())288 ITERATOR_SAVE_AND_RESTORE_TEST_P(BatchDatasetOpTest, BatchDatasetParams,
289                                  IteratorSaveAndRestoreTestCases())
290 
291 TEST_F(BatchDatasetOpTest, InvalidBatchSize) {
292   auto batch_dataset_params = InvalidBatchSizeBatchDatasetParams();
293   EXPECT_EQ(Initialize(batch_dataset_params).code(),
294             tensorflow::error::INVALID_ARGUMENT);
295 }
296 
297 // TODO(b/222556529) when Const has type constructor, remove the following
298 REGISTER_OP("BatchDatasetOpTest>ConstTypeCtor")
299     .Output("output: dtype")
300     .Attr("value: tensor")
301     .Attr("dtype: type")
302     .SetTypeConstructor(full_type::Unary(TFT_TENSOR, "dtype"));
303 
304 // Adds identity notes to all outputs of this node
add_identity_nodes(Node * node,Graph & graph,std::vector<Node * > & identity_nodes)305 static void add_identity_nodes(Node* node, Graph& graph,
306                                std::vector<Node*>& identity_nodes) {
307   for (int i = 0; i < node->num_outputs(); i++) {
308     Node* new_node;
309     std::string name = absl::StrCat("Identity", i);
310     TF_EXPECT_OK(NodeBuilder(name, "Identity")
311                      .Attr("T", node->output_type(i))
312                      .Input(node, i)
313                      .Finalize(&graph, &new_node));
314     identity_nodes.push_back(new_node);
315   }
316 }
317 
318 // Runs type inference pass on graph
type_inference(Graph & graph)319 static Status type_inference(Graph& graph) {
320   GraphOptimizationPassOptions opt_options;
321   std::unique_ptr<Graph> graph_ptr(new Graph(OpRegistry::Global()));
322   graph_ptr->Copy(graph);
323   opt_options.graph = &graph_ptr;
324   opt_options.flib_def = graph.mutable_flib_def();
325   ForwardTypeInferencePass pass;
326   return pass.Run(opt_options);
327 }
328 
// Checks that, after forward type inference, the output of BatchDatasetV2
// (observed through an Identity node) has the same full type as its input
// dataset.
// The suite name "BatchDatsetOpTest" (sic) is kept so existing test filters
// keep matching. It cannot simply become BatchDatasetOpTest, because that name
// is already used as a TEST_F fixture in this file.
TEST(BatchDatsetOpTest, TypeInference) {
  Graph graph(OpRegistry::Global());
  // Initialized so a failed NodeBuilder can never leave a garbage pointer.
  Node* input_dataset = nullptr;
  Node* batch_size = nullptr;
  Node* drop_remainder = nullptr;
  Node* batch_dataset_v2 = nullptr;

  // Full type of the input: a dataset whose single component is a ragged
  // string tensor.
  FullTypeDef input_dataset_t;
  protobuf::TextFormat::Parser parser;
  // ASSERT instead of CHECK: a parse failure fails this test rather than
  // aborting the whole test binary.
  ASSERT_TRUE(parser.ParseFromString(
      R"pb(type_id: TFT_PRODUCT
           args {
             type_id: TFT_DATASET
             args {
               type_id: TFT_PRODUCT
               args {
                 type_id: TFT_RAGGED
                 args { type_id: TFT_STRING }
               }
             }
           })pb",
      &input_dataset_t));

  // Every node is built under TF_ASSERT_OK: the Node* out-params are
  // dereferenced below, so the test must stop if any build fails.
  TensorProto tensor_proto;
  TF_ASSERT_OK(NodeBuilder("input_dataset", "Const")
                   .Attr("value", tensor_proto)
                   .Attr("dtype", DT_VARIANT)
                   .Finalize(&graph, &input_dataset));
  (*input_dataset->mutable_def()->mutable_experimental_type()) =
      input_dataset_t;
  // TODO(b/222556529) when Const has type constructor, use Const
  TF_ASSERT_OK(NodeBuilder("batch_size", "BatchDatasetOpTest>ConstTypeCtor")
                   .Attr("value", tensor_proto)
                   .Attr("dtype", DT_INT64)
                   .Finalize(&graph, &batch_size));
  // TODO(b/222556529) when Const has type constructor, use Const
  TF_ASSERT_OK(NodeBuilder("drop_remainder", "BatchDatasetOpTest>ConstTypeCtor")
                   .Attr("value", tensor_proto)
                   .Attr("dtype", DT_BOOL)
                   .Finalize(&graph, &drop_remainder));
  TF_ASSERT_OK(NodeBuilder("BatchDatasetV2", "BatchDatasetV2")
                   .Attr("output_types", {DT_VARIANT})
                   .Attr("output_shapes", {TensorShape({1})})
                   .Input(input_dataset)
                   .Input(batch_size)
                   .Input(drop_remainder)
                   .Finalize(&graph, &batch_dataset_v2));

  std::vector<Node*> identity_nodes;
  add_identity_nodes(batch_dataset_v2, graph, identity_nodes);
  // identity_nodes[0] is read below; make sure it exists.
  ASSERT_FALSE(identity_nodes.empty());
  TF_ASSERT_OK(type_inference(graph));
  EXPECT_TRUE(full_type::IsEqual(identity_nodes[0]->def().experimental_type(),
                                 input_dataset_t))
      << "fulltype is\n"
      << identity_nodes[0]->def().experimental_type().DebugString()
      << "\nexpected\n"
      << input_dataset_t.DebugString();
}
385 
386 }  // namespace
387 }  // namespace data
388 }  // namespace tensorflow
389