1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
/// Strategy used to pad an encrypted group message when it is sent.
///
/// [`PaddingMode::StepFunction`] is the [`Default`] mode.
#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type)]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[repr(u8)]
pub enum PaddingMode {
    /// Pad according to a step function of the message size: the larger the
    /// original message, the larger the amount of padding that is applied.
    #[default]
    StepFunction,
    /// Leave the message unpadded.
    None,
}
18 
19 impl PaddingMode {
padded_size(&self, content_size: usize) -> usize20     pub(super) fn padded_size(&self, content_size: usize) -> usize {
21         match self {
22             PaddingMode::StepFunction => {
23                 // The padding hides all but 2 most significant bits of `length`. The hidden bits are replaced
24                 // by zeros and then the next number is taken to make sure the message fits.
25                 let blind = 1
26                     << ((content_size + 1)
27                         .next_power_of_two()
28                         .max(256)
29                         .trailing_zeros()
30                         - 3);
31 
32                 (content_size | (blind - 1)) + 1
33             }
34             PaddingMode::None => content_size,
35         }
36     }
37 }
38 
#[cfg(test)]
mod tests {
    use super::PaddingMode;

    use alloc::vec;
    use alloc::vec::Vec;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    /// One entry of the message padding test vector: a content size and the
    /// padded size recorded for it under [`PaddingMode::StepFunction`].
    #[derive(serde::Deserialize, serde::Serialize)]
    struct TestCase {
        input: usize,
        output: usize,
    }

    /// Builds the test vector by recording the step-function padded size of
    /// every content size in `1..1024`.
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_message_padding_test_vector() -> Vec<TestCase> {
        let mut vectors = vec![];
        vectors.extend((1..1024).map(|input| TestCase {
            input,
            output: PaddingMode::StepFunction.padded_size(input),
        }));
        vectors
    }

    /// Obtains the `message_padding_test_vector` cases through the
    /// `load_test_case_json!` macro (defined outside this module), handing it
    /// the generator above.
    fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(
            message_padding_test_vector,
            generate_message_padding_test_vector()
        )
    }

    /// `PaddingMode::None` must return every size untouched.
    #[test]
    fn test_no_padding() {
        const SIZES: [usize; 4] = [0, 100, 1000, 10000];

        for size in SIZES {
            assert_eq!(PaddingMode::None.padded_size(size), size)
        }
    }

    /// Checks hand-picked sizes and then the full test vector against
    /// `PaddingMode::StepFunction`.
    #[test]
    fn test_padding_length() {
        // (content size, expected padded size)
        const EXPECTED: [(usize, usize); 11] = [
            // Empty message
            (0, 32),
            // Short
            (63, 64),
            (64, 96),
            (65, 96),
            // Almost long and almost short
            (127, 128),
            (128, 160),
            (129, 160),
            // One length from each of the 4 buckets between 256 and 512
            (260, 320),
            (330, 384),
            (390, 448),
            (490, 512),
        ];

        for (content_size, padded) in EXPECTED {
            assert_eq!(PaddingMode::StepFunction.padded_size(content_size), padded);
        }

        // All test cases
        for TestCase { input, output } in load_test_cases() {
            assert_eq!(output, PaddingMode::StepFunction.padded_size(input));
        }
    }
}
110