// Source: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/dataset_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#include "tensorflow/core/framework/dataset.h"

#include <cstdint>
#include <functional>
#include <string>
#include <vector>

#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/test.h"
21 
22 namespace tensorflow {
23 namespace data {
24 
// Checks that FullName("prefix", "name") produces the expected key: a fixed
// hex token, a '|' separator, then the prefix and name joined by ':'.
TEST(DatasetTest, FullName) {
  EXPECT_EQ(FullName("prefix", "name"),
            "60d899aa0d8ce4351e7c3b419e92d25b|prefix:name");
}
29 
// Tags the element type exercised by a DatasetTestTotalBytes test case.
// The TestTotalBytes test only special-cases `_tf_string_`; every other value
// is checked for an exact byte count.
enum DataTypeTest {
  _tf_int_32,
  _tf_int_64,
  _tf_float_,
  _tf_double_,
  _tf_string_
};
37 
38 struct DatasetTestParam {
39   const DataTypeTest type;
40   // This has to be a function pointer, to make sure the tensors we use as
41   // parameters of the test case do not become globals. Ordering of static
42   // initializers and globals can cause errors in the test.
43   std::function<std::vector<Tensor>()> tensor_factory;
44   const int64_t expected_bytes;
45 };
46 
47 class DatasetTestTotalBytes
48     : public ::testing::TestWithParam<DatasetTestParam> {};
49 
// Checks GetTotalBytes() over the tensors produced by the case's factory
// against the case's expected byte count.
TEST_P(DatasetTestTotalBytes, TestTotalBytes) {
  const DatasetTestParam& test_case = GetParam();
  // The tensors are created here, inside the test body, and measured once;
  // both branches below compare this same value.
  const auto total_bytes = GetTotalBytes(test_case.tensor_factory());
  if (test_case.type == _tf_string_) {
    // TotalBytes() is approximate and gives an upper bound for strings.
    EXPECT_LE(total_bytes, test_case.expected_bytes);
  } else {
    EXPECT_EQ(total_bytes, test_case.expected_bytes);
  }
}
61 
tensor_tf_int_32s()62 std::vector<Tensor> tensor_tf_int_32s() {
63   return {test::AsTensor<int32>({1, 2, 3, 4, 5}),
64           test::AsTensor<int32>({1, 2, 3, 4})};
65 }
66 
tensor_tf_int_64s()67 std::vector<Tensor> tensor_tf_int_64s() {
68   return {test::AsTensor<int64_t>({1, 2, 3, 4, 5}),
69           test::AsTensor<int64_t>({10, 12})};
70 }
71 
tensor_tf_float_s()72 std::vector<Tensor> tensor_tf_float_s() {
73   return {test::AsTensor<float>({1.0, 2.0, 3.0, 4.0})};
74 }
75 
tensor_tf_double_s()76 std::vector<Tensor> tensor_tf_double_s() {
77   return {test::AsTensor<double>({100.0}), test::AsTensor<double>({200.0}),
78           test::AsTensor<double>({400.0}), test::AsTensor<double>({800.0})};
79 }
80 
81 const tstring str = "test string";  // NOLINT
tensor_strs()82 std::vector<Tensor> tensor_strs() { return {test::AsTensor<tstring>({str})}; }
83 
84 INSTANTIATE_TEST_SUITE_P(
85     DatasetTestTotalBytes, DatasetTestTotalBytes,
86     ::testing::ValuesIn(std::vector<DatasetTestParam>{
87         {_tf_int_32, tensor_tf_int_32s, 4 /*bytes*/ * 9 /*elements*/},
88         {_tf_int_64, tensor_tf_int_64s, 8 /*bytes*/ * 7 /*elements*/},
89         {_tf_float_, tensor_tf_float_s, 4 /*bytes*/ * 4 /*elements*/},
90         {_tf_double_, tensor_tf_double_s, 8 /*bytes*/ * 4 /*elements*/},
91         {_tf_string_, tensor_strs,
92          static_cast<int64_t>(sizeof(str) + str.size()) /*bytes*/}}));
93 
// One parameterized case for the MergeOptionsTest suite. Each field holds an
// Options message in protobuf text format.
struct MergeOptionsTestParam {
  // Options merged from.
  const std::string source;
  // Options merged into (its initial contents).
  const std::string destination;
  // Contents `destination` is expected to have after the merge.
  const std::string expected;
};
99 
100 class MergeOptionsTest
101     : public ::testing::TestWithParam<MergeOptionsTestParam> {};
102 
// Parses the three text-format Options messages of the test case, merges
// `source` into `destination` with internal::MergeOptions(), and compares the
// serialized result with the serialized expectation.
TEST_P(MergeOptionsTest, MergeOptions) {
  const MergeOptionsTestParam& test_case = GetParam();

  // A malformed fixture fails only this test case (ASSERT_TRUE) instead of
  // aborting the whole test binary.
  Options source;
  ASSERT_TRUE(tensorflow::protobuf::TextFormat::ParseFromString(
      test_case.source, &source))
      << "unparsable source options: " << test_case.source;
  Options destination;
  ASSERT_TRUE(tensorflow::protobuf::TextFormat::ParseFromString(
      test_case.destination, &destination))
      << "unparsable destination options: " << test_case.destination;
  Options expected;
  ASSERT_TRUE(tensorflow::protobuf::TextFormat::ParseFromString(
      test_case.expected, &expected))
      << "unparsable expected options: " << test_case.expected;

  internal::MergeOptions(source, &destination);
  EXPECT_EQ(expected.SerializeAsString(), destination.SerializeAsString());
}
117 
118 INSTANTIATE_TEST_SUITE_P(
119     MergeOptionsTest, MergeOptionsTest,
120     ::testing::ValuesIn(std::vector<MergeOptionsTestParam>{
121         // Destination is empty.
122         {/*source=*/"deterministic: false", /*destination=*/"",
123          /*expected=*/"deterministic: false"},
124         // Source and destination have the same values.
125         {/*source=*/"deterministic: false",
126          /*destination=*/"deterministic: false",
127          /*expected=*/"deterministic: false"},
128         // Source values override destination values.
129         {/*source=*/"deterministic: false",
130          /*destination=*/"deterministic: true",
131          /*expected=*/"deterministic: false"},
132         // Values are enums.
133         {/*source=*/"external_state_policy: POLICY_IGNORE",
134          /*destination=*/"external_state_policy: POLICY_FAIL",
135          /*expected=*/"external_state_policy: POLICY_IGNORE"}}));
136 
// Exercises DatasetOpKernel::IsDatasetOp() on an OpDef that is built up step
// by step: no outputs, a non-variant output, a variant output with a
// non-dataset name, and finally a list of names expected to be accepted.
TEST(DatasetTest, IsDatasetOp) {
  OpDef op_def;
  // Test zero outputs.
  EXPECT_FALSE(DatasetOpKernel::IsDatasetOp(op_def));

  // Test invalid output type.
  op_def.add_output_arg()->set_type(DT_STRING);
  EXPECT_FALSE(DatasetOpKernel::IsDatasetOp(op_def));

  // Test invalid op name.
  op_def.mutable_output_arg(0)->set_type(DT_VARIANT);
  op_def.set_name("Identity");
  EXPECT_FALSE(DatasetOpKernel::IsDatasetOp(op_def));

  // Test valid op names.
  for (const char* name : {"Dataset", "RangeDataset", "MapDatasetV1",
                           "ParallelInterleaveDatasetV42",
                           "DataServiceDatasetV1000", "DatasetFromGraph"}) {
    op_def.set_name(name);
    // Include the name so a failure identifies which entry was rejected.
    EXPECT_TRUE(DatasetOpKernel::IsDatasetOp(op_def)) << "op name: " << name;
  }
}
159 
160 }  // namespace data
161 }  // namespace tensorflow
162