xref: /aosp_15_r20/system/authgraph/wire/src/fragmentation.rs (revision 4185b0660fbe514985fdcf75410317caad8afad1)
1 // Copyright 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 //! Helpers for message fragmentation and reassembly.
16 
17 use alloc::borrow::Cow;
18 use alloc::vec::Vec;
19 
/// Leading byte of a fragment that is followed by further fragments of the same message.
const PREFIX_MORE_TO_COME: u8 = 0xCC; // mnemonic: 'cc' => "continues"
/// Leading byte of the last fragment of a message.
const PREFIX_FINAL_FRAGMENT: u8 = 0xDD; // mnemonic: 'dd' => "done"

/// Content-free fragment made up of just the "more to come" prefix, signalling that further
/// data is due.
pub const PLACEHOLDER_MORE_TO_COME: &[u8] = &[PREFIX_MORE_TO_COME];
27 
28 /// Helper to emit a single message in fragments.
29 pub struct Fragmenter<'a> {
30     data: &'a [u8],
31     max_size: usize,
32 }
33 
34 impl<'a> Fragmenter<'a> {
35     /// Create a fragmentation iterator for the given data.
new(data: &'a [u8], max_size: usize) -> Fragmenter<'a>36     pub fn new(data: &'a [u8], max_size: usize) -> Fragmenter<'a> {
37         assert!(max_size > 1);
38         Self { data, max_size }
39     }
40 }
41 
42 impl<'a> Iterator for Fragmenter<'a> {
43     type Item = Vec<u8>;
next(&mut self) -> Option<Self::Item>44     fn next(&mut self) -> Option<Self::Item> {
45         if self.data.is_empty() {
46             None
47         } else {
48             let consume = core::cmp::min(self.max_size - 1, self.data.len());
49             let marker =
50                 if consume < self.data.len() { PREFIX_MORE_TO_COME } else { PREFIX_FINAL_FRAGMENT };
51             let mut result = Vec::with_capacity(consume + 1);
52             result.push(marker);
53             result.extend_from_slice(&self.data[..consume]);
54             self.data = &self.data[consume..];
55             Some(result)
56         }
57     }
58 }
59 
/// Buffer that collects the content of fragmented messages until a complete message is
/// available.
#[derive(Default)]
pub struct Reassembler {
    /// Content gathered from "more to come" fragments of the message currently in progress.
    pending: Vec<u8>,
}

impl Reassembler {
    /// Feed in one message fragment, returning the complete message once its final fragment
    /// has arrived.
    ///
    /// Returns `None` while more fragments are expected. When the final fragment arrives:
    /// - if no content was buffered from earlier fragments, the content is returned borrowed
    ///   directly from `frag`;
    /// - otherwise the buffered content plus this fragment's content is returned as an owned
    ///   vector, and the internal buffer is left empty for the next message.
    ///
    /// # Panics
    ///
    /// Panics if `frag` is non-empty and starts with an unrecognized prefix byte.
    pub fn accumulate<'a>(&mut self, frag: &'a [u8]) -> Option<Cow<'a, [u8]>> {
        let (more, content) = Self::split_msg(frag);
        if more {
            // Not finished yet: stash the content and wait for further fragments.
            self.pending.extend_from_slice(content);
            return None;
        }
        if self.pending.is_empty() {
            // Nothing buffered (the common single-fragment case): hand back the content as-is.
            return Some(Cow::Borrowed(content));
        }
        // Move the buffered content out, resetting the buffer, and append the final piece.
        let mut whole = core::mem::take(&mut self.pending);
        whole.extend_from_slice(content);
        Some(Cow::Owned(whole))
    }

    /// Separate a fragment into its "more to come" flag and its content.
    ///
    /// An empty fragment has no prefix byte and is treated as a final fragment with no content.
    ///
    /// # Panics
    ///
    /// Panics if the first byte is neither the "more to come" nor the "final fragment" prefix.
    fn split_msg(data: &[u8]) -> (bool, &[u8]) {
        match data.split_first() {
            None => (false, data),
            Some((&PREFIX_MORE_TO_COME, rest)) => (true, rest),
            Some((&PREFIX_FINAL_FRAGMENT, rest)) => (false, rest),
            Some(_) => panic!("data fragment with incorrect prefix"),
        }
    }
}
101 
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::string::{String, ToString};
    use alloc::vec;
    use core::cell::RefCell;

    #[test]
    fn test_fragmentation() {
        let tests = [
            (
                "a0a1a2a3a4a5a6a7a8a9",
                2,
                vec![
                    "cca0", "cca1", "cca2", "cca3", "cca4", "cca5", "cca6", "cca7", "cca8", "dda9",
                ],
            ),
            ("a0a1a2a3a4a5a6a7a8a9", 5, vec!["cca0a1a2a3", "cca4a5a6a7", "dda8a9"]),
            ("a0a1a2a3a4a5a6a7a8a9", 9, vec!["cca0a1a2a3a4a5a6a7", "dda8a9"]),
            ("a0a1a2a3a4a5a6a7a8a9", 80, vec!["dda0a1a2a3a4a5a6a7a8a9"]),
        ];
        for (input, max_size, want) in &tests {
            let data = hex::decode(input).unwrap();
            let fragmenter = Fragmenter::new(&data, *max_size);
            let got: Vec<String> = fragmenter.map(hex::encode).collect();
            let want: Vec<String> = want.iter().map(|s| s.to_string()).collect();
            assert_eq!(got, want, "for input {input} max_size {max_size}");
        }
    }

    #[test]
    #[should_panic]
    fn test_reassembly_wrong_prefix() {
        // Failure case: unexpected marker byte
        let mut pending = Reassembler::default();
        let _ = pending.accumulate(&[0x00, 0x01, 0x02, 0x03]);
    }

    #[test]
    fn test_reassembly() {
        let tests = [
            // Single messages
            (vec!["dd"], ""),
            (vec!["dd0000"], "0000"),
            (vec!["dd010203"], "010203"),
            // Multipart messages.
            (vec!["cc0102", "dd0304"], "01020304"),
            (vec!["cc01", "cc02", "dd0304"], "01020304"),
            (vec!["cc", "cc02", "dd0304"], "020304"),
            // Degenerate case: no fragments at all, so no message is ever produced.
            (vec![], ""),
        ];
        for (frags, want) in &tests {
            let mut done = false;
            let mut pending = Reassembler::default();
            for frag in frags {
                assert!(!done, "left over fragments found");
                let frag = hex::decode(frag).unwrap();
                let result = pending.accumulate(&frag);
                if let Some(got) = result {
                    assert_eq!(&hex::encode(got), want, "for input {frags:?}");
                    done = true;
                }
            }
            // Every non-empty fragment sequence above ends with a final fragment, so it must
            // have produced a complete message; without this check a reassembler that never
            // returned a message would pass.
            assert_eq!(done, !frags.is_empty(), "for input {frags:?}");
        }
    }

    #[test]
    fn test_reassembly_placeholder() {
        // The placeholder is a "more to come" fragment with no content: it must not complete a
        // message, and must not contribute any bytes to the eventual message.
        let mut pending = Reassembler::default();
        assert!(pending.accumulate(PLACEHOLDER_MORE_TO_COME).is_none());
        let got = pending.accumulate(&[PREFIX_FINAL_FRAGMENT, 0x01, 0x02]);
        assert_eq!(got.as_deref(), Some(&[0x01u8, 0x02][..]));
    }

    #[test]
    fn test_fragmentation_reassembly() {
        let input = "a0a1a2a3a4a5a6a7a8a9b0b1b2b3b4b5b6b7b8b9c0c1c2c3c4c5c6c7c8c9";
        let data = hex::decode(input).unwrap();
        for max_size in 2..data.len() + 2 {
            let fragmenter = Fragmenter::new(&data, max_size);
            let mut done = false;
            let mut pending = Reassembler::default();
            for frag in fragmenter {
                assert!(!done, "left over fragments found");
                let result = pending.accumulate(&frag);
                if let Some(got) = result {
                    assert_eq!(&hex::encode(got), input, "for max_size {max_size}");
                    done = true;
                }
            }
            assert!(done);
        }
    }

    #[test]
    fn test_ta_fragmentation_wrapper() {
        // Simulate a `send()` standalone function for responses.
        let rsp_reassembler = RefCell::new(Reassembler::default());
        let full_rsp: RefCell<Option<Vec<u8>>> = RefCell::new(None);
        let send = |data: &[u8]| {
            if let Some(msg) = rsp_reassembler.borrow_mut().accumulate(data) {
                *full_rsp.borrow_mut() = Some(msg.to_vec());
            }
        };

        // Simulate a TA.
        struct Ta {
            pending_req: Reassembler,
            max_size: usize,
        }
        impl Ta {
            // Request fragment, response fragments emitted via `send`.
            fn process_fragment<S: Fn(&[u8])>(&mut self, req_frag: &[u8], send: S) {
                // Accumulate request fragments until able to feed complete request to `process`.
                if let Some(full_req) = self.pending_req.accumulate(req_frag) {
                    // Full request message is available, invoke the callback to get a full
                    // response.
                    let full_rsp = self.process(&full_req);
                    for rsp_frag in Fragmenter::new(&full_rsp, self.max_size) {
                        send(&rsp_frag);
                    }
                }
            }
            // Full request to full response.
            fn process(&self, req: &[u8]) -> Vec<u8> {
                // Simulate processing a request by echoing it back as a response.
                req.to_vec()
            }
        }

        let req = "a0a1a2a3a4a5a6a7a8a9b0b1b2b3b4b5b6b7b8b9c0c1c2c3c4c5c6c7c8c9";
        let req_data = hex::decode(req).unwrap();
        for max_size in 2..req_data.len() + 2 {
            let mut ta = Ta { pending_req: Default::default(), max_size };
            // Simulate multiple fragmented messages arriving at the TA.
            for _msg_idx in 0..3 {
                // Reset the received-response buffer.
                *rsp_reassembler.borrow_mut() = Reassembler::default();
                *full_rsp.borrow_mut() = None;

                for req_frag in Fragmenter::new(&req_data, max_size) {
                    assert!(full_rsp.borrow().is_none(), "left over fragments found");
                    ta.process_fragment(&req_frag, send);
                }
                // After all the request fragments have been sent in, expect to have a complete
                // response.
                if let Some(rsp) = full_rsp.borrow().as_ref() {
                    assert_eq!(hex::encode(rsp), req);
                } else {
                    panic!("no response received");
                }
            }
        }
    }
}
250