xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/decode_proto_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // DecodeProto is a TensorFlow op which extracts arbitrary fields from protos
17 // serialized as strings.
18 //
19 // See docs in ../ops/decode_proto_op.cc.
20 //
21 // This implementation reads the serialized format using a handful of calls from
22 // the WireFormatLite API used by generated proto code. WireFormatLite is marked
23 // as an "internal" proto API but is widely used in practice and highly unlikely
24 // to change. This will be much faster than the previous implementation based on
25 // constructing a temporary dynamic message in memory and using the proto
26 // reflection api to read it. It can be used with any proto whose descriptors
27 // are available at runtime but should be competitive in speed with approaches
28 // that compile in the proto definitions.
29 
30 #include <memory>
31 #include <string>
32 #include <vector>
33 
34 #include "absl/container/flat_hash_map.h"
35 #include "absl/types/span.h"
36 #include "third_party/eigen3/Eigen/Core"
37 #include "tensorflow/core/framework/op_kernel.h"
38 #include "tensorflow/core/framework/tensor_types.h"
39 #include "tensorflow/core/framework/types.h"
40 #include "tensorflow/core/lib/core/errors.h"
41 #include "tensorflow/core/platform/logging.h"
42 #include "tensorflow/core/platform/protobuf.h"
43 #include "tensorflow/core/util/proto/decode.h"
44 #include "tensorflow/core/util/proto/descriptors.h"
45 #include "tensorflow/core/util/proto/proto_utils.h"
46 #include "tensorflow/core/util/ptr_util.h"
47 
48 namespace tensorflow {
49 namespace {
50 
51 using ::tensorflow::MakeUnique;
52 using ::tensorflow::protobuf::Descriptor;
53 using ::tensorflow::protobuf::DescriptorPool;
54 using ::tensorflow::protobuf::DynamicMessageFactory;
55 using ::tensorflow::protobuf::FieldDescriptor;
56 using ::tensorflow::protobuf::Message;
57 using ::tensorflow::protobuf::TextFormat;
58 using ::tensorflow::protobuf::internal::WireFormatLite;
59 using ::tensorflow::protobuf::io::CodedInputStream;
60 
61 const bool kFailOnDecodeError = true;
62 
63 // Used to store the default value of a protocol message field, casted to the
64 // type of the output tensor.
65 //
66 // TODO(paskin): Use absl::variant once TensorFlow gets absl dependencies.
67 struct DefaultValue {
68   DataType dtype = DataType::DT_INVALID;
69   union Value {
70     bool v_bool;           // DT_BOOL
71     double v_double;       // DT_DOUBLE
72     float v_float;         // DT_FLOAT
73     int8 v_int8;           // DT_INT8
74     int32 v_int32;         // DT_INT32
75     int64_t v_int64;       // DT_INT64
76     const char* v_string;  // DT_STRING
77     uint8 v_uint8;         // DT_UINT8
78     uint8 v_uint32;        // DT_UINT32
79     uint8 v_uint64;        // DT_UINT64
80   };
81   Value value;
82 };
83 
84 // Initializes a DefaultValue object.  This generic template handles numeric
85 // types and strings are handled by a template specialization below.
86 //
87 // Args:
88 //   dtype: the type of the output tensor
89 //   value: the default value as obtained from the FieldDescriptor
90 //   result: the object to initialize
91 template <typename T>
InitDefaultValue(DataType dtype,const T value,DefaultValue * result)92 Status InitDefaultValue(DataType dtype, const T value, DefaultValue* result) {
93   result->dtype = dtype;
94   switch (dtype) {
95     case DT_BOOL:
96       result->value.v_bool = static_cast<bool>(value);
97       break;
98     case DT_DOUBLE:
99       result->value.v_double = static_cast<double>(value);
100       break;
101     case DT_FLOAT:
102       result->value.v_float = static_cast<float>(value);
103       break;
104     case DT_INT8:
105       result->value.v_int8 = static_cast<int8>(value);
106       break;
107     case DT_INT32:
108       result->value.v_int32 = static_cast<int32>(value);
109       break;
110     case DT_INT64:
111       result->value.v_int64 = static_cast<int64_t>(value);
112       break;
113     case DT_UINT8:
114       result->value.v_uint8 = static_cast<uint8>(value);
115       break;
116     case DT_UINT32:
117       result->value.v_uint32 = static_cast<uint32>(value);
118       break;
119     case DT_UINT64:
120       result->value.v_uint64 = static_cast<uint64>(value);
121       break;
122     default:
123       // We should never get here, given the type checking that occurs earlier.
124       return errors::Internal(
125           "Cannot initialize default value for unsupported type: ",
126           DataTypeString(dtype));
127   }
128   return OkStatus();
129 }
130 
131 template <>
InitDefaultValue(DataType dtype,const char * value,DefaultValue * result)132 Status InitDefaultValue(DataType dtype, const char* value,
133                         DefaultValue* result) {
134   // These are sanity checks that should never trigger given the code that
135   // leads here.
136   if (TF_PREDICT_FALSE(dtype != DT_STRING)) {
137     return errors::InvalidArgument(
138         "Cannot cast field to anything but DT_STRING");
139   }
140   if (TF_PREDICT_FALSE(value == nullptr)) {
141     return errors::InvalidArgument("Null default string value.");
142   }
143   result->dtype = DT_STRING;
144   result->value.v_string = value;
145   return OkStatus();
146 }
147 
148 // Initializes a default value from the output data type and the field
149 // descriptor.
InitDefaultValueFromFieldDescriptor(DataType dtype,const FieldDescriptor * field_desc,DefaultValue * result)150 Status InitDefaultValueFromFieldDescriptor(DataType dtype,
151                                            const FieldDescriptor* field_desc,
152                                            DefaultValue* result) {
153   switch (field_desc->type()) {
154     case WireFormatLite::TYPE_DOUBLE:
155       return InitDefaultValue(dtype, field_desc->default_value_double(),
156                               result);
157     case WireFormatLite::TYPE_FLOAT:
158       return InitDefaultValue(dtype, field_desc->default_value_float(), result);
159     case WireFormatLite::TYPE_INT64:
160     case WireFormatLite::TYPE_SINT64:
161     case WireFormatLite::TYPE_SFIXED64:
162       return InitDefaultValue(dtype, field_desc->default_value_int64(), result);
163     case WireFormatLite::TYPE_FIXED64:
164     case WireFormatLite::TYPE_UINT64:
165       return InitDefaultValue(dtype, field_desc->default_value_uint64(),
166                               result);
167     case WireFormatLite::TYPE_INT32:
168     case WireFormatLite::TYPE_SINT32:
169     case WireFormatLite::TYPE_SFIXED32:
170       return InitDefaultValue(dtype, field_desc->default_value_int32(), result);
171     case WireFormatLite::TYPE_FIXED32:
172     case WireFormatLite::TYPE_UINT32:
173       return InitDefaultValue(dtype, field_desc->default_value_uint32(),
174                               result);
175     case WireFormatLite::TYPE_BOOL:
176       return InitDefaultValue(dtype, field_desc->default_value_bool(), result);
177     case WireFormatLite::TYPE_ENUM:
178       return InitDefaultValue(dtype, field_desc->default_value_enum()->number(),
179                               result);
180     case WireFormatLite::TYPE_BYTES:
181     case WireFormatLite::TYPE_STRING:
182       // Manipulating default string values as C-style pointers should be OK
183       // for typical code-generated protocol messages.  It is possible in
184       // principle to register a message descriptor on the fly, and these
185       // pointers may not be stable if that descriptor has a weird
186       // implementation.  (But the return type of default_value_string() is
187       // const string&, so it'd have to be very weird.)
188       return InitDefaultValue(dtype, field_desc->default_value_string().c_str(),
189                               result);
190     case WireFormatLite::TYPE_GROUP:
191     case WireFormatLite::TYPE_MESSAGE:
192       return InitDefaultValue(dtype, "", result);
193       // default: intentionally omitted in order to enable static checking.
194   }
195   return OkStatus();
196 }
197 
198 // A FieldInfo holds a handful of information from the FieldDescriptor
199 // and user attributes.
200 struct FieldInfo {
FieldInfotensorflow::__anon9e74c1ef0111::FieldInfo201   FieldInfo(const FieldDescriptor* field_desc, int user_index,
202             DefaultValue def_value)
203       : output_index(user_index), default_value(def_value) {
204     // Without this intermediate data structure, the profile had hotspots
205     // calling methods of FieldDescriptor.
206     number = field_desc->number();
207 
208     // The wire format library defines the same constants used in
209     // descriptor.proto. This static_cast is safe because they are guaranteed to
210     // stay in sync. We need the field type from the FieldDescriptor here
211     // because the wire format doesn't tell us anything about what happens
212     // inside a packed repeated field: there is enough information in the wire
213     // format to skip the whole field but not enough to know how to parse what's
214     // inside. For that we go to the schema.
215     type = static_cast<WireFormatLite::FieldType>(field_desc->type());
216     is_repeated = field_desc->is_repeated();
217   }
218 
219   // Disable copy and move.
220   FieldInfo(const FieldInfo&) = delete;
221   FieldInfo& operator=(const FieldInfo&) = delete;
222 
223   // Internally we sort field descriptors by wire number for fast lookup. In
224   // general this is different from the order given by the user. Output_index
225   // gives the index into the field_names and output_types attributes and into
226   // the output tensor list.
227   int output_index = -1;
228 
229   // This is a cache of the relevant fields from `FieldDescriptorProto`. This
230   // was added after noticing that FieldDescriptor->type() was using 6% of the
231   // cpu profile.
232   WireFormatLite::FieldType type;
233   int number;
234   bool is_repeated;
235   DefaultValue default_value;
236 };
237 
// A CountCollector counts sizes of repeated and optional fields in a proto.
//
// Each field is tracked by a single CountCollector instance. The instance
// manages a single count, which is stored as a pointer (it is intended to be a
// reference to the `sizes` output which is being filled in). The pointer is
// passed in at initialization.
//
// Counting is done as a separate pass in order to allocate output tensors all
// at once. This allows the TensorFlow runtime to optimize allocation for the
// consumer, while removing the need for copying inside this op. After this
// pass, the DenseCollector class (below) gathers the data: it is more complex
// and provides better motivation for the API here.
class CountCollector {
 public:
  CountCollector() = delete;

  // The count may be stored inside an Eigen Tensor to eliminate copying.
  explicit CountCollector(int32* count) : count_ptr_(count) {}

  // Reads (in this case counts) a single value.
  Status ReadValue(CodedInputStream* input, const FieldInfo& field) {
    // Only repeated fields can have count > 1.
    // For non-repeated fields the count saturates at 1: later occurrences on
    // the wire overwrite earlier ones (per proto3 "last one wins" semantics),
    // so they do not increase the output size.
    if (*count_ptr_ == 0 || field.is_repeated) {
      (*count_ptr_)++;
    }
    // We expect a wire type based on the schema field_type, to allow a little
    // more checking.
    if (!SkipValue(input, field)) {
      return errors::DataLoss("ReadValue: Failed skipping field when counting");
    }
    return OkStatus();
  }

  // Reads (in this case counts) a length-delimited list of values.
  Status ReadPackedValues(CodedInputStream* input, const FieldInfo& field,
                          size_t buf_size) {
    // An empty packed region contributes no values; nothing to scan.
    if (buf_size == 0) {
      return OkStatus();
    }

    const void* tmpbuf;
    int unused_max_buf_size;

    input->GetDirectBufferPointerInline(&tmpbuf, &unused_max_buf_size);
    // This is safe because the underlying storage for the CodedInputStream is
    // owned by the input tensor. If it were a Cord or file-backed stream this
    // pointer would go stale after the bytes were skipped.
    const uint8* buf = reinterpret_cast<const uint8*>(tmpbuf);

    // Important: we skipped the input->{Push,Pop}Limit() calls for speed,
    // so the bounds check on buf_size inside Skip() is critical, and
    // must be done before scanning the contents.
    if (!input->Skip(buf_size)) {
      return errors::DataLoss("ReadPackedValues: Skipping packed field failed");
    }

    // Dispatch to the appropriately typed field reader based on the schema
    // type.  Varint-encoded types must be parsed byte-by-byte to count them;
    // fixed-width types are counted by arithmetic; length-delimited and group
    // types are never legal inside a packed region.
    Status st;
    switch (field.type) {
      case WireFormatLite::TYPE_DOUBLE:
        st = CountPackedFixed<double>(buf, buf_size);
        break;
      case WireFormatLite::TYPE_FLOAT:
        st = CountPackedFixed<float>(buf, buf_size);
        break;
      case WireFormatLite::TYPE_INT64:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_UINT64:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_INT32:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_FIXED64:
        st = CountPackedFixed<uint64>(buf, buf_size);
        break;
      case WireFormatLite::TYPE_FIXED32:
        st = CountPackedFixed<uint32>(buf, buf_size);
        break;
      case WireFormatLite::TYPE_BOOL:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_STRING:
        st = errors::DataLoss("TYPE_STRING encountered as packed");
        break;
      case WireFormatLite::TYPE_GROUP:
        st = errors::DataLoss("TYPE_GROUP encountered as packed");
        break;
      case WireFormatLite::TYPE_MESSAGE:
        st = errors::DataLoss("TYPE_MESSAGE encountered as packed");
        break;
      case WireFormatLite::TYPE_BYTES:
        st = errors::DataLoss("TYPE_BYTES encountered as packed");
        break;
      case WireFormatLite::TYPE_UINT32:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_ENUM:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_SFIXED32:
        st = CountPackedFixed<int32>(buf, buf_size);
        break;
      case WireFormatLite::TYPE_SFIXED64:
        st = CountPackedFixed<int64_t>(buf, buf_size);
        break;
      case WireFormatLite::TYPE_SINT32:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_SINT64:
        st = CountPackedVarint(buf, buf_size);
        break;
        // default: intentionally omitted in order to enable static checking.
    }
    if (!st.ok()) {
      return st;
    }

    // A packed region of a non-repeated field still yields at most one
    // output value, so clamp the count back down after scanning.
    if (!field.is_repeated && *count_ptr_ > 1) {
      *count_ptr_ = 1;
    }
    return OkStatus();
  }

 private:
  // Skips a length-delimited value.
  static bool SkipBytes(CodedInputStream* input) {
    uint32 length;
    if (!input->ReadVarint32(&length)) {
      return false;
    }
    return input->Skip(length);
  }

  // Counts the number of packed varints in an array. The end of a varint is
  // signaled by a value < 0x80, so counting them requires parsing the
  // bytestream. It is the caller's responsibility to ensure that len > 0.
  Status CountPackedVarint(const uint8* buf, size_t len) {
    const uint8* bound = buf + len;
    int count;

    // The last byte in a valid encoded varint is guaranteed to have the high
    // bit unset. We rely on this property to prevent ReadVarint64FromArray from
    // going out of bounds, so validate the end of the buf before scanning
    // anything.
    // (bound[-1] is safe because len > 0 is a documented precondition.)
    if (bound[-1] & 0x80) {
      return errors::DataLoss("Corrupt packed varint");
    }

    // Now we can trust ReadVarint64FromArray to stay in bounds.
    for (count = 0; buf < bound; ++count) {
      uint64 temp;
      bool ok;
      buf = internal::ReadVarint64FromArray(buf, &ok, &temp);
      if (!ok) {
        return errors::DataLoss("Corrupt packed varint");
      }
    }

    *count_ptr_ += count;
    return OkStatus();
  }

  // Counts the number of fixed-size values in a packed field. This can be done
  // without actually parsing anything.
  template <typename T>
  Status CountPackedFixed(const uint8* unused_buf, size_t len) {
    int count = len / sizeof(T);
    // Reject a region whose length is not a whole multiple of the element
    // size: that can only come from corrupt or truncated input.
    if (count * sizeof(T) != len) {
      return errors::DataLoss(
          "Illegal data length for packed fixed-size type: ", len);
    }
    *count_ptr_ += len / sizeof(T);
    return OkStatus();
  }

  // Skips a single value in the input stream. Dispatches to the appropriately
  // typed field skipper based on the schema type tag. This is not as permissive
  // as just handling the wire type.
  // The switch covers every WireFormatLite::FieldType value; the default is
  // intentionally omitted (see note at the end) so the compiler can flag any
  // newly added enum value.
  static bool SkipValue(CodedInputStream* input, const FieldInfo& field) {
    uint32 tmp32;
    protobuf_uint64 tmp64;
    switch (field.type) {
      case WireFormatLite::TYPE_DOUBLE:
        return input->ReadLittleEndian64(&tmp64);
      case WireFormatLite::TYPE_FLOAT:
        return input->ReadLittleEndian32(&tmp32);
      case WireFormatLite::TYPE_INT64:
        return input->ReadVarint64(&tmp64);
      case WireFormatLite::TYPE_UINT64:
        return input->ReadVarint64(&tmp64);
      case WireFormatLite::TYPE_INT32:
        return input->ReadVarint32(&tmp32);
      case WireFormatLite::TYPE_FIXED64:
        return input->ReadLittleEndian64(&tmp64);
      case WireFormatLite::TYPE_FIXED32:
        return input->ReadLittleEndian32(&tmp32);
      case WireFormatLite::TYPE_BOOL:
        return input->ReadVarint32(&tmp32);
      case WireFormatLite::TYPE_STRING:
        return SkipBytes(input);
      case WireFormatLite::TYPE_GROUP:
        // Groups are delimited by matching start/end tags rather than a
        // length prefix, so delegate to WireFormatLite to find the end tag.
        return WireFormatLite::SkipField(
            input, WireFormatLite::MakeTag(
                       field.number, WireFormatLite::WIRETYPE_START_GROUP));
      case WireFormatLite::TYPE_MESSAGE:
        return SkipBytes(input);
      case WireFormatLite::TYPE_BYTES:
        return SkipBytes(input);
      case WireFormatLite::TYPE_UINT32:
        return input->ReadVarint32(&tmp32);
      case WireFormatLite::TYPE_ENUM:
        return input->ReadVarint32(&tmp32);
      case WireFormatLite::TYPE_SFIXED32:
        return input->ReadLittleEndian32(&tmp32);
      case WireFormatLite::TYPE_SFIXED64:
        return input->ReadLittleEndian64(&tmp64);
      case WireFormatLite::TYPE_SINT32:
        return input->ReadVarint32(&tmp32);
      case WireFormatLite::TYPE_SINT64:
        return input->ReadVarint64(&tmp64);
        // default: intentionally omitted in order to enable static checking.
    }
  }

  // Borrowed pointer to the count being accumulated; typically points into
  // the `sizes` output tensor. Not owned.
  int32* count_ptr_ = nullptr;
};
467 
// A DenseCollector accumulates values from a proto into a tensor.
//
// There is an instance of DenseCollector for each field of each proto. The
// DenseCollector deserializes the value from the wire directly into the
// preallocated output Tensor.
//
// This class is named DenseCollector because in the future there should be a
// SparseCollector that accumulates field data into sparse tensors if the user
// requests it.
class DenseCollector {
 public:
  DenseCollector() = delete;

  // A DenseCollector applies to one field of a serialized message.
  // Note that default_value.dtype is the type of the output tensor.
  // `datap` points into the output tensor at this message's slot; `max_repeat_count`
  // is the per-message element capacity computed by the counting pass.
  DenseCollector(uint8* datap, DefaultValue default_value, int max_repeat_count)
      : datap_(datap),
        default_value_(default_value),
        max_repeat_count_(max_repeat_count) {}

  // Reads a value from the input stream and stores it.
  //
  // Always inlining gave a ~50% speedup on microbenchmarks at one point.
  // TODO(nix): try removing it to see if that still holds.
  // TODO(jsimsa): ABSL_ATTRIBUTE_ALWAYS_INLINE
  Status ReadValue(CodedInputStream* input, const FieldInfo& field) {
    // For required and optional fields, we overwrite values[0] with
    // the latest one in the wire stream.
    // See https://developers.google.com/protocol-buffers/docs/encoding#optional
    // Only for repeated fields do we advance the next_repeat_index_ past 1.
    // TODO(nix): to handle oneof we must also zero out any previous values
    //  seen on the wire.
    // No bounds check against max_repeat_count_ here: the counting pass sized
    // the output with the same traversal logic (see the datap_ comment below).
    int32_t index = 0;
    if (field.is_repeated) {
      index = next_repeat_index_;
    }
    next_repeat_index_ = index + 1;

    return internal::ReadValue(input, field.type, field.number,
                               default_value_.dtype, index, datap_);
  }

  // Reads and stores a length-delimited list of values.
  Status ReadPackedValues(CodedInputStream* input, const FieldInfo& field,
                          const size_t buf_size) {
    const void* buf;
    int unused_max_buf_size;
    input->GetDirectBufferPointerInline(&buf, &unused_max_buf_size);
    // This is safe because the underlying storage for the CodedInputStream is
    // owned by the input tensor. If it were a Cord or file-backed stream this
    // pointer would go stale after the bytes were skipped.
    // Skip() also performs the bounds check on buf_size, so it must succeed
    // before `buf` is handed to the packed-array reader.
    if (!input->Skip(buf_size)) {
      return errors::DataLoss(
          "ReadPackedValues: Skipping packed field failed.  Field tag: ",
          field.number);
    }

    // Setting stride=0 causes new values to overwrite old ones for
    // non-repeated fields.
    const int stride = field.is_repeated ? 1 : 0;

    if (next_repeat_index_ >= max_repeat_count_) {
      return errors::DataLoss(
          "ReadPackedValues: Tried to write more entries than allowed.  "
          "Field tag: ",
          field.number, ", Max entries allowed: ", max_repeat_count_);
    } else {
      return internal::ReadPackedFromArray(buf, buf_size, field.type,
                                           field.number, default_value_.dtype,
                                           stride, &next_repeat_index_, datap_);
    }
  }

  // Fills in any missing values in the output array with defaults. Dispatches
  // to the appropriately typed field default based on the runtime type tag.
  Status FillWithDefaults() {
    switch (default_value_.dtype) {
      case DataType::DT_BOOL:
        return FillDefault<bool>(default_value_.value.v_bool);
      case DataType::DT_FLOAT:
        return FillDefault<float>(default_value_.value.v_float);
      case DataType::DT_DOUBLE:
        return FillDefault<double>(default_value_.value.v_double);
      case DataType::DT_INT8:
        return FillDefault<int8>(default_value_.value.v_int8);
      case DataType::DT_INT32:
        return FillDefault<int32>(default_value_.value.v_int32);
      case DataType::DT_INT64:
        return FillDefault<int64_t>(default_value_.value.v_int64);
      case DataType::DT_STRING:
        // The stored const char* is converted to tstring elements here.
        return FillDefault<tstring>(default_value_.value.v_string);
      case DataType::DT_UINT8:
        return FillDefault<uint8>(default_value_.value.v_uint8);
      case DataType::DT_UINT32:
        return FillDefault<uint32>(default_value_.value.v_uint32);
      case DataType::DT_UINT64:
        return FillDefault<uint64>(default_value_.value.v_uint64);
      default:
        // There are many tensorflow dtypes not handled here, but they
        // should not come up unless type casting is added to the Op.
        // Chaining with tf.cast() should do the right thing until then.
        return errors::DataLoss("Failed filling defaults for ",
                                DataTypeString(default_value_.dtype));
    }
  }

 private:
  // Fills empty values in the dense representation with a default value. This
  // uses next_repeat_index_ which counts the number of parsed values for the
  // field.
  template <class T>
  Status FillDefault(const T& default_value) {
    // Positions [next_repeat_index_, max_repeat_count_) were never written by
    // the wire stream; pad them so the output tensor is fully initialized.
    for (int i = next_repeat_index_; i < max_repeat_count_; i++) {
      reinterpret_cast<T*>(datap_)[i] = default_value;
    }
    return OkStatus();
  }

  // Index of the next repeated element to write; also the number of values
  // parsed so far for this field in this message.
  int32 next_repeat_index_ = 0;

  // This is a pointer to data_[message_index_]. There is no bounds checking at
  // this level: we computed the max repeat size for each field in
  // CountCollector and use the same code to traverse it here, so we are
  // guaranteed not to be called for more items than we have allocated space.
  void* const datap_ = nullptr;

  const DefaultValue default_value_;
  const int max_repeat_count_ = 0;
};
597 
598 class DecodeProtoOp : public OpKernel {
599  public:
  // Resolves the message descriptor, validates the requested fields against
  // the schema, precomputes per-field metadata (FieldInfo) sorted by wire
  // number, and reads the remaining op attributes.
  explicit DecodeProtoOp(OpKernelConstruction* context) : OpKernel(context) {
    string descriptor_source;
    OP_REQUIRES_OK(context,
                   context->GetAttr("descriptor_source", &descriptor_source));

    // We always get back a desc_pool, but we may not own it. If we own it,
    // owned_desc_pool_ will be filled in.
    DescriptorPool const* desc_pool;
    OP_REQUIRES_OK(context, GetDescriptorPool(context->env(), descriptor_source,
                                              &desc_pool, &owned_desc_pool_));

    string message_type;
    OP_REQUIRES_OK(context, context->GetAttr("message_type", &message_type));

    const Descriptor* message_desc =
        desc_pool->FindMessageTypeByName(message_type);
    OP_REQUIRES(context, message_desc != nullptr,
                errors::InvalidArgument("No descriptor found for message type ",
                                        message_type));

    std::vector<string> field_names;
    OP_REQUIRES_OK(context, context->GetAttr("field_names", &field_names));
    std::vector<DataType> output_types;
    OP_REQUIRES_OK(context, context->GetAttr("output_types", &output_types));
    // field_names[i] and output_types[i] describe the same output, so the two
    // attribute lists must be parallel.
    OP_REQUIRES(
        context, field_names.size() == output_types.size(),
        errors::InvalidArgument("field_names and output_types attributes must "
                                "have the same length"));

    // Gather the field descriptors and check that requested output types match.
    int field_index = 0;
    std::vector<const FieldDescriptor*> field_descs;
    std::vector<const FieldDescriptor*> exts;
    absl::flat_hash_map<string, const FieldDescriptor*> ext_name_to_field;
    // ext_it doubles as a "FindAllExtensions not yet called" sentinel: it
    // equals exts.begin() until the first extension lookup populates `exts`.
    std::vector<const FieldDescriptor*>::iterator ext_it = exts.begin();
    for (const string& name : field_names) {
      auto fd = message_desc->FindFieldByName(name);
      if (fd == nullptr) {
        // If field can't be found in original message, try to find a matching
        // extension (by its full_name). First check a hashmap for a matching
        // extension, and if not found, then iterate through available
        // extensions to find a match (updating the hashmap while iterating.)
        auto lookup_result = ext_name_to_field.find(name);
        if (lookup_result != ext_name_to_field.end()) {
          fd = lookup_result->second;
        } else {
          // Lazily fetch all extensions the first time one is needed.
          if (ext_it == exts.begin()) {
            desc_pool->FindAllExtensions(message_desc, &exts);
            ext_it = exts.begin();
          }
          // Scan forward from where the previous lookup stopped, memoizing
          // every extension seen so later lookups hit the hashmap instead.
          while (ext_it != exts.end()) {
            auto ext_name = (*ext_it)->full_name();
            auto ext_field = *ext_it;
            ++ext_it;

            ext_name_to_field.insert({ext_name, ext_field});
            if (ext_name == name) {
              fd = ext_field;
              break;
            }
          }
        }
      }
      OP_REQUIRES(context, fd != nullptr,
                  errors::InvalidArgument("Unknown field: ", name,
                                          " in message type ", message_type));
      OP_REQUIRES(
          context,
          proto_utils::IsCompatibleType(fd->type(), output_types[field_index]),
          // Many TensorFlow types don't have corresponding proto types and the
          // user will get an error if they are requested. It would be nice to
          // allow conversions here, but tf.cast already exists so we don't
          // duplicate the functionality.
          errors::InvalidArgument("Unexpected output type for ",
                                  fd->full_name(), ": ", fd->cpp_type(), " to ",
                                  output_types[field_index]));

      field_index++;
      field_descs.push_back(fd);
    }

    // Internally we want the field_descs sorted by their number on the wire.
    // But the output tensors are allocated in the order given by the caller.
    // Build a mapping i->j, where field_descs[i] corresponds to outputs[j].
    std::vector<int> output_indices;
    output_indices.reserve(field_names.size());
    for (int i = 0; i < field_names.size(); i++) {
      output_indices.push_back(i);
    }
    std::sort(output_indices.begin(), output_indices.end(),
              [field_descs](int a, int b) {
                return field_descs[a]->number() < field_descs[b]->number();
              });

    // Now store the fields in sorted order.
    for (int i = 0; i < field_names.size(); i++) {
      const int output_index = output_indices[i];
      const DataType dtype = output_types[output_index];
      const FieldDescriptor* field_descriptor = field_descs[output_index];
      DefaultValue default_value;
      OP_REQUIRES_OK(context, InitDefaultValueFromFieldDescriptor(
                                  dtype, field_descriptor, &default_value));
      fields_.push_back(
          MakeUnique<FieldInfo>(field_descriptor, output_index, default_value));
    }

    // A prototype message is needed for the "text" format and for sanitizing.
    message_prototype_ = message_factory_.GetPrototype(message_desc);
    OP_REQUIRES(context, message_prototype_ != nullptr,
                errors::InvalidArgument("Couldn't get prototype message: ",
                                        message_desc->full_name()));
    string format;
    OP_REQUIRES_OK(context, context->GetAttr("message_format", &format));
    OP_REQUIRES(
        context, format == "binary" || format == "text",
        errors::InvalidArgument("format must be one of binary or text"));
    is_binary_ = format == "binary";

    // Enable the initial protobuf sanitizer, which is much more expensive than
    // the decoder.
    // TODO(nix): Remove this once the fast decoder has passed security review.
    OP_REQUIRES_OK(context, context->GetAttr("sanitize", &sanitize_));
  }
722 
Compute(OpKernelContext * ctx)723   void Compute(OpKernelContext* ctx) override {
724     const Tensor& buf_tensor = ctx->input(0);
725     int message_count = buf_tensor.NumElements();
726     OP_REQUIRES(ctx, message_count >= 1,
727                 errors::InvalidArgument(
728                     "Bufs argument must contain at least one value"));
729 
730     int field_count = fields_.size();
731 
732     // Save the argument shape for later, then flatten the input Tensor since we
733     // are working componentwise. We will restore the same shape in the returned
734     // Tensor.
735     const TensorShape& shape_prefix = buf_tensor.shape();
736 
737     TensorShape sizes_shape = shape_prefix;
738     sizes_shape.AddDim(field_count);
739     Tensor* sizes_tensor = nullptr;
740     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, sizes_shape, &sizes_tensor));
741 
742     // This is used to allocate binary bufs if used. It serves only to define
743     // memory ownership.
744     std::vector<tstring> tmp_binary_bufs(message_count);
745 
746     // These are the actual buffers to use, which may be in tmp_binary_bufs
747     // or may be pointers into the buf_tensor. Either way they are not owned
748     // here.
749     std::vector<const tstring*> bufs;
750 
751     if (is_binary_ && !sanitize_) {
752       // Fast path.
753       for (int mi = 0; mi < message_count; ++mi) {
754         const tstring* buf = &buf_tensor.flat<tstring>()(mi);
755         bufs.push_back(buf);
756       }
757     } else {
758       // We will have to allocate a copy, either to convert from text to binary
759       // or to sanitize a binary proto.
760       for (int mi = 0; mi < message_count; ++mi) {
761         ReserializeMessage(ctx, buf_tensor.flat<tstring>()(mi),
762                            &tmp_binary_bufs[mi]);
763         if (!ctx->status().ok()) {
764           return;
765         }
766         bufs.push_back(&tmp_binary_bufs[mi]);
767       }
768     }
769 
770     // Walk through all the strings in the input tensor, counting the number of
771     // fields in each. We can't allocate our actual output Tensor until we know
772     // the maximum repeat count, so we do a first pass through the serialized
773     // proto just counting fields. We always allocate at least one value so that
774     // optional fields are populated with default values - this avoids a TF
775     // conditional when handling the output data. The caller can distinguish
776     // between real data and defaults using the repeat count matrix that is
777     // returned by decode_proto.
778     std::vector<int32> max_sizes(field_count, 1);
779     for (int mi = 0; mi < message_count; ++mi) {
780       CountFields(ctx, mi, *bufs[mi], sizes_tensor, &max_sizes);
781       if (!ctx->status().ok()) {
782         return;
783       }
784     }
785 
786     // Allocate the output tensors now that we've seen the max size.
787     // TODO(nix): Use allocate_output_or_forward_input for the largest
788     //   output tensor. This can avoid one large allocation by re-using
789     //   the memory of the input tensor.
790     std::vector<Tensor*> outputs(field_count);
791     for (int fi = 0; fi < field_count; ++fi) {
792       TensorShape flat_shape = {static_cast<int64_t>(message_count),
793                                 max_sizes[fi]};
794       TensorShape out_shape = shape_prefix;
795       out_shape.AddDim(max_sizes[fi]);
796 
797       // Surprisingly we don't specify the types from the output_types
798       // attribute: that is done for us based on the Op declaration:
799       //  REGISTER_OP(...)
800       //    .Attr("output_types: list(type) >= 0")
801       //    .Output("values: output_types")
802       OP_REQUIRES_OK(ctx, ctx->allocate_output(fields_[fi]->output_index + 1,
803                                                out_shape, &outputs[fi]));
804     }
805 
806     // Make the second pass through the serialized proto, decoding into
807     // preallocated tensors.
808     AccumulateFields(ctx, bufs, outputs);
809   }
810 
811  private:
812   // Copy a serialized message to binary, e.g. to handle text proto inputs.
ReserializeMessage(OpKernelContext * ctx,const tstring & buf,tstring * binary_buf)813   void ReserializeMessage(OpKernelContext* ctx, const tstring& buf,
814                           tstring* binary_buf) {
815     // Handle text protos by translating them to binary.
816     std::unique_ptr<Message> message(message_prototype_->New());
817     OP_REQUIRES(ctx, message, errors::DataLoss("Initializing message failed"));
818 
819     if (is_binary_) {
820       // If we get here we are sanitizing the input protobuf by parsing
821       // and reserializing it with a trusted (but very slow) library.
822       OP_REQUIRES(ctx, message->ParseFromString(buf),
823                   errors::DataLoss("Unable to parse binary protobuf"));
824     } else {
825       OP_REQUIRES(ctx, TextFormat::ParseFromString(buf, message.get()),
826                   errors::DataLoss("Unable to parse text protobuf"));
827     }
828 
829     OP_REQUIRES(ctx, SerializeToTString(*message, binary_buf),
830                 errors::DataLoss("Unable to reserialize text proto as binary"));
831   }
832 
833   // Count the number of occurrences of each requested field in a message batch.
CountFields(OpKernelContext * ctx,int message_index,const tstring & buf,Tensor * sizes_tensor,std::vector<int32> * max_sizes)834   void CountFields(OpKernelContext* ctx, int message_index, const tstring& buf,
835                    Tensor* sizes_tensor, std::vector<int32>* max_sizes) {
836     int field_count = fields_.size();
837 
838     CodedInputStream input(reinterpret_cast<const uint8*>(buf.c_str()),
839                            buf.size());
840 
841     std::vector<int32> field_sizes(field_count, 0);
842     std::vector<CountCollector> counters;
843     counters.reserve(field_count);
844     for (int i = 0; i < field_count; i++) {
845       counters.emplace_back(&field_sizes[i]);
846     }
847 
848     Status st = Collect(&input, absl::MakeSpan(counters));
849     if (st.ok() && !input.ConsumedEntireMessage()) {
850       st = errors::DataLoss("CountFields: Failed to consume entire buffer");
851     }
852     if (kFailOnDecodeError) {
853       OP_REQUIRES_OK(ctx, st);  // NOLINT
854     }
855     if (!st.ok()) {
856       // This code suppresses the corrupt proto, treating it as empty
857       // to avoid crashing the process.
858       LOG(WARNING) << "Proto counting error for message type " << message_type_
859                    << ": " << st;
860 
861       for (int fi = 0; fi < field_count; fi++) {
862         field_sizes[fi] = 0;
863       }
864       // Finished decoding this message.
865       return;
866     }
867 
868     // Update the size tensor and max repeat size for each field.
869     auto sizes = sizes_tensor->flat_inner_dims<int32>();
870     for (int fi = 0; fi < field_count; fi++) {
871       int32_t size = field_sizes[fi];
872       sizes(message_index, fields_[fi]->output_index) = size;
873       if ((*max_sizes)[fi] < size) {
874         (*max_sizes)[fi] = size;
875       }
876     }
877   }
878 
  // Parse fields from a serialized message into preallocated tensors.
  // Decodes every message in `bufs` directly into the flat storage of
  // `outputs` (one tensor per requested field, already sized to the maximum
  // repeat count), then fills any unused trailing slots in each row with the
  // field's default value.
  void AccumulateFields(OpKernelContext* ctx,
                        const std::vector<const tstring*>& bufs,
                        std::vector<Tensor*> outputs) {
    // Precomputed raw-buffer view of one output tensor, so the inner loop can
    // address each message's row by pointer arithmetic.
    struct TensorInfo {
      explicit TensorInfo(Tensor* tensor) {
        // Note that we can decode only max_repeat_count values before overflow.
        // No other bounds checking is done for repeated fields. For
        // optional fields there is a check to make sure that only the last
        // value on the wire appears in the output tensor.
        dtype = tensor->dtype();
        last_dim_size = tensor->dim_size(tensor->dims() - 1);

        if (dtype != DT_STRING) {
          const int element_size = DataTypeSize(dtype);
          CHECK_GT(element_size, 0);
          stride = last_dim_size * element_size;

          const int64_t flatshape[1] = {tensor->NumElements() * element_size};
          data = tensor->bit_casted_shaped<uint8, 1>(flatshape).data();
        } else {
          // DataTypeSize() returns 0 for string types.
          stride = last_dim_size * sizeof(tstring);
          data = reinterpret_cast<uint8*>(tensor->flat<tstring>().data());
        }
      }

      DataType dtype;
      int last_dim_size;  // Max repeat count for this field.
      int stride;         // Bytes per message in the flat buffer.
      uint8* data;        // Start of the tensor's flat storage.
    };

    int field_count = fields_.size();

    std::vector<TensorInfo> tensors;
    tensors.reserve(field_count);
    for (int fi = 0; fi < field_count; fi++) {
      tensors.emplace_back(outputs[fi]);
    }

    for (int message_index = 0; message_index < bufs.size(); ++message_index) {
      const tstring& buf = *bufs[message_index];

      // One DenseCollector per field, writing into this message's row of the
      // corresponding output tensor.
      std::vector<DenseCollector> collectors;
      collectors.reserve(field_count);
      for (int output_index = 0; output_index < field_count; ++output_index) {
        const TensorInfo& info = tensors[output_index];
        const FieldInfo* field_info = fields_[output_index].get();
        DCHECK(field_info != nullptr);
        const DefaultValue default_value = field_info->default_value;
        collectors.emplace_back(info.data + message_index * info.stride,
                                default_value, info.last_dim_size);
      }

      // Fill in output tensors from the wire.
      CodedInputStream input(reinterpret_cast<const uint8*>(buf.c_str()),
                             buf.size());
      Status st = Collect(&input, absl::MakeSpan(collectors));
      if (st.ok() && !input.ConsumedEntireMessage()) {
        st = errors::DataLoss(
            "AccumulateFields: Failed to consume entire buffer");
      }
      if (kFailOnDecodeError) {
        OP_REQUIRES_OK(ctx, st);  // NOLINT
      }
      if (!st.ok()) {
        // This code suppresses the corrupt proto, treating it as empty
        // to avoid crashing training.
        LOG(WARNING) << "Proto counting error for message type "
                     << message_type_ << ": " << st;
      }

      // Fill the remainder of the dense outputs with default values.
      for (auto& collector : collectors) {
        OP_REQUIRES_OK(ctx, collector.FillWithDefaults());
      }
    }
  }
958 
  // Traverses a serialized protobuf, dispatching values to the collectors.
  // `collectors` is indexed in fields_ order (one collector per requested
  // field). Stops cleanly at end of input or at an END_GROUP tag; returns
  // DataLoss if skipping an unrequested field fails.
  template <class CollectorClass>
  Status Collect(CodedInputStream* input,
                 absl::Span<CollectorClass> collectors) {
    // At the beginning of each loop, the last field number that was seen,
    // regardless of whether it was collected or not, or -1 if no field has
    // been seen before.
    int last_seen_field_number = -1;
    // The FieldInfo that is expected to be used next.
    // It was either used to collect the last seen field number, or if the
    // last seen field number was not in fields_, it is the next FieldInfo after
    // the last seen field number. At the beginning it is the first FieldInfo.
    auto expected_field_info_iter = fields_.begin();

    // The 'tag' variable should always be treated as tainted.
    for (uint32 tag = input->ReadTag();
         tag != 0 && WireFormatLite::GetTagWireType(tag) !=
                         WireFormatLite::WIRETYPE_END_GROUP;
         tag = input->ReadTag()) {
      // Loop invariant: expected_field_info_iter brackets
      // last_seen_field_number from above, and its predecessor (if any) lies
      // strictly below it.
      DCHECK(expected_field_info_iter == fields_.begin() ||
             last_seen_field_number >
                 (*(expected_field_info_iter - 1))->number);
      DCHECK(expected_field_info_iter == fields_.end() ||
             last_seen_field_number <= (*expected_field_info_iter)->number);

      // The field wire number.
      const int field_number = WireFormatLite::GetTagFieldNumber(tag);
      // The field info associated with the field wire number.
      const FieldInfo* field_info = nullptr;

      // fields_ are ordered by their field numbers. If the field numbers
      // on wire are also ordered (which is a convention), then we can
      // monotonically increment `expected_field_info_iter` as the field
      // numbers on wire get larger. If we detect any out-of-order
      // field number, we reset `expected_field_info_iter`, and expect that
      // future wire numbers are ordered. This algorithm is quadratic in the
      // worst case where field numbers on wire are in descending order, however
      // it works well in the case where two serialized protobufs are
      // concatenated together.
      if (field_number < last_seen_field_number) {
        expected_field_info_iter = fields_.begin();
      }

      // Advance expected_field_info_iter until
      // field_number <= expected_field_number.
      for (; expected_field_info_iter != fields_.end();
           ++expected_field_info_iter) {
        DCHECK(expected_field_info_iter == fields_.begin() ||
               field_number > (*(expected_field_info_iter - 1))->number);
        const FieldInfo* expected_field_info = expected_field_info_iter->get();
        if (field_number <= expected_field_info->number) {
          if (field_number == expected_field_info->number) {
            field_info = expected_field_info;
          }
          break;
        }
      }
      last_seen_field_number = field_number;
      if (!field_info) {
        // This DCHECK verifies that if we skip a field, we didn't want it.
        // In particular, field_builders is empty or the field_number is either:
        // before fields_.begin().number or  after (fields_.end() - 1).number or
        // in-between expected_field_info_iter and expected_field_info_iter - 1.
        DCHECK(fields_.empty() || (field_number < (*fields_.begin())->number) ||
               (field_number > (*(fields_.end() - 1))->number) ||
               (((*(expected_field_info_iter - 1))->number < field_number) &&
                (field_number < (*(expected_field_info_iter))->number)));
        // Unknown and unrequested fields are skipped.
        if (!WireFormatLite::SkipField(input, tag)) {
          return errors::DataLoss("Failed skipping unrequested field");
        }
        continue;
      }

      // Matched a requested field: hand its value(s) to the collector at the
      // same position in fields_ order.
      TF_RETURN_IF_ERROR(CollectField(
          *field_info, WireFormatLite::GetTagWireType(tag), input,
          &collectors[expected_field_info_iter - fields_.begin()]));
    }
    return OkStatus();
  }
1039 
1040   // Collects values for a single field.
1041   template <class CollectorClass>
CollectField(const FieldInfo & field,WireFormatLite::WireType wire_type,CodedInputStream * input,CollectorClass * collector)1042   Status CollectField(const FieldInfo& field,
1043                       WireFormatLite::WireType wire_type,
1044                       CodedInputStream* input, CollectorClass* collector) {
1045     // The wire format library defines the same constants used in
1046     // descriptor.proto. This static_cast is safe because they are guaranteed to
1047     // stay in sync.
1048     //
1049     // We need the field type from the FieldDescriptor here because the wire
1050     // format doesn't tell us anything about what happens inside a packed
1051     // repeated field: there is enough information in the wire format to skip
1052     // the whole field but not enough to know how to parse what's inside. For
1053     // that we go to the schema.
1054     WireFormatLite::WireType schema_wire_type =
1055         WireFormatLite::WireTypeForFieldType(field.type);
1056 
1057     // Handle packed repeated fields. SkipField would skip the whole
1058     // length-delimited blob without letting us count the values, so we have to
1059     // scan them ourselves.
1060     if (wire_type == WireFormatLite::WIRETYPE_LENGTH_DELIMITED &&
1061         schema_wire_type != WireFormatLite::WIRETYPE_LENGTH_DELIMITED) {
1062       // Handle packed repeated primitives.
1063       int length;
1064       if (!input->ReadVarintSizeAsInt(&length)) {
1065         return errors::DataLoss("CollectField: Failed reading packed size");
1066       }
1067       return collector->ReadPackedValues(input, field, length);
1068     }
1069 
1070     // Read ordinary values, including strings, bytes, and messages.
1071     if (wire_type != schema_wire_type) {
1072       if (!WireFormatLite::SkipField(
1073               input, WireFormatLite::MakeTag(field.number, wire_type))) {
1074         return errors::DataLoss(
1075             "CollectField: Failed skipping malformed field");
1076       }
1077       return OkStatus();
1078     }
1079     return collector->ReadValue(input, field);
1080   }
1081 
  // Fully-qualified name of the proto message type this op decodes; used here
  // only for log messages.
  string message_type_;
  // Note that fields are sorted by increasing field number, which is not in
  // general the order given by the user-specified field_names and output_types
  // Op attributes.
  std::vector<std::unique_ptr<const FieldInfo>> fields_;

  // Owned_desc_pool_ is null when using descriptor_source=local.
  std::unique_ptr<DescriptorPool> owned_desc_pool_;
  DynamicMessageFactory message_factory_;
  // Prototype obtained from message_factory_.GetPrototype(); used by
  // ReserializeMessage to create scratch messages. Presumably owned by
  // message_factory_ (standard GetPrototype contract) -- not freed here.
  const Message* message_prototype_;

  // True if decoding binary format, false if decoding text format.
  bool is_binary_;

  // True if the protos should be sanitized before parsing. Enables the initial
  // protobuf sanitizer, which is much more expensive than the decoder. The flag
  // defaults to true but can be set to false for trusted sources.
  //
  // TODO(nix): Flip the default to false when the fast decoder has passed
  // security review.
  bool sanitize_;

  TF_DISALLOW_COPY_AND_ASSIGN(DecodeProtoOp);
1105 };
1106 
// Register the CPU kernel implementing the DecodeProtoV2 op.
REGISTER_KERNEL_BUILDER(Name("DecodeProtoV2").Device(DEVICE_CPU),
                        DecodeProtoOp);
1109 
1110 }  // namespace
1111 }  // namespace tensorflow
1112