1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
6 
// Tuple struct with a single unnamed `u64` field. Exercises the derives on
// tuple structs; it is also the element type of `TestType::field_g` and the
// payload of `TestEnum::Case3`.
#[derive(Debug, Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
struct TestTupleStruct(u64);
9 
// Struct with named fields mixing an optional and a plain integer. It is nested
// inside `TestType::field_h` and used as the payload of `TestEnum::Case2`, so
// the derives are checked on composite types too.
#[derive(Debug, Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
struct TestFieldStruct {
    item1: Option<u8>,
    item2: u64,
}
15 
// Aggregate covering the main field kinds the derives must handle: integers of
// several widths, byte vectors, `Option`s of both vectors and integers, a vector
// of user-defined structs, and a nested struct. Round-tripped by
// `round_trip_struct_encode`.
//
// NOTE(review): the derived codec presumably writes fields in declaration
// order, so reordering these fields would change the wire format.
#[derive(Debug, Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
struct TestType {
    field_a: u8,
    field_b: Vec<u8>,
    field_c: u16,
    field_d: Option<Vec<u8>>,
    field_e: u32,
    field_f: Option<u16>,
    field_g: Vec<TestTupleStruct>,
    field_h: TestFieldStruct,
}
27 
28 #[derive(Debug, Clone, PartialEq, Eq, MlsSize, MlsEncode)]
29 struct BorrowedTestType<'a> {
30     field_a: u8,
31     field_b: Option<&'a [u8]>,
32     field_c: &'a [u16],
33 }
34 
// `u16`-represented enum with explicit, non-sequential discriminants written as
// suffixed literals. It mixes a unit variant with single-field tuple variants.
// `round_trip_enum_encode_simple` checks that `Case1` encodes exactly as `1u16`.
#[repr(u16)]
#[derive(Debug, Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
enum TestEnum {
    Case1 = 1u16,
    Case2(TestFieldStruct) = 200u16,
    Case3(TestTupleStruct) = 42u16,
}
42 
43 #[derive(Debug, Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
44 #[repr(u8)]
45 enum TestEnumWithoutSuffixedLiterals {
46     Case1 = 1,
47     Case2(TestFieldStruct) = 200,
48     Case3(TestTupleStruct) = 42,
49 }
50 
// Generic newtype whose type parameter carries inline codec bounds. The bounds
// are presumably kept on the definition on purpose, to exercise how the derives
// handle bounded generics. Instantiated with `u16` in `round_trip_generic_encode`.
#[derive(Debug, Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
struct TestGeneric<T: MlsSize + MlsEncode + MlsDecode>(T);
53 
// Encodes a fully populated `TestType` (mixing `Some`/`None` options, vectors
// and nested structs) and checks that decoding yields an equal value.
#[test]
fn round_trip_struct_encode() {
    let nested = TestFieldStruct {
        item1: Some(42),
        item2: 84,
    };

    // 100, 200, 300
    let tuples: Vec<TestTupleStruct> = (1..=3).map(|k| TestTupleStruct(k * 100)).collect();

    let original = TestType {
        field_a: 42,
        field_b: vec![1, 3, 5, 7, 9],
        field_c: 65000,
        field_d: Some(vec![0, 2, 4, 6, 8]),
        field_e: 1000000,
        field_f: None,
        field_g: tuples,
        field_h: nested,
    };

    let encoded = original.mls_encode_to_vec().unwrap();

    let mut cursor: &[u8] = encoded.as_slice();
    let decoded = TestType::mls_decode(&mut cursor).unwrap();

    assert_eq!(decoded, original);
}
79 
// Round-trips the generic newtype instantiated with `u16`.
#[test]
fn round_trip_generic_encode() {
    let original = TestGeneric(42u16);

    let encoded = original.mls_encode_to_vec().unwrap();
    let mut reader = &encoded[..];
    let decoded: TestGeneric<u16> = TestGeneric::mls_decode(&mut reader).unwrap();

    assert_eq!(decoded, original);
}
88 
// A unit variant must encode as nothing but its `u16` discriminant, and must
// decode back to itself.
#[test]
fn round_trip_enum_encode_simple() {
    let original = TestEnum::Case1;

    let encoded = original.mls_encode_to_vec().unwrap();
    let decoded = TestEnum::mls_decode(&mut encoded.as_slice()).unwrap();

    let discriminant_only = 1u16.mls_encode_to_vec().unwrap();
    assert_eq!(encoded, discriminant_only);
    assert_eq!(decoded, original);
}
99 
// Round-trips the named-field variant `TestEnum::Case2` and pins its wire
// format: the explicit `u16` discriminant (200) followed by the payload.
#[test]
fn round_trip_enum_encode_one_field() {
    let payload = TestFieldStruct {
        item1: None,
        item2: 42,
    };
    let item = TestEnum::Case2(payload.clone());

    let serialized = item.mls_encode_to_vec().unwrap();

    // A plain round trip would still pass if encode and decode agreed on a
    // wrong discriminant, so check the bytes explicitly.
    let mut expected = 200u16.mls_encode_to_vec().unwrap();
    expected.extend(payload.mls_encode_to_vec().unwrap());
    assert_eq!(serialized, expected);

    let decoded = TestEnum::mls_decode(&mut &*serialized).unwrap();
    assert_eq!(decoded, item);
}
112 
// Round-trips the tuple-struct variant `TestEnum::Case3` and pins its wire
// format: the explicit `u16` discriminant (42) followed by the payload.
#[test]
fn round_trip_enum_encode_one_tuple() {
    let payload = TestTupleStruct(42);
    let item = TestEnum::Case3(payload.clone());

    let serialized = item.mls_encode_to_vec().unwrap();

    // A plain round trip would still pass if encode and decode agreed on a
    // wrong discriminant, so check the bytes explicitly.
    let mut expected = 42u16.mls_encode_to_vec().unwrap();
    expected.extend(payload.mls_encode_to_vec().unwrap());
    assert_eq!(serialized, expected);

    let decoded = TestEnum::mls_decode(&mut &*serialized).unwrap();
    assert_eq!(decoded, item);
}
122 
// Checks that `#[mls_codec(with = "...")]` on a struct field routes the size,
// encode and decode paths through the `test_with` module. That module writes
// the value followed by a `42` byte and always reports a size of 2.
#[test]
fn round_trip_custom_module_struct() {
    #[derive(Debug, PartialEq, Eq, Clone, MlsSize, MlsEncode, MlsDecode)]
    struct TestCustomStruct {
        #[mls_codec(with = "self::test_with")]
        value: u8,
    }

    let item = TestCustomStruct { value: 33 };

    let serialized = item.mls_encode_to_vec().unwrap();
    // Pin the exact bytes, not just the length, so the test fails if the custom
    // encoder's output is altered or reordered.
    assert_eq!(serialized, vec![33, 42]);
    // The derived size must come from `test_with::mls_encoded_len`.
    assert_eq!(item.mls_encoded_len(), serialized.len());

    let decoded = TestCustomStruct::mls_decode(&mut &*serialized).unwrap();
    assert_eq!(item, decoded);
}
139 
// Checks that `#[mls_codec(with = "...")]` works on an enum variant's field.
// The encoding must be the `u16` discriminant (2) followed by the two bytes
// written by `test_with::mls_encode`.
#[test]
fn round_trip_custom_module_enum() {
    #[derive(Debug, PartialEq, Eq, Clone, MlsSize, MlsEncode, MlsDecode)]
    #[repr(u16)]
    enum TestCustomEnum {
        CustomCase(#[mls_codec(with = "self::test_with")] u8) = 2u16,
    }

    let item = TestCustomEnum::CustomCase(33);

    let serialized = item.mls_encode_to_vec().unwrap();
    assert_eq!(serialized.len(), 4);

    // Pin the discriminant and the custom payload bytes, not just the length.
    let mut expected = 2u16.mls_encode_to_vec().unwrap();
    expected.extend_from_slice(&[33, 42]);
    assert_eq!(serialized, expected);
    // The derived size must account for the custom module's reported length.
    assert_eq!(item.mls_encoded_len(), serialized.len());

    let decoded = TestCustomEnum::mls_decode(&mut &*serialized).unwrap();
    assert_eq!(item, decoded)
}
156 
157 mod test_with {
158     use mls_rs_codec::MlsDecode;
159 
mls_encoded_len(_val: &u8) -> usize160     pub fn mls_encoded_len(_val: &u8) -> usize {
161         2
162     }
163 
mls_encode(val: &u8, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error>164     pub fn mls_encode(val: &u8, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
165         writer.extend([*val, 42]);
166         Ok(())
167     }
168 
mls_decode(reader: &mut &[u8]) -> Result<u8, mls_rs_codec::Error>169     pub fn mls_decode(reader: &mut &[u8]) -> Result<u8, mls_rs_codec::Error> {
170         Ok(<[u8; 2]>::mls_decode(reader)?[0])
171     }
172 }
173