1 // Copyright 2024 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 //! Helpers for message fragmentation and reassembly. 16 17 use alloc::borrow::Cow; 18 use alloc::vec::Vec; 19 20 /// Prefix byte indicating more data is to come. 21 const PREFIX_MORE_TO_COME: u8 = 0xcc; // 'ccontinues' 22 /// Prefix byte indicating that this is the final fragment. 23 const PREFIX_FINAL_FRAGMENT: u8 = 0xdd; // 'ddone' 24 25 /// Empty placeholder message indicating more data is due. 26 pub const PLACEHOLDER_MORE_TO_COME: &[u8] = &[PREFIX_MORE_TO_COME]; 27 28 /// Helper to emit a single message in fragments. 29 pub struct Fragmenter<'a> { 30 data: &'a [u8], 31 max_size: usize, 32 } 33 34 impl<'a> Fragmenter<'a> { 35 /// Create a fragmentation iterator for the given data. new(data: &'a [u8], max_size: usize) -> Fragmenter<'a>36 pub fn new(data: &'a [u8], max_size: usize) -> Fragmenter<'a> { 37 assert!(max_size > 1); 38 Self { data, max_size } 39 } 40 } 41 42 impl<'a> Iterator for Fragmenter<'a> { 43 type Item = Vec<u8>; next(&mut self) -> Option<Self::Item>44 fn next(&mut self) -> Option<Self::Item> { 45 if self.data.is_empty() { 46 None 47 } else { 48 let consume = core::cmp::min(self.max_size - 1, self.data.len()); 49 let marker = 50 if consume < self.data.len() { PREFIX_MORE_TO_COME } else { PREFIX_FINAL_FRAGMENT }; 51 let mut result = Vec::with_capacity(consume + 1); 52 result.push(marker); 53 result.extend_from_slice(&self.data[..consume]); 54 self.data = &self.data[consume..]; 55 Some(result) 56 } 57 } 58 } 59 60 /// Buffer to accumulate fragmented messages. 61 #[derive(Default)] 62 pub struct Reassembler(Vec<u8>); 63 64 impl Reassembler { 65 /// Accumulate message data, possibly resulting in a complete message. accumulate<'a>(&mut self, frag: &'a [u8]) -> Option<Cow<'a, [u8]>>66 pub fn accumulate<'a>(&mut self, frag: &'a [u8]) -> Option<Cow<'a, [u8]>> { 67 let (more, content) = Self::split_msg(frag); 68 if more { 69 // More to come, so accumulate this data and return empty response. 70 self.0.extend_from_slice(content); 71 None 72 } else if self.0.is_empty() { 73 // For shorter messages (the mainline case) we can directly pass through the single 74 // message's content. 75 Some(Cow::Borrowed(content)) 76 } else { 77 // Process the accumulated full request as an owned vector 78 let mut full_req = core::mem::take(&mut self.0); 79 full_req.extend_from_slice(content); 80 Some(Cow::Owned(full_req)) 81 } 82 } 83 84 /// Split a message into an indication of whether more data is to come, and the content. 85 /// 86 /// # Panics 87 /// 88 /// This function panics if the provided message fragment has an unexpected message prefix. split_msg(data: &[u8]) -> (bool, &[u8])89 fn split_msg(data: &[u8]) -> (bool, &[u8]) { 90 if data.is_empty() { 91 (false, data) 92 } else { 93 match data[0] { 94 PREFIX_MORE_TO_COME => (true, &data[1..]), 95 PREFIX_FINAL_FRAGMENT => (false, &data[1..]), 96 _ => panic!("data fragment with incorrect prefix"), 97 } 98 } 99 } 100 } 101 102 #[cfg(test)] 103 mod tests { 104 use super::*; 105 use alloc::string::{String, ToString}; 106 use alloc::vec; 107 use core::cell::RefCell; 108 109 #[test] test_fragmentation()110 fn test_fragmentation() { 111 let tests = [ 112 ( 113 "a0a1a2a3a4a5a6a7a8a9", 114 2, 115 vec![ 116 "cca0", "cca1", "cca2", "cca3", "cca4", "cca5", "cca6", "cca7", "cca8", "dda9", 117 ], 118 ), 119 ("a0a1a2a3a4a5a6a7a8a9", 5, vec!["cca0a1a2a3", "cca4a5a6a7", "dda8a9"]), 120 ("a0a1a2a3a4a5a6a7a8a9", 9, vec!["cca0a1a2a3a4a5a6a7", "dda8a9"]), 121 ("a0a1a2a3a4a5a6a7a8a9", 80, vec!["dda0a1a2a3a4a5a6a7a8a9"]), 122 ]; 123 for (input, max_size, want) in &tests { 124 let data = hex::decode(input).unwrap(); 125 let fragmenter = Fragmenter::new(&data, *max_size); 126 let got: Vec<String> = fragmenter.map(hex::encode).collect(); 127 let want: Vec<String> = want.iter().map(|s| s.to_string()).collect(); 128 assert_eq!(got, want, "for input {input} max_size {max_size}"); 129 } 130 } 131 132 #[test] 133 #[should_panic] test_reassembly_wrong_prefix()134 fn test_reassembly_wrong_prefix() { 135 // Failure case: unexpected marker byte 136 let mut pending = Reassembler::default(); 137 let _ = pending.accumulate(&[0x00, 0x01, 0x02, 0x03]); 138 } 139 140 #[test] test_reassembly()141 fn test_reassembly() { 142 let tests = [ 143 // Single messages 144 (vec!["dd"], ""), 145 (vec!["dd0000"], "0000"), 146 (vec!["dd010203"], "010203"), 147 // Multipart messages. 148 (vec!["cc0102", "dd0304"], "01020304"), 149 (vec!["cc01", "cc02", "dd0304"], "01020304"), 150 (vec!["cc", "cc02", "dd0304"], "020304"), 151 // Failure case: empty message (no marker byte) 152 (vec![], ""), 153 ]; 154 for (frags, want) in &tests { 155 let mut done = false; 156 let mut pending = Reassembler::default(); 157 for frag in frags { 158 assert!(!done, "left over fragments found"); 159 let frag = hex::decode(frag).unwrap(); 160 let result = pending.accumulate(&frag); 161 if let Some(got) = result { 162 assert_eq!(&hex::encode(got), want, "for input {frags:?}"); 163 done = true; 164 } 165 } 166 } 167 } 168 169 #[test] test_fragmentation_reassembly()170 fn test_fragmentation_reassembly() { 171 let input = "a0a1a2a3a4a5a6a7a8a9b0b1b2b3b4b5b6b7b8b9c0c1c2c3c4c5c6c7c8c9"; 172 let data = hex::decode(input).unwrap(); 173 for max_size in 2..data.len() + 2 { 174 let fragmenter = Fragmenter::new(&data, max_size); 175 let mut done = false; 176 let mut pending = Reassembler::default(); 177 for frag in fragmenter { 178 assert!(!done, "left over fragments found"); 179 let result = pending.accumulate(&frag); 180 if let Some(got) = result { 181 assert_eq!(&hex::encode(got), input, "for max_size {max_size}"); 182 done = true; 183 } 184 } 185 assert!(done); 186 } 187 } 188 189 #[test] test_ta_fragmentation_wrapper()190 fn test_ta_fragmentation_wrapper() { 191 // Simulate a `send()` standalone function for responses. 192 let rsp_reassembler = RefCell::new(Reassembler::default()); 193 let full_rsp: RefCell<Option<Vec<u8>>> = RefCell::new(None); 194 let send = |data: &[u8]| { 195 if let Some(msg) = rsp_reassembler.borrow_mut().accumulate(data) { 196 *full_rsp.borrow_mut() = Some(msg.to_vec()); 197 } 198 }; 199 200 // Simulate a TA. 201 struct Ta { 202 pending_req: Reassembler, 203 max_size: usize, 204 } 205 impl Ta { 206 // Request fragment, response fragments emitted via `send`. 207 fn process_fragment<S: Fn(&[u8])>(&mut self, req_frag: &[u8], send: S) { 208 // Accumulate request fragments until able to feed complete request to `process`. 209 if let Some(full_req) = self.pending_req.accumulate(req_frag) { 210 // Full request message is available, invoke the callback to get a full 211 // response. 212 let full_rsp = self.process(&full_req); 213 for rsp_frag in Fragmenter::new(&full_rsp, self.max_size) { 214 send(&rsp_frag); 215 } 216 } 217 } 218 // Full request to full response. 219 fn process(&self, req: &[u8]) -> Vec<u8> { 220 // Simulate processing a request by echoing it back as a response. 221 req.to_vec() 222 } 223 } 224 225 let req = "a0a1a2a3a4a5a6a7a8a9b0b1b2b3b4b5b6b7b8b9c0c1c2c3c4c5c6c7c8c9"; 226 let req_data = hex::decode(req).unwrap(); 227 for max_size in 2..req_data.len() + 2 { 228 let mut ta = Ta { pending_req: Default::default(), max_size }; 229 // Simulate multiple fragmented messages arriving at the TA. 230 for _msg_idx in 0..3 { 231 // Reset the received-response buffer. 232 *rsp_reassembler.borrow_mut() = Reassembler::default(); 233 *full_rsp.borrow_mut() = None; 234 235 for req_frag in Fragmenter::new(&req_data, max_size) { 236 assert!(full_rsp.borrow().is_none(), "left over fragments found"); 237 ta.process_fragment(&req_frag, send); 238 } 239 // After all the request fragments have been sent in, expect to have a complete 240 // response. 241 if let Some(rsp) = full_rsp.borrow().as_ref() { 242 assert_eq!(hex::encode(rsp), req); 243 } else { 244 panic!("no response received"); 245 } 246 } 247 } 248 } 249 } 250