1 //! Utility functions and types for encoding and decoding Protobuf types.
2 //!
3 //! Meant to be used only from `Message` implementations.
4
5 #![allow(clippy::implicit_hasher, clippy::ptr_arg)]
6
7 use alloc::collections::BTreeMap;
8 use alloc::format;
9 use alloc::string::String;
10 use alloc::vec::Vec;
11 use core::cmp::min;
12 use core::convert::TryFrom;
13 use core::mem;
14 use core::str;
15 use core::u32;
16 use core::usize;
17
18 use ::bytes::{Buf, BufMut, Bytes};
19
20 use crate::DecodeError;
21 use crate::Message;
22
23 /// Encodes an integer value into LEB128 variable length format, and writes it to the buffer.
24 /// The buffer must have enough remaining space (maximum 10 bytes).
25 #[inline]
encode_varint<B>(mut value: u64, buf: &mut B) where B: BufMut,26 pub fn encode_varint<B>(mut value: u64, buf: &mut B)
27 where
28 B: BufMut,
29 {
30 loop {
31 if value < 0x80 {
32 buf.put_u8(value as u8);
33 break;
34 } else {
35 buf.put_u8(((value & 0x7F) | 0x80) as u8);
36 value >>= 7;
37 }
38 }
39 }
40
41 /// Decodes a LEB128-encoded variable length integer from the buffer.
42 #[inline]
decode_varint<B>(buf: &mut B) -> Result<u64, DecodeError> where B: Buf,43 pub fn decode_varint<B>(buf: &mut B) -> Result<u64, DecodeError>
44 where
45 B: Buf,
46 {
47 let bytes = buf.chunk();
48 let len = bytes.len();
49 if len == 0 {
50 return Err(DecodeError::new("invalid varint"));
51 }
52
53 let byte = bytes[0];
54 if byte < 0x80 {
55 buf.advance(1);
56 Ok(u64::from(byte))
57 } else if len > 10 || bytes[len - 1] < 0x80 {
58 let (value, advance) = decode_varint_slice(bytes)?;
59 buf.advance(advance);
60 Ok(value)
61 } else {
62 decode_varint_slow(buf)
63 }
64 }
65
/// Decodes a LEB128-encoded variable length integer from the slice, returning the value and the
/// number of bytes read.
///
/// Based loosely on [`ReadVarint64FromArray`][1] with a varint overflow check from
/// [`ConsumeVarint`][2].
///
/// ## Safety
///
/// The caller must ensure that `bytes` is non-empty and either `bytes.len() > 10` or the last
/// element in bytes is < `0x80`. Both requirements are checked by the `assert!`s at the top of
/// the function, so a violating caller panics instead of reading out of bounds.
///
/// [1]: https://github.com/google/protobuf/blob/3.3.x/src/google/protobuf/io/coded_stream.cc#L365-L406
/// [2]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
#[inline]
fn decode_varint_slice(bytes: &[u8]) -> Result<(u64, usize), DecodeError> {
    // Fully unrolled varint decoding loop. Splitting into 32-bit pieces gives better performance.
    //
    // Each step adds the next byte shifted into place and, if that byte had its continuation
    // bit set, subtracts the bit back out (`part -= 0x80 << shift`). Bytes 0-3 accumulate in
    // `part0` (value bits 0-27), bytes 4-7 in `part1` (bits 28-55) and bytes 8-9 in `part2`
    // (bits 56-63).

    // Use assertions to ensure memory safety, but it should always be optimized after inline.
    //
    // SAFETY: given these assertions every `get_unchecked(i)` below is in bounds. No index
    // greater than 9 is ever read, so `bytes.len() > 10` suffices; otherwise the last byte has
    // its continuation bit clear, and decoding returns at or before that final index.
    assert!(!bytes.is_empty());
    assert!(bytes.len() > 10 || bytes[bytes.len() - 1] < 0x80);

    let mut b: u8 = unsafe { *bytes.get_unchecked(0) };
    let mut part0: u32 = u32::from(b);
    if b < 0x80 {
        return Ok((u64::from(part0), 1));
    };
    part0 -= 0x80;
    b = unsafe { *bytes.get_unchecked(1) };
    part0 += u32::from(b) << 7;
    if b < 0x80 {
        return Ok((u64::from(part0), 2));
    };
    part0 -= 0x80 << 7;
    b = unsafe { *bytes.get_unchecked(2) };
    part0 += u32::from(b) << 14;
    if b < 0x80 {
        return Ok((u64::from(part0), 3));
    };
    part0 -= 0x80 << 14;
    b = unsafe { *bytes.get_unchecked(3) };
    part0 += u32::from(b) << 21;
    if b < 0x80 {
        return Ok((u64::from(part0), 4));
    };
    part0 -= 0x80 << 21;
    let value = u64::from(part0);

    b = unsafe { *bytes.get_unchecked(4) };
    let mut part1: u32 = u32::from(b);
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 5));
    };
    part1 -= 0x80;
    b = unsafe { *bytes.get_unchecked(5) };
    part1 += u32::from(b) << 7;
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 6));
    };
    part1 -= 0x80 << 7;
    b = unsafe { *bytes.get_unchecked(6) };
    part1 += u32::from(b) << 14;
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 7));
    };
    part1 -= 0x80 << 14;
    b = unsafe { *bytes.get_unchecked(7) };
    part1 += u32::from(b) << 21;
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 8));
    };
    part1 -= 0x80 << 21;
    let value = value + ((u64::from(part1)) << 28);

    b = unsafe { *bytes.get_unchecked(8) };
    let mut part2: u32 = u32::from(b);
    if b < 0x80 {
        return Ok((value + (u64::from(part2) << 56), 9));
    };
    part2 -= 0x80;
    b = unsafe { *bytes.get_unchecked(9) };
    part2 += u32::from(b) << 7;
    // Check for u64::MAX overflow. See [`ConsumeVarint`][1] for details.
    // [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
    //
    // The tenth byte may only contribute the single remaining bit (bit 63). Any value >= 0x02
    // either overflows u64 or has the continuation bit set (an over-long varint).
    if b < 0x02 {
        return Ok((value + (u64::from(part2) << 56), 10));
    };

    // We have overrun the maximum size of a varint (10 bytes) or the final byte caused an overflow.
    // Assume the data is corrupt.
    Err(DecodeError::new("invalid varint"))
}
157
158 /// Decodes a LEB128-encoded variable length integer from the buffer, advancing the buffer as
159 /// necessary.
160 ///
161 /// Contains a varint overflow check from [`ConsumeVarint`][1].
162 ///
163 /// [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
164 #[inline(never)]
165 #[cold]
decode_varint_slow<B>(buf: &mut B) -> Result<u64, DecodeError> where B: Buf,166 fn decode_varint_slow<B>(buf: &mut B) -> Result<u64, DecodeError>
167 where
168 B: Buf,
169 {
170 let mut value = 0;
171 for count in 0..min(10, buf.remaining()) {
172 let byte = buf.get_u8();
173 value |= u64::from(byte & 0x7F) << (count * 7);
174 if byte <= 0x7F {
175 // Check for u64::MAX overflow. See [`ConsumeVarint`][1] for details.
176 // [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
177 if count == 9 && byte >= 0x02 {
178 return Err(DecodeError::new("invalid varint"));
179 } else {
180 return Ok(value);
181 }
182 }
183 }
184
185 Err(DecodeError::new("invalid varint"))
186 }
187
/// Additional information passed to every decode/merge function.
///
/// The context should be passed by value and can be freely cloned. When passing
/// to a function which is decoding a nested object, then use `enter_recursion`.
///
/// With the `no-recursion-limit` feature enabled the struct has no fields and derives
/// `Default`. Otherwise `Default` is implemented by hand so that a fresh context starts
/// with `crate::RECURSION_LIMIT` levels of recursion available.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "no-recursion-limit", derive(Default))]
pub struct DecodeContext {
    /// How many times we can recurse in the current decode stack before we hit
    /// the recursion limit.
    ///
    /// The recursion limit is defined by `RECURSION_LIMIT` and cannot be
    /// customized. The recursion limit can be ignored by building the Prost
    /// crate with the `no-recursion-limit` feature.
    #[cfg(not(feature = "no-recursion-limit"))]
    recurse_count: u32,
}
204
205 #[cfg(not(feature = "no-recursion-limit"))]
206 impl Default for DecodeContext {
207 #[inline]
default() -> DecodeContext208 fn default() -> DecodeContext {
209 DecodeContext {
210 recurse_count: crate::RECURSION_LIMIT,
211 }
212 }
213 }
214
215 impl DecodeContext {
216 /// Call this function before recursively decoding.
217 ///
218 /// There is no `exit` function since this function creates a new `DecodeContext`
219 /// to be used at the next level of recursion. Continue to use the old context
220 // at the previous level of recursion.
221 #[cfg(not(feature = "no-recursion-limit"))]
222 #[inline]
enter_recursion(&self) -> DecodeContext223 pub(crate) fn enter_recursion(&self) -> DecodeContext {
224 DecodeContext {
225 recurse_count: self.recurse_count - 1,
226 }
227 }
228
229 #[cfg(feature = "no-recursion-limit")]
230 #[inline]
enter_recursion(&self) -> DecodeContext231 pub(crate) fn enter_recursion(&self) -> DecodeContext {
232 DecodeContext {}
233 }
234
235 /// Checks whether the recursion limit has been reached in the stack of
236 /// decodes described by the `DecodeContext` at `self.ctx`.
237 ///
238 /// Returns `Ok<()>` if it is ok to continue recursing.
239 /// Returns `Err<DecodeError>` if the recursion limit has been reached.
240 #[cfg(not(feature = "no-recursion-limit"))]
241 #[inline]
limit_reached(&self) -> Result<(), DecodeError>242 pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> {
243 if self.recurse_count == 0 {
244 Err(DecodeError::new("recursion limit reached"))
245 } else {
246 Ok(())
247 }
248 }
249
250 #[cfg(feature = "no-recursion-limit")]
251 #[inline]
252 #[allow(clippy::unnecessary_wraps)] // needed in other features
limit_reached(&self) -> Result<(), DecodeError>253 pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> {
254 Ok(())
255 }
256 }
257
/// Returns the encoded length of the value in LEB128 variable length format.
/// The returned value will be between 1 and 10, inclusive.
#[inline]
pub fn encoded_len_varint(value: u64) -> usize {
    // Based on [VarintSize64][1].
    // [1]: https://github.com/google/protobuf/blob/3.3.x/src/google/protobuf/io/coded_stream.h#L1301-L1309
    //
    // `value | 1` guarantees at least one set bit, so `highest_bit` lies in 0..=63. For that
    // range `(highest_bit * 9 + 73) / 64 == highest_bit / 7 + 1`, i.e. the number of 7-bit
    // groups needed, computed without a division by 7.
    let highest_bit = 63 - (value | 1).leading_zeros();
    ((highest_bit * 9 + 73) / 64) as usize
}
266
/// The wire type of an encoded field, stored in the low three bits of a field key
/// (see `encode_key` and `decode_key`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum WireType {
    /// A LEB128 variable-length integer (see `encode_varint` / `decode_varint`).
    Varint = 0,
    /// A fixed 8-byte little-endian value (used here for `double`, `fixed64`, `sfixed64`).
    SixtyFourBit = 1,
    /// A varint length prefix followed by that many bytes (strings, bytes, nested messages
    /// and packed repeated fields).
    LengthDelimited = 2,
    /// Starts a group; the group ends at an `EndGroup` key carrying the same tag.
    StartGroup = 3,
    /// Ends a group opened by `StartGroup`.
    EndGroup = 4,
    /// A fixed 4-byte little-endian value (used here for `float`, `fixed32`, `sfixed32`).
    ThirtyTwoBit = 5,
}
277
/// Smallest valid field tag; a decoded tag of `0` is rejected.
pub const MIN_TAG: u32 = 1;
/// Largest valid field tag. A key is a `u32` whose low three bits hold the wire type, which
/// leaves 29 bits for the tag: `2^29 - 1`.
pub const MAX_TAG: u32 = 0x1FFF_FFFF;
280
281 impl TryFrom<u64> for WireType {
282 type Error = DecodeError;
283
284 #[inline]
try_from(value: u64) -> Result<Self, Self::Error>285 fn try_from(value: u64) -> Result<Self, Self::Error> {
286 match value {
287 0 => Ok(WireType::Varint),
288 1 => Ok(WireType::SixtyFourBit),
289 2 => Ok(WireType::LengthDelimited),
290 3 => Ok(WireType::StartGroup),
291 4 => Ok(WireType::EndGroup),
292 5 => Ok(WireType::ThirtyTwoBit),
293 _ => Err(DecodeError::new(format!(
294 "invalid wire type value: {}",
295 value
296 ))),
297 }
298 }
299 }
300
301 /// Encodes a Protobuf field key, which consists of a wire type designator and
302 /// the field tag.
303 #[inline]
encode_key<B>(tag: u32, wire_type: WireType, buf: &mut B) where B: BufMut,304 pub fn encode_key<B>(tag: u32, wire_type: WireType, buf: &mut B)
305 where
306 B: BufMut,
307 {
308 debug_assert!((MIN_TAG..=MAX_TAG).contains(&tag));
309 let key = (tag << 3) | wire_type as u32;
310 encode_varint(u64::from(key), buf);
311 }
312
313 /// Decodes a Protobuf field key, which consists of a wire type designator and
314 /// the field tag.
315 #[inline(always)]
decode_key<B>(buf: &mut B) -> Result<(u32, WireType), DecodeError> where B: Buf,316 pub fn decode_key<B>(buf: &mut B) -> Result<(u32, WireType), DecodeError>
317 where
318 B: Buf,
319 {
320 let key = decode_varint(buf)?;
321 if key > u64::from(u32::MAX) {
322 return Err(DecodeError::new(format!("invalid key value: {}", key)));
323 }
324 let wire_type = WireType::try_from(key & 0x07)?;
325 let tag = key as u32 >> 3;
326
327 if tag < MIN_TAG {
328 return Err(DecodeError::new("invalid tag value: 0"));
329 }
330
331 Ok((tag, wire_type))
332 }
333
334 /// Returns the width of an encoded Protobuf field key with the given tag.
335 /// The returned width will be between 1 and 5 bytes (inclusive).
336 #[inline]
key_len(tag: u32) -> usize337 pub fn key_len(tag: u32) -> usize {
338 encoded_len_varint(u64::from(tag << 3))
339 }
340
341 /// Checks that the expected wire type matches the actual wire type,
342 /// or returns an error result.
343 #[inline]
check_wire_type(expected: WireType, actual: WireType) -> Result<(), DecodeError>344 pub fn check_wire_type(expected: WireType, actual: WireType) -> Result<(), DecodeError> {
345 if expected != actual {
346 return Err(DecodeError::new(format!(
347 "invalid wire type: {:?} (expected {:?})",
348 actual, expected
349 )));
350 }
351 Ok(())
352 }
353
354 /// Helper function which abstracts reading a length delimiter prefix followed
355 /// by decoding values until the length of bytes is exhausted.
merge_loop<T, M, B>( value: &mut T, buf: &mut B, ctx: DecodeContext, mut merge: M, ) -> Result<(), DecodeError> where M: FnMut(&mut T, &mut B, DecodeContext) -> Result<(), DecodeError>, B: Buf,356 pub fn merge_loop<T, M, B>(
357 value: &mut T,
358 buf: &mut B,
359 ctx: DecodeContext,
360 mut merge: M,
361 ) -> Result<(), DecodeError>
362 where
363 M: FnMut(&mut T, &mut B, DecodeContext) -> Result<(), DecodeError>,
364 B: Buf,
365 {
366 let len = decode_varint(buf)?;
367 let remaining = buf.remaining();
368 if len > remaining as u64 {
369 return Err(DecodeError::new("buffer underflow"));
370 }
371
372 let limit = remaining - len as usize;
373 while buf.remaining() > limit {
374 merge(value, buf, ctx.clone())?;
375 }
376
377 if buf.remaining() != limit {
378 return Err(DecodeError::new("delimited length exceeded"));
379 }
380 Ok(())
381 }
382
skip_field<B>( wire_type: WireType, tag: u32, buf: &mut B, ctx: DecodeContext, ) -> Result<(), DecodeError> where B: Buf,383 pub fn skip_field<B>(
384 wire_type: WireType,
385 tag: u32,
386 buf: &mut B,
387 ctx: DecodeContext,
388 ) -> Result<(), DecodeError>
389 where
390 B: Buf,
391 {
392 ctx.limit_reached()?;
393 let len = match wire_type {
394 WireType::Varint => decode_varint(buf).map(|_| 0)?,
395 WireType::ThirtyTwoBit => 4,
396 WireType::SixtyFourBit => 8,
397 WireType::LengthDelimited => decode_varint(buf)?,
398 WireType::StartGroup => loop {
399 let (inner_tag, inner_wire_type) = decode_key(buf)?;
400 match inner_wire_type {
401 WireType::EndGroup => {
402 if inner_tag != tag {
403 return Err(DecodeError::new("unexpected end group tag"));
404 }
405 break 0;
406 }
407 _ => skip_field(inner_wire_type, inner_tag, buf, ctx.enter_recursion())?,
408 }
409 },
410 WireType::EndGroup => return Err(DecodeError::new("unexpected end group tag")),
411 };
412
413 if len > buf.remaining() as u64 {
414 return Err(DecodeError::new("buffer underflow"));
415 }
416
417 buf.advance(len as usize);
418 Ok(())
419 }
420
/// Helper macro which emits an `encode_repeated` function for the type.
///
/// The generated function writes each element as its own field by calling the `encode`
/// function of the invoking module (the unpacked representation).
macro_rules! encode_repeated {
    ($ty:ty) => {
        pub fn encode_repeated<B>(tag: u32, values: &[$ty], buf: &mut B)
        where
            B: BufMut,
        {
            values.iter().for_each(|value| encode(tag, value, buf));
        }
    };
}
434
/// Helper macro which emits a `merge_repeated` function for the numeric type.
///
/// The generated function accepts both encodings of a repeated numeric field. A packed run
/// (`LengthDelimited`) is decoded until its length prefix is exhausted. Otherwise a single
/// unpacked element carrying the element's own wire type is decoded.
macro_rules! merge_repeated_numeric {
    ($ty:ty,
     $wire_type:expr,
     $merge:ident,
     $merge_repeated:ident) => {
        pub fn $merge_repeated<B>(
            wire_type: WireType,
            values: &mut Vec<$ty>,
            buf: &mut B,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError>
        where
            B: Buf,
        {
            if wire_type != WireType::LengthDelimited {
                // Unpacked: exactly one element, encoded with the element's own wire type.
                check_wire_type($wire_type, wire_type)?;
                let mut element = Default::default();
                $merge(wire_type, &mut element, buf, ctx)?;
                values.push(element);
                return Ok(());
            }

            // Packed: elements are concatenated inside one length-delimited payload.
            merge_loop(values, buf, ctx, |values, buf, ctx| {
                let mut element = Default::default();
                $merge($wire_type, &mut element, buf, ctx)?;
                values.push(element);
                Ok(())
            })
        }
    };
}
469
/// Macro which emits a module containing a set of encoding functions for a
/// variable width numeric type.
///
/// The short form converts with plain `as` casts in both directions. The long form takes two
/// conversions:
/// - `to_uint64(v) expr` computes the `u64` to encode from `v: &$ty`;
/// - `from_uint64(v) expr` turns a decoded `v: u64` back into `$ty`.
macro_rules! varint {
    ($ty:ty,
     $proto_ty:ident) => (
        varint!($ty,
                $proto_ty,
                to_uint64(value) { *value as u64 },
                from_uint64(value) { value as $ty });
    );

    ($ty:ty,
     $proto_ty:ident,
     to_uint64($to_uint64_value:ident) $to_uint64:expr,
     from_uint64($from_uint64_value:ident) $from_uint64:expr) => (

         pub mod $proto_ty {
            use crate::encoding::*;

            /// Encodes a single value as a `Varint` field.
            pub fn encode<B>(tag: u32, $to_uint64_value: &$ty, buf: &mut B) where B: BufMut {
                encode_key(tag, WireType::Varint, buf);
                encode_varint($to_uint64, buf);
            }

            /// Decodes a `Varint` field and overwrites `value` with the converted result.
            pub fn merge<B>(wire_type: WireType, value: &mut $ty, buf: &mut B, _ctx: DecodeContext) -> Result<(), DecodeError> where B: Buf {
                check_wire_type(WireType::Varint, wire_type)?;
                let $from_uint64_value = decode_varint(buf)?;
                *value = $from_uint64;
                Ok(())
            }

            // Unpacked repeated encoding: one `Varint` field per element.
            encode_repeated!($ty);

            /// Encodes all values as a single packed, length-delimited field.
            /// Emits nothing for an empty slice.
            pub fn encode_packed<B>(tag: u32, values: &[$ty], buf: &mut B) where B: BufMut {
                if values.is_empty() { return; }

                encode_key(tag, WireType::LengthDelimited, buf);
                // The length prefix is the total size of all element varints.
                let len: usize = values.iter().map(|$to_uint64_value| {
                    encoded_len_varint($to_uint64)
                }).sum();
                encode_varint(len as u64, buf);

                for $to_uint64_value in values {
                    encode_varint($to_uint64, buf);
                }
            }

            // Accepts both packed and unpacked repeated encodings.
            merge_repeated_numeric!($ty, WireType::Varint, merge, merge_repeated);

            /// Encoded size of a single field: key plus value varint.
            #[inline]
            pub fn encoded_len(tag: u32, $to_uint64_value: &$ty) -> usize {
                key_len(tag) + encoded_len_varint($to_uint64)
            }

            /// Encoded size of the unpacked repeated representation.
            #[inline]
            pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
                key_len(tag) * values.len() + values.iter().map(|$to_uint64_value| {
                    encoded_len_varint($to_uint64)
                }).sum::<usize>()
            }

            /// Encoded size of the packed representation (0 for an empty slice).
            #[inline]
            pub fn encoded_len_packed(tag: u32, values: &[$ty]) -> usize {
                if values.is_empty() {
                    0
                } else {
                    let len = values.iter()
                                    .map(|$to_uint64_value| encoded_len_varint($to_uint64))
                                    .sum::<usize>();
                    key_len(tag) + encoded_len_varint(len as u64) + len
                }
            }

            #[cfg(test)]
            mod test {
                use proptest::prelude::*;

                use crate::encoding::$proto_ty::*;
                use crate::encoding::test::{
                    check_collection_type,
                    check_type,
                };

                proptest! {
                    #[test]
                    fn check(value: $ty, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::Varint,
                                   encode, merge, encoded_len)?;
                    }
                    #[test]
                    fn check_repeated(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_collection_type(value, tag, WireType::Varint,
                                              encode_repeated, merge_repeated,
                                              encoded_len_repeated)?;
                    }
                    #[test]
                    fn check_packed(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::LengthDelimited,
                                   encode_packed, merge_repeated,
                                   encoded_len_packed)?;
                    }
                }
            }
         }

    );
}
// `bool` is written as the varint 0 or 1; any non-zero decoded value reads back as `true`.
varint!(bool, bool,
        to_uint64(value) u64::from(*value),
        from_uint64(value) value != 0);
