xref: /aosp_15_r20/external/tensorflow/tensorflow/core/util/proto/decode.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Inline functions for parsing the protocol buffers wire format.
17 //
18 // These functions have been optimized at the expense of safety.
19 // They are broken out into a separate file for readability but are
20 // not intended for use by clients other than the decode_proto op.
21 //
22 // The calling code in the decode_proto op does some fairly
23 // complicated things to ensure that this code is called
24 // safely. Changes to this code should be thoroughly fuzz tested.
25 
26 #ifndef TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
27 #define TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
28 
29 #include "tensorflow/core/framework/tensor.h"
30 #include "tensorflow/core/framework/types.h"
31 #include "tensorflow/core/platform/protobuf.h"
32 #include "tensorflow/core/platform/types.h"
33 
34 namespace tensorflow {
35 namespace internal {
36 
37 using tensorflow::protobuf::internal::WireFormatLite;
38 using tensorflow::protobuf::io::CodedInputStream;
39 using tensorflow::protobuf::io::CodedOutputStream;
40 using tensorflow::protobuf::io::StringOutputStream;
41 
42 // Converts an uint64 to an int64 without loss of information.
43 // Unsigned values greater than INT64_MAX are represented as
44 // negative numbers by wrapping (same as twos-complement bit equivalence).
WrapUnsignedAsSigned64(uint64 unsigned_value)45 inline int64_t WrapUnsignedAsSigned64(uint64 unsigned_value) {
46   // For a detailed explanation of why this works to wrap unsigned ints, see
47   // http://stackoverflow.com/questions/13150449/efficient-unsigned-to-signed-cast-avoiding-implementation-defined-behavior
48   // Both if tests should be optimized out.
49   if (unsigned_value <= INT64_MAX) {
50     return static_cast<int64_t>(unsigned_value);
51   }
52   // The C++ spec allows an architecture where this test is required.
53   if (unsigned_value >= INT64_MIN) {
54     return static_cast<int64_t>(unsigned_value - INT64_MIN) + INT64_MIN;
55   }
56   return 0;  // This should never occur.
57 }
58 
59 // Converts an uint32 to an int32 without loss of information.
60 // Unsigned values greater than INT_MAX are represented as
61 // negative numbers by wrapping (same as twos-complement bit equivalence).
WrapUnsignedAsSigned32(uint32 unsigned_value)62 inline int32 WrapUnsignedAsSigned32(uint32 unsigned_value) {
63   // For a detailed explanation of why this works to wrap unsigned ints, see
64   // http://stackoverflow.com/questions/13150449/efficient-unsigned-to-signed-cast-avoiding-implementation-defined-behavior
65   // Both if tests should be optimized out.
66   if (unsigned_value <= INT_MAX) {
67     return static_cast<int32>(unsigned_value);
68   }
69   // The C++ spec allows an architecture where this test is required.
70   if (unsigned_value >= INT_MIN) {
71     return static_cast<int32>(unsigned_value - INT_MIN) + INT_MIN;
72   }
73   return 0;  // This should never occur.
74 }
75 
// Reads a single varint64 from a byte array.
// It is the caller's responsibility to ensure that there is enough
// space in the buffer.
// The ok value will be set to false if the buffer does not contain
// a valid varint.
81 inline const uint8* ReadVarint64FromArray(const uint8* buffer, bool* ok,
82                                           uint64* value);
83 
84 // Reads a single varint32 from a byte array.
85 // It is the caller's responsibility to ensure that there is enough
86 // space in the buffer.
87 // The ok value will be set to false if the buffer does not contain
88 // a valid varint.
89 // This is slightly less efficient than the private version in
90 // coded_stream.cc but we duplicate less code by calling
91 // the 64 bit version instead of copying the code.
ReadVarint32FromArray(const uint8 * buffer,bool * ok,uint32 * value)92 inline const uint8* ReadVarint32FromArray(const uint8* buffer, bool* ok,
93                                           uint32* value) {
94   uint64 tmp = 0;
95   const uint8* buf = ReadVarint64FromArray(buffer, ok, &tmp);
96   *value = tmp & 0xffffffff;
97   return buf;
98 }
99 
100 // Reads a single proto field value from a byte array into an array.
101 // The array is part of a Tensor that was allocated by the caller
102 // with type TensorType, while DeclaredType is the proto field type.
103 template <class TensorType, enum WireFormatLite::FieldType DeclaredType>
104 const uint8* ReadFromArray(const uint8* buf, TensorType* value);
105 
106 template <>
107 inline const uint8* ReadFromArray<int64_t, WireFormatLite::TYPE_INT32>(
108     const uint8* buf, int64_t* value) {
109   uint32 temp = 0;
110   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
111   buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
112   *value = static_cast<int64_t>(temp);
113   return buf;
114 }
115 
116 template <>
117 inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_INT32>(
118     const uint8* buf, int32* value) {
119   uint32 temp = 0;
120   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
121   buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
122   *value = static_cast<int32>(temp);
123   return buf;
124 }
125 
126 template <>
127 inline const uint8* ReadFromArray<int64_t, WireFormatLite::TYPE_INT64>(
128     const uint8* buf, int64_t* value) {
129   uint64 temp = 0;
130   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
131   buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
132   *value = WrapUnsignedAsSigned64(temp);
133   return buf;
134 }
135 
136 template <>
137 inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_UINT32>(
138     const uint8* buf, uint64* value) {
139   uint32 temp = 0;
140   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
141   buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
142   *value = temp;
143   return buf;
144 }
145 
146 template <>
147 inline const uint8* ReadFromArray<uint32, WireFormatLite::TYPE_UINT32>(
148     const uint8* buf, uint32* value) {
149   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
150   return ReadVarint32FromArray(buf, &unused_ok, value);
151 }
152 
153 template <>
154 inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_UINT64>(
155     const uint8* buf, uint64* value) {
156   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
157   return ReadVarint64FromArray(buf, &unused_ok, value);
158 }
159 
160 template <>
161 inline const uint8* ReadFromArray<int64_t, WireFormatLite::TYPE_SINT32>(
162     const uint8* buf, int64_t* value) {
163   uint64 temp = 0;
164   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
165   buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
166   *value = WireFormatLite::ZigZagDecode32(temp);
167   return buf;
168 }
169 
170 template <>
171 inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_SINT32>(
172     const uint8* buf, int32* value) {
173   uint32 temp = 0;
174   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
175   buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
176   *value = WireFormatLite::ZigZagDecode32(temp);
177   return buf;
178 }
179 
180 template <>
181 inline const uint8* ReadFromArray<int64_t, WireFormatLite::TYPE_SINT64>(
182     const uint8* buf, int64_t* value) {
183   uint64 temp = 0;
184   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
185   buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
186   *value = WireFormatLite::ZigZagDecode64(temp);
187   return buf;
188 }
189 
190 template <>
191 inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_FIXED32>(
192     const uint8* buf, uint64* value) {
193   uint32 temp;
194   buf = WireFormatLite::ReadPrimitiveFromArray<uint32,
195                                                WireFormatLite::TYPE_FIXED32>(
196       buf, &temp);
197   *value = temp;
198   return buf;
199 }
200 
201 template <>
202 inline const uint8* ReadFromArray<uint32, WireFormatLite::TYPE_FIXED32>(
203     const uint8* buf, uint32* value) {
204   uint32 temp;
205   buf = WireFormatLite::ReadPrimitiveFromArray<uint32,
206                                                WireFormatLite::TYPE_FIXED32>(
207       buf, &temp);
208   *value = WrapUnsignedAsSigned32(temp);
209   return buf;
210 }
211 
212 template <>
213 inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_FIXED64>(
214     const uint8* buf, uint64* value) {
215   protobuf_uint64 temp;
216   buf = WireFormatLite::ReadPrimitiveFromArray<protobuf_uint64,
217                                                WireFormatLite::TYPE_FIXED64>(
218       buf, &temp);
219   *value = WrapUnsignedAsSigned64(temp);
220   return buf;
221 }
222 
223 template <>
224 inline const uint8* ReadFromArray<int64_t, WireFormatLite::TYPE_SFIXED32>(
225     const uint8* buf, int64_t* value) {
226   int32_t temp;
227   buf = WireFormatLite::ReadPrimitiveFromArray<int32,
228                                                WireFormatLite::TYPE_SFIXED32>(
229       buf, &temp);
230   *value = temp;
231   return buf;
232 }
233 
234 template <>
235 inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_SFIXED32>(
236     const uint8* buf, int32* value) {
237   return WireFormatLite::ReadPrimitiveFromArray<int32,
238                                                 WireFormatLite::TYPE_SFIXED32>(
239       buf, value);
240 }
241 
242 template <>
243 inline const uint8* ReadFromArray<int64_t, WireFormatLite::TYPE_SFIXED64>(
244     const uint8* buf, int64_t* value) {
245   protobuf_int64 temp;
246   buf = WireFormatLite::ReadPrimitiveFromArray<protobuf_int64,
247                                                WireFormatLite::TYPE_SFIXED64>(
248       buf, &temp);
249   *value = temp;
250   return buf;
251 }
252 
253 template <>
254 inline const uint8* ReadFromArray<float, WireFormatLite::TYPE_FLOAT>(
255     const uint8* buf, float* value) {
256   return WireFormatLite::ReadPrimitiveFromArray<float,
257                                                 WireFormatLite::TYPE_FLOAT>(
258       buf, value);
259 }
260 
261 template <>
262 inline const uint8* ReadFromArray<double, WireFormatLite::TYPE_FLOAT>(
263     const uint8* buf, double* value) {
264   float temp;
265   buf =
266       WireFormatLite::ReadPrimitiveFromArray<float, WireFormatLite::TYPE_FLOAT>(
267           buf, &temp);
268   *value = temp;
269   return buf;
270 }
271 
272 template <>
273 inline const uint8* ReadFromArray<double, WireFormatLite::TYPE_DOUBLE>(
274     const uint8* buf, double* value) {
275   return WireFormatLite::ReadPrimitiveFromArray<double,
276                                                 WireFormatLite::TYPE_DOUBLE>(
277       buf, value);
278 }
279 
280 template <>
281 inline const uint8* ReadFromArray<bool, WireFormatLite::TYPE_BOOL>(
282     const uint8* buf, bool* value) {
283   uint64 temp = 0;
284   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
285   buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
286   *value = temp != 0;
287   return buf;
288 }
289 
290 template <>
291 inline const uint8* ReadFromArray<int, WireFormatLite::TYPE_ENUM>(
292     const uint8* buf, int* value) {
293   uint32 temp = 0;
294   bool unused_ok;  // The Counting pass would have failed if this were corrupt.
295   buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
296   *value = static_cast<int>(temp);
297   return buf;
298 }
299 
300 // Reads packed values from an array.
301 // Stride is set to 1 for repeated fields, and 0 for non-repeated fields
302 // (where any value overwrites previous values).
303 template <class TensorType, enum WireFormatLite::FieldType DeclaredType>
ReadPackedPrimitives(const void * bufp,const size_t len,const int index,const int stride,void * datap)304 inline int ReadPackedPrimitives(const void* bufp, const size_t len,
305                                 const int index, const int stride,
306                                 void* datap) {
307   const uint8* buf = reinterpret_cast<const uint8*>(bufp);
308   const uint8* bound = buf + len;
309   TensorType* data = reinterpret_cast<TensorType*>(datap) + index;
310   int count;
311 
312   // This could overrun the bound by stride-1. This is defended
313   // against in the caller, where it ensures that the input buffer
314   // contains complete values.
315   for (count = 0; buf < bound; count += stride) {
316     buf = ReadFromArray<TensorType, DeclaredType>(buf, data + count);
317   }
318   return count;
319 }
320 
321 // Reads a value of a primitive type field from a serialized proto.
322 // The value is parsed from the serialized format, then static_cast
323 // to the desired type for TensorFlow and stored.
324 template <class ValueType, class TensorType,
325           enum WireFormatLite::FieldType DeclaredType>
ReadPrimitive(CodedInputStream * input,int index,void * data)326 inline Status ReadPrimitive(CodedInputStream* input, int index, void* data) {
327   ValueType v;
328   if (!WireFormatLite::ReadPrimitive<ValueType, DeclaredType>(input, &v)) {
329     return errors::DataLoss("Failed reading primitive");
330   }
331 
332   reinterpret_cast<TensorType*>(data)[index] = v;
333   return OkStatus();
334 }
335 
336 // Reads a string, submessage, or other variable-length field from a
337 // serialized proto.
338 // May read all or part of a repeated field.
ReadBytes(CodedInputStream * input,int index,void * datap)339 inline Status ReadBytes(CodedInputStream* input, int index, void* datap) {
340   tstring* data = reinterpret_cast<tstring*>(datap) + index;
341 
342   uint32 length;
343   if (!input->ReadVarint32(&length)) {
344     return errors::DataLoss("Failed reading bytes");
345   }
346 
347   data->resize_uninitialized(length);
348 
349   if (!input->ReadRaw(data->data(), length)) {
350     return errors::DataLoss("Failed reading bytes");
351   }
352   return OkStatus();
353 }
354 
355 // Reads a tag-delimited field (TYPE_GROUP) from a serialized proto,
356 // as a bytestring.
ReadGroupBytes(CodedInputStream * input,int field_number,int index,void * datap)357 inline Status ReadGroupBytes(CodedInputStream* input, int field_number,
358                              int index, void* datap) {
359   // WireFormatLite::SkipField has an option to emit the
360   // skipped bytes to an output stream. We could do better by implementing our
361   // own scanner but this is simpler for now.
362   // TODO(nix): there is a faster way to grab TYPE_GROUP bytes by relying
363   // on input->IsFlat() == true and using input->GetDirectBufferPointer()
364   // with input->CurrentPosition().
365   tstring* data = reinterpret_cast<tstring*>(datap) + index;
366   // TODO(dero): To mitigate the string to tstring copy, we can implement our
367   // own scanner as described above.  We would first need to obtain the length
368   // in an initial pass and resize/reserve the tstring. But, given that
369   // TYPE_GROUP is deprecated and currently no tests in
370   // tensorflow/python/kernel_tests/proto:decode_proto_op_test target a
371   // TYPE_GROUP tag, we use std::string as a read buffer.
372   string buf;
373   StringOutputStream string_stream(&buf);
374   {
375     CodedOutputStream out(&string_stream);
376     if (!WireFormatLite::SkipField(
377             input,
378             WireFormatLite::MakeTag(field_number,
379                                     WireFormatLite::WIRETYPE_START_GROUP),
380             &out)) {
381       return errors::DataLoss("Failed reading group");
382     }
383   }
384   *data = buf;
385   return OkStatus();
386 }
387 
388 // Reads a single field value from a CodedInputStream into a tensor.
ReadValue(CodedInputStream * input,WireFormatLite::FieldType field_type,int field_number,DataType dtype,int index,void * datap)389 inline Status ReadValue(CodedInputStream* input,
390                         WireFormatLite::FieldType field_type, int field_number,
391                         DataType dtype, int index, void* datap) {
392   // Dispatch to the appropriately typed field reader based on the schema type.
393   switch (field_type) {
394     case WireFormatLite::TYPE_DOUBLE:
395       return ReadPrimitive<double, double, WireFormatLite::TYPE_DOUBLE>(
396           input, index, datap);
397     case WireFormatLite::TYPE_FLOAT:
398       switch (dtype) {
399         case DataType::DT_DOUBLE:
400           return ReadPrimitive<float, double, WireFormatLite::TYPE_FLOAT>(
401               input, index, datap);
402         case DataType::DT_FLOAT:
403           return ReadPrimitive<float, float, WireFormatLite::TYPE_FLOAT>(
404               input, index, datap);
405         default:
406           return errors::DataLoss("Failed reading TYPE_FLOAT for ",
407                                   DataTypeString(dtype));
408       }
409     case WireFormatLite::TYPE_INT64:
410       return ReadPrimitive<protobuf_int64, int64_t, WireFormatLite::TYPE_INT64>(
411           input, index, datap);
412     case WireFormatLite::TYPE_UINT64:
413       return ReadPrimitive<protobuf_uint64, uint64,
414                            WireFormatLite::TYPE_UINT64>(input, index, datap);
415     case WireFormatLite::TYPE_INT32:
416       switch (dtype) {
417         case DataType::DT_INT64:
418           return ReadPrimitive<int32, int64_t, WireFormatLite::TYPE_INT32>(
419               input, index, datap);
420         case DataType::DT_INT32:
421           return ReadPrimitive<int32, int32, WireFormatLite::TYPE_INT32>(
422               input, index, datap);
423         default:
424           return errors::DataLoss("Failed reading TYPE_INT32 for ",
425                                   DataTypeString(dtype));
426       }
427     case WireFormatLite::TYPE_FIXED64:
428       return ReadPrimitive<protobuf_uint64, uint64,
429                            WireFormatLite::TYPE_FIXED64>(input, index, datap);
430     case WireFormatLite::TYPE_FIXED32:
431       switch (dtype) {
432         case DataType::DT_UINT64:
433           return ReadPrimitive<uint32, uint64, WireFormatLite::TYPE_FIXED32>(
434               input, index, datap);
435         case DataType::DT_UINT32:
436           return ReadPrimitive<uint32, uint32, WireFormatLite::TYPE_FIXED32>(
437               input, index, datap);
438         default:
439           return errors::DataLoss("Failed reading TYPE_FIXED32 for ",
440                                   DataTypeString(dtype));
441       }
442     case WireFormatLite::TYPE_BOOL:
443       return ReadPrimitive<bool, bool, WireFormatLite::TYPE_BOOL>(input, index,
444                                                                   datap);
445     case WireFormatLite::TYPE_STRING:
446       return ReadBytes(input, index, datap);
447     case WireFormatLite::TYPE_GROUP:
448       return ReadGroupBytes(input, field_number, index, datap);
449     case WireFormatLite::TYPE_MESSAGE:
450       return ReadBytes(input, index, datap);
451     case WireFormatLite::TYPE_BYTES:
452       return ReadBytes(input, index, datap);
453     case WireFormatLite::TYPE_UINT32:
454       switch (dtype) {
455         case DataType::DT_UINT64:
456           return ReadPrimitive<uint32, uint64, WireFormatLite::TYPE_UINT32>(
457               input, index, datap);
458         case DataType::DT_UINT32:
459           return ReadPrimitive<uint32, uint32, WireFormatLite::TYPE_UINT32>(
460               input, index, datap);
461         default:
462           return errors::DataLoss("Failed reading TYPE_UINT32 for ",
463                                   DataTypeString(dtype));
464       }
465     case WireFormatLite::TYPE_ENUM:
466       return ReadPrimitive<int32, int32, WireFormatLite::TYPE_ENUM>(
467           input, index, datap);
468     case WireFormatLite::TYPE_SFIXED32:
469       switch (dtype) {
470         case DataType::DT_INT64:
471           return ReadPrimitive<int32, int64_t, WireFormatLite::TYPE_SFIXED32>(
472               input, index, datap);
473         case DataType::DT_INT32:
474           return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SFIXED32>(
475               input, index, datap);
476         default:
477           return errors::DataLoss("Failed reading TYPE_SFIXED32 for ",
478                                   DataTypeString(dtype));
479       }
480     case WireFormatLite::TYPE_SFIXED64:
481       return ReadPrimitive<protobuf_int64, int64_t,
482                            WireFormatLite::TYPE_SFIXED64>(input, index, datap);
483     case WireFormatLite::TYPE_SINT32:
484       switch (dtype) {
485         case DataType::DT_INT64:
486           return ReadPrimitive<int32, int64_t, WireFormatLite::TYPE_SINT32>(
487               input, index, datap);
488         case DataType::DT_INT32:
489           return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SINT32>(
490               input, index, datap);
491         default:
492           return errors::DataLoss("Failed reading TYPE_SINT32 for ",
493                                   DataTypeString(dtype));
494       }
495     case WireFormatLite::TYPE_SINT64:
496       return ReadPrimitive<protobuf_int64, int64_t,
497                            WireFormatLite::TYPE_SINT64>(input, index, datap);
498       // default: intentionally omitted in order to enable static checking.
499   }
500   // Unreachable.
501   return errors::DataLoss("Failed reading unknown wire type");
502 }
503 
504 // Reads and stores a length-delimited list of values.
ReadPackedFromArray(const void * buf,size_t buf_size,const WireFormatLite::FieldType field_type,const int field_number,const DataType dtype,const int stride,int * index,void * data)505 inline Status ReadPackedFromArray(const void* buf, size_t buf_size,
506                                   const WireFormatLite::FieldType field_type,
507                                   const int field_number, const DataType dtype,
508                                   const int stride, int* index, void* data) {
509   // Dispatch to the appropriately typed field reader based on the schema type.
510   switch (field_type) {
511     case WireFormatLite::TYPE_DOUBLE:
512       *index += ReadPackedPrimitives<double, WireFormatLite::TYPE_DOUBLE>(
513           buf, buf_size, *index, stride, data);
514       return OkStatus();
515     case WireFormatLite::TYPE_FLOAT:
516       switch (dtype) {
517         case DataType::DT_DOUBLE:
518           *index += ReadPackedPrimitives<double, WireFormatLite::TYPE_FLOAT>(
519               buf, buf_size, *index, stride, data);
520           return OkStatus();
521         case DataType::DT_FLOAT:
522           *index += ReadPackedPrimitives<float, WireFormatLite::TYPE_FLOAT>(
523               buf, buf_size, *index, stride, data);
524           return OkStatus();
525         default:
526           return errors::DataLoss("Failed reading TYPE_FLOAT for ",
527                                   DataTypeString(dtype));
528       }
529     case WireFormatLite::TYPE_INT64:
530       *index += ReadPackedPrimitives<int64_t, WireFormatLite::TYPE_INT64>(
531           buf, buf_size, *index, stride, data);
532       return OkStatus();
533     case WireFormatLite::TYPE_UINT64:
534       *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_UINT64>(
535           buf, buf_size, *index, stride, data);
536       return OkStatus();
537     case WireFormatLite::TYPE_INT32:
538       switch (dtype) {
539         case DataType::DT_INT64:
540           *index += ReadPackedPrimitives<int64_t, WireFormatLite::TYPE_INT32>(
541               buf, buf_size, *index, stride, data);
542           return OkStatus();
543         case DataType::DT_INT32:
544           *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_INT32>(
545               buf, buf_size, *index, stride, data);
546           return OkStatus();
547         default:
548           return errors::DataLoss("Failed reading TYPE_INT32 for ",
549                                   DataTypeString(dtype));
550       }
551     case WireFormatLite::TYPE_FIXED64:
552       *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_FIXED64>(
553           buf, buf_size, *index, stride, data);
554       return OkStatus();
555     case WireFormatLite::TYPE_FIXED32:
556       switch (dtype) {
557         case DataType::DT_UINT64:
558           *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_FIXED32>(
559               buf, buf_size, *index, stride, data);
560           return OkStatus();
561         case DataType::DT_UINT32:
562           *index += ReadPackedPrimitives<uint32, WireFormatLite::TYPE_FIXED32>(
563               buf, buf_size, *index, stride, data);
564           return OkStatus();
565         default:
566           return errors::DataLoss("Failed reading TYPE_FIXED32 for ",
567                                   DataTypeString(dtype));
568       }
569     case WireFormatLite::TYPE_BOOL:
570       *index += ReadPackedPrimitives<bool, WireFormatLite::TYPE_BOOL>(
571           buf, buf_size, *index, stride, data);
572       return OkStatus();
573     case WireFormatLite::TYPE_STRING:
574     case WireFormatLite::TYPE_GROUP:
575     case WireFormatLite::TYPE_MESSAGE:
576     case WireFormatLite::TYPE_BYTES:
577       return errors::DataLoss("Non-primitive type encountered as packed");
578     case WireFormatLite::TYPE_UINT32:
579       switch (dtype) {
580         case DataType::DT_UINT64:
581           *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_UINT32>(
582               buf, buf_size, *index, stride, data);
583           return OkStatus();
584         case DataType::DT_UINT32:
585           *index += ReadPackedPrimitives<uint32, WireFormatLite::TYPE_UINT32>(
586               buf, buf_size, *index, stride, data);
587           return OkStatus();
588         default:
589           return errors::DataLoss("Failed reading TYPE_UINT32 for ",
590                                   DataTypeString(dtype));
591       }
592     case WireFormatLite::TYPE_ENUM:
593       *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_ENUM>(
594           buf, buf_size, *index, stride, data);
595       return OkStatus();
596     case WireFormatLite::TYPE_SFIXED32:
597       switch (dtype) {
598         case DataType::DT_INT64:
599           *index +=
600               ReadPackedPrimitives<int64_t, WireFormatLite::TYPE_SFIXED32>(
601                   buf, buf_size, *index, stride, data);
602           return OkStatus();
603         case DataType::DT_INT32:
604           *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SFIXED32>(
605               buf, buf_size, *index, stride, data);
606           return OkStatus();
607         default:
608           return errors::DataLoss("Failed reading TYPE_INT32 for ",
609                                   DataTypeString(dtype));
610       }
611     case WireFormatLite::TYPE_SFIXED64:
612       *index += ReadPackedPrimitives<int64_t, WireFormatLite::TYPE_SFIXED64>(
613           buf, buf_size, *index, stride, data);
614       return OkStatus();
615 
616     case WireFormatLite::TYPE_SINT32:
617       switch (dtype) {
618         case DataType::DT_INT64:
619           *index += ReadPackedPrimitives<int64_t, WireFormatLite::TYPE_SINT32>(
620               buf, buf_size, *index, stride, data);
621           return OkStatus();
622         case DataType::DT_INT32:
623           *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SINT32>(
624               buf, buf_size, *index, stride, data);
625           return OkStatus();
626         default:
627           return errors::DataLoss("Failed reading TYPE_SINT32 for ",
628                                   DataTypeString(dtype));
629       }
630     case WireFormatLite::TYPE_SINT64:
631       *index += ReadPackedPrimitives<int64_t, WireFormatLite::TYPE_SINT64>(
632           buf, buf_size, *index, stride, data);
633       return OkStatus();
634       // default: intentionally omitted in order to enable static checking.
635   }
636   // Unreachable.
637   return errors::DataLoss("Failed reading unknown wire type");
638 }
639 
// Reads a varint from the given buffer, writes it to *value, and returns the
// new buffer pointer.
// This was copied from coded_stream.cc where it is private.
// Important: This routine may read as much as kMaxVarintBytes from
// the buffer. It is the caller's responsibility to make sure that there is
// enough space in the buffer.
//
// The decode is hand-unrolled and consumes at most 10 bytes. Each byte
// contributes its low 7 bits, least-significant group first. A byte whose
// high (0x80) bit is clear ends the varint.
// On success: *ok is set to true, *value receives the decoded value, and the
// returned pointer is one past the terminating byte.
// If none of the first 10 bytes terminates the varint: *ok is set to false,
// *value is left untouched, and the returned pointer is buffer + 10.
inline const uint8* ReadVarint64FromArray(const uint8* buffer, bool* ok,
                                          uint64* value) {
  const uint8* ptr = buffer;
  uint32 b;

  // Splitting into 32-bit pieces gives better performance on 32-bit
  // processors.
  // part0 collects bytes 0-3 (result bits 0..27), part1 bytes 4-7
  // (bits 28..55), and part2 bytes 8-9 (bits 56..63). They are combined
  // at `done`.
  uint32 part0 = 0, part1 = 0, part2 = 0;

  // Each step adds the whole byte (continuation bit included) at its shift.
  // Only if decoding continues is that continuation bit subtracted back out.
  b = *(ptr++);
  part0 = b;
  if (!(b & 0x80)) goto done;
  part0 -= 0x80;
  b = *(ptr++);
  part0 += b << 7;
  if (!(b & 0x80)) goto done;
  part0 -= 0x80 << 7;
  b = *(ptr++);
  part0 += b << 14;
  if (!(b & 0x80)) goto done;
  part0 -= 0x80 << 14;
  b = *(ptr++);
  part0 += b << 21;
  if (!(b & 0x80)) goto done;
  part0 -= 0x80 << 21;
  b = *(ptr++);
  part1 = b;
  if (!(b & 0x80)) goto done;
  part1 -= 0x80;
  b = *(ptr++);
  part1 += b << 7;
  if (!(b & 0x80)) goto done;
  part1 -= 0x80 << 7;
  b = *(ptr++);
  part1 += b << 14;
  if (!(b & 0x80)) goto done;
  part1 -= 0x80 << 14;
  b = *(ptr++);
  part1 += b << 21;
  if (!(b & 0x80)) goto done;
  part1 -= 0x80 << 21;
  b = *(ptr++);
  part2 = b;
  if (!(b & 0x80)) goto done;
  part2 -= 0x80;
  b = *(ptr++);
  part2 += b << 7;
  if (!(b & 0x80)) goto done;
  // "part2 -= 0x80 << 7" is irrelevant because (0x80 << 7) << 56 is 0.

  // We have overrun the maximum size of a varint (10 bytes).  Assume
  // the data is corrupt.
  *ok = false;
  return ptr;

done:
  *ok = true;
  // Shifting part2 by 56 drops everything above its low 8 bits, so only bit 0
  // of a 10th byte survives (as result bit 63).
  *value = (static_cast<uint64>(part0)) | (static_cast<uint64>(part1) << 28) |
           (static_cast<uint64>(part2) << 56);
  return ptr;
}
707 
708 }  // namespace internal
709 }  // namespace tensorflow
710 
711 #endif  // TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
712