xref: /aosp_15_r20/external/cronet/third_party/rust/chromium_crates_io/vendor/prost-0.12.3/src/encoding.rs (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 //! Utility functions and types for encoding and decoding Protobuf types.
2 //!
3 //! Meant to be used only from `Message` implementations.
4 
5 #![allow(clippy::implicit_hasher, clippy::ptr_arg)]
6 
7 use alloc::collections::BTreeMap;
8 use alloc::format;
9 use alloc::string::String;
10 use alloc::vec::Vec;
11 use core::cmp::min;
12 use core::convert::TryFrom;
13 use core::mem;
14 use core::str;
15 use core::u32;
16 use core::usize;
17 
18 use ::bytes::{Buf, BufMut, Bytes};
19 
20 use crate::DecodeError;
21 use crate::Message;
22 
23 /// Encodes an integer value into LEB128 variable length format, and writes it to the buffer.
24 /// The buffer must have enough remaining space (maximum 10 bytes).
25 #[inline]
encode_varint<B>(mut value: u64, buf: &mut B) where B: BufMut,26 pub fn encode_varint<B>(mut value: u64, buf: &mut B)
27 where
28     B: BufMut,
29 {
30     loop {
31         if value < 0x80 {
32             buf.put_u8(value as u8);
33             break;
34         } else {
35             buf.put_u8(((value & 0x7F) | 0x80) as u8);
36             value >>= 7;
37         }
38     }
39 }
40 
41 /// Decodes a LEB128-encoded variable length integer from the buffer.
42 #[inline]
decode_varint<B>(buf: &mut B) -> Result<u64, DecodeError> where B: Buf,43 pub fn decode_varint<B>(buf: &mut B) -> Result<u64, DecodeError>
44 where
45     B: Buf,
46 {
47     let bytes = buf.chunk();
48     let len = bytes.len();
49     if len == 0 {
50         return Err(DecodeError::new("invalid varint"));
51     }
52 
53     let byte = bytes[0];
54     if byte < 0x80 {
55         buf.advance(1);
56         Ok(u64::from(byte))
57     } else if len > 10 || bytes[len - 1] < 0x80 {
58         let (value, advance) = decode_varint_slice(bytes)?;
59         buf.advance(advance);
60         Ok(value)
61     } else {
62         decode_varint_slow(buf)
63     }
64 }
65 
/// Decodes a LEB128-encoded variable length integer from the slice, returning the value and the
/// number of bytes read.
///
/// Based loosely on [`ReadVarint64FromArray`][1] with a varint overflow check from
/// [`ConsumeVarint`][2].
///
/// ## Safety
///
/// The caller must ensure that `bytes` is non-empty and either `bytes.len() > 10` or the last
/// element in bytes is < `0x80`. Both conditions are re-checked with `assert!` at the top of the
/// function, so a violation panics before any unchecked read happens. At most the first ten
/// bytes (indices 0 through 9) are ever read, and decoding stops at the first byte < `0x80`.
///
/// ## Errors
///
/// Returns a `DecodeError` ("invalid varint") when the tenth byte is `0x02` or greater, i.e. the
/// varint runs past 10 bytes or its final byte would overflow a `u64`.
///
/// [1]: https://github.com/google/protobuf/blob/3.3.x/src/google/protobuf/io/coded_stream.cc#L365-L406
/// [2]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
#[inline]
fn decode_varint_slice(bytes: &[u8]) -> Result<(u64, usize), DecodeError> {
    // Fully unrolled varint decoding loop. Splitting into 32-bit pieces gives better performance.

    // Use assertions to ensure memory safety, but it should always be optimized after inline.
    assert!(!bytes.is_empty());
    assert!(bytes.len() > 10 || bytes[bytes.len() - 1] < 0x80);

    // part0 accumulates bytes 0..=3 (value bits 0..28). After each byte that had its
    // continuation bit set, that bit is subtracted back out at its shifted position.
    let mut b: u8 = unsafe { *bytes.get_unchecked(0) };
    let mut part0: u32 = u32::from(b);
    if b < 0x80 {
        return Ok((u64::from(part0), 1));
    };
    part0 -= 0x80;
    b = unsafe { *bytes.get_unchecked(1) };
    part0 += u32::from(b) << 7;
    if b < 0x80 {
        return Ok((u64::from(part0), 2));
    };
    part0 -= 0x80 << 7;
    b = unsafe { *bytes.get_unchecked(2) };
    part0 += u32::from(b) << 14;
    if b < 0x80 {
        return Ok((u64::from(part0), 3));
    };
    part0 -= 0x80 << 14;
    b = unsafe { *bytes.get_unchecked(3) };
    part0 += u32::from(b) << 21;
    if b < 0x80 {
        return Ok((u64::from(part0), 4));
    };
    part0 -= 0x80 << 21;
    let value = u64::from(part0);

    // part1 accumulates bytes 4..=7 and is placed at bit 28 of the result.
    b = unsafe { *bytes.get_unchecked(4) };
    let mut part1: u32 = u32::from(b);
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 5));
    };
    part1 -= 0x80;
    b = unsafe { *bytes.get_unchecked(5) };
    part1 += u32::from(b) << 7;
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 6));
    };
    part1 -= 0x80 << 7;
    b = unsafe { *bytes.get_unchecked(6) };
    part1 += u32::from(b) << 14;
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 7));
    };
    part1 -= 0x80 << 14;
    b = unsafe { *bytes.get_unchecked(7) };
    part1 += u32::from(b) << 21;
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 8));
    };
    part1 -= 0x80 << 21;
    let value = value + ((u64::from(part1)) << 28);

    // part2 accumulates bytes 8 and 9 and is placed at bit 56 of the result.
    b = unsafe { *bytes.get_unchecked(8) };
    let mut part2: u32 = u32::from(b);
    if b < 0x80 {
        return Ok((value + (u64::from(part2) << 56), 9));
    };
    part2 -= 0x80;
    b = unsafe { *bytes.get_unchecked(9) };
    part2 += u32::from(b) << 7;
    // Check for u64::MAX overflow. See [`ConsumeVarint`][1] for details.
    // [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
    // Only 0x00 or 0x01 is accepted for the tenth byte: it contributes bit 63 and must also be
    // the terminating byte.
    if b < 0x02 {
        return Ok((value + (u64::from(part2) << 56), 10));
    };

    // We have overrun the maximum size of a varint (10 bytes) or the final byte caused an overflow.
    // Assume the data is corrupt.
    Err(DecodeError::new("invalid varint"))
}
157 
158 /// Decodes a LEB128-encoded variable length integer from the buffer, advancing the buffer as
159 /// necessary.
160 ///
161 /// Contains a varint overflow check from [`ConsumeVarint`][1].
162 ///
163 /// [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
164 #[inline(never)]
165 #[cold]
decode_varint_slow<B>(buf: &mut B) -> Result<u64, DecodeError> where B: Buf,166 fn decode_varint_slow<B>(buf: &mut B) -> Result<u64, DecodeError>
167 where
168     B: Buf,
169 {
170     let mut value = 0;
171     for count in 0..min(10, buf.remaining()) {
172         let byte = buf.get_u8();
173         value |= u64::from(byte & 0x7F) << (count * 7);
174         if byte <= 0x7F {
175             // Check for u64::MAX overflow. See [`ConsumeVarint`][1] for details.
176             // [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
177             if count == 9 && byte >= 0x02 {
178                 return Err(DecodeError::new("invalid varint"));
179             } else {
180                 return Ok(value);
181             }
182         }
183     }
184 
185     Err(DecodeError::new("invalid varint"))
186 }
187 
/// Additional information passed to every decode/merge function.
///
/// The context should be passed by value and can be freely cloned. When passing
/// to a function which is decoding a nested object, then use `enter_recursion`.
///
/// With the `no-recursion-limit` feature enabled the struct has no fields and `Default` is
/// derived; otherwise `Default` is implemented by hand so the counter starts at
/// `crate::RECURSION_LIMIT`.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "no-recursion-limit", derive(Default))]
pub struct DecodeContext {
    /// How many times we can recurse in the current decode stack before we hit
    /// the recursion limit.
    ///
    /// The recursion limit is defined by `RECURSION_LIMIT` and cannot be
    /// customized. The recursion limit can be ignored by building the Prost
    /// crate with the `no-recursion-limit` feature.
    #[cfg(not(feature = "no-recursion-limit"))]
    recurse_count: u32,
}
204 
205 #[cfg(not(feature = "no-recursion-limit"))]
206 impl Default for DecodeContext {
207     #[inline]
default() -> DecodeContext208     fn default() -> DecodeContext {
209         DecodeContext {
210             recurse_count: crate::RECURSION_LIMIT,
211         }
212     }
213 }
214 
215 impl DecodeContext {
216     /// Call this function before recursively decoding.
217     ///
218     /// There is no `exit` function since this function creates a new `DecodeContext`
219     /// to be used at the next level of recursion. Continue to use the old context
220     // at the previous level of recursion.
221     #[cfg(not(feature = "no-recursion-limit"))]
222     #[inline]
enter_recursion(&self) -> DecodeContext223     pub(crate) fn enter_recursion(&self) -> DecodeContext {
224         DecodeContext {
225             recurse_count: self.recurse_count - 1,
226         }
227     }
228 
229     #[cfg(feature = "no-recursion-limit")]
230     #[inline]
enter_recursion(&self) -> DecodeContext231     pub(crate) fn enter_recursion(&self) -> DecodeContext {
232         DecodeContext {}
233     }
234 
235     /// Checks whether the recursion limit has been reached in the stack of
236     /// decodes described by the `DecodeContext` at `self.ctx`.
237     ///
238     /// Returns `Ok<()>` if it is ok to continue recursing.
239     /// Returns `Err<DecodeError>` if the recursion limit has been reached.
240     #[cfg(not(feature = "no-recursion-limit"))]
241     #[inline]
limit_reached(&self) -> Result<(), DecodeError>242     pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> {
243         if self.recurse_count == 0 {
244             Err(DecodeError::new("recursion limit reached"))
245         } else {
246             Ok(())
247         }
248     }
249 
250     #[cfg(feature = "no-recursion-limit")]
251     #[inline]
252     #[allow(clippy::unnecessary_wraps)] // needed in other features
limit_reached(&self) -> Result<(), DecodeError>253     pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> {
254         Ok(())
255     }
256 }
257 
/// Returns the number of bytes `value` occupies in LEB128 variable length format.
/// The returned value will be between 1 and 10, inclusive.
#[inline]
pub fn encoded_len_varint(value: u64) -> usize {
    // Based on [VarintSize64][1].
    // [1]: https://github.com/google/protobuf/blob/3.3.x/src/google/protobuf/io/coded_stream.h#L1301-L1309
    //
    // OR-ing with 1 makes zero behave like a one-bit value, so the index is always in 0..=63.
    let highest_set_bit = 63 - (value | 1).leading_zeros();
    // Branch-free equivalent of `highest_set_bit / 7 + 1` over that range.
    let byte_count = (highest_set_bit * 9 + 73) / 64;
    byte_count as usize
}
266 
/// The wire type designator stored in the low three bits of an encoded field key.
///
/// The discriminants are the numeric values written to and read from the wire
/// (see `encode_key` and the `TryFrom<u64>` implementation).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum WireType {
    /// A single varint value.
    Varint = 0,
    /// A fixed 8-byte value.
    SixtyFourBit = 1,
    /// A varint length prefix followed by that many bytes.
    LengthDelimited = 2,
    /// Opens a group; fields follow until the matching `EndGroup` key.
    StartGroup = 3,
    /// Closes the group opened by a `StartGroup` key with the same tag.
    EndGroup = 4,
    /// A fixed 4-byte value.
    ThirtyTwoBit = 5,
}
277 
/// Smallest valid field tag; `decode_key` rejects anything below it.
pub const MIN_TAG: u32 = 1;
/// Largest valid field tag: 29 bits, leaving room for the 3 wire-type bits in a `u32` key.
pub const MAX_TAG: u32 = (1 << 29) - 1;
280 
281 impl TryFrom<u64> for WireType {
282     type Error = DecodeError;
283 
284     #[inline]
try_from(value: u64) -> Result<Self, Self::Error>285     fn try_from(value: u64) -> Result<Self, Self::Error> {
286         match value {
287             0 => Ok(WireType::Varint),
288             1 => Ok(WireType::SixtyFourBit),
289             2 => Ok(WireType::LengthDelimited),
290             3 => Ok(WireType::StartGroup),
291             4 => Ok(WireType::EndGroup),
292             5 => Ok(WireType::ThirtyTwoBit),
293             _ => Err(DecodeError::new(format!(
294                 "invalid wire type value: {}",
295                 value
296             ))),
297         }
298     }
299 }
300 
301 /// Encodes a Protobuf field key, which consists of a wire type designator and
302 /// the field tag.
303 #[inline]
encode_key<B>(tag: u32, wire_type: WireType, buf: &mut B) where B: BufMut,304 pub fn encode_key<B>(tag: u32, wire_type: WireType, buf: &mut B)
305 where
306     B: BufMut,
307 {
308     debug_assert!((MIN_TAG..=MAX_TAG).contains(&tag));
309     let key = (tag << 3) | wire_type as u32;
310     encode_varint(u64::from(key), buf);
311 }
312 
313 /// Decodes a Protobuf field key, which consists of a wire type designator and
314 /// the field tag.
315 #[inline(always)]
decode_key<B>(buf: &mut B) -> Result<(u32, WireType), DecodeError> where B: Buf,316 pub fn decode_key<B>(buf: &mut B) -> Result<(u32, WireType), DecodeError>
317 where
318     B: Buf,
319 {
320     let key = decode_varint(buf)?;
321     if key > u64::from(u32::MAX) {
322         return Err(DecodeError::new(format!("invalid key value: {}", key)));
323     }
324     let wire_type = WireType::try_from(key & 0x07)?;
325     let tag = key as u32 >> 3;
326 
327     if tag < MIN_TAG {
328         return Err(DecodeError::new("invalid tag value: 0"));
329     }
330 
331     Ok((tag, wire_type))
332 }
333 
334 /// Returns the width of an encoded Protobuf field key with the given tag.
335 /// The returned width will be between 1 and 5 bytes (inclusive).
336 #[inline]
key_len(tag: u32) -> usize337 pub fn key_len(tag: u32) -> usize {
338     encoded_len_varint(u64::from(tag << 3))
339 }
340 
341 /// Checks that the expected wire type matches the actual wire type,
342 /// or returns an error result.
343 #[inline]
check_wire_type(expected: WireType, actual: WireType) -> Result<(), DecodeError>344 pub fn check_wire_type(expected: WireType, actual: WireType) -> Result<(), DecodeError> {
345     if expected != actual {
346         return Err(DecodeError::new(format!(
347             "invalid wire type: {:?} (expected {:?})",
348             actual, expected
349         )));
350     }
351     Ok(())
352 }
353 
354 /// Helper function which abstracts reading a length delimiter prefix followed
355 /// by decoding values until the length of bytes is exhausted.
merge_loop<T, M, B>( value: &mut T, buf: &mut B, ctx: DecodeContext, mut merge: M, ) -> Result<(), DecodeError> where M: FnMut(&mut T, &mut B, DecodeContext) -> Result<(), DecodeError>, B: Buf,356 pub fn merge_loop<T, M, B>(
357     value: &mut T,
358     buf: &mut B,
359     ctx: DecodeContext,
360     mut merge: M,
361 ) -> Result<(), DecodeError>
362 where
363     M: FnMut(&mut T, &mut B, DecodeContext) -> Result<(), DecodeError>,
364     B: Buf,
365 {
366     let len = decode_varint(buf)?;
367     let remaining = buf.remaining();
368     if len > remaining as u64 {
369         return Err(DecodeError::new("buffer underflow"));
370     }
371 
372     let limit = remaining - len as usize;
373     while buf.remaining() > limit {
374         merge(value, buf, ctx.clone())?;
375     }
376 
377     if buf.remaining() != limit {
378         return Err(DecodeError::new("delimited length exceeded"));
379     }
380     Ok(())
381 }
382 
skip_field<B>( wire_type: WireType, tag: u32, buf: &mut B, ctx: DecodeContext, ) -> Result<(), DecodeError> where B: Buf,383 pub fn skip_field<B>(
384     wire_type: WireType,
385     tag: u32,
386     buf: &mut B,
387     ctx: DecodeContext,
388 ) -> Result<(), DecodeError>
389 where
390     B: Buf,
391 {
392     ctx.limit_reached()?;
393     let len = match wire_type {
394         WireType::Varint => decode_varint(buf).map(|_| 0)?,
395         WireType::ThirtyTwoBit => 4,
396         WireType::SixtyFourBit => 8,
397         WireType::LengthDelimited => decode_varint(buf)?,
398         WireType::StartGroup => loop {
399             let (inner_tag, inner_wire_type) = decode_key(buf)?;
400             match inner_wire_type {
401                 WireType::EndGroup => {
402                     if inner_tag != tag {
403                         return Err(DecodeError::new("unexpected end group tag"));
404                     }
405                     break 0;
406                 }
407                 _ => skip_field(inner_wire_type, inner_tag, buf, ctx.enter_recursion())?,
408             }
409         },
410         WireType::EndGroup => return Err(DecodeError::new("unexpected end group tag")),
411     };
412 
413     if len > buf.remaining() as u64 {
414         return Err(DecodeError::new("buffer underflow"));
415     }
416 
417     buf.advance(len as usize);
418     Ok(())
419 }
420 
/// Helper macro which emits an `encode_repeated` function for the type.
///
/// The emitted function calls the `encode` function in scope at the expansion site once per
/// element, so each element is written with its own key (unpacked form).
macro_rules! encode_repeated {
    ($ty:ty) => {
        pub fn encode_repeated<B>(tag: u32, values: &[$ty], buf: &mut B)
        where
            B: BufMut,
        {
            for value in values {
                encode(tag, value, buf);
            }
        }
    };
}
434 
/// Helper macro which emits a `merge_repeated` function for the numeric type.
///
/// Arguments: the element type, the wire type of a single element, the name of the
/// single-value merge function to call, and the name to give the emitted function.
///
/// The emitted function accepts both encodings: a `LengthDelimited` field is treated as packed
/// and decoded element by element via `merge_loop`; any other wire type must equal
/// `$wire_type` and yields exactly one element.
macro_rules! merge_repeated_numeric {
    ($ty:ty,
     $wire_type:expr,
     $merge:ident,
     $merge_repeated:ident) => {
        pub fn $merge_repeated<B>(
            wire_type: WireType,
            values: &mut Vec<$ty>,
            buf: &mut B,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError>
        where
            B: Buf,
        {
            if wire_type == WireType::LengthDelimited {
                // Packed.
                merge_loop(values, buf, ctx, |values, buf, ctx| {
                    let mut value = Default::default();
                    $merge($wire_type, &mut value, buf, ctx)?;
                    values.push(value);
                    Ok(())
                })
            } else {
                // Unpacked.
                check_wire_type($wire_type, wire_type)?;
                let mut value = Default::default();
                $merge(wire_type, &mut value, buf, ctx)?;
                values.push(value);
                Ok(())
            }
        }
    };
}
469 
/// Macro which emits a module containing a set of encoding functions for a
/// variable width numeric type.
///
/// The two-argument form uses plain `as` casts to convert to and from `u64`. The four-argument
/// form takes custom conversion expressions: `to_uint64(ident) expr` names the binding for the
/// `&$ty` input and the expression producing a `u64`; `from_uint64(ident) expr` names the
/// binding for the decoded `u64` and the expression producing a `$ty`.
macro_rules! varint {
    ($ty:ty,
     $proto_ty:ident) => (
        varint!($ty,
                $proto_ty,
                to_uint64(value) { *value as u64 },
                from_uint64(value) { value as $ty });
    );

    ($ty:ty,
     $proto_ty:ident,
     to_uint64($to_uint64_value:ident) $to_uint64:expr,
     from_uint64($from_uint64_value:ident) $from_uint64:expr) => (

         pub mod $proto_ty {
            use crate::encoding::*;

            // Writes the key followed by the value converted to a varint.
            pub fn encode<B>(tag: u32, $to_uint64_value: &$ty, buf: &mut B) where B: BufMut {
                encode_key(tag, WireType::Varint, buf);
                encode_varint($to_uint64, buf);
            }

            // Reads one varint (the key has already been consumed by the caller) and stores
            // the converted value.
            pub fn merge<B>(wire_type: WireType, value: &mut $ty, buf: &mut B, _ctx: DecodeContext) -> Result<(), DecodeError> where B: Buf {
                check_wire_type(WireType::Varint, wire_type)?;
                let $from_uint64_value = decode_varint(buf)?;
                *value = $from_uint64;
                Ok(())
            }

            encode_repeated!($ty);

            // Packed form: one key, the total payload length, then the varints back to back.
            // Nothing is written for an empty slice.
            pub fn encode_packed<B>(tag: u32, values: &[$ty], buf: &mut B) where B: BufMut {
                if values.is_empty() { return; }

                encode_key(tag, WireType::LengthDelimited, buf);
                let len: usize = values.iter().map(|$to_uint64_value| {
                    encoded_len_varint($to_uint64)
                }).sum();
                encode_varint(len as u64, buf);

                for $to_uint64_value in values {
                    encode_varint($to_uint64, buf);
                }
            }

            merge_repeated_numeric!($ty, WireType::Varint, merge, merge_repeated);

            // Size of key plus value for a single field.
            #[inline]
            pub fn encoded_len(tag: u32, $to_uint64_value: &$ty) -> usize {
                key_len(tag) + encoded_len_varint($to_uint64)
            }

            // Size of the unpacked form: one key per element plus each element's varint.
            #[inline]
            pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
                key_len(tag) * values.len() + values.iter().map(|$to_uint64_value| {
                    encoded_len_varint($to_uint64)
                }).sum::<usize>()
            }

            // Size of the packed form; zero for an empty slice, matching `encode_packed`.
            #[inline]
            pub fn encoded_len_packed(tag: u32, values: &[$ty]) -> usize {
                if values.is_empty() {
                    0
                } else {
                    let len = values.iter()
                                    .map(|$to_uint64_value| encoded_len_varint($to_uint64))
                                    .sum::<usize>();
                    key_len(tag) + encoded_len_varint(len as u64) + len
                }
            }

            #[cfg(test)]
            mod test {
                use proptest::prelude::*;

                use crate::encoding::$proto_ty::*;
                use crate::encoding::test::{
                    check_collection_type,
                    check_type,
                };

                proptest! {
                    #[test]
                    fn check(value: $ty, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::Varint,
                                   encode, merge, encoded_len)?;
                    }
                    #[test]
                    fn check_repeated(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_collection_type(value, tag, WireType::Varint,
                                              encode_repeated, merge_repeated,
                                              encoded_len_repeated)?;
                    }
                    #[test]
                    fn check_packed(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::LengthDelimited,
                                   encode_packed, merge_repeated,
                                   encoded_len_packed)?;
                    }
                }
            }
         }

    );
}
// `bool`: encoded as 0 or 1; any non-zero varint decodes to `true`.
varint!(bool, bool,
        to_uint64(value) u64::from(*value),
        from_uint64(value) value != 0);