// These use the default `as` casts. Negative `int32`/`int64` values are sign-extended to 64
// bits and therefore always occupy ten bytes. On decode the `u64` is cast back with `as`,
// which truncates for the 32-bit types.
varint!(i32, int32);
varint!(i64, int64);
varint!(u32, uint32);
varint!(u64, uint64);
// ZigZag encoding maps signed integers to unsigned ones so that small magnitudes stay short
// (0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, ...). It encodes with `(n << 1) ^ (n >> 31)`, using an
// arithmetic shift on the signed value, and decodes with `(n >> 1) ^ -(n & 1)`. The decoded
// `u64` is first truncated to 32 bits.
varint!(i32, sint32,
        to_uint64(value) {
            ((value << 1) ^ (value >> 31)) as u32 as u64
        },
        from_uint64(value) {
            let value = value as u32;
            ((value >> 1) as i32) ^ (-((value & 1) as i32))
        });
// 64-bit ZigZag: `(n << 1) ^ (n >> 63)` to encode, `(n >> 1) ^ -(n & 1)` to decode.
varint!(i64, sint64,
        to_uint64(value) {
            ((value << 1) ^ (value >> 63)) as u64
        },
        from_uint64(value) {
            ((value >> 1) as i64) ^ (-((value & 1) as i64))
        });
