1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use crate::{MlsDecode, MlsEncode, MlsSize};
6 use alloc::vec::Vec;
7 
8 macro_rules! impl_stdint {
9     ($t:ty) => {
10         impl MlsSize for $t {
11             fn mls_encoded_len(&self) -> usize {
12                 core::mem::size_of::<$t>()
13             }
14         }
15 
16         impl MlsEncode for $t {
17             fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), crate::Error> {
18                 writer.extend(self.to_be_bytes());
19                 Ok(())
20             }
21         }
22 
23         impl MlsDecode for $t {
24             fn mls_decode(reader: &mut &[u8]) -> Result<Self, crate::Error> {
25                 MlsDecode::mls_decode(reader).map(<$t>::from_be_bytes)
26             }
27         }
28     };
29 }
30 
31 impl_stdint!(u8);
32 impl_stdint!(u16);
33 impl_stdint!(u32);
34 impl_stdint!(u64);
35 impl_stdint!(u128);
36 
#[cfg(test)]
mod tests {
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    use crate::{MlsDecode, MlsEncode, MlsSize};

    /// Shared round-trip check for the integer codec impls.
    ///
    /// The check has three steps:
    /// - Encode `value` and assert the bytes equal `expected`.
    /// - Assert that `mls_encoded_len` reports exactly the number of bytes that were written.
    /// - Decode those bytes and assert the result equals the original `value`.
    fn assert_round_trip<T>(value: T, expected: &[u8])
    where
        T: MlsEncode + MlsDecode + MlsSize + PartialEq + core::fmt::Debug,
    {
        let serialized = value.mls_encode_to_vec().unwrap();
        assert_eq!(serialized, expected);

        // The size hint must agree with the real encoding, otherwise callers that
        // pre-size buffers or length prefixes from it would be wrong.
        assert_eq!(value.mls_encoded_len(), serialized.len());

        let recovered = T::mls_decode(&mut &*serialized).unwrap();
        assert_eq!(recovered, value);
    }

    #[test]
    fn u8_round_trip() {
        assert_round_trip(42u8, &[42u8]);
    }

    #[test]
    fn u16_round_trip() {
        assert_round_trip(1024u16, &[4, 0]);
    }

    #[test]
    fn u32_round_trip() {
        assert_round_trip(1000000u32, &[0, 15, 66, 64]);
    }

    #[test]
    fn u64_round_trip() {
        assert_round_trip(100000000000u64, &[0, 0, 0, 23, 72, 118, 232, 0]);
    }

    #[test]
    fn u128_round_trip() {
        assert_round_trip(
            10000000000000000u128,
            &[0, 0, 0, 0, 0, 0, 0, 0, 0, 35, 134, 242, 111, 193, 0, 0],
        );
    }
}
99