1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use crate::{MlsDecode, MlsEncode, MlsSize};
6 use alloc::vec::Vec;
7 
8 impl<T: MlsSize> MlsSize for Option<T> {
9     #[inline]
mls_encoded_len(&self) -> usize10     fn mls_encoded_len(&self) -> usize {
11         1 + match self {
12             Some(v) => v.mls_encoded_len(),
13             None => 0,
14         }
15     }
16 }
17 
18 impl<T: MlsEncode> MlsEncode for Option<T> {
mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), crate::Error>19     fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), crate::Error> {
20         if let Some(item) = self {
21             writer.push(1);
22             item.mls_encode(writer)
23         } else {
24             writer.push(0);
25             Ok(())
26         }
27     }
28 }
29 
30 impl<T: MlsDecode> MlsDecode for Option<T> {
mls_decode(reader: &mut &[u8]) -> Result<Self, crate::Error>31     fn mls_decode(reader: &mut &[u8]) -> Result<Self, crate::Error> {
32         match u8::mls_decode(reader)? {
33             0 => Ok(None),
34             1 => T::mls_decode(reader).map(Some),
35             n => Err(crate::Error::OptionOutOfRange(n)),
36         }
37     }
38 }
39 
#[cfg(test)]
mod tests {
    use alloc::vec;

    use crate::{Error, MlsDecode, MlsEncode, MlsSize};
    use assert_matches::assert_matches;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    #[test]
    fn none_is_serialized_correctly() {
        assert_eq!(vec![0u8], None::<u8>.mls_encode_to_vec().unwrap());
    }

    #[test]
    fn some_is_serialized_correctly() {
        assert_eq!(vec![1u8, 2], Some(2u8).mls_encode_to_vec().unwrap());
    }

    #[test]
    fn none_round_trips() {
        let val = None::<u8>;
        let x = val.mls_encode_to_vec().unwrap();
        let mut reader = &*x;
        assert_eq!(val, Option::mls_decode(&mut reader).unwrap());
        // Decoding must consume exactly the bytes that were written.
        assert!(reader.is_empty());
    }

    #[test]
    fn some_round_trips() {
        let val = Some(32u8);
        let x = val.mls_encode_to_vec().unwrap();
        let mut reader = &*x;
        assert_eq!(val, Option::mls_decode(&mut reader).unwrap());
        // Decoding must consume exactly the bytes that were written.
        assert!(reader.is_empty());
    }

    #[test]
    fn nested_option_round_trips() {
        // `Some(None)` and `None` must remain distinguishable on the wire.
        for val in [Some(None::<u8>), None, Some(Some(7u8))].iter() {
            let x = val.mls_encode_to_vec().unwrap();
            let mut reader = &*x;
            assert_eq!(*val, Option::mls_decode(&mut reader).unwrap());
            assert!(reader.is_empty());
        }
    }

    #[test]
    fn encoded_len_matches_serialized_len() {
        // `MlsSize` must agree with what `MlsEncode` actually writes.
        for val in [None, Some(0u8), Some(255u8)].iter() {
            let bytes = val.mls_encode_to_vec().unwrap();
            assert_eq!(val.mls_encoded_len(), bytes.len());
        }

        let nested = Some(Some(1u8));
        assert_eq!(
            nested.mls_encoded_len(),
            nested.mls_encode_to_vec().unwrap().len()
        );
    }

    #[test]
    fn deserializing_invalid_discriminant_fails() {
        assert_matches!(
            Option::<u8>::mls_decode(&mut &[2u8][..]),
            Err(Error::OptionOutOfRange(_))
        );
    }

    #[test]
    fn deserializing_empty_input_fails() {
        let mut empty: &[u8] = &[];
        assert!(Option::<u8>::mls_decode(&mut empty).is_err());
    }

    #[test]
    fn deserializing_some_without_payload_fails() {
        // The presence flag says `Some`, but no payload byte follows.
        assert!(Option::<u8>::mls_decode(&mut &[1u8][..]).is_err());
    }
}
82