599
/// Macro which emits a module containing a set of encoding functions for a
/// fixed width numeric type.
///
/// Arguments:
/// - the Rust type and its encoded byte width;
/// - the wire type and the name of the emitted module;
/// - the `BufMut` writer method and the `Buf` reader method used for a single value.
macro_rules! fixed_width {
    ($ty:ty,
     $width:expr,
     $wire_type:expr,
     $proto_ty:ident,
     $put:ident,
     $get:ident) => {
        pub mod $proto_ty {
            use crate::encoding::*;

            /// Encodes a single value: the key followed by `$width` raw bytes.
            pub fn encode<B>(tag: u32, value: &$ty, buf: &mut B)
            where
                B: BufMut,
            {
                encode_key(tag, $wire_type, buf);
                buf.$put(*value);
            }

            /// Decodes one value into `value`, failing with "buffer underflow" when fewer than
            /// `$width` bytes remain.
            pub fn merge<B>(
                wire_type: WireType,
                value: &mut $ty,
                buf: &mut B,
                _ctx: DecodeContext,
            ) -> Result<(), DecodeError>
            where
                B: Buf,
            {
                check_wire_type($wire_type, wire_type)?;
                // Check first so that `$get` is never called on a short buffer.
                if buf.remaining() < $width {
                    return Err(DecodeError::new("buffer underflow"));
                }
                *value = buf.$get();
                Ok(())
            }

            // Unpacked repeated encoding: one field per element.
            encode_repeated!($ty);

            /// Encodes all values as a single packed, length-delimited field.
            /// Emits nothing for an empty slice.
            pub fn encode_packed<B>(tag: u32, values: &[$ty], buf: &mut B)
            where
                B: BufMut,
            {
                if values.is_empty() {
                    return;
                }

                encode_key(tag, WireType::LengthDelimited, buf);
                let len = values.len() as u64 * $width;
                encode_varint(len as u64, buf);

                for value in values {
                    buf.$put(*value);
                }
            }

            // Accepts both packed and unpacked repeated encodings.
            merge_repeated_numeric!($ty, $wire_type, merge, merge_repeated);

            /// Encoded size of a single field: key plus `$width` bytes.
            #[inline]
            pub fn encoded_len(tag: u32, _: &$ty) -> usize {
                key_len(tag) + $width
            }

            /// Encoded size of the unpacked repeated representation.
            #[inline]
            pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
                (key_len(tag) + $width) * values.len()
            }

            /// Encoded size of the packed representation (0 for an empty slice).
            #[inline]
            pub fn encoded_len_packed(tag: u32, values: &[$ty]) -> usize {
                if values.is_empty() {
                    0
                } else {
                    let len = $width * values.len();
                    key_len(tag) + encoded_len_varint(len as u64) + len
                }
            }

            #[cfg(test)]
            mod test {
                use proptest::prelude::*;

                use super::super::test::{check_collection_type, check_type};
                use super::*;

                proptest! {
                    #[test]
                    fn check(value: $ty, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, $wire_type,
                                   encode, merge, encoded_len)?;
                    }
                    #[test]
                    fn check_repeated(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_collection_type(value, tag, $wire_type,
                                              encode_repeated, merge_repeated,
                                              encoded_len_repeated)?;
                    }
                    #[test]
                    fn check_packed(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::LengthDelimited,
                                   encode_packed, merge_repeated,
                                   encoded_len_packed)?;
                    }
                }
            }
        }
    };
}
708 fixed_width!(
709 f32,
710 4,
711 WireType::ThirtyTwoBit,
712 float,
713 put_f32_le,
714 get_f32_le
715 );
716 fixed_width!(
717 f64,
718 8,
719 WireType::SixtyFourBit,
720 double,
721 put_f64_le,
722 get_f64_le
723 );
724 fixed_width!(
725 u32,
726 4,
727 WireType::ThirtyTwoBit,
728 fixed32,
729 put_u32_le,
730 get_u32_le
731 );
732 fixed_width!(
733 u64,
734 8,
735 WireType::SixtyFourBit,
736 fixed64,
737 put_u64_le,
738 get_u64_le
739 );
740 fixed_width!(
741 i32,
742 4,
743 WireType::ThirtyTwoBit,
744 sfixed32,
745 put_i32_le,
746 get_i32_le
747 );
748 fixed_width!(
749 i64,
750 8,
751 WireType::SixtyFourBit,
752 sfixed64,
753 put_i64_le,
754 get_i64_le
755 );
756
/// Macro which emits encoding functions for a length-delimited type.
///
/// The invoking module must define `encode` and `merge` for `$ty`. Values of `$ty` must
/// provide `len()`, the number of payload bytes written after the length prefix.
macro_rules! length_delimited {
    ($ty:ty) => {
        encode_repeated!($ty);

        pub fn merge_repeated<B>(
            wire_type: WireType,
            values: &mut Vec<$ty>,
            buf: &mut B,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError>
        where
            B: Buf,
        {
            check_wire_type(WireType::LengthDelimited, wire_type)?;
            let mut element = Default::default();
            merge(wire_type, &mut element, buf, ctx)?;
            values.push(element);
            Ok(())
        }

        #[inline]
        pub fn encoded_len(tag: u32, value: &$ty) -> usize {
            let payload_len = value.len();
            key_len(tag) + encoded_len_varint(payload_len as u64) + payload_len
        }

        #[inline]
        pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
            let keys_len = key_len(tag) * values.len();
            let payloads_len: usize = values
                .iter()
                .map(|value| {
                    let payload_len = value.len();
                    encoded_len_varint(payload_len as u64) + payload_len
                })
                .sum();
            keys_len + payloads_len
        }
    };
}
793
pub mod string {
    use super::*;