// Plain integer types use `as` casts to and from `u64`.
varint!(i32, int32);
varint!(i64, int64);
varint!(u32, uint32);
varint!(u64, uint64);
// `sint32`: zigzag encoding over 32 bits, so small negative numbers get short varints.
varint!(i32, sint32,
to_uint64(value) {
    ((value << 1) ^ (value >> 31)) as u32 as u64
},
from_uint64(value) {
    let value = value as u32;
    ((value >> 1) as i32) ^ (-((value & 1) as i32))
});
// `sint64`: the same zigzag encoding over 64 bits.
varint!(i64, sint64,
to_uint64(value) {
    ((value << 1) ^ (value >> 63)) as u64
},
from_uint64(value) {
    ((value >> 1) as i64) ^ (-((value & 1) as i64))
});
599 
/// Macro which emits a module containing a set of encoding functions for a
/// fixed width numeric type.
///
/// Arguments: the Rust type, its encoded width in bytes, its wire type, the name of the module
/// to emit, and the `BufMut`/`Buf` method names used to write and read one value.
macro_rules! fixed_width {
    ($ty:ty,
     $width:expr,
     $wire_type:expr,
     $proto_ty:ident,
     $put:ident,
     $get:ident) => {
        pub mod $proto_ty {
            use crate::encoding::*;

            // Writes the key followed by the value via `$put`.
            pub fn encode<B>(tag: u32, value: &$ty, buf: &mut B)
            where
                B: BufMut,
            {
                encode_key(tag, $wire_type, buf);
                buf.$put(*value);
            }

            // Reads one value via `$get`; the length is checked first so a short buffer
            // produces a `DecodeError` instead of reaching `$get`.
            pub fn merge<B>(
                wire_type: WireType,
                value: &mut $ty,
                buf: &mut B,
                _ctx: DecodeContext,
            ) -> Result<(), DecodeError>
            where
                B: Buf,
            {
                check_wire_type($wire_type, wire_type)?;
                if buf.remaining() < $width {
                    return Err(DecodeError::new("buffer underflow"));
                }
                *value = buf.$get();
                Ok(())
            }

            encode_repeated!($ty);

            // Packed form: one key, the payload length (`count * $width`), then the values
            // back to back. Nothing is written for an empty slice.
            pub fn encode_packed<B>(tag: u32, values: &[$ty], buf: &mut B)
            where
                B: BufMut,
            {
                if values.is_empty() {
                    return;
                }

                encode_key(tag, WireType::LengthDelimited, buf);
                let len = values.len() as u64 * $width;
                encode_varint(len as u64, buf);

                for value in values {
                    buf.$put(*value);
                }
            }

            merge_repeated_numeric!($ty, $wire_type, merge, merge_repeated);

            // Size of key plus value; independent of the value itself.
            #[inline]
            pub fn encoded_len(tag: u32, _: &$ty) -> usize {
                key_len(tag) + $width
            }

            // Size of the unpacked form: key plus value for every element.
            #[inline]
            pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
                (key_len(tag) + $width) * values.len()
            }

            // Size of the packed form; zero for an empty slice, matching `encode_packed`.
            #[inline]
            pub fn encoded_len_packed(tag: u32, values: &[$ty]) -> usize {
                if values.is_empty() {
                    0
                } else {
                    let len = $width * values.len();
                    key_len(tag) + encoded_len_varint(len as u64) + len
                }
            }

            #[cfg(test)]
            mod test {
                use proptest::prelude::*;

                use super::super::test::{check_collection_type, check_type};
                use super::*;

                proptest! {
                    #[test]
                    fn check(value: $ty, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, $wire_type,
                                   encode, merge, encoded_len)?;
                    }
                    #[test]
                    fn check_repeated(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_collection_type(value, tag, $wire_type,
                                              encode_repeated, merge_repeated,
                                              encoded_len_repeated)?;
                    }
                    #[test]
                    fn check_packed(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::LengthDelimited,
                                   encode_packed, merge_repeated,
                                   encoded_len_packed)?;
                    }
                }
            }
        }
    };
}
// `float`: 4-byte little-endian `f32`.
fixed_width!(
    f32,
    4,
    WireType::ThirtyTwoBit,
    float,
    put_f32_le,
    get_f32_le
);
// `double`: 8-byte little-endian `f64`.
fixed_width!(
    f64,
    8,
    WireType::SixtyFourBit,
    double,
    put_f64_le,
    get_f64_le
);
// `fixed32`: 4-byte little-endian `u32`.
fixed_width!(
    u32,
    4,
    WireType::ThirtyTwoBit,
    fixed32,
    put_u32_le,
    get_u32_le
);
// `fixed64`: 8-byte little-endian `u64`.
fixed_width!(
    u64,
    8,
    WireType::SixtyFourBit,
    fixed64,
    put_u64_le,
    get_u64_le
);
// `sfixed32`: 4-byte little-endian `i32`.
fixed_width!(
    i32,
    4,
    WireType::ThirtyTwoBit,
    sfixed32,
    put_i32_le,
    get_i32_le
);
// `sfixed64`: 8-byte little-endian `i64`.
fixed_width!(
    i64,
    8,
    WireType::SixtyFourBit,
    sfixed64,
    put_i64_le,
    get_i64_le
);
756 
/// Macro which emits encoding functions for a length-delimited type.
///
/// Emits `encode_repeated`, `merge_repeated`, `encoded_len` and `encoded_len_repeated`. The
/// expansion site must provide `encode` and `merge` functions for `$ty`, and `$ty` must have a
/// `len()` method giving its payload size in bytes.
macro_rules! length_delimited {
    ($ty:ty) => {
        encode_repeated!($ty);

        // Decodes one element with `merge` and appends it; length-delimited values have no
        // packed form, so each call yields exactly one element.
        pub fn merge_repeated<B>(
            wire_type: WireType,
            values: &mut Vec<$ty>,
            buf: &mut B,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError>
        where
            B: Buf,
        {
            check_wire_type(WireType::LengthDelimited, wire_type)?;
            let mut value = Default::default();
            merge(wire_type, &mut value, buf, ctx)?;
            values.push(value);
            Ok(())
        }

        // Size of key, length prefix and payload for a single value.
        #[inline]
        pub fn encoded_len(tag: u32, value: &$ty) -> usize {
            key_len(tag) + encoded_len_varint(value.len() as u64) + value.len()
        }

        // Size of all elements, each with its own key and length prefix.
        #[inline]
        pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
            key_len(tag) * values.len()
                + values
                    .iter()
                    .map(|value| encoded_len_varint(value.len() as u64) + value.len())
                    .sum::<usize>()
        }
    };
}
793 
794 pub mod string {
795     use super::*;
796 
encode<B>(tag: u32, value: &String, buf: &mut B) where B: BufMut,797     pub fn encode<B>(tag: u32, value: &String, buf: &mut B)
798     where
799         B: BufMut,
800     {
801         encode_key(tag, WireType::LengthDelimited, buf);
802         encode_varint(value.len() as u64, buf);
803         buf.put_slice(value.as_bytes());
804     }
merge<B>( wire_type: WireType, value: &mut String, buf: &mut B, ctx: DecodeContext, ) -> Result<(), DecodeError> where B: Buf,805     pub fn merge<B>(
806         wire_type: WireType,
807         value: &mut String,
808         buf: &mut B,
809         ctx: DecodeContext,
810     ) -> Result<(), DecodeError>
811     where
812         B: Buf,
813     {
814         // ## Unsafety
815         //
816         // `string::merge` reuses `bytes::merge`, with an additional check of utf-8
817         // well-formedness. If the utf-8 is not well-formed, or if any other error occurs, then the
818         // string is cleared, so as to avoid leaking a string field with invalid data.
819         //
820         // This implementation uses the unsafe `String::as_mut_vec` method instead of the safe
821         // alternative of temporarily swapping an empty `String` into the field, because it results
822         // in up to 10% better performance on the protobuf message decoding benchmarks.
823         //
824         // It's required when using `String::as_mut_vec` that invalid utf-8 data not be leaked into
825         // the backing `String`. To enforce this, even in the event of a panic in `bytes::merge` or
826         // in the buf implementation, a drop guard is used.
827         unsafe {
828             struct DropGuard<'a>(&'a mut Vec<u8>);
829             impl<'a> Drop for DropGuard<'a> {
830                 #[inline]
831                 fn drop(&mut self) {
832                     self.0.clear();
833                 }
834             }
835 
836             let drop_guard = DropGuard(value.as_mut_vec());
837             bytes::merge_one_copy(wire_type, drop_guard.0, buf, ctx)?;
838             match str::from_utf8(drop_guard.0) {
839                 Ok(_) => {
840                     // Success; do not clear the bytes.
841                     mem::forget(drop_guard);
842                     Ok(())
843                 }
844                 Err(_) => Err(DecodeError::new(
845                     "invalid string value: data is not UTF-8 encoded",
846                 )),
847             }
848         }
849     }
850 
851     length_delimited!(String);
852 
853     #[cfg(test)]
854     mod test {
855         use proptest::prelude::*;
856 
857         use super::super::test::{check_collection_type, check_type};
858         use super::*;
859 
860         proptest! {
861             #[test]
862             fn check(value: String, tag in MIN_TAG..=MAX_TAG) {
863                 super::test::check_type(value, tag, WireType::LengthDelimited,
864                                         encode, merge, encoded_len)?;
865             }
866             #[test]
867             fn check_repeated(value: Vec<String>, tag in MIN_TAG..=MAX_TAG) {
868                 super::test::check_collection_type(value, tag, WireType::LengthDelimited,
869                                                    encode_repeated, merge_repeated,
870                                                    encoded_len_repeated)?;
871             }
872         }
873     }
874 }
875 
/// Public marker trait for the byte-container types usable with the bytes encoding functions.
///
/// The supertrait `sealed::BytesAdapter` lives in a private module, so it cannot be named, and
/// therefore this trait cannot be implemented, outside this module.
pub trait BytesAdapter: sealed::BytesAdapter {}
877 
878 mod sealed {
879     use super::{Buf, BufMut};
880 
881     pub trait BytesAdapter: Default + Sized + 'static {
len(&self) -> usize882         fn len(&self) -> usize;
883 
884         /// Replace contents of this buffer with the contents of another buffer.
replace_with<B>(&mut self, buf: B) where B: Buf885         fn replace_with<B>(&mut self, buf: B)
886         where
887             B: Buf;
888 
889         /// Appends this buffer to the (contents of) other buffer.
append_to<B>(&self, buf: &mut B) where B: BufMut890         fn append_to<B>(&self, buf: &mut B)
891         where
892             B: BufMut;
893 
is_empty(&self) -> bool894         fn is_empty(&self) -> bool {
895             self.len() == 0
896         }
897     }
898 }
899 
900 impl BytesAdapter for Bytes {}
901 
902 impl sealed::BytesAdapter for Bytes {
len(&self) -> usize903     fn len(&self) -> usize {
904         Buf::remaining(self)
905     }
906 
replace_with<B>(&mut self, mut buf: B) where B: Buf,907     fn replace_with<B>(&mut self, mut buf: B)
908     where
909         B: Buf,
910     {
911         *self = buf.copy_to_bytes(buf.remaining());
912     }
913 
append_to<B>(&self, buf: &mut B) where B: BufMut,914     fn append_to<B>(&self, buf: &mut B)
915     where
916         B: BufMut,
917     {
918         buf.put(self.clone())
919     }
920 }
921 
922 impl BytesAdapter for Vec<u8> {}
923 
924 impl sealed::BytesAdapter for Vec<u8> {
len(&self) -> usize925     fn len(&self) -> usize {
926         Vec::len(self)
927     }
928 
replace_with<B>(&mut self, buf: B) where B: Buf,929     fn replace_with<B>(&mut self, buf: B)
930     where
931         B: Buf,
932     {
933         self.clear();
934         self.reserve(buf.remaining());
935         self.put(buf);
936     }
937 
append_to<B>(&self, buf: &mut B) where B: BufMut,938     fn append_to<B>(&self, buf: &mut B)
939     where
940         B: BufMut,
941     {
942         buf.put(self.as_slice())
943     }
944 }
945 
946 pub mod bytes {
947     use super::*;
948 
encode<A, B>(tag: u32, value: &A, buf: &mut B) where A: BytesAdapter, B: BufMut,949     pub fn encode<A, B>(tag: u32, value: &A, buf: &mut B)
950     where
951         A: BytesAdapter,
952         B: BufMut,
953     {
954         encode_key(tag, WireType::LengthDelimited, buf);
955         encode_varint(value.len() as u64, buf);
956         value.append_to(buf);
957     }
958 
merge<A, B>( wire_type: WireType, value: &mut A, buf: &mut B, _ctx: DecodeContext, ) -> Result<(), DecodeError> where A: BytesAdapter, B: Buf,959     pub fn merge<A, B>(
960         wire_type: WireType,
961         value: &mut A,
962         buf: &mut B,
963         _ctx: DecodeContext,
964     ) -> Result<(), DecodeError>
965     where
966         A: BytesAdapter,
967         B: Buf,
968     {
969         check_wire_type(WireType::LengthDelimited, wire_type)?;
970         let len = decode_varint(buf)?;
971         if len > buf.remaining() as u64 {
972             return Err(DecodeError::new("buffer underflow"));
973         }
974         let len = len as usize;
975 
976         // Clear the existing value. This follows from the following rule in the encoding guide[1]:
977         //
978         // > Normally, an encoded message would never have more than one instance of a non-repeated
979         // > field. However, parsers are expected to handle the case in which they do. For numeric
980         // > types and strings, if the same field appears multiple times, the parser accepts the
981         // > last value it sees.
982         //
983         // [1]: https://developers.google.com/protocol-buffers/docs/encoding#optional
984         //
985         // This is intended for A and B both being Bytes so it is zero-copy.
986         // Some combinations of A and B types may cause a double-copy,
987         // in which case merge_one_copy() should be used instead.
988         value.replace_with(buf.copy_to_bytes(len));
989         Ok(())
990     }
991 
merge_one_copy<A, B>( wire_type: WireType, value: &mut A, buf: &mut B, _ctx: DecodeContext, ) -> Result<(), DecodeError> where A: BytesAdapter, B: Buf,992     pub(super) fn merge_one_copy<A, B>(
993         wire_type: WireType,
994         value: &mut A,
995         buf: &mut B,
996         _ctx: DecodeContext,
997     ) -> Result<(), DecodeError>
998     where
999         A: BytesAdapter,
1000         B: Buf,
1001     {
1002         check_wire_type(WireType::LengthDelimited, wire_type)?;
1003         let len = decode_varint(buf)?;
1004         if len > buf.remaining() as u64 {
1005             return Err(DecodeError::new("buffer underflow"));
1006         }
1007         let len = len as usize;
1008 
1009         // If we must copy, make sure to copy only once.
1010         value.replace_with(buf.take(len));
1011         Ok(())
1012     }
1013 
1014     length_delimited!(impl BytesAdapter);
1015 
1016     #[cfg(test)]
1017     mod test {
1018         use proptest::prelude::*;
1019 
1020         use super::super::test::{check_collection_type, check_type};
1021         use super::*;
1022 
1023         proptest! {
1024             #[test]
1025             fn check_vec(value: Vec<u8>, tag in MIN_TAG..=MAX_TAG) {
1026                 super::test::check_type::<Vec<u8>, Vec<u8>>(value, tag, WireType::LengthDelimited,
1027                                                             encode, merge, encoded_len)?;
1028             }
1029 
1030             #[test]
1031             fn check_bytes(value: Vec<u8>, tag in MIN_TAG..=MAX_TAG) {
1032                 let value = Bytes::from(value);
1033                 super::test::check_type::<Bytes, Bytes>(value, tag, WireType::LengthDelimited,
1034                                                         encode, merge, encoded_len)?;
1035             }
1036 
1037             #[test]
1038             fn check_repeated_vec(value: Vec<Vec<u8>>, tag in MIN_TAG..=MAX_TAG) {
1039                 super::test::check_collection_type(value, tag, WireType::LengthDelimited,
1040                                                    encode_repeated, merge_repeated,
1041                                                    encoded_len_repeated)?;
1042             }
1043 
1044             #[test]
1045             fn check_repeated_bytes(value: Vec<Vec<u8>>, tag in MIN_TAG..=MAX_TAG) {
1046                 let value = value.into_iter().map(Bytes::from).collect();
1047                 super::test::check_collection_type(value, tag, WireType::LengthDelimited,
1048                                                    encode_repeated, merge_repeated,
1049                                                    encoded_len_repeated)?;
1050             }
1051         }
1052     }
1053 }
1054 
1055 pub mod message {
1056     use super::*;
1057 
encode<M, B>(tag: u32, msg: &M, buf: &mut B) where M: Message, B: BufMut,1058     pub fn encode<M, B>(tag: u32, msg: &M, buf: &mut B)
1059     where
1060         M: Message,
1061         B: BufMut,
1062     {
1063         encode_key(tag, WireType::LengthDelimited, buf);
1064         encode_varint(msg.encoded_len() as u64, buf);
1065         msg.encode_raw(buf);
1066     }
1067 
merge<M, B>( wire_type: WireType, msg: &mut M, buf: &mut B, ctx: DecodeContext, ) -> Result<(), DecodeError> where M: Message, B: Buf,1068     pub fn merge<M, B>(
1069         wire_type: WireType,
1070         msg: &mut M,
1071         buf: &mut B,
1072         ctx: DecodeContext,
1073     ) -> Result<(), DecodeError>
1074     where
1075         M: Message,
1076         B: Buf,
1077     {
1078         check_wire_type(WireType::LengthDelimited, wire_type)?;
1079         ctx.limit_reached()?;
1080         merge_loop(
1081             msg,
1082             buf,
1083             ctx.enter_recursion(),
1084             |msg: &mut M, buf: &mut B, ctx| {
1085                 let (tag, wire_type) = decode_key(buf)?;
1086                 msg.merge_field(tag, wire_type, buf, ctx)
1087             },
1088         )
1089     }
1090 
encode_repeated<M, B>(tag: u32, messages: &[M], buf: &mut B) where M: Message, B: BufMut,1091     pub fn encode_repeated<M, B>(tag: u32, messages: &[M], buf: &mut B)
1092     where
1093         M: Message,
1094         B: BufMut,
1095     {
1096         for msg in messages {
1097             encode(tag, msg, buf);
1098         }
1099     }
1100 
merge_repeated<M, B>( wire_type: WireType, messages: &mut Vec<M>, buf: &mut B, ctx: DecodeContext, ) -> Result<(), DecodeError> where M: Message + Default, B: Buf,1101     pub fn merge_repeated<M, B>(
1102         wire_type: WireType,
1103         messages: &mut Vec<M>,
1104         buf: &mut B,
1105         ctx: DecodeContext,
1106     ) -> Result<(), DecodeError>
1107     where
1108         M: Message + Default,
1109         B: Buf,
1110     {
1111         check_wire_type(WireType::LengthDelimited, wire_type)?;
1112         let mut msg = M::default();
1113         merge(WireType::LengthDelimited, &mut msg, buf, ctx)?;
1114         messages.push(msg);
1115         Ok(())
1116     }
1117 
1118     #[inline]
encoded_len<M>(tag: u32, msg: &M) -> usize where M: Message,1119     pub fn encoded_len<M>(tag: u32, msg: &M) -> usize
1120     where
1121         M: Message,
1122     {
1123         let len = msg.encoded_len();
1124         key_len(tag) + encoded_len_varint(len as u64) + len
1125     }
1126 
1127     #[inline]
encoded_len_repeated<M>(tag: u32, messages: &[M]) -> usize where M: Message,1128     pub fn encoded_len_repeated<M>(tag: u32, messages: &[M]) -> usize
1129     where
1130         M: Message,
1131     {
1132         key_len(tag) * messages.len()
1133             + messages
1134                 .iter()
1135                 .map(Message::encoded_len)
1136                 .map(|len| len + encoded_len_varint(len as u64))
1137                 .sum::<usize>()
1138     }
1139 }
1140 
1141 pub mod group {
1142     use super::*;
1143 
encode<M, B>(tag: u32, msg: &M, buf: &mut B) where M: Message, B: BufMut,1144     pub fn encode<M, B>(tag: u32, msg: &M, buf: &mut B)
1145     where
1146         M: Message,
1147         B: BufMut,
1148     {
1149         encode_key(tag, WireType::StartGroup, buf);
1150         msg.encode_raw(buf);
1151         encode_key(tag, WireType::EndGroup, buf);
1152     }
1153 
merge<M, B>( tag: u32, wire_type: WireType, msg: &mut M, buf: &mut B, ctx: DecodeContext, ) -> Result<(), DecodeError> where M: Message, B: Buf,1154     pub fn merge<M, B>(
1155         tag: u32,
1156         wire_type: WireType,
1157         msg: &mut M,
1158         buf: &mut B,
1159         ctx: DecodeContext,
1160     ) -> Result<(), DecodeError>
1161     where
1162         M: Message,
1163         B: Buf,
1164     {
1165         check_wire_type(WireType::StartGroup, wire_type)?;
1166 
1167         ctx.limit_reached()?;
1168         loop {
1169             let (field_tag, field_wire_type) = decode_key(buf)?;
1170             if field_wire_type == WireType::EndGroup {
1171                 if field_tag != tag {
1172                     return Err(DecodeError::new("unexpected end group tag"));
1173                 }
1174                 return Ok(());
1175             }
1176 
1177             M::merge_field(msg, field_tag, field_wire_type, buf, ctx.enter_recursion())?;
1178         }
1179     }
1180 
encode_repeated<M, B>(tag: u32, messages: &[M], buf: &mut B) where M: Message, B: BufMut,1181     pub fn encode_repeated<M, B>(tag: u32, messages: &[M], buf: &mut B)
1182     where
1183         M: Message,
1184         B: BufMut,
1185     {
1186         for msg in messages {
1187             encode(tag, msg, buf);
1188         }
1189     }
1190 
merge_repeated<M, B>( tag: u32, wire_type: WireType, messages: &mut Vec<M>, buf: &mut B, ctx: DecodeContext, ) -> Result<(), DecodeError> where M: Message + Default, B: Buf,1191     pub fn merge_repeated<M, B>(
1192         tag: u32,
1193         wire_type: WireType,
1194         messages: &mut Vec<M>,
1195         buf: &mut B,
1196         ctx: DecodeContext,
1197     ) -> Result<(), DecodeError>
1198     where
1199         M: Message + Default,
1200         B: Buf,
1201     {
1202         check_wire_type(WireType::StartGroup, wire_type)?;
1203         let mut msg = M::default();
1204         merge(tag, WireType::StartGroup, &mut msg, buf, ctx)?;
1205         messages.push(msg);
1206         Ok(())
1207     }
1208 
1209     #[inline]
encoded_len<M>(tag: u32, msg: &M) -> usize where M: Message,1210     pub fn encoded_len<M>(tag: u32, msg: &M) -> usize
1211     where
1212         M: Message,
1213     {
1214         2 * key_len(tag) + msg.encoded_len()
1215     }
1216 
1217     #[inline]
encoded_len_repeated<M>(tag: u32, messages: &[M]) -> usize where M: Message,1218     pub fn encoded_len_repeated<M>(tag: u32, messages: &[M]) -> usize
1219     where
1220         M: Message,
1221     {
1222         2 * key_len(tag) * messages.len() + messages.iter().map(Message::encoded_len).sum::<usize>()
1223     }
1224 }
1225 
/// Rust doesn't have a `Map` trait, so macros are currently the best way to be
/// generic over `HashMap` and `BTreeMap`.
///
/// `map!(Ty)` expands to the encode / merge / encoded-length functions for the map type `Ty`.
/// Each map entry is written as its own length-delimited record in which the key is field 1
/// and the value is field 2.
macro_rules! map {
    ($map_ty:ident) => {
        use crate::encoding::*;
        use core::hash::Hash;

        /// Generic protobuf map encode function.
        ///
        /// Equivalent to `encode_with_default` with `V::default()` as the value default.
        pub fn encode<K, V, B, KE, KL, VE, VL>(
            key_encode: KE,
            key_encoded_len: KL,
            val_encode: VE,
            val_encoded_len: VL,
            tag: u32,
            values: &$map_ty<K, V>,
            buf: &mut B,
        ) where
            K: Default + Eq + Hash + Ord,
            V: Default + PartialEq,
            B: BufMut,
            KE: Fn(u32, &K, &mut B),
            KL: Fn(u32, &K) -> usize,
            VE: Fn(u32, &V, &mut B),
            VL: Fn(u32, &V) -> usize,
        {
            encode_with_default(
                key_encode,
                key_encoded_len,
                val_encode,
                val_encoded_len,
                &V::default(),
                tag,
                values,
                buf,
            )
        }

        /// Generic protobuf map merge function.
        ///
        /// Equivalent to `merge_with_default` with `V::default()` as the value default.
        pub fn merge<K, V, B, KM, VM>(
            key_merge: KM,
            val_merge: VM,
            values: &mut $map_ty<K, V>,
            buf: &mut B,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError>
        where
            K: Default + Eq + Hash + Ord,
            V: Default,
            B: Buf,
            KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>,
            VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>,
        {
            merge_with_default(key_merge, val_merge, V::default(), values, buf, ctx)
        }

        /// Generic protobuf map encoded-length function.
        ///
        /// Equivalent to `encoded_len_with_default` with `V::default()` as the value default.
        pub fn encoded_len<K, V, KL, VL>(
            key_encoded_len: KL,
            val_encoded_len: VL,
            tag: u32,
            values: &$map_ty<K, V>,
        ) -> usize
        where
            K: Default + Eq + Hash + Ord,
            V: Default + PartialEq,
            KL: Fn(u32, &K) -> usize,
            VL: Fn(u32, &V) -> usize,
        {
            encoded_len_with_default(key_encoded_len, val_encoded_len, &V::default(), tag, values)
        }

        /// Generic protobuf map encode function with an overridden value default.
        ///
        /// This is necessary because enumeration values can have a default value other
        /// than 0 in proto2.
        ///
        /// A key equal to `K::default()` and a value equal to `val_default` are omitted from
        /// their entry; the entry's key and length prefix are still written.
        pub fn encode_with_default<K, V, B, KE, KL, VE, VL>(
            key_encode: KE,
            key_encoded_len: KL,
            val_encode: VE,
            val_encoded_len: VL,
            val_default: &V,
            tag: u32,
            values: &$map_ty<K, V>,
            buf: &mut B,
        ) where
            K: Default + Eq + Hash + Ord,
            V: PartialEq,
            B: BufMut,
            KE: Fn(u32, &K, &mut B),
            KL: Fn(u32, &K) -> usize,
            VE: Fn(u32, &V, &mut B),
            VL: Fn(u32, &V) -> usize,
        {
            // Built once: every entry's key is compared against the same default.
            let key_default = K::default();

            for (key, val) in values.iter() {
                let skip_key = key == &key_default;
                let skip_val = val == val_default;

                // Length of the entry body, i.e. of the fields that will actually be written.
                let len = (if skip_key { 0 } else { key_encoded_len(1, key) })
                    + (if skip_val { 0 } else { val_encoded_len(2, val) });

                encode_key(tag, WireType::LengthDelimited, buf);
                encode_varint(len as u64, buf);
                if !skip_key {
                    key_encode(1, key, buf);
                }
                if !skip_val {
                    val_encode(2, val, buf);
                }
            }
        }

        /// Generic protobuf map merge function with an overridden value default.
        ///
        /// This is necessary because enumeration values can have a default value other
        /// than 0 in proto2.
        ///
        /// Decodes a single entry and inserts it into `values`. The key starts as
        /// `K::default()` and the value as `val_default`, so a field missing from the entry
        /// keeps that starting value. Nothing is inserted when decoding fails.
        pub fn merge_with_default<K, V, B, KM, VM>(
            key_merge: KM,
            val_merge: VM,
            val_default: V,
            values: &mut $map_ty<K, V>,
            buf: &mut B,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError>
        where
            K: Default + Eq + Hash + Ord,
            B: Buf,
            KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>,
            VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>,
        {
            let mut key = Default::default();
            let mut val = val_default;
            ctx.limit_reached()?;
            merge_loop(
                &mut (&mut key, &mut val),
                buf,
                ctx.enter_recursion(),
                |&mut (ref mut key, ref mut val), buf, ctx| {
                    let (tag, wire_type) = decode_key(buf)?;
                    match tag {
                        // Field 1 is the key, field 2 the value; anything else is skipped.
                        1 => key_merge(wire_type, key, buf, ctx),
                        2 => val_merge(wire_type, val, buf, ctx),
                        _ => skip_field(wire_type, tag, buf, ctx),
                    }
                },
            )?;
            values.insert(key, val);

            Ok(())
        }

        /// Generic protobuf map encoded-length function with an overridden value default.
        ///
        /// This is necessary because enumeration values can have a default value other
        /// than 0 in proto2.
        ///
        /// Uses the same omission rules as `encode_with_default`: a key equal to
        /// `K::default()` or a value equal to `val_default` contributes no bytes.
        pub fn encoded_len_with_default<K, V, KL, VL>(
            key_encoded_len: KL,
            val_encoded_len: VL,
            val_default: &V,
            tag: u32,
            values: &$map_ty<K, V>,
        ) -> usize
        where
            K: Default + Eq + Hash + Ord,
            V: PartialEq,
            KL: Fn(u32, &K) -> usize,
            VL: Fn(u32, &V) -> usize,
        {
            // Built once: every entry's key is compared against the same default.
            let key_default = K::default();

            key_len(tag) * values.len()
                + values
                    .iter()
                    .map(|(key, val)| {
                        let len = (if key == &key_default {
                            0
                        } else {
                            key_encoded_len(1, key)
                        }) + (if val == val_default {
                            0
                        } else {
                            val_encoded_len(2, val)
                        });
                        encoded_len_varint(len as u64) + len
                    })
                    .sum::<usize>()
        }
    };
}
1412 
/// Map encoding functions instantiated for `std::collections::HashMap` via the `map!` macro.
///
/// Only compiled when the `std` feature is enabled.
#[cfg(feature = "std")]
pub mod hash_map {
    use std::collections::HashMap;
    map!(HashMap);
}
1418 
/// Map encoding functions instantiated for `BTreeMap` via the `map!` macro.
pub mod btree_map {
    map!(BTreeMap);
}
1422 
1423 #[cfg(test)]
1424 mod test {
1425     use alloc::string::ToString;
1426     use core::borrow::Borrow;
1427     use core::fmt::Debug;
1428     use core::u64;
1429 
1430     use ::bytes::{Bytes, BytesMut};
1431     use proptest::{prelude::*, test_runner::TestCaseResult};
1432 
1433     use crate::encoding::*;
1434 
check_type<T, B>( value: T, tag: u32, wire_type: WireType, encode: fn(u32, &B, &mut BytesMut), merge: fn(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>, encoded_len: fn(u32, &B) -> usize, ) -> TestCaseResult where T: Debug + Default + PartialEq + Borrow<B>, B: ?Sized,1435     pub fn check_type<T, B>(
1436         value: T,
1437         tag: u32,
1438         wire_type: WireType,
1439         encode: fn(u32, &B, &mut BytesMut),
1440         merge: fn(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>,
1441         encoded_len: fn(u32, &B) -> usize,
1442     ) -> TestCaseResult
1443     where
1444         T: Debug + Default + PartialEq + Borrow<B>,
1445         B: ?Sized,
1446     {
1447         prop_assume!((MIN_TAG..=MAX_TAG).contains(&tag));
1448 
1449         let expected_len = encoded_len(tag, value.borrow());
1450 
1451         let mut buf = BytesMut::with_capacity(expected_len);
1452         encode(tag, value.borrow(), &mut buf);
1453 
1454         let mut buf = buf.freeze();
1455 
1456         prop_assert_eq!(
1457             buf.remaining(),
1458             expected_len,
1459             "encoded_len wrong; expected: {}, actual: {}",
1460             expected_len,
1461             buf.remaining()
1462         );
1463 
1464         if !buf.has_remaining() {
1465             // Short circuit for empty packed values.
1466             return Ok(());
1467         }
1468 
1469         let (decoded_tag, decoded_wire_type) =
1470             decode_key(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?;
1471         prop_assert_eq!(
1472             tag,
1473             decoded_tag,
1474             "decoded tag does not match; expected: {}, actual: {}",
1475             tag,
1476             decoded_tag
1477         );
1478 
1479         prop_assert_eq!(
1480             wire_type,
1481             decoded_wire_type,
1482             "decoded wire type does not match; expected: {:?}, actual: {:?}",
1483             wire_type,
1484             decoded_wire_type,
1485         );
1486 
1487         match wire_type {
1488             WireType::SixtyFourBit if buf.remaining() != 8 => Err(TestCaseError::fail(format!(
1489                 "64bit wire type illegal remaining: {}, tag: {}",
1490                 buf.remaining(),
1491                 tag
1492             ))),
1493             WireType::ThirtyTwoBit if buf.remaining() != 4 => Err(TestCaseError::fail(format!(
1494                 "32bit wire type illegal remaining: {}, tag: {}",
1495                 buf.remaining(),
1496                 tag
1497             ))),
1498             _ => Ok(()),
1499         }?;
1500 
1501         let mut roundtrip_value = T::default();
1502         merge(
1503             wire_type,
1504             &mut roundtrip_value,
1505             &mut buf,
1506             DecodeContext::default(),
1507         )
1508         .map_err(|error| TestCaseError::fail(error.to_string()))?;
1509 
1510         prop_assert!(
1511             !buf.has_remaining(),
1512             "expected buffer to be empty, remaining: {}",
1513             buf.remaining()
1514         );
1515 
1516         prop_assert_eq!(value, roundtrip_value);
1517 
1518         Ok(())
1519     }
1520 
check_collection_type<T, B, E, M, L>( value: T, tag: u32, wire_type: WireType, encode: E, mut merge: M, encoded_len: L, ) -> TestCaseResult where T: Debug + Default + PartialEq + Borrow<B>, B: ?Sized, E: FnOnce(u32, &B, &mut BytesMut), M: FnMut(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>, L: FnOnce(u32, &B) -> usize,1521     pub fn check_collection_type<T, B, E, M, L>(
1522         value: T,
1523         tag: u32,
1524         wire_type: WireType,
1525         encode: E,
1526         mut merge: M,
1527         encoded_len: L,
1528     ) -> TestCaseResult
1529     where
1530         T: Debug + Default + PartialEq + Borrow<B>,
1531         B: ?Sized,
1532         E: FnOnce(u32, &B, &mut BytesMut),
1533         M: FnMut(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>,
1534         L: FnOnce(u32, &B) -> usize,
1535     {
1536         prop_assume!((MIN_TAG..=MAX_TAG).contains(&tag));
1537 
1538         let expected_len = encoded_len(tag, value.borrow());
1539 
1540         let mut buf = BytesMut::with_capacity(expected_len);
1541         encode(tag, value.borrow(), &mut buf);
1542 
1543         let mut buf = buf.freeze();
1544 
1545         prop_assert_eq!(
1546             buf.remaining(),
1547             expected_len,
1548             "encoded_len wrong; expected: {}, actual: {}",
1549             expected_len,
1550             buf.remaining()
1551         );
1552 
1553         let mut roundtrip_value = Default::default();
1554         while buf.has_remaining() {
1555             let (decoded_tag, decoded_wire_type) =
1556                 decode_key(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?;
1557 
1558             prop_assert_eq!(
1559                 tag,
1560                 decoded_tag,
1561                 "decoded tag does not match; expected: {}, actual: {}",
1562                 tag,
1563                 decoded_tag
1564             );
1565 
1566             prop_assert_eq!(
1567                 wire_type,
1568                 decoded_wire_type,
1569                 "decoded wire type does not match; expected: {:?}, actual: {:?}",
1570                 wire_type,
1571                 decoded_wire_type
1572             );
1573 
1574             merge(
1575                 wire_type,
1576                 &mut roundtrip_value,
1577                 &mut buf,
1578                 DecodeContext::default(),
1579             )
1580             .map_err(|error| TestCaseError::fail(error.to_string()))?;
1581         }
1582 
1583         prop_assert_eq!(value, roundtrip_value);
1584 
1585         Ok(())
1586     }
1587 
1588     #[test]
string_merge_invalid_utf8()1589     fn string_merge_invalid_utf8() {
1590         let mut s = String::new();
1591         let buf = b"\x02\x80\x80";
1592 
1593         let r = string::merge(
1594             WireType::LengthDelimited,
1595             &mut s,
1596             &mut &buf[..],
1597             DecodeContext::default(),
1598         );
1599         r.expect_err("must be an error");
1600         assert!(s.is_empty());
1601     }
1602 
1603     #[test]
varint()1604     fn varint() {
1605         fn check(value: u64, mut encoded: &[u8]) {
1606             // Small buffer.
1607             let mut buf = Vec::with_capacity(1);
1608             encode_varint(value, &mut buf);
1609             assert_eq!(buf, encoded);
1610 
1611             // Large buffer.
1612             let mut buf = Vec::with_capacity(100);
1613             encode_varint(value, &mut buf);
1614             assert_eq!(buf, encoded);
1615 
1616             assert_eq!(encoded_len_varint(value), encoded.len());
1617 
1618             let roundtrip_value =
1619                 decode_varint(&mut <&[u8]>::clone(&encoded)).expect("decoding failed");
1620             assert_eq!(value, roundtrip_value);
1621 
1622             let roundtrip_value = decode_varint_slow(&mut encoded).expect("slow decoding failed");
1623             assert_eq!(value, roundtrip_value);
1624         }
1625 
1626         check(2u64.pow(0) - 1, &[0x00]);
1627         check(2u64.pow(0), &[0x01]);
1628 
1629         check(2u64.pow(7) - 1, &[0x7F]);
1630         check(2u64.pow(7), &[0x80, 0x01]);
1631         check(300, &[0xAC, 0x02]);
1632 
1633         check(2u64.pow(14) - 1, &[0xFF, 0x7F]);
1634         check(2u64.pow(14), &[0x80, 0x80, 0x01]);
1635 
1636         check(2u64.pow(21) - 1, &[0xFF, 0xFF, 0x7F]);
1637         check(2u64.pow(21), &[0x80, 0x80, 0x80, 0x01]);
1638 
1639         check(2u64.pow(28) - 1, &[0xFF, 0xFF, 0xFF, 0x7F]);
1640         check(2u64.pow(28), &[0x80, 0x80, 0x80, 0x80, 0x01]);
1641 
1642         check(2u64.pow(35) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0x7F]);
1643         check(2u64.pow(35), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x01]);
1644 
1645         check(2u64.pow(42) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]);
1646         check(2u64.pow(42), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]);
1647 
1648         check(
1649             2u64.pow(49) - 1,
1650             &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
1651         );
1652         check(
1653             2u64.pow(49),
1654             &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
1655         );
1656 
1657         check(
1658             2u64.pow(56) - 1,
1659             &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
1660         );
1661         check(
1662             2u64.pow(56),
1663             &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
1664         );
1665 
1666         check(
1667             2u64.pow(63) - 1,
1668             &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
1669         );
1670         check(
1671             2u64.pow(63),
1672             &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
1673         );
1674 
1675         check(
1676             u64::MAX,
1677             &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01],
1678         );
1679     }
1680 
1681     #[test]
varint_overflow()1682     fn varint_overflow() {
1683         let mut u64_max_plus_one: &[u8] =
1684             &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x02];
1685 
1686         decode_varint(&mut u64_max_plus_one).expect_err("decoding u64::MAX + 1 succeeded");
1687         decode_varint_slow(&mut u64_max_plus_one)
1688             .expect_err("slow decoding u64::MAX + 1 succeeded");
1689     }
1690 
/// Expands to one encoding property test for every combination of map type, scalar map key
/// type, and map value type.
///
/// `keys` and `vals` are bracketed lists of `(Rust type, encoding module name)` pairs. The
/// expansion is laid out as `<map module>::<key module>::<value test fn>`, where the map module
/// is `hash_map` or `btree_map`.
///
/// TODO: these tests take a long time to compile, can this be improved?
#[cfg(feature = "std")]
macro_rules! map_tests {
    // Entry point: emit one module per map implementation and forward both lists untouched.
    (keys: $key_list:tt,
     vals: $val_list:tt) => {
        mod hash_map {
            map_tests!(@private HashMap, hash_map, $key_list, $val_list);
        }
        mod btree_map {
            map_tests!(@private BTreeMap, btree_map, $key_list, $val_list);
        }
    };

    // Key fan-out: one module per key pair, named after the key's encoding module. The value
    // list is still an opaque token tree at this point and is forwarded as-is.
    (@private $map:ident,
              $map_mod:ident,
              [$(($k_ty:ty, $k_mod:ident)),*],
              $val_list:tt) => {
        $(
            mod $k_mod {
                use std::collections::$map;

                use proptest::prelude::*;

                use crate::encoding::test::check_collection_type;
                use crate::encoding::*;

                map_tests!(@private $map, $map_mod, ($k_ty, $k_mod), $val_list);
            }
        )*
    };

    // Value fan-out: for a single key pair, emit one property test per value pair. Each test
    // function is named after the value's encoding module.
    (@private $map:ident,
              $map_mod:ident,
              ($k_ty:ty, $k_mod:ident),
              [$(($v_ty:ty, $v_mod:ident)),*]) => {
        $(
            proptest! {
                #[test]
                fn $v_mod(values: $map<$k_ty, $v_ty>, tag in MIN_TAG..=MAX_TAG) {
                    check_collection_type(
                        values,
                        tag,
                        WireType::LengthDelimited,
                        // Encoder: delegates to the map module's `encode` with the key and
                        // value codecs plugged in.
                        |field_tag, map, out| {
                            $map_mod::encode(
                                $k_mod::encode,
                                $k_mod::encoded_len,
                                $v_mod::encode,
                                $v_mod::encoded_len,
                                field_tag,
                                map,
                                out
                            )
                        },
                        // Decoder: map entries must arrive length-delimited before merging.
                        |wire, map, src, ctx| {
                            check_wire_type(WireType::LengthDelimited, wire)?;
                            $map_mod::merge($k_mod::merge, $v_mod::merge, map, src, ctx)
                        },
                        // Length calculator: delegates to the map module's `encoded_len`.
                        |field_tag, map| {
                            $map_mod::encoded_len(
                                $k_mod::encoded_len,
                                $v_mod::encoded_len,
                                field_tag,
                                map
                            )
                        }
                    )?;
                }
            }
        )*
    };
}
1761 
// Instantiates the map property tests. Every pair is `(Rust type, encoding module)`; the module
// name doubles as the generated test module name (for keys) or test function name (for values).
#[cfg(feature = "std")]
map_tests!(
    keys: [
        (i32, int32), (i64, int64),
        (u32, uint32), (u64, uint64),
        (i32, sint32), (i64, sint64),
        (u32, fixed32), (u64, fixed64),
        (i32, sfixed32), (i64, sfixed64),
        (bool, bool),
        (String, string)
    ],
    // Values cover every key pairing above plus `float`, `double` and `bytes`.
    vals: [
        (f32, float), (f64, double),
        (i32, int32), (i64, int64),
        (u32, uint32), (u64, uint64),
        (i32, sint32), (i64, sint64),
        (u32, fixed32), (u64, fixed64),
        (i32, sfixed32), (i64, sfixed64),
        (bool, bool),
        (String, string),
        (Vec<u8>, bytes)
    ]
);
1794 }
1795