1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <utility>
17
18 #include "tensorflow/core/util/example_proto_fast_parsing.h"
19
20 #include "tensorflow/core/example/example.pb.h"
21 #include "tensorflow/core/example/feature.pb.h"
22 #include "tensorflow/core/lib/random/philox_random.h"
23 #include "tensorflow/core/lib/random/simple_philox.h"
24 #include "tensorflow/core/platform/protobuf.h"
25 #include "tensorflow/core/platform/test.h"
26 #include "tensorflow/core/platform/test_benchmark.h"
27 #include "tensorflow/core/util/example_proto_fast_parsing_test.pb.h"
28
29 namespace tensorflow {
30 namespace example {
31 namespace {
32
// String constants, one per (dense | sparse) x (int64 | float | string)
// combination; presumably feature keys for parser configurations.
// NOTE(review): nothing in the visible part of this file references them --
// confirm they are still needed before relying on or removing them.
constexpr char kDenseInt64Key[] = "dense_int64";
constexpr char kDenseFloatKey[] = "dense_float";
constexpr char kDenseStringKey[] = "dense_string";

constexpr char kSparseInt64Key[] = "sparse_int64";
constexpr char kSparseFloatKey[] = "sparse_float";
constexpr char kSparseStringKey[] = "sparse_string";
40
SerializedToReadable(string serialized)41 string SerializedToReadable(string serialized) {
42 string result;
43 result += '"';
44 for (char c : serialized)
45 result += strings::StrCat("\\x", strings::Hex(c, strings::kZeroPad2));
46 result += '"';
47 return result;
48 }
49
// Returns the string that `example.SerializeToString()` writes.
//
// T may be any type offering a const `SerializeToString(string*)` method; the
// tests in this file pass Example and ExampleWithExtras. The boolean result of
// SerializeToString() is deliberately not inspected here.
template <class T>
string Serialize(const T& example) {
  string wire_bytes;
  example.SerializeToString(&wire_bytes);
  return wire_bytes;
}
56
57 // Tests that serialized gets parsed identically by TestFastParse(..)
58 // and the regular Example.ParseFromString(..).
TestCorrectness(const string & serialized)59 void TestCorrectness(const string& serialized) {
60 Example example;
61 Example fast_example;
62 EXPECT_TRUE(example.ParseFromString(serialized));
63 example.DiscardUnknownFields();
64 EXPECT_TRUE(TestFastParse(serialized, &fast_example));
65 EXPECT_EQ(example.DebugString(), fast_example.DebugString());
66 if (example.DebugString() != fast_example.DebugString()) {
67 LOG(ERROR) << "Bad serialized: " << SerializedToReadable(serialized);
68 }
69 }
70
71 // Fast parsing does not differentiate between EmptyExample and EmptyFeatures
72 // TEST(FastParse, EmptyExample) {
73 // Example example;
74 // TestCorrectness(example);
75 // }
76
// Serializes an ExampleWithExtras (with extra1..extra7 populated) followed by
// a plain Example, and checks that both parsers agree on the concatenation.
TEST(FastParse, IgnoresPrecedingUnknownTopLevelFields) {
  ExampleWithExtras with_extras;
  auto& leading_features = *with_extras.mutable_features()->mutable_feature();
  leading_features["age"].mutable_int64_list()->add_value(13);
  with_extras.set_extra1("some_str");
  with_extras.set_extra2(123);
  with_extras.set_extra3(234);
  with_extras.set_extra4(345);
  with_extras.set_extra5(4.56);
  with_extras.add_extra6(5.67);
  with_extras.add_extra6(6.78);
  auto& extra7_features = *with_extras.mutable_extra7()->mutable_feature();
  extra7_features["extra7"].mutable_int64_list()->add_value(1337);

  Example plain;
  auto& trailing_features = *plain.mutable_features()->mutable_feature();
  trailing_features["zipcode"].mutable_int64_list()->add_value(94043);

  const string concatenated =
      strings::StrCat(Serialize(with_extras), Serialize(plain));
  TestCorrectness(concatenated);
}
100
// Serializes a plain Example followed by an ExampleWithExtras (with
// extra1..extra7 populated), and checks that both parsers agree on the
// concatenation.
TEST(FastParse, IgnoresTrailingUnknownTopLevelFields) {
  Example plain;
  auto& leading_features = *plain.mutable_features()->mutable_feature();
  leading_features["age"].mutable_int64_list()->add_value(13);

  ExampleWithExtras with_extras;
  auto& trailing_features = *with_extras.mutable_features()->mutable_feature();
  trailing_features["zipcode"].mutable_int64_list()->add_value(94043);
  with_extras.set_extra1("some_str");
  with_extras.set_extra2(123);
  with_extras.set_extra3(234);
  with_extras.set_extra4(345);
  with_extras.set_extra5(4.56);
  with_extras.add_extra6(5.67);
  with_extras.add_extra6(6.78);
  auto& extra7_features = *with_extras.mutable_extra7()->mutable_feature();
  extra7_features["extra7"].mutable_int64_list()->add_value(1337);

  const string concatenated =
      strings::StrCat(Serialize(plain), Serialize(with_extras));
  TestCorrectness(concatenated);
}
124
// Concatenates two serialized Examples that hold different int64 features
// ("age" and "zipcode") and checks that both parsers agree on the result.
TEST(FastParse, SingleInt64WithContext) {
  Example first;
  auto& first_features = *first.mutable_features()->mutable_feature();
  first_features["age"].mutable_int64_list()->add_value(13);

  Example second;
  auto& second_features = *second.mutable_features()->mutable_feature();
  second_features["zipcode"].mutable_int64_list()->add_value(94043);

  const string concatenated =
      strings::StrCat(Serialize(first), Serialize(second));
  TestCorrectness(concatenated);
}
138
// Concatenates two serialized Examples that both define the "age" feature
// (values 0 and 15) and checks how the duplicate key is resolved.
TEST(FastParse, DenseInt64WithContext) {
  Example first;
  auto& first_features = *first.mutable_features()->mutable_feature();
  first_features["age"].mutable_int64_list()->add_value(0);

  Example second;
  auto& second_features = *second.mutable_features()->mutable_feature();
  second_features["age"].mutable_int64_list()->add_value(15);

  string serialized = Serialize(first);
  serialized += Serialize(second);

  {
    // Standard deserialization of the concatenation yields exactly `second`.
    // That is very surprising, but it is what the regular parser does, and the
    // Servo team requested that this 'feature' be replicated.
    // In future we should return an error.
    Example reparsed;
    EXPECT_TRUE(reparsed.ParseFromString(serialized));
    EXPECT_EQ(reparsed.DebugString(), second.DebugString());
  }
  TestCorrectness(serialized);
}
162
// Runs a hand-written wire encoding through TestCorrectness().
// NOTE(review): assuming the usual Example/Features/Feature/Int64List field
// numbers, these bytes decode to a feature "age" whose int64 value 13 is
// stored as a length-delimited run (tag 0x0a, length 1) -- i.e. the *packed*
// layout, which does not match the test name. Confirm whether the names of
// this test and FastParse.Packed were swapped.
TEST(FastParse, NonPacked) {
  static const char kWireBytes[] =
      "\x0a\x0e\x0a\x0c\x0a\x03\x61\x67\x65\x12\x05\x1a\x03\x0a\x01\x0d";
  TestCorrectness(kWireBytes);
}
167
// Runs a hand-written wire encoding through TestCorrectness().
// NOTE(review): assuming the usual Example/Features/Feature/Int64List field
// numbers, these bytes decode to a feature "age" whose int64 value 13 is
// stored as a single varint element (tag 0x08) -- i.e. the *unpacked* layout,
// which does not match the test name. Confirm whether the names of this test
// and FastParse.NonPacked were swapped.
TEST(FastParse, Packed) {
  static const char kWireBytes[] =
      "\x0a\x0d\x0a\x0b\x0a\x03\x61\x67\x65\x12\x04\x1a\x02\x08\x0d";
  TestCorrectness(kWireBytes);
}
172
// Runs a hand-written wire encoding through TestCorrectness(). Within the
// map entry, the field holding the bytes "value" is written ahead of the field
// holding the bytes "key", so the parsers must cope with that ordering.
TEST(FastParse, ValueBeforeKeyInMap) {
  static const char kWireBytes[] =
      "\x0a\x12\x0a\x10\x12\x09\x0a\x07\x0a\x05value\x0a\x03key";
  TestCorrectness(kWireBytes);
}
176
// Checks both parsers on an Example whose `features` field has been touched
// via mutable_features() but contains no entries.
TEST(FastParse, EmptyFeatures) {
  Example features_only;
  features_only.mutable_features();
  const string wire_bytes = Serialize(features_only);
  TestCorrectness(wire_bytes);
}
182
TestCorrectnessJson(const string & json)183 void TestCorrectnessJson(const string& json) {
184 auto resolver = protobuf::util::NewTypeResolverForDescriptorPool(
185 "type.googleapis.com", protobuf::DescriptorPool::generated_pool());
186 string serialized;
187 auto s = protobuf::util::JsonToBinaryString(
188 resolver, "type.googleapis.com/tensorflow.Example", json, &serialized);
189 EXPECT_TRUE(s.ok()) << s;
190 delete resolver;
191 TestCorrectness(serialized);
192 }
193
// JSON-described Example with one int64, one float and one bytes feature,
// each holding a single value.
TEST(FastParse, JsonUnivalent) {
  static const char kJson[] =
      "{'features': {"
      " 'feature': {'age': {'int64_list': {'value': [0]} }}, "
      " 'feature': {'flo': {'float_list': {'value': [1.1]} }}, "
      " 'feature': {'byt': {'bytes_list': {'value': ['WW8='] }}}"
      "}}";
  TestCorrectnessJson(kJson);
}
202
// JSON-described Example with one int64, one float and one bytes feature,
// each holding several values.
TEST(FastParse, JsonMultivalent) {
  static const char kJson[] =
      "{'features': {"
      " 'feature': {'age': {'int64_list': {'value': [0, 13, 23]} }}, "
      " 'feature': {'flo': {'float_list': {'value': [1.1, 1.2, 1.3]} }}, "
      " 'feature': {'byt': {'bytes_list': {'value': ['WW8=', 'WW8K'] }}}"
      "}}";
  TestCorrectnessJson(kJson);
}
211
// Checks both parsers on an Example holding a single int64 feature "age"
// with the value 13.
TEST(FastParse, SingleInt64) {
  Example single;
  auto& features = *single.mutable_features()->mutable_feature();
  features["age"].mutable_int64_list()->add_value(13);

  const string wire_bytes = Serialize(single);
  TestCorrectness(wire_bytes);
}
219
ExampleWithSomeFeatures()220 static string ExampleWithSomeFeatures() {
221 Example example;
222
223 (*example.mutable_features()->mutable_feature())[""];
224
225 (*example.mutable_features()->mutable_feature())["empty_bytes_list"]
226 .mutable_bytes_list();
227 (*example.mutable_features()->mutable_feature())["empty_float_list"]
228 .mutable_float_list();
229 (*example.mutable_features()->mutable_feature())["empty_int64_list"]
230 .mutable_int64_list();
231
232 BytesList* bytes_list =
233 (*example.mutable_features()->mutable_feature())["bytes_list"]
234 .mutable_bytes_list();
235 bytes_list->add_value("bytes1");
236 bytes_list->add_value("bytes2");
237
238 FloatList* float_list =
239 (*example.mutable_features()->mutable_feature())["float_list"]
240 .mutable_float_list();
241 float_list->add_value(1.0);
242 float_list->add_value(2.0);
243
244 Int64List* int64_list =
245 (*example.mutable_features()->mutable_feature())["int64_list"]
246 .mutable_int64_list();
247 int64_list->add_value(3);
248 int64_list->add_value(270);
249 int64_list->add_value(86942);
250
251 return Serialize(example);
252 }
253
// Checks both parsers on the mixed Example built by ExampleWithSomeFeatures().
TEST(FastParse, SomeFeatures) {
  const string wire_bytes = ExampleWithSomeFeatures();
  TestCorrectness(wire_bytes);
}
255
AddDenseFeature(const char * feature_name,DataType dtype,PartialTensorShape shape,bool variable_length,size_t elements_per_stride,FastParseExampleConfig * out_config)256 static void AddDenseFeature(const char* feature_name, DataType dtype,
257 PartialTensorShape shape, bool variable_length,
258 size_t elements_per_stride,
259 FastParseExampleConfig* out_config) {
260 out_config->dense.emplace_back();
261 auto& new_feature = out_config->dense.back();
262 new_feature.feature_name = feature_name;
263 new_feature.dtype = dtype;
264 new_feature.shape = std::move(shape);
265 new_feature.default_value = Tensor(dtype, {});
266 new_feature.variable_length = variable_length;
267 new_feature.elements_per_stride = elements_per_stride;
268 }
269
AddSparseFeature(const char * feature_name,DataType dtype,FastParseExampleConfig * out_config)270 static void AddSparseFeature(const char* feature_name, DataType dtype,
271 FastParseExampleConfig* out_config) {
272 out_config->sparse.emplace_back();
273 auto& new_feature = out_config->sparse.back();
274 new_feature.feature_name = feature_name;
275 new_feature.dtype = dtype;
276 }
277
// Verifies the per-example feature statistics reported when
// `collect_feature_stats` is enabled.
//
// The same serialized example (from ExampleWithSomeFeatures()) is parsed with
// four configurations -- fixed-shape dense, variable-length dense, sparse, and
// a mix of the three -- through both FastParseExample (batch of 13) and
// FastParseSingleExample. In every case each example is expected to report
// 7 features and 7 feature values, which matches the 7 map entries and the
// 2 + 2 + 3 values that ExampleWithSomeFeatures() adds.
TEST(FastParse, StatsCollection) {
  const size_t kNumExamples = 13;
  std::vector<tstring> serialized(kNumExamples, ExampleWithSomeFeatures());

  FastParseExampleConfig config_dense;
  AddDenseFeature("bytes_list", DT_STRING, {2}, false, 2, &config_dense);
  AddDenseFeature("float_list", DT_FLOAT, {2}, false, 2, &config_dense);
  AddDenseFeature("int64_list", DT_INT64, {3}, false, 3, &config_dense);
  config_dense.collect_feature_stats = true;

  FastParseExampleConfig config_varlen;
  AddDenseFeature("bytes_list", DT_STRING, {-1}, true, 1, &config_varlen);
  AddDenseFeature("float_list", DT_FLOAT, {-1}, true, 1, &config_varlen);
  AddDenseFeature("int64_list", DT_INT64, {-1}, true, 1, &config_varlen);
  config_varlen.collect_feature_stats = true;

  FastParseExampleConfig config_sparse;
  AddSparseFeature("bytes_list", DT_STRING, &config_sparse);
  AddSparseFeature("float_list", DT_FLOAT, &config_sparse);
  AddSparseFeature("int64_list", DT_INT64, &config_sparse);
  config_sparse.collect_feature_stats = true;

  FastParseExampleConfig config_mixed;
  AddDenseFeature("bytes_list", DT_STRING, {2}, false, 2, &config_mixed);
  AddDenseFeature("float_list", DT_FLOAT, {-1}, true, 1, &config_mixed);
  AddSparseFeature("int64_list", DT_INT64, &config_mixed);
  config_mixed.collect_feature_stats = true;

  for (const FastParseExampleConfig& config :
       {config_dense, config_varlen, config_sparse, config_mixed}) {
    // Batch API: one stats record per input example.
    {
      Result result;
      TF_CHECK_OK(FastParseExample(config, serialized, {}, nullptr, &result));
      EXPECT_EQ(kNumExamples, result.feature_stats.size());
      for (const PerExampleFeatureStats& stats : result.feature_stats) {
        EXPECT_EQ(7, stats.features_count);
        EXPECT_EQ(7, stats.feature_values_count);
      }
    }

    // Single-example API: exactly one stats record.
    {
      Result result;
      TF_CHECK_OK(FastParseSingleExample(config, serialized[0], &result));
      // Fatal assertion: the lines below index element 0, which would read
      // out of bounds if the stats vector came back empty.
      ASSERT_EQ(1, result.feature_stats.size());
      EXPECT_EQ(7, result.feature_stats[0].features_count);
      EXPECT_EQ(7, result.feature_stats[0].feature_values_count);
    }
  }
}
327
RandStr(random::SimplePhilox * rng)328 string RandStr(random::SimplePhilox* rng) {
329 static const char key_char_lookup[] =
330 "0123456789{}~`!@#$%^&*()"
331 "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
332 "abcdefghijklmnopqrstuvwxyz";
333 auto len = 1 + rng->Rand32() % 200;
334 string str;
335 str.reserve(len);
336 while (len-- > 0) {
337 str.push_back(
338 key_char_lookup[rng->Rand32() % (sizeof(key_char_lookup) /
339 sizeof(key_char_lookup[0]))]);
340 }
341 return str;
342 }
343
Fuzz(random::SimplePhilox * rng)344 void Fuzz(random::SimplePhilox* rng) {
345 // Generate keys.
346 auto num_keys = 1 + rng->Rand32() % 100;
347 std::unordered_set<string> unique_keys;
348 for (auto i = 0; i < num_keys; ++i) {
349 unique_keys.emplace(RandStr(rng));
350 }
351
352 // Generate serialized example.
353 Example example;
354 string serialized_example;
355 auto num_concats = 1 + rng->Rand32() % 4;
356 std::vector<Feature::KindCase> feat_types(
357 {Feature::kBytesList, Feature::kFloatList, Feature::kInt64List});
358 std::vector<string> all_keys(unique_keys.begin(), unique_keys.end());
359 while (num_concats--) {
360 example.Clear();
361 auto num_active_keys = 1 + rng->Rand32() % all_keys.size();
362
363 // Generate features.
364 for (auto i = 0; i < num_active_keys; ++i) {
365 auto fkey = all_keys[rng->Rand32() % all_keys.size()];
366 auto ftype_idx = rng->Rand32() % feat_types.size();
367 auto num_features = 1 + rng->Rand32() % 5;
368 switch (static_cast<Feature::KindCase>(feat_types[ftype_idx])) {
369 case Feature::kBytesList: {
370 BytesList* bytes_list =
371 (*example.mutable_features()->mutable_feature())[fkey]
372 .mutable_bytes_list();
373 while (num_features--) {
374 bytes_list->add_value(RandStr(rng));
375 }
376 break;
377 }
378 case Feature::kFloatList: {
379 FloatList* float_list =
380 (*example.mutable_features()->mutable_feature())[fkey]
381 .mutable_float_list();
382 while (num_features--) {
383 float_list->add_value(rng->RandFloat());
384 }
385 break;
386 }
387 case Feature::kInt64List: {
388 Int64List* int64_list =
389 (*example.mutable_features()->mutable_feature())[fkey]
390 .mutable_int64_list();
391 while (num_features--) {
392 int64_list->add_value(rng->Rand64());
393 }
394 break;
395 }
396 default: {
397 LOG(QFATAL);
398 break;
399 }
400 }
401 }
402 serialized_example += example.SerializeAsString();
403 }
404
405 // Test correctness.
406 TestCorrectness(serialized_example);
407 }
408
// Runs Fuzz() 200 times from one generator seeded with a fixed value, logging
// the number of remaining runs before each iteration.
TEST(FastParse, FuzzTest) {
  const uint64 seed = 1337;
  random::PhiloxRandom philox(seed);
  random::SimplePhilox rng(&philox);

  constexpr int kNumRuns = 200;
  for (int runs_left = kNumRuns - 1; runs_left >= 0; --runs_left) {
    LOG(INFO) << "runs left: " << runs_left;
    Fuzz(&rng);
  }
}
419
// FastParseExample must return an OK status when given no serialized inputs
// and no example names, even with one sparse feature configured.
TEST(TestFastParseExample, Empty) {
  FastParseExampleConfig config;
  config.sparse.push_back({"test", DT_STRING});

  const auto no_serialized = gtl::ArraySlice<tstring>();
  const auto no_names = gtl::ArraySlice<tstring>();

  Result result;
  Status status =
      FastParseExample(config, no_serialized, no_names, nullptr, &result);
  EXPECT_TRUE(status.ok()) << status;
}
429
430 } // namespace
431 } // namespace example
432 } // namespace tensorflow
433