    /// Encodes `value` as a length-delimited field: the key, the byte length as a varint, and
    /// then the UTF-8 bytes.
    pub fn encode<B>(tag: u32, value: &String, buf: &mut B)
    where
        B: BufMut,
    {
        encode_key(tag, WireType::LengthDelimited, buf);
        encode_varint(value.len() as u64, buf);
        buf.put_slice(value.as_bytes());
    }
    /// Decodes a length-delimited field into `value`, replacing its previous contents.
    ///
    /// Fails if the wire type is not `LengthDelimited`, if the length prefix exceeds the buffer,
    /// or if the bytes are not valid UTF-8. On every error path `value` is left empty.
    pub fn merge<B>(
        wire_type: WireType,
        value: &mut String,
        buf: &mut B,
        ctx: DecodeContext,
    ) -> Result<(), DecodeError>
    where
        B: Buf,
    {
        // ## Unsafety
        //
        // `string::merge` reuses `bytes::merge`, with an additional check of utf-8
        // well-formedness. If the utf-8 is not well-formed, or if any other error occurs, then the
        // string is cleared, so as to avoid leaking a string field with invalid data.
        //
        // This implementation uses the unsafe `String::as_mut_vec` method instead of the safe
        // alternative of temporarily swapping an empty `String` into the field, because it results
        // in up to 10% better performance on the protobuf message decoding benchmarks.
        //
        // It's required when using `String::as_mut_vec` that invalid utf-8 data not be leaked into
        // the backing `String`. To enforce this, even in the event of a panic in `bytes::merge` or
        // in the buf implementation, a drop guard is used.
        //
        // SAFETY: the `DropGuard` clears the vector whenever it is dropped. That happens on every
        // exit path (early `?` return, UTF-8 failure, or unwinding) except the successful UTF-8
        // check, where the guard is `mem::forget`-ed. Non-UTF-8 bytes can therefore never remain
        // in `value`.
        unsafe {
            struct DropGuard<'a>(&'a mut Vec<u8>);
            impl<'a> Drop for DropGuard<'a> {
                #[inline]
                fn drop(&mut self) {
                    self.0.clear();
                }
            }

            let drop_guard = DropGuard(value.as_mut_vec());
            bytes::merge_one_copy(wire_type, drop_guard.0, buf, ctx)?;
            match str::from_utf8(drop_guard.0) {
                Ok(_) => {
                    // Success; do not clear the bytes.
                    mem::forget(drop_guard);
                    Ok(())
                }
                Err(_) => Err(DecodeError::new(
                    "invalid string value: data is not UTF-8 encoded",
                )),
            }
        }
    }

    length_delimited!(String);

    #[cfg(test)]
    mod test {
        use proptest::prelude::*;

        use super::super::test::{check_collection_type, check_type};
        use super::*;

        proptest! {
            #[test]
            fn check(value: String, tag in MIN_TAG..=MAX_TAG) {
                super::test::check_type(value, tag, WireType::LengthDelimited,
                                        encode, merge, encoded_len)?;
            }
            #[test]
            fn check_repeated(value: Vec<String>, tag in MIN_TAG..=MAX_TAG) {
                super::test::check_collection_type(value, tag, WireType::LengthDelimited,
                                                   encode_repeated, merge_repeated,
                                                   encoded_len_repeated)?;
            }
        }
    }
}
875
/// Container types usable for a protobuf `bytes` field; implemented here for `Bytes` and
/// `Vec<u8>`. The required operations live in the private `sealed` supertrait, so the trait
/// cannot be implemented outside this module.
pub trait BytesAdapter: sealed::BytesAdapter {}
877
mod sealed {
    use super::{Buf, BufMut};

    /// Operations a `bytes` field container must support. Declared in a private module so that
    /// only this module can implement the public `BytesAdapter` trait.
    pub trait BytesAdapter: Default + Sized + 'static {
        /// Returns the number of bytes held.
        fn len(&self) -> usize;

        /// Replace contents of this buffer with the contents of another buffer.
        fn replace_with<B>(&mut self, buf: B)
        where
            B: Buf;

        /// Appends this buffer to the (contents of) other buffer.
        fn append_to<B>(&self, buf: &mut B)
        where
            B: BufMut;

        /// Returns `true` when `len()` is zero.
        fn is_empty(&self) -> bool {
            self.len() == 0
        }
    }
}
899
900 impl BytesAdapter for Bytes {}
901
902 impl sealed::BytesAdapter for Bytes {
len(&self) -> usize903 fn len(&self) -> usize {
904 Buf::remaining(self)
905 }
906
replace_with<B>(&mut self, mut buf: B) where B: Buf,907 fn replace_with<B>(&mut self, mut buf: B)
908 where
909 B: Buf,
910 {
911 *self = buf.copy_to_bytes(buf.remaining());
912 }
913
append_to<B>(&self, buf: &mut B) where B: BufMut,914 fn append_to<B>(&self, buf: &mut B)
915 where
916 B: BufMut,
917 {
918 buf.put(self.clone())
919 }
920 }
921
922 impl BytesAdapter for Vec<u8> {}
923
924 impl sealed::BytesAdapter for Vec<u8> {
len(&self) -> usize925 fn len(&self) -> usize {
926 Vec::len(self)
927 }
928
replace_with<B>(&mut self, buf: B) where B: Buf,929 fn replace_with<B>(&mut self, buf: B)
930 where
931 B: Buf,
932 {
933 self.clear();
934 self.reserve(buf.remaining());
935 self.put(buf);
936 }
937
append_to<B>(&self, buf: &mut B) where B: BufMut,938 fn append_to<B>(&self, buf: &mut B)
939 where
940 B: BufMut,
941 {
942 buf.put(self.as_slice())
943 }
944 }
945
946 pub mod bytes {
947 use super::*;
948
encode<A, B>(tag: u32, value: &A, buf: &mut B) where A: BytesAdapter, B: BufMut,949 pub fn encode<A, B>(tag: u32, value: &A, buf: &mut B)
950 where
951 A: BytesAdapter,
952 B: BufMut,
953 {
954 encode_key(tag, WireType::LengthDelimited, buf);
955 encode_varint(value.len() as u64, buf);
956 value.append_to(buf);
957 }
958
merge<A, B>( wire_type: WireType, value: &mut A, buf: &mut B, _ctx: DecodeContext, ) -> Result<(), DecodeError> where A: BytesAdapter, B: Buf,959 pub fn merge<A, B>(
960 wire_type: WireType,
961 value: &mut A,
962 buf: &mut B,
963 _ctx: DecodeContext,
964 ) -> Result<(), DecodeError>
965 where
966 A: BytesAdapter,
967 B: Buf,
968 {
969 check_wire_type(WireType::LengthDelimited, wire_type)?;
970 let len = decode_varint(buf)?;
971 if len > buf.remaining() as u64 {
972 return Err(DecodeError::new("buffer underflow"));
973 }
974 let len = len as usize;
975
976 // Clear the existing value. This follows from the following rule in the encoding guide[1]:
977 //
978 // > Normally, an encoded message would never have more than one instance of a non-repeated
979 // > field. However, parsers are expected to handle the case in which they do. For numeric
980 // > types and strings, if the same field appears multiple times, the parser accepts the
981 // > last value it sees.
982 //
983 // [1]: https://developers.google.com/protocol-buffers/docs/encoding#optional
984 //
985 // This is intended for A and B both being Bytes so it is zero-copy.
986 // Some combinations of A and B types may cause a double-copy,
987 // in which case merge_one_copy() should be used instead.
988 value.replace_with(buf.copy_to_bytes(len));
989 Ok(())
990 }
991
merge_one_copy<A, B>( wire_type: WireType, value: &mut A, buf: &mut B, _ctx: DecodeContext, ) -> Result<(), DecodeError> where A: BytesAdapter, B: Buf,992 pub(super) fn merge_one_copy<A, B>(
993 wire_type: WireType,
994 value: &mut A,
995 buf: &mut B,
996 _ctx: DecodeContext,
997 ) -> Result<(), DecodeError>
998 where
999 A: BytesAdapter,
1000 B: Buf,
1001 {
1002 check_wire_type(WireType::LengthDelimited, wire_type)?;
1003 let len = decode_varint(buf)?;
1004 if len > buf.remaining() as u64 {
1005 return Err(DecodeError::new("buffer underflow"));
1006 }
1007 let len = len as usize;
1008
1009 // If we must copy, make sure to copy only once.
1010 value.replace_with(buf.take(len));
1011 Ok(())
1012 }
1013
1014 length_delimited!(impl BytesAdapter);
1015
1016 #[cfg(test)]
1017 mod test {
1018 use proptest::prelude::*;
1019
1020 use super::super::test::{check_collection_type, check_type};
1021 use super::*;
1022
1023 proptest! {
1024 #[test]
1025 fn check_vec(value: Vec<u8>, tag in MIN_TAG..=MAX_TAG) {
1026 super::test::check_type::<Vec<u8>, Vec<u8>>(value, tag, WireType::LengthDelimited,
1027 encode, merge, encoded_len)?;
1028 }
1029
1030 #[test]
1031 fn check_bytes(value: Vec<u8>, tag in MIN_TAG..=MAX_TAG) {
1032 let value = Bytes::from(value);
1033 super::test::check_type::<Bytes, Bytes>(value, tag, WireType::LengthDelimited,
1034 encode, merge, encoded_len)?;
1035 }
1036
1037 #[test]
1038 fn check_repeated_vec(value: Vec<Vec<u8>>, tag in MIN_TAG..=MAX_TAG) {
1039 super::test::check_collection_type(value, tag, WireType::LengthDelimited,
1040 encode_repeated, merge_repeated,
1041 encoded_len_repeated)?;
1042 }
1043
1044 #[test]
1045 fn check_repeated_bytes(value: Vec<Vec<u8>>, tag in MIN_TAG..=MAX_TAG) {
1046 let value = value.into_iter().map(Bytes::from).collect();
1047 super::test::check_collection_type(value, tag, WireType::LengthDelimited,
1048 encode_repeated, merge_repeated,
1049 encoded_len_repeated)?;
1050 }
1051 }
1052 }
1053 }
1054
1055 pub mod message {
1056 use super::*;
1057
encode<M, B>(tag: u32, msg: &M, buf: &mut B) where M: Message, B: BufMut,1058 pub fn encode<M, B>(tag: u32, msg: &M, buf: &mut B)
1059 where
1060 M: Message,
1061 B: BufMut,
1062 {
1063 encode_key(tag, WireType::LengthDelimited, buf);
1064 encode_varint(msg.encoded_len() as u64, buf);
1065 msg.encode_raw(buf);
1066 }
1067
merge<M, B>( wire_type: WireType, msg: &mut M, buf: &mut B, ctx: DecodeContext, ) -> Result<(), DecodeError> where M: Message, B: Buf,1068 pub fn merge<M, B>(
1069 wire_type: WireType,
1070 msg: &mut M,
1071 buf: &mut B,
1072 ctx: DecodeContext,
1073 ) -> Result<(), DecodeError>
1074 where
1075 M: Message,
1076 B: Buf,
1077 {
1078 check_wire_type(WireType::LengthDelimited, wire_type)?;
1079 ctx.limit_reached()?;
1080 merge_loop(
1081 msg,
1082 buf,
1083 ctx.enter_recursion(),
1084 |msg: &mut M, buf: &mut B, ctx| {
1085 let (tag, wire_type) = decode_key(buf)?;
1086 msg.merge_field(tag, wire_type, buf, ctx)
1087 },
1088 )
1089 }
1090
encode_repeated<M, B>(tag: u32, messages: &[M], buf: &mut B) where M: Message, B: BufMut,1091 pub fn encode_repeated<M, B>(tag: u32, messages: &[M], buf: &mut B)
1092 where
1093 M: Message,
1094 B: BufMut,
1095 {
1096 for msg in messages {
1097 encode(tag, msg, buf);
1098 }
1099 }
1100
merge_repeated<M, B>( wire_type: WireType, messages: &mut Vec<M>, buf: &mut B, ctx: DecodeContext, ) -> Result<(), DecodeError> where M: Message + Default, B: Buf,1101 pub fn merge_repeated<M, B>(
1102 wire_type: WireType,
1103 messages: &mut Vec<M>,
1104 buf: &mut B,
1105 ctx: DecodeContext,
1106 ) -> Result<(), DecodeError>
1107 where
1108 M: Message + Default,
1109 B: Buf,
1110 {
1111 check_wire_type(WireType::LengthDelimited, wire_type)?;
1112 let mut msg = M::default();
1113 merge(WireType::LengthDelimited, &mut msg, buf, ctx)?;
1114 messages.push(msg);
1115 Ok(())
1116 }
1117
1118 #[inline]
encoded_len<M>(tag: u32, msg: &M) -> usize where M: Message,1119 pub fn encoded_len<M>(tag: u32, msg: &M) -> usize
1120 where
1121 M: Message,
1122 {
1123 let len = msg.encoded_len();
1124 key_len(tag) + encoded_len_varint(len as u64) + len
1125 }
1126
1127 #[inline]
encoded_len_repeated<M>(tag: u32, messages: &[M]) -> usize where M: Message,1128 pub fn encoded_len_repeated<M>(tag: u32, messages: &[M]) -> usize
1129 where
1130 M: Message,
1131 {
1132 key_len(tag) * messages.len()
1133 + messages
1134 .iter()
1135 .map(Message::encoded_len)
1136 .map(|len| len + encoded_len_varint(len as u64))
1137 .sum::<usize>()
1138 }
1139 }
1140
1141 pub mod group {
1142 use super::*;
1143
encode<M, B>(tag: u32, msg: &M, buf: &mut B) where M: Message, B: BufMut,1144 pub fn encode<M, B>(tag: u32, msg: &M, buf: &mut B)
1145 where
1146 M: Message,
1147 B: BufMut,
1148 {
1149 encode_key(tag, WireType::StartGroup, buf);
1150 msg.encode_raw(buf);
1151 encode_key(tag, WireType::EndGroup, buf);
1152 }
1153
merge<M, B>( tag: u32, wire_type: WireType, msg: &mut M, buf: &mut B, ctx: DecodeContext, ) -> Result<(), DecodeError> where M: Message, B: Buf,1154 pub fn merge<M, B>(
1155 tag: u32,
1156 wire_type: WireType,
1157 msg: &mut M,
1158 buf: &mut B,
1159 ctx: DecodeContext,
1160 ) -> Result<(), DecodeError>
1161 where
1162 M: Message,
1163 B: Buf,
1164 {
1165 check_wire_type(WireType::StartGroup, wire_type)?;
1166
1167 ctx.limit_reached()?;
1168 loop {
1169 let (field_tag, field_wire_type) = decode_key(buf)?;
1170 if field_wire_type == WireType::EndGroup {
1171 if field_tag != tag {
1172 return Err(DecodeError::new("unexpected end group tag"));
1173 }
1174 return Ok(());
1175 }
1176
1177 M::merge_field(msg, field_tag, field_wire_type, buf, ctx.enter_recursion())?;
1178 }
1179 }
1180
encode_repeated<M, B>(tag: u32, messages: &[M], buf: &mut B) where M: Message, B: BufMut,1181 pub fn encode_repeated<M, B>(tag: u32, messages: &[M], buf: &mut B)
1182 where
1183 M: Message,
1184 B: BufMut,
1185 {
1186 for msg in messages {
1187 encode(tag, msg, buf);
1188 }
1189 }
1190
merge_repeated<M, B>( tag: u32, wire_type: WireType, messages: &mut Vec<M>, buf: &mut B, ctx: DecodeContext, ) -> Result<(), DecodeError> where M: Message + Default, B: Buf,1191 pub fn merge_repeated<M, B>(
1192 tag: u32,
1193 wire_type: WireType,
1194 messages: &mut Vec<M>,
1195 buf: &mut B,
1196 ctx: DecodeContext,
1197 ) -> Result<(), DecodeError>
1198 where
1199 M: Message + Default,
1200 B: Buf,
1201 {
1202 check_wire_type(WireType::StartGroup, wire_type)?;
1203 let mut msg = M::default();
1204 merge(tag, WireType::StartGroup, &mut msg, buf, ctx)?;
1205 messages.push(msg);
1206 Ok(())
1207 }
1208
1209 #[inline]
encoded_len<M>(tag: u32, msg: &M) -> usize where M: Message,1210 pub fn encoded_len<M>(tag: u32, msg: &M) -> usize
1211 where
1212 M: Message,
1213 {
1214 2 * key_len(tag) + msg.encoded_len()
1215 }
1216
1217 #[inline]
encoded_len_repeated<M>(tag: u32, messages: &[M]) -> usize where M: Message,1218 pub fn encoded_len_repeated<M>(tag: u32, messages: &[M]) -> usize
1219 where
1220 M: Message,
1221 {
1222 2 * key_len(tag) * messages.len() + messages.iter().map(Message::encoded_len).sum::<usize>()
1223 }
1224 }
1225
/// Rust doesn't have a `Map` trait, so macros are currently the best way to be
/// generic over `HashMap` and `BTreeMap`.
///
/// `map!(MapType)` generates `encode`, `merge`, `encoded_len` and their `*_with_default`
/// variants for `MapType<K, V>`. Each entry is written as a length-delimited field under the
/// map's tag; the entry body holds the key as field 1 and the value as field 2. A key equal to
/// `K::default()` or a value equal to the value default is omitted from the entry body, and
/// decoding starts from those same defaults, so an omitted field round-trips.
macro_rules! map {
    ($map_ty:ident) => {
        use crate::encoding::*;
        use core::hash::Hash;

        /// Generic protobuf map encode function.
        ///
        /// Uses `V::default()` as the value default; see `encode_with_default`.
        pub fn encode<K, V, B, KE, KL, VE, VL>(
            key_encode: KE,
            key_encoded_len: KL,
            val_encode: VE,
            val_encoded_len: VL,
            tag: u32,
            values: &$map_ty<K, V>,
            buf: &mut B,
        ) where
            K: Default + Eq + Hash + Ord,
            V: Default + PartialEq,
            B: BufMut,
            KE: Fn(u32, &K, &mut B),
            KL: Fn(u32, &K) -> usize,
            VE: Fn(u32, &V, &mut B),
            VL: Fn(u32, &V) -> usize,
        {
            encode_with_default(
                key_encode,
                key_encoded_len,
                val_encode,
                val_encoded_len,
                &V::default(),
                tag,
                values,
                buf,
            )
        }

        /// Generic protobuf map merge function.
        ///
        /// Decodes a single map entry from `buf` and inserts it into `values`, using
        /// `V::default()` for an absent value; see `merge_with_default`.
        pub fn merge<K, V, B, KM, VM>(
            key_merge: KM,
            val_merge: VM,
            values: &mut $map_ty<K, V>,
            buf: &mut B,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError>
        where
            K: Default + Eq + Hash + Ord,
            V: Default,
            B: Buf,
            KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>,
            VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>,
        {
            merge_with_default(key_merge, val_merge, V::default(), values, buf, ctx)
        }

        /// Generic protobuf map encoded length function.
        ///
        /// Returns the number of bytes `encode` would write for `values` under `tag`, using
        /// `V::default()` as the value default.
        pub fn encoded_len<K, V, KL, VL>(
            key_encoded_len: KL,
            val_encoded_len: VL,
            tag: u32,
            values: &$map_ty<K, V>,
        ) -> usize
        where
            K: Default + Eq + Hash + Ord,
            V: Default + PartialEq,
            KL: Fn(u32, &K) -> usize,
            VL: Fn(u32, &V) -> usize,
        {
            encoded_len_with_default(key_encoded_len, val_encoded_len, &V::default(), tag, values)
        }

        /// Generic protobuf map encode function with an overridden value default.
        ///
        /// This is necessary because enumeration values can have a default value other
        /// than 0 in proto2.
        pub fn encode_with_default<K, V, B, KE, KL, VE, VL>(
            key_encode: KE,
            key_encoded_len: KL,
            val_encode: VE,
            val_encoded_len: VL,
            val_default: &V,
            tag: u32,
            values: &$map_ty<K, V>,
            buf: &mut B,
        ) where
            K: Default + Eq + Hash + Ord,
            V: PartialEq,
            B: BufMut,
            KE: Fn(u32, &K, &mut B),
            KL: Fn(u32, &K) -> usize,
            VE: Fn(u32, &V, &mut B),
            VL: Fn(u32, &V) -> usize,
        {
            // Loop-invariant: the default key is identical for every entry.
            let default_key = K::default();
            for (key, val) in values.iter() {
                // Default-valued fields are omitted from the entry body; `merge_with_default`
                // starts from the same defaults, so they decode back unchanged.
                let skip_key = key == &default_key;
                let skip_val = val == val_default;

                let len = (if skip_key { 0 } else { key_encoded_len(1, key) })
                    + (if skip_val { 0 } else { val_encoded_len(2, val) });

                encode_key(tag, WireType::LengthDelimited, buf);
                encode_varint(len as u64, buf);
                if !skip_key {
                    key_encode(1, key, buf);
                }
                if !skip_val {
                    val_encode(2, val, buf);
                }
            }
        }

        /// Generic protobuf map merge function with an overridden value default.
        ///
        /// This is necessary because enumeration values can have a default value other
        /// than 0 in proto2.
        ///
        /// Field 1 of the entry is merged into the key and field 2 into the value; any other
        /// field is skipped. The decoded pair is inserted into `values`, replacing an existing
        /// entry with the same key.
        pub fn merge_with_default<K, V, B, KM, VM>(
            key_merge: KM,
            val_merge: VM,
            val_default: V,
            values: &mut $map_ty<K, V>,
            buf: &mut B,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError>
        where
            K: Default + Eq + Hash + Ord,
            B: Buf,
            KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>,
            VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>,
        {
            let mut key = Default::default();
            let mut val = val_default;
            ctx.limit_reached()?;
            merge_loop(
                &mut (&mut key, &mut val),
                buf,
                ctx.enter_recursion(),
                |&mut (ref mut key, ref mut val), buf, ctx| {
                    let (tag, wire_type) = decode_key(buf)?;
                    match tag {
                        1 => key_merge(wire_type, key, buf, ctx),
                        2 => val_merge(wire_type, val, buf, ctx),
                        _ => skip_field(wire_type, tag, buf, ctx),
                    }
                },
            )?;
            values.insert(key, val);

            Ok(())
        }

        /// Generic protobuf map encoded length function with an overridden value default.
        ///
        /// Returns the number of bytes `encode_with_default` would write for `values` under
        /// `tag`. This is necessary because enumeration values can have a default value other
        /// than 0 in proto2.
        pub fn encoded_len_with_default<K, V, KL, VL>(
            key_encoded_len: KL,
            val_encoded_len: VL,
            val_default: &V,
            tag: u32,
            values: &$map_ty<K, V>,
        ) -> usize
        where
            K: Default + Eq + Hash + Ord,
            V: PartialEq,
            KL: Fn(u32, &K) -> usize,
            VL: Fn(u32, &V) -> usize,
        {
            // Loop-invariant: the default key is identical for every entry.
            let default_key = K::default();
            key_len(tag) * values.len()
                + values
                    .iter()
                    .map(|(key, val)| {
                        let len = (if key == &default_key {
                            0
                        } else {
                            key_encoded_len(1, key)
                        }) + (if val == val_default {
                            0
                        } else {
                            val_encoded_len(2, val)
                        });
                        encoded_len_varint(len as u64) + len
                    })
                    .sum::<usize>()
        }
    };
}
1412
1413 #[cfg(feature = "std")]
1414 pub mod hash_map {
1415 use std::collections::HashMap;
1416 map!(HashMap);
1417 }
1418
1419 pub mod btree_map {
1420 map!(BTreeMap);
1421 }
1422
#[cfg(test)]
mod test {
    use alloc::string::ToString;
    use core::borrow::Borrow;
    use core::fmt::Debug;
    use core::u64;

    use ::bytes::{Bytes, BytesMut};
    use proptest::{prelude::*, test_runner::TestCaseResult};

    use crate::encoding::*;

    /// Round-trips a single `value` through `encode` and `merge`.
    ///
    /// Checks that:
    /// - `encoded_len` matches the number of bytes actually written;
    /// - the encoded key carries `tag` and `wire_type`;
    /// - fixed-width wire types leave exactly 8 (64-bit) or 4 (32-bit) bytes after the key;
    /// - `merge` consumes the whole buffer;
    /// - the decoded value equals the original.
    ///
    /// An empty encoding (e.g. an empty packed field) is accepted without decoding.
    pub fn check_type<T, B>(
        value: T,
        tag: u32,
        wire_type: WireType,
        encode: fn(u32, &B, &mut BytesMut),
        merge: fn(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>,
        encoded_len: fn(u32, &B) -> usize,
    ) -> TestCaseResult
    where
        T: Debug + Default + PartialEq + Borrow<B>,
        B: ?Sized,
    {
        prop_assume!((MIN_TAG..=MAX_TAG).contains(&tag));

        let expected_len = encoded_len(tag, value.borrow());

        let mut buf = BytesMut::with_capacity(expected_len);
        encode(tag, value.borrow(), &mut buf);

        let mut buf = buf.freeze();

        prop_assert_eq!(
            buf.remaining(),
            expected_len,
            "encoded_len wrong; expected: {}, actual: {}",
            expected_len,
            buf.remaining()
        );

        if !buf.has_remaining() {
            // Short circuit for empty packed values.
            return Ok(());
        }

        let (decoded_tag, decoded_wire_type) =
            decode_key(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?;
        prop_assert_eq!(
            tag,
            decoded_tag,
            "decoded tag does not match; expected: {}, actual: {}",
            tag,
            decoded_tag
        );

        prop_assert_eq!(
            wire_type,
            decoded_wire_type,
            "decoded wire type does not match; expected: {:?}, actual: {:?}",
            wire_type,
            decoded_wire_type,
        );

        match wire_type {
            WireType::SixtyFourBit if buf.remaining() != 8 => Err(TestCaseError::fail(format!(
                "64bit wire type illegal remaining: {}, tag: {}",
                buf.remaining(),
                tag
            ))),
            WireType::ThirtyTwoBit if buf.remaining() != 4 => Err(TestCaseError::fail(format!(
                "32bit wire type illegal remaining: {}, tag: {}",
                buf.remaining(),
                tag
            ))),
            _ => Ok(()),
        }?;

        let mut roundtrip_value = T::default();
        merge(
            wire_type,
            &mut roundtrip_value,
            &mut buf,
            DecodeContext::default(),
        )
        .map_err(|error| TestCaseError::fail(error.to_string()))?;

        prop_assert!(
            !buf.has_remaining(),
            "expected buffer to be empty, remaining: {}",
            buf.remaining()
        );

        prop_assert_eq!(value, roundtrip_value);

        Ok(())
    }

    /// Round-trips a collection `value` that is encoded as zero or more consecutive fields
    /// sharing one key.
    ///
    /// Checks that `encoded_len` matches the bytes written, and that every field's key carries
    /// `tag` and `wire_type`. Every field is merged into a single accumulator, which must equal
    /// `value` once the buffer is exhausted.
    pub fn check_collection_type<T, B, E, M, L>(
        value: T,
        tag: u32,
        wire_type: WireType,
        encode: E,
        mut merge: M,
        encoded_len: L,
    ) -> TestCaseResult
    where
        T: Debug + Default + PartialEq + Borrow<B>,
        B: ?Sized,
        E: FnOnce(u32, &B, &mut BytesMut),
        M: FnMut(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>,
        L: FnOnce(u32, &B) -> usize,
    {
        prop_assume!((MIN_TAG..=MAX_TAG).contains(&tag));

        let expected_len = encoded_len(tag, value.borrow());

        let mut buf = BytesMut::with_capacity(expected_len);
        encode(tag, value.borrow(), &mut buf);

        let mut buf = buf.freeze();

        prop_assert_eq!(
            buf.remaining(),
            expected_len,
            "encoded_len wrong; expected: {}, actual: {}",
            expected_len,
            buf.remaining()
        );

        let mut roundtrip_value = Default::default();
        while buf.has_remaining() {
            let (decoded_tag, decoded_wire_type) =
                decode_key(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?;

            prop_assert_eq!(
                tag,
                decoded_tag,
                "decoded tag does not match; expected: {}, actual: {}",
                tag,
                decoded_tag
            );

            prop_assert_eq!(
                wire_type,
                decoded_wire_type,
                "decoded wire type does not match; expected: {:?}, actual: {:?}",
                wire_type,
                decoded_wire_type
            );

            merge(
                wire_type,
                &mut roundtrip_value,
                &mut buf,
                DecodeContext::default(),
            )
            .map_err(|error| TestCaseError::fail(error.to_string()))?;
        }

        prop_assert_eq!(value, roundtrip_value);

        Ok(())
    }

    #[test]
    fn string_merge_invalid_utf8() {
        let mut s = String::new();
        let buf = b"\x02\x80\x80";

        let r = string::merge(
            WireType::LengthDelimited,
            &mut s,
            &mut &buf[..],
            DecodeContext::default(),
        );
        r.expect_err("must be an error");
        assert!(s.is_empty());
    }

    #[test]
    fn varint() {
        fn check(value: u64, encoded: &[u8]) {
            // Small buffer.
            let mut buf = Vec::with_capacity(1);
            encode_varint(value, &mut buf);
            assert_eq!(buf, encoded);

            // Large buffer.
            let mut buf = Vec::with_capacity(100);
            encode_varint(value, &mut buf);
            assert_eq!(buf, encoded);

            assert_eq!(encoded_len_varint(value), encoded.len());

            // Fast path: must decode the value and consume exactly the encoded bytes.
            let mut fast: &[u8] = encoded;
            let roundtrip_value = decode_varint(&mut fast).expect("decoding failed");
            assert_eq!(value, roundtrip_value);
            assert!(
                fast.is_empty(),
                "decode_varint left {} trailing bytes",
                fast.len()
            );

            // Slow path: same expectations, on an independent view of the input.
            let mut slow: &[u8] = encoded;
            let roundtrip_value = decode_varint_slow(&mut slow).expect("slow decoding failed");
            assert_eq!(value, roundtrip_value);
            assert!(
                slow.is_empty(),
                "decode_varint_slow left {} trailing bytes",
                slow.len()
            );
        }

        check(2u64.pow(0) - 1, &[0x00]);
        check(2u64.pow(0), &[0x01]);

        check(2u64.pow(7) - 1, &[0x7F]);
        check(2u64.pow(7), &[0x80, 0x01]);
        check(300, &[0xAC, 0x02]);

        check(2u64.pow(14) - 1, &[0xFF, 0x7F]);
        check(2u64.pow(14), &[0x80, 0x80, 0x01]);

        check(2u64.pow(21) - 1, &[0xFF, 0xFF, 0x7F]);
        check(2u64.pow(21), &[0x80, 0x80, 0x80, 0x01]);

        check(2u64.pow(28) - 1, &[0xFF, 0xFF, 0xFF, 0x7F]);
        check(2u64.pow(28), &[0x80, 0x80, 0x80, 0x80, 0x01]);

        check(2u64.pow(35) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0x7F]);
        check(2u64.pow(35), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x01]);

        check(2u64.pow(42) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]);
        check(2u64.pow(42), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]);

        check(
            2u64.pow(49) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        );
        check(
            2u64.pow(49),
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        );

        check(
            2u64.pow(56) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        );
        check(
            2u64.pow(56),
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        );

        check(
            2u64.pow(63) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        );
        check(
            2u64.pow(63),
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        );

        check(
            u64::MAX,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01],
        );
    }

    #[test]
    fn varint_overflow() {
        let u64_max_plus_one: &[u8] =
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x02];

        // Give each decoder its own view of the input. This ensures a fast-path failure that
        // consumed bytes cannot change what the slow-path check sees.
        let mut fast = u64_max_plus_one;
        decode_varint(&mut fast).expect_err("decoding u64::MAX + 1 succeeded");

        let mut slow = u64_max_plus_one;
        decode_varint_slow(&mut slow).expect_err("slow decoding u64::MAX + 1 succeeded");
    }

    /// This big bowl o' macro soup generates an encoding property test for each combination of map
    /// type, scalar map key, and value type.
    /// TODO: these tests take a long time to compile, can this be improved?
    #[cfg(feature = "std")]
    macro_rules! map_tests {
        (keys: $keys:tt,
         vals: $vals:tt) => {
            mod hash_map {
                map_tests!(@private HashMap, hash_map, $keys, $vals);
            }
            mod btree_map {
                map_tests!(@private BTreeMap, btree_map, $keys, $vals);
            }
        };

        (@private $map_type:ident,
                  $mod_name:ident,
                  [$(($key_ty:ty, $key_proto:ident)),*],
                  $vals:tt) => {
            $(
                mod $key_proto {
                    use std::collections::$map_type;

                    use proptest::prelude::*;

                    use crate::encoding::*;
                    use crate::encoding::test::check_collection_type;

                    map_tests!(@private $map_type, $mod_name, ($key_ty, $key_proto), $vals);
                }
            )*
        };

        (@private $map_type:ident,
                  $mod_name:ident,
                  ($key_ty:ty, $key_proto:ident),
                  [$(($val_ty:ty, $val_proto:ident)),*]) => {
            $(
                proptest! {
                    #[test]
                    fn $val_proto(values: $map_type<$key_ty, $val_ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_collection_type(values, tag, WireType::LengthDelimited,
                                              |tag, values, buf| {
                                                  $mod_name::encode($key_proto::encode,
                                                                    $key_proto::encoded_len,
                                                                    $val_proto::encode,
                                                                    $val_proto::encoded_len,
                                                                    tag,
                                                                    values,
                                                                    buf)
                                              },
                                              |wire_type, values, buf, ctx| {
                                                  check_wire_type(WireType::LengthDelimited, wire_type)?;
                                                  $mod_name::merge($key_proto::merge,
                                                                   $val_proto::merge,
                                                                   values,
                                                                   buf,
                                                                   ctx)
                                              },
                                              |tag, values| {
                                                  $mod_name::encoded_len($key_proto::encoded_len,
                                                                         $val_proto::encoded_len,
                                                                         tag,
                                                                         values)
                                              })?;
                    }
                }
             )*
        };
    }

    #[cfg(feature = "std")]
    map_tests!(keys: [
        (i32, int32),
        (i64, int64),
        (u32, uint32),
        (u64, uint64),
        (i32, sint32),
        (i64, sint64),
        (u32, fixed32),
        (u64, fixed64),
        (i32, sfixed32),
        (i64, sfixed64),
        (bool, bool),
        (String, string)
    ],
    vals: [
        (f32, float),
        (f64, double),
        (i32, int32),
        (i64, int64),
        (u32, uint32),
        (u64, uint64),
        (i32, sint32),
        (i64, sint64),
        (u32, fixed32),
        (u64, fixed64),
        (i32, sfixed32),
        (i64, sfixed64),
        (bool, bool),
        (String, string),
        (Vec<u8>, bytes)
    ]);
}
1795