1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/dataset.h"
17
18 #include "tensorflow/core/framework/tensor_testutil.h"
19 #include "tensorflow/core/framework/types.pb.h"
20 #include "tensorflow/core/platform/test.h"
21
22 namespace tensorflow {
23 namespace data {
24
// Pins the exact string FullName() produces for a simple prefix/name pair.
TEST(DatasetTest, FullName) {
  const std::string expected = "60d899aa0d8ce4351e7c3b419e92d25b|prefix:name";
  EXPECT_EQ(FullName("prefix", "name"), expected);
}
29
// Tags which element type a DatasetTestTotalBytes case exercises. The values
// are spelled out explicitly; they match the implicit 0..4 numbering.
enum DataTypeTest {
  _tf_int_32 = 0,
  _tf_int_64 = 1,
  _tf_float_ = 2,
  _tf_double_ = 3,
  _tf_string_ = 4,
};
37
38 struct DatasetTestParam {
39 const DataTypeTest type;
40 // This has to be a function pointer, to make sure the tensors we use as
41 // parameters of the test case do not become globals. Ordering of static
42 // initializers and globals can cause errors in the test.
43 std::function<std::vector<Tensor>()> tensor_factory;
44 const int64_t expected_bytes;
45 };
46
47 class DatasetTestTotalBytes
48 : public ::testing::TestWithParam<DatasetTestParam> {};
49
// Builds the case's tensors and compares GetTotalBytes() with the expectation:
// an exact match for fixed-width types, an upper bound for strings.
TEST_P(DatasetTestTotalBytes, TestTotalBytes) {
  const auto& param = GetParam();
  const auto total_bytes = GetTotalBytes(param.tensor_factory());
  if (param.type != _tf_string_) {
    EXPECT_EQ(total_bytes, param.expected_bytes);
    return;
  }
  // GetTotalBytes() is only approximate for string tensors and yields an
  // upper bound, so the expectation cannot be checked for equality.
  EXPECT_LE(total_bytes, param.expected_bytes);
}
61
tensor_tf_int_32s()62 std::vector<Tensor> tensor_tf_int_32s() {
63 return {test::AsTensor<int32>({1, 2, 3, 4, 5}),
64 test::AsTensor<int32>({1, 2, 3, 4})};
65 }
66
tensor_tf_int_64s()67 std::vector<Tensor> tensor_tf_int_64s() {
68 return {test::AsTensor<int64_t>({1, 2, 3, 4, 5}),
69 test::AsTensor<int64_t>({10, 12})};
70 }
71
tensor_tf_float_s()72 std::vector<Tensor> tensor_tf_float_s() {
73 return {test::AsTensor<float>({1.0, 2.0, 3.0, 4.0})};
74 }
75
tensor_tf_double_s()76 std::vector<Tensor> tensor_tf_double_s() {
77 return {test::AsTensor<double>({100.0}), test::AsTensor<double>({200.0}),
78 test::AsTensor<double>({400.0}), test::AsTensor<double>({800.0})};
79 }
80
81 const tstring str = "test string"; // NOLINT
tensor_strs()82 std::vector<Tensor> tensor_strs() { return {test::AsTensor<tstring>({str})}; }
83
84 INSTANTIATE_TEST_SUITE_P(
85 DatasetTestTotalBytes, DatasetTestTotalBytes,
86 ::testing::ValuesIn(std::vector<DatasetTestParam>{
87 {_tf_int_32, tensor_tf_int_32s, 4 /*bytes*/ * 9 /*elements*/},
88 {_tf_int_64, tensor_tf_int_64s, 8 /*bytes*/ * 7 /*elements*/},
89 {_tf_float_, tensor_tf_float_s, 4 /*bytes*/ * 4 /*elements*/},
90 {_tf_double_, tensor_tf_double_s, 8 /*bytes*/ * 4 /*elements*/},
91 {_tf_string_, tensor_strs,
92 static_cast<int64_t>(sizeof(str) + str.size()) /*bytes*/}}));
93
// One MergeOptions test case. Each field holds an `Options` proto in text
// format: `source` is merged into `destination`, and the result must equal
// `expected`.
//
// The members are non-const so the struct stays copy- and move-assignable
// (C++ Core Guidelines C.12). Const std::string members would also turn every
// move into a copy.
struct MergeOptionsTestParam {
  std::string source;
  std::string destination;
  std::string expected;
};
99
100 class MergeOptionsTest
101 : public ::testing::TestWithParam<MergeOptionsTestParam> {};
102
// Parses the three text-format Options of the case, merges `source` into
// `destination`, and checks that the result equals `expected`.
TEST_P(MergeOptionsTest, MergeOptions) {
  const MergeOptionsTestParam& test_case = GetParam();
  // Parse failures are reported with ASSERT_TRUE rather than CHECK. A
  // malformed test case then fails (and stops) only this test instead of
  // aborting the whole test binary, and the offending text is reported.
  Options source;
  ASSERT_TRUE(tensorflow::protobuf::TextFormat::ParseFromString(
      test_case.source, &source))
      << "Could not parse source options: " << test_case.source;
  Options destination;
  ASSERT_TRUE(tensorflow::protobuf::TextFormat::ParseFromString(
      test_case.destination, &destination))
      << "Could not parse destination options: " << test_case.destination;
  Options expected;
  ASSERT_TRUE(tensorflow::protobuf::TextFormat::ParseFromString(
      test_case.expected, &expected))
      << "Could not parse expected options: " << test_case.expected;

  // Merge in place, then compare against the expectation via the serialized
  // bytes of both messages.
  internal::MergeOptions(source, &destination);
  EXPECT_EQ(expected.SerializeAsString(), destination.SerializeAsString());
}
117
118 INSTANTIATE_TEST_SUITE_P(
119 MergeOptionsTest, MergeOptionsTest,
120 ::testing::ValuesIn(std::vector<MergeOptionsTestParam>{
121 // Destination is empty.
122 {/*source=*/"deterministic: false", /*destination=*/"",
123 /*expected=*/"deterministic: false"},
124 // Source and destination have the same values.
125 {/*source=*/"deterministic: false",
126 /*destination=*/"deterministic: false",
127 /*expected=*/"deterministic: false"},
128 // Source values override destination values.
129 {/*source=*/"deterministic: false",
130 /*destination=*/"deterministic: true",
131 /*expected=*/"deterministic: false"},
132 // Values are enums.
133 {/*source=*/"external_state_policy: POLICY_IGNORE",
134 /*destination=*/"external_state_policy: POLICY_FAIL",
135 /*expected=*/"external_state_policy: POLICY_IGNORE"}}));
136
// Exercises DatasetOpKernel::IsDatasetOp() on OpDefs that are built up step
// by step: first rejected shapes, then a list of accepted op names.
TEST(DatasetTest, IsDatasetOp) {
  OpDef op_def;
  // An OpDef with no outputs is rejected.
  EXPECT_FALSE(DatasetOpKernel::IsDatasetOp(op_def));

  // A single DT_STRING output is rejected.
  op_def.add_output_arg()->set_type(DT_STRING);
  EXPECT_FALSE(DatasetOpKernel::IsDatasetOp(op_def));

  // With a DT_VARIANT output, the name "Identity" is still rejected.
  op_def.mutable_output_arg(0)->set_type(DT_VARIANT);
  op_def.set_name("Identity");
  EXPECT_FALSE(DatasetOpKernel::IsDatasetOp(op_def));

  // Each of these names, with the single DT_VARIANT output, is accepted. The
  // list covers plain, versioned (V1/V42/V1000) and "DatasetFromGraph"
  // spellings. The name is streamed into the failure message so a regression
  // identifies which entry failed.
  for (const auto& name : {"Dataset", "RangeDataset", "MapDatasetV1",
                           "ParallelInterleaveDatasetV42",
                           "DataServiceDatasetV1000", "DatasetFromGraph"}) {
    op_def.set_name(name);
    EXPECT_TRUE(DatasetOpKernel::IsDatasetOp(op_def))
        << "Expected a dataset op for name: " << name;
  }
}
159
160 } // namespace data
161 } // namespace tensorflow
162