xref: /aosp_15_r20/external/tensorflow/tensorflow/core/util/example_proto_fast_parsing_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <utility>
17 
18 #include "tensorflow/core/util/example_proto_fast_parsing.h"
19 
20 #include "tensorflow/core/example/example.pb.h"
21 #include "tensorflow/core/example/feature.pb.h"
22 #include "tensorflow/core/lib/random/philox_random.h"
23 #include "tensorflow/core/lib/random/simple_philox.h"
24 #include "tensorflow/core/platform/protobuf.h"
25 #include "tensorflow/core/platform/test.h"
26 #include "tensorflow/core/platform/test_benchmark.h"
27 #include "tensorflow/core/util/example_proto_fast_parsing_test.pb.h"
28 
29 namespace tensorflow {
30 namespace example {
31 namespace {
32 
// String constants named after the dense int64 / float / string cases.
// NOTE(review): nothing in the visible part of this file refers to them;
// presumably they are feature keys left over from (or reserved for)
// benchmarks -- confirm before removing.
33 constexpr char kDenseInt64Key[] = "dense_int64";
34 constexpr char kDenseFloatKey[] = "dense_float";
35 constexpr char kDenseStringKey[] = "dense_string";
36 
// String constants named after the sparse int64 / float / string cases.
// NOTE(review): nothing in the visible part of this file refers to them;
// presumably they are feature keys left over from (or reserved for)
// benchmarks -- confirm before removing.
37 constexpr char kSparseInt64Key[] = "sparse_int64";
38 constexpr char kSparseFloatKey[] = "sparse_float";
39 constexpr char kSparseStringKey[] = "sparse_string";
40 
SerializedToReadable(string serialized)41 string SerializedToReadable(string serialized) {
42   string result;
43   result += '"';
44   for (char c : serialized)
45     result += strings::StrCat("\\x", strings::Hex(c, strings::kZeroPad2));
46   result += '"';
47   return result;
48 }
49 
// Returns the bytes that `example.SerializeToString(..)` writes.
// T only needs to provide SerializeToString(string*). The value returned by
// that call is ignored, so a failed serialization yields whatever was written
// to the output string.
template <class T>
string Serialize(const T& example) {
  string wire_bytes;
  example.SerializeToString(&wire_bytes);
  return wire_bytes;
}
56 
57 // Tests that serialized gets parsed identically by TestFastParse(..)
58 // and the regular Example.ParseFromString(..).
TestCorrectness(const string & serialized)59 void TestCorrectness(const string& serialized) {
60   Example example;
61   Example fast_example;
62   EXPECT_TRUE(example.ParseFromString(serialized));
63   example.DiscardUnknownFields();
64   EXPECT_TRUE(TestFastParse(serialized, &fast_example));
65   EXPECT_EQ(example.DebugString(), fast_example.DebugString());
66   if (example.DebugString() != fast_example.DebugString()) {
67     LOG(ERROR) << "Bad serialized: " << SerializedToReadable(serialized);
68   }
69 }
70 
71 // Fast parsing does not differentiate between EmptyExample and EmptyFeatures
72 // TEST(FastParse, EmptyExample) {
73 //   Example example;
74 //   TestCorrectness(example);
75 // }
76 
// Serializes an ExampleWithExtras (an "age" feature plus the extra1..extra7
// fields) followed by a plain Example with a "zipcode" feature, then checks
// that both parsers agree on the concatenated bytes.
TEST(FastParse, IgnoresPrecedingUnknownTopLevelFields) {
  ExampleWithExtras with_extras;
  auto& leading_features = *with_extras.mutable_features()->mutable_feature();
  leading_features["age"].mutable_int64_list()->add_value(13);
  with_extras.set_extra1("some_str");
  with_extras.set_extra2(123);
  with_extras.set_extra3(234);
  with_extras.set_extra4(345);
  with_extras.set_extra5(4.56);
  with_extras.add_extra6(5.67);
  with_extras.add_extra6(6.78);
  auto& extra7_features = *with_extras.mutable_extra7()->mutable_feature();
  extra7_features["extra7"].mutable_int64_list()->add_value(1337);

  Example context;
  auto& context_features = *context.mutable_features()->mutable_feature();
  context_features["zipcode"].mutable_int64_list()->add_value(94043);

  const string leading_bytes = Serialize(with_extras);
  const string trailing_bytes = Serialize(context);
  TestCorrectness(strings::StrCat(leading_bytes, trailing_bytes));
}
100 
// Serializes a plain Example with an "age" feature followed by an
// ExampleWithExtras (a "zipcode" feature plus the extra1..extra7 fields), then
// checks that both parsers agree on the concatenated bytes.
TEST(FastParse, IgnoresTrailingUnknownTopLevelFields) {
  Example example;
  auto& example_features = *example.mutable_features()->mutable_feature();
  example_features["age"].mutable_int64_list()->add_value(13);

  ExampleWithExtras with_extras;
  auto& trailing_features = *with_extras.mutable_features()->mutable_feature();
  trailing_features["zipcode"].mutable_int64_list()->add_value(94043);
  with_extras.set_extra1("some_str");
  with_extras.set_extra2(123);
  with_extras.set_extra3(234);
  with_extras.set_extra4(345);
  with_extras.set_extra5(4.56);
  with_extras.add_extra6(5.67);
  with_extras.add_extra6(6.78);
  auto& extra7_features = *with_extras.mutable_extra7()->mutable_feature();
  extra7_features["extra7"].mutable_int64_list()->add_value(1337);

  const string leading_bytes = Serialize(example);
  const string trailing_bytes = Serialize(with_extras);
  TestCorrectness(strings::StrCat(leading_bytes, trailing_bytes));
}
124 
// Concatenates two serialized Examples that carry different int64 features
// ("age" and "zipcode") and checks that both parsers agree on the result.
TEST(FastParse, SingleInt64WithContext) {
  Example example;
  auto& example_features = *example.mutable_features()->mutable_feature();
  example_features["age"].mutable_int64_list()->add_value(13);

  Example context;
  auto& context_features = *context.mutable_features()->mutable_feature();
  context_features["zipcode"].mutable_int64_list()->add_value(94043);

  const string example_bytes = Serialize(example);
  const string context_bytes = Serialize(context);
  TestCorrectness(strings::StrCat(example_bytes, context_bytes));
}
138 
// Concatenates two serialized Examples that both define the "age" feature
// (values 0 and 15) and checks that both parsers agree on the result.
TEST(FastParse, DenseInt64WithContext) {
  Example example;
  auto& example_features = *example.mutable_features()->mutable_feature();
  example_features["age"].mutable_int64_list()->add_value(0);

  Example context;
  auto& context_features = *context.mutable_features()->mutable_feature();
  context_features["age"].mutable_int64_list()->add_value(15);

  const string serialized = Serialize(example) + Serialize(context);

  {
    // Surprising but intentional: regular proto deserialization of the
    // concatenation yields exactly `context`, i.e. the later "age" entry
    // replaces the earlier one. The Servo team asked for this 'feature' to be
    // replicated by the fast parser; in the future an error should be returned
    // instead.
    Example reparsed;
    EXPECT_TRUE(reparsed.ParseFromString(serialized));
    EXPECT_EQ(reparsed.DebugString(), context.DebugString());
  }
  TestCorrectness(serialized);
}
162 
// Hand-written bytes for an Example whose "age" (\x61\x67\x65) feature holds
// an int64 list with the single value 13 (\x0d).
// NOTE(review): the value is introduced by tag \x0a (field 1,
// length-delimited), which per the protobuf wire format is the *packed* form,
// while the test named Packed uses the varint tag \x08. The two names look
// swapped -- confirm before relying on them.
TEST(FastParse, NonPacked) {
  const string serialized =
      "\x0a\x0e\x0a\x0c\x0a\x03\x61\x67\x65\x12\x05\x1a\x03\x0a\x01\x0d";
  TestCorrectness(serialized);
}
167 
// Hand-written bytes for an Example whose "age" (\x61\x67\x65) feature holds
// an int64 list with the single value 13 (\x0d).
// NOTE(review): the value is introduced by tag \x08 (field 1, varint), which
// per the protobuf wire format is the *unpacked* form, while the test named
// NonPacked uses the length-delimited tag \x0a. The two names look swapped --
// confirm before relying on them.
TEST(FastParse, Packed) {
  const string serialized =
      "\x0a\x0d\x0a\x0b\x0a\x03\x61\x67\x65\x12\x04\x1a\x02\x08\x0d";
  TestCorrectness(serialized);
}
172 
// Hand-written bytes for a feature map entry in which the value field
// (tag \x12, a bytes list containing "value") is encoded before the key field
// (tag \x0a, "key").
TEST(FastParse, ValueBeforeKeyInMap) {
  const string serialized =
      "\x0a\x12\x0a\x10\x12\x09\x0a\x07\x0a\x05value\x0a\x03key";
  TestCorrectness(serialized);
}
176 
// An Example whose `features` message has been touched but holds no entries.
TEST(FastParse, EmptyFeatures) {
  Example features_only;
  features_only.mutable_features();  // Touch the submessage; add nothing.
  const string serialized = Serialize(features_only);
  TestCorrectness(serialized);
}
182 
TestCorrectnessJson(const string & json)183 void TestCorrectnessJson(const string& json) {
184   auto resolver = protobuf::util::NewTypeResolverForDescriptorPool(
185       "type.googleapis.com", protobuf::DescriptorPool::generated_pool());
186   string serialized;
187   auto s = protobuf::util::JsonToBinaryString(
188       resolver, "type.googleapis.com/tensorflow.Example", json, &serialized);
189   EXPECT_TRUE(s.ok()) << s;
190   delete resolver;
191   TestCorrectness(serialized);
192 }
193 
// JSON-built Example with one int64, one float and one bytes feature, each
// holding a single value.
TEST(FastParse, JsonUnivalent) {
  const string json =
      "{'features': {"
      "  'feature': {'age': {'int64_list': {'value': [0]} }}, "
      "  'feature': {'flo': {'float_list': {'value': [1.1]} }}, "
      "  'feature': {'byt': {'bytes_list': {'value': ['WW8='] }}}"
      "}}";
  TestCorrectnessJson(json);
}
202 
// JSON-built Example with one int64, one float and one bytes feature, each
// holding several values.
TEST(FastParse, JsonMultivalent) {
  const string json =
      "{'features': {"
      "  'feature': {'age': {'int64_list': {'value': [0, 13, 23]} }}, "
      "  'feature': {'flo': {'float_list': {'value': [1.1, 1.2, 1.3]} }}, "
      "  'feature': {'byt': {'bytes_list': {'value': ['WW8=', 'WW8K'] }}}"
      "}}";
  TestCorrectnessJson(json);
}
211 
// Example with a single "age" feature holding one int64 value.
TEST(FastParse, SingleInt64) {
  Example example;
  auto& features = *example.mutable_features()->mutable_feature();
  features["age"].mutable_int64_list()->add_value(13);
  const string serialized = Serialize(example);
  TestCorrectness(serialized);
}
219 
ExampleWithSomeFeatures()220 static string ExampleWithSomeFeatures() {
221   Example example;
222 
223   (*example.mutable_features()->mutable_feature())[""];
224 
225   (*example.mutable_features()->mutable_feature())["empty_bytes_list"]
226       .mutable_bytes_list();
227   (*example.mutable_features()->mutable_feature())["empty_float_list"]
228       .mutable_float_list();
229   (*example.mutable_features()->mutable_feature())["empty_int64_list"]
230       .mutable_int64_list();
231 
232   BytesList* bytes_list =
233       (*example.mutable_features()->mutable_feature())["bytes_list"]
234           .mutable_bytes_list();
235   bytes_list->add_value("bytes1");
236   bytes_list->add_value("bytes2");
237 
238   FloatList* float_list =
239       (*example.mutable_features()->mutable_feature())["float_list"]
240           .mutable_float_list();
241   float_list->add_value(1.0);
242   float_list->add_value(2.0);
243 
244   Int64List* int64_list =
245       (*example.mutable_features()->mutable_feature())["int64_list"]
246           .mutable_int64_list();
247   int64_list->add_value(3);
248   int64_list->add_value(270);
249   int64_list->add_value(86942);
250 
251   return Serialize(example);
252 }
253 
// Runs the parser comparison on the Example built by ExampleWithSomeFeatures().
TEST(FastParse, SomeFeatures) {
  const string serialized = ExampleWithSomeFeatures();
  TestCorrectness(serialized);
}
255 
AddDenseFeature(const char * feature_name,DataType dtype,PartialTensorShape shape,bool variable_length,size_t elements_per_stride,FastParseExampleConfig * out_config)256 static void AddDenseFeature(const char* feature_name, DataType dtype,
257                             PartialTensorShape shape, bool variable_length,
258                             size_t elements_per_stride,
259                             FastParseExampleConfig* out_config) {
260   out_config->dense.emplace_back();
261   auto& new_feature = out_config->dense.back();
262   new_feature.feature_name = feature_name;
263   new_feature.dtype = dtype;
264   new_feature.shape = std::move(shape);
265   new_feature.default_value = Tensor(dtype, {});
266   new_feature.variable_length = variable_length;
267   new_feature.elements_per_stride = elements_per_stride;
268 }
269 
AddSparseFeature(const char * feature_name,DataType dtype,FastParseExampleConfig * out_config)270 static void AddSparseFeature(const char* feature_name, DataType dtype,
271                              FastParseExampleConfig* out_config) {
272   out_config->sparse.emplace_back();
273   auto& new_feature = out_config->sparse.back();
274   new_feature.feature_name = feature_name;
275   new_feature.dtype = dtype;
276 }
277 
// Parses 13 copies of ExampleWithSomeFeatures() with feature-stats collection
// enabled under four configurations (fixed-shape dense, variable-length dense,
// sparse, and a mix of the three). For each configuration, both the batch
// parse and the single-example parse must report 7 features and 7 feature
// values per example.
TEST(FastParse, StatsCollection) {
  const size_t kNumExamples = 13;
  std::vector<tstring> batch(kNumExamples, ExampleWithSomeFeatures());

  // Fixed-shape dense features.
  FastParseExampleConfig dense_config;
  dense_config.collect_feature_stats = true;
  AddDenseFeature("bytes_list", DT_STRING, {2}, false, 2, &dense_config);
  AddDenseFeature("float_list", DT_FLOAT, {2}, false, 2, &dense_config);
  AddDenseFeature("int64_list", DT_INT64, {3}, false, 3, &dense_config);

  // Variable-length dense features.
  FastParseExampleConfig varlen_config;
  varlen_config.collect_feature_stats = true;
  AddDenseFeature("bytes_list", DT_STRING, {-1}, true, 1, &varlen_config);
  AddDenseFeature("float_list", DT_FLOAT, {-1}, true, 1, &varlen_config);
  AddDenseFeature("int64_list", DT_INT64, {-1}, true, 1, &varlen_config);

  // Sparse features.
  FastParseExampleConfig sparse_config;
  sparse_config.collect_feature_stats = true;
  AddSparseFeature("bytes_list", DT_STRING, &sparse_config);
  AddSparseFeature("float_list", DT_FLOAT, &sparse_config);
  AddSparseFeature("int64_list", DT_INT64, &sparse_config);

  // One feature of each flavour.
  FastParseExampleConfig mixed_config;
  mixed_config.collect_feature_stats = true;
  AddDenseFeature("bytes_list", DT_STRING, {2}, false, 2, &mixed_config);
  AddDenseFeature("float_list", DT_FLOAT, {-1}, true, 1, &mixed_config);
  AddSparseFeature("int64_list", DT_INT64, &mixed_config);

  for (const FastParseExampleConfig& config :
       {dense_config, varlen_config, sparse_config, mixed_config}) {
    {
      // Whole batch at once: one stats record per example.
      Result batch_result;
      TF_CHECK_OK(FastParseExample(config, batch, {}, nullptr, &batch_result));
      EXPECT_EQ(kNumExamples, batch_result.feature_stats.size());
      for (const PerExampleFeatureStats& stats : batch_result.feature_stats) {
        EXPECT_EQ(7, stats.features_count);
        EXPECT_EQ(7, stats.feature_values_count);
      }
    }

    {
      // Single-example entry point: exactly one stats record.
      Result single_result;
      TF_CHECK_OK(FastParseSingleExample(config, batch[0], &single_result));
      EXPECT_EQ(1, single_result.feature_stats.size());
      EXPECT_EQ(7, single_result.feature_stats[0].features_count);
      EXPECT_EQ(7, single_result.feature_stats[0].feature_values_count);
    }
  }
}
327 
RandStr(random::SimplePhilox * rng)328 string RandStr(random::SimplePhilox* rng) {
329   static const char key_char_lookup[] =
330       "0123456789{}~`!@#$%^&*()"
331       "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
332       "abcdefghijklmnopqrstuvwxyz";
333   auto len = 1 + rng->Rand32() % 200;
334   string str;
335   str.reserve(len);
336   while (len-- > 0) {
337     str.push_back(
338         key_char_lookup[rng->Rand32() % (sizeof(key_char_lookup) /
339                                          sizeof(key_char_lookup[0]))]);
340   }
341   return str;
342 }
343 
Fuzz(random::SimplePhilox * rng)344 void Fuzz(random::SimplePhilox* rng) {
345   // Generate keys.
346   auto num_keys = 1 + rng->Rand32() % 100;
347   std::unordered_set<string> unique_keys;
348   for (auto i = 0; i < num_keys; ++i) {
349     unique_keys.emplace(RandStr(rng));
350   }
351 
352   // Generate serialized example.
353   Example example;
354   string serialized_example;
355   auto num_concats = 1 + rng->Rand32() % 4;
356   std::vector<Feature::KindCase> feat_types(
357       {Feature::kBytesList, Feature::kFloatList, Feature::kInt64List});
358   std::vector<string> all_keys(unique_keys.begin(), unique_keys.end());
359   while (num_concats--) {
360     example.Clear();
361     auto num_active_keys = 1 + rng->Rand32() % all_keys.size();
362 
363     // Generate features.
364     for (auto i = 0; i < num_active_keys; ++i) {
365       auto fkey = all_keys[rng->Rand32() % all_keys.size()];
366       auto ftype_idx = rng->Rand32() % feat_types.size();
367       auto num_features = 1 + rng->Rand32() % 5;
368       switch (static_cast<Feature::KindCase>(feat_types[ftype_idx])) {
369         case Feature::kBytesList: {
370           BytesList* bytes_list =
371               (*example.mutable_features()->mutable_feature())[fkey]
372                   .mutable_bytes_list();
373           while (num_features--) {
374             bytes_list->add_value(RandStr(rng));
375           }
376           break;
377         }
378         case Feature::kFloatList: {
379           FloatList* float_list =
380               (*example.mutable_features()->mutable_feature())[fkey]
381                   .mutable_float_list();
382           while (num_features--) {
383             float_list->add_value(rng->RandFloat());
384           }
385           break;
386         }
387         case Feature::kInt64List: {
388           Int64List* int64_list =
389               (*example.mutable_features()->mutable_feature())[fkey]
390                   .mutable_int64_list();
391           while (num_features--) {
392             int64_list->add_value(rng->Rand64());
393           }
394           break;
395         }
396         default: {
397           LOG(QFATAL);
398           break;
399         }
400       }
401     }
402     serialized_example += example.SerializeAsString();
403   }
404 
405   // Test correctness.
406   TestCorrectness(serialized_example);
407 }
408 
// Runs Fuzz 200 times off a single generator seeded with a fixed value,
// logging the number of remaining runs before each one.
TEST(FastParse, FuzzTest) {
  const uint64 seed = 1337;
  random::PhiloxRandom philox(seed);
  random::SimplePhilox rng(&philox);
  for (int runs_left = 200; runs_left-- > 0;) {
    LOG(INFO) << "runs left: " << runs_left;
    Fuzz(&rng);
  }
}
419 
// FastParseExample must return OK when given no serialized examples and no
// example names, with one sparse string feature configured.
TEST(TestFastParseExample, Empty) {
  FastParseExampleConfig config;
  config.sparse.push_back({"test", DT_STRING});

  const gtl::ArraySlice<tstring> no_serialized;
  const gtl::ArraySlice<tstring> no_names;
  Result result;
  const Status status =
      FastParseExample(config, no_serialized, no_names, nullptr, &result);
  EXPECT_TRUE(status.ok()) << status;
}
429 
430 }  // namespace
431 }  // namespace example
432 }  // namespace tensorflow
433