// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

//! Derivation of the sender data key/nonce from a sample of a message's
//! ciphertext, and AEAD sealing/opening of the `SenderData` structure with
//! `SenderDataAAD` as associated data (cf. the "Sender Data Encryption"
//! scheme of RFC 9420).

use alloc::vec::Vec;
use core::fmt::{self, Debug};
use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
use mls_rs_core::error::IntoAnyError;
use zeroize::Zeroizing;

use crate::{
    client::MlsError,
    crypto::CipherSuiteProvider,
    group::{epoch::SenderDataSecret, framing::ContentType, key_schedule::kdf_expand_with_label},
    tree_kem::node::LeafIndex,
};

use super::ReuseGuard;

/// Sender information that is encrypted with a `SenderDataKey`.
///
/// The MLS encoding of this struct is the AEAD plaintext in
/// `SenderDataKey::seal` / `SenderDataKey::open`.
/// NOTE(review): the derived codec presumably encodes fields in declaration
/// order, which would make the field order part of the wire format — confirm
/// before reordering.
#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
pub(crate) struct SenderData {
    /// Leaf index identifying the sender.
    pub sender: LeafIndex,
    /// Generation value carried alongside the sender (presumably the sender's
    /// ratchet generation — confirm against callers).
    pub generation: u32,
    /// Reuse guard value (type defined in the parent module).
    pub reuse_guard: ReuseGuard,
}

/// Associated data bound to sender data encryption.
///
/// Its MLS encoding is passed as the AEAD associated data in both
/// `SenderDataKey::seal` and `SenderDataKey::open`, so both sides must
/// construct identical values for decryption to succeed.
#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
pub(crate) struct SenderDataAAD {
    /// Group identifier, encoded with `mls_rs_codec::byte_vec`.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    pub group_id: Vec<u8>,
    /// Epoch number.
    pub epoch: u64,
    /// Content type of the enclosing message.
    pub content_type: ContentType,
}

// Hand-written `Debug` so that `group_id` is rendered through
// `mls_rs_core::debug::pretty_group_id` rather than as a raw byte list.
impl Debug for SenderDataAAD {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SenderDataAAD")
            .field(
                "group_id",
                &mls_rs_core::debug::pretty_group_id(&self.group_id),
            )
            .field("epoch", &self.epoch)
            .field("content_type", &self.content_type)
            .finish()
    }
}

/// AEAD key and nonce used to protect a single message's `SenderData`,
/// together with the cipher suite provider that performs the AEAD operations.
pub(crate) struct SenderDataKey<'a, CP: CipherSuiteProvider> {
    /// AEAD key of `aead_key_size()` bytes; wiped on drop via `Zeroizing`.
    pub(crate) key: Zeroizing<Vec<u8>>,
    /// AEAD nonce of `aead_nonce_size()` bytes; wiped on drop via `Zeroizing`.
    pub(crate) nonce: Zeroizing<Vec<u8>>,
    cipher_suite_provider: &'a CP,
}

// NOTE(review): this `Debug` output includes the key and nonce bytes (via
// `pretty_bytes`). They are secret material — confirm whether `pretty_bytes`
// redacts/truncates before relying on this in logs.
impl<CP: CipherSuiteProvider + Debug> Debug for SenderDataKey<'_, CP> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SenderDataKey")
            .field("key", &mls_rs_core::debug::pretty_bytes(&self.key))
            .field("nonce", &mls_rs_core::debug::pretty_bytes(&self.nonce))
            .field("cipher_suite_provider", self.cipher_suite_provider)
            .finish()
    }
}

impl<'a, CP: CipherSuiteProvider> SenderDataKey<'a, CP> {
    /// Derives the sender data key and nonce for one message.
    ///
    /// Both values are expanded from `sender_data_secret` with
    /// `kdf_expand_with_label`, using the labels `"key"` and `"nonce"` and
    /// passing a sample of `ciphertext` alongside the label. The sample is the
    /// first `kdf_extract_size()` bytes of `ciphertext`, or the whole
    /// ciphertext when it is shorter than that. Output lengths are the suite's
    /// `aead_key_size()` and `aead_nonce_size()`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `kdf_expand_with_label`.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(super) async fn new(
        sender_data_secret: &SenderDataSecret,
        ciphertext: &[u8],
        cipher_suite_provider: &'a CP,
    ) -> Result<SenderDataKey<'a, CP>, MlsError> {
        // Sample the first extract_size bytes of the ciphertext, and if it is shorter, just use
        // the ciphertext itself
        let extract_size = cipher_suite_provider.kdf_extract_size();
        let ciphertext_sample = ciphertext.get(0..extract_size).unwrap_or(ciphertext);

        // Generate a sender data key and nonce using the sender_data_secret from the current
        // epoch's key schedule
        let key = kdf_expand_with_label(
            cipher_suite_provider,
            sender_data_secret,
            b"key",
            ciphertext_sample,
            Some(cipher_suite_provider.aead_key_size()),
        )
        .await?;

        let nonce = kdf_expand_with_label(
            cipher_suite_provider,
            sender_data_secret,
            b"nonce",
            ciphertext_sample,
            Some(cipher_suite_provider.aead_nonce_size()),
        )
        .await?;

        Ok(Self {
            key,
            nonce,
            cipher_suite_provider,
        })
    }

    /// Encrypts `sender_data` with this key and nonce.
    ///
    /// The MLS encoding of `sender_data` is the AEAD plaintext and the MLS
    /// encoding of `aad` is the associated data.
    ///
    /// # Errors
    ///
    /// Returns an error if either value fails to MLS-encode, or
    /// `MlsError::CryptoProviderError` if the provider's `aead_seal` fails.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn seal(
        &self,
        sender_data: &SenderData,
        aad: &SenderDataAAD,
    ) -> Result<Vec<u8>, MlsError> {
        self.cipher_suite_provider
            .aead_seal(
                &self.key,
                &sender_data.mls_encode_to_vec()?,
                Some(&aad.mls_encode_to_vec()?),
                &self.nonce,
            )
            .await
            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
    }

    /// Decrypts `sender_data` (a ciphertext produced by `seal`) and decodes
    /// the resulting plaintext as `SenderData`.
    ///
    /// `aad` must encode to the same bytes that were used when sealing.
    ///
    /// # Errors
    ///
    /// Returns an error if `aad` fails to MLS-encode,
    /// `MlsError::CryptoProviderError` if the provider's `aead_open` fails,
    /// or the converted codec error if the plaintext does not decode.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn open(
        &self,
        sender_data: &[u8],
        aad: &SenderDataAAD,
    ) -> Result<SenderData, MlsError> {
        self.cipher_suite_provider
            .aead_open(
                &self.key,
                sender_data,
                Some(&aad.mls_encode_to_vec()?),
                &self.nonce,
            )
            .await
            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
            // NOTE(review): nothing here checks that the whole plaintext was
            // consumed by `mls_decode`; confirm whether trailing bytes should
            // be rejected.
            .and_then(|data| SenderData::mls_decode(&mut &**data).map_err(From::from))
    }
}

/// Test helpers for cross-implementation (interop) sender data key vectors.
#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::vec::Vec;
    use mls_rs_core::crypto::CipherSuiteProvider;

    use super::SenderDataKey;

    /// One interop vector: a secret, a ciphertext, and the key/nonce expected
    /// from `SenderDataKey::new`. All fields serialize as hex strings.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    pub struct InteropSenderData {
        #[serde(with = "hex::serde")]
        pub sender_data_secret: Vec<u8>,
        #[serde(with = "hex::serde")]
        pub ciphertext: Vec<u8>,
        #[serde(with = "hex::serde")]
        pub key: Vec<u8>,
        #[serde(with = "hex::serde")]
        pub nonce: Vec<u8>,
    }

    impl InteropSenderData {
        /// Builds a vector from a random `kdf_extract_size()`-byte secret and a
        /// random 77-byte ciphertext, recording the derived key and nonce.
        /// Sync builds only; panics on any provider/derivation failure.
        /// NOTE(review): 77 is presumably chosen to exceed every suite's
        /// extract size so a strict prefix is sampled — confirm.
        #[cfg(not(mls_build_async))]
        #[cfg_attr(coverage_nightly, coverage(off))]
        pub(crate) fn new<P: CipherSuiteProvider>(cs: &P) -> Self {
            let secret = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap().into();
            let ciphertext = cs.random_bytes_vec(77).unwrap();
            let key = SenderDataKey::new(&secret, &ciphertext, cs).unwrap();
            let secret = (*secret).clone();

            Self {
                ciphertext,
                key: key.key.to_vec(),
                nonce: key.nonce.to_vec(),
                sender_data_secret: secret,
            }
        }

        /// Re-derives the key and nonce from this vector's secret and
        /// ciphertext with `cs`, and panics if either differs from the
        /// recorded values.
        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
        pub async fn verify<P: CipherSuiteProvider>(&self, cs: &P) {
            let secret = self.sender_data_secret.clone().into();

            let key = SenderDataKey::new(&secret, &self.ciphertext, cs)
                .await
                .unwrap();

            assert_eq!(key.key.to_vec(), self.key, "sender data key mismatch");
            assert_eq!(key.nonce.to_vec(), self.nonce, "sender data nonce mismatch");
        }
    }
}

#[cfg(test)]
mod tests {

    use alloc::vec::Vec;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    use crate::{
        crypto::test_utils::try_test_cipher_suite_provider,
        group::{ciphertext_processor::reuse_guard::ReuseGuard, framing::ContentType},
        tree_kem::node::LeafIndex,
    };

    use super::{SenderData, SenderDataAAD, SenderDataKey};

    #[cfg(not(mls_build_async))]
    use crate::{
        cipher_suite::CipherSuite, crypto::test_utils::test_cipher_suite_provider,
        group::test_utils::random_bytes, CipherSuiteProvider,
    };

    /// One entry of the `sender_data_key_test_vector` fixture. The serde
    /// field names are the fixture's keys, so renaming a field breaks loading
    /// of existing vectors.
    #[derive(serde::Deserialize, serde::Serialize)]
    struct TestCase {
        cipher_suite: u16,
        #[serde(with = "hex::serde")]
        secret: Vec<u8>,
        #[serde(with = "hex::serde")]
        ciphertext_bytes: Vec<u8>,
        #[serde(with = "hex::serde")]
        expected_key: Vec<u8>,
        #[serde(with = "hex::serde")]
        expected_nonce: Vec<u8>,
        sender_data: TestSenderData,
        sender_data_aad: TestSenderDataAAD,
        #[serde(with = "hex::serde")]
        expected_ciphertext: Vec<u8>,
    }

    /// Serializable mirror of `SenderData` used inside `TestCase`.
    #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
    struct TestSenderData {
        sender: u32,
        generation: u32,
        #[serde(with = "hex::serde")]
        reuse_guard: Vec<u8>,
    }

    impl From<TestSenderData> for SenderData {
        fn from(value: TestSenderData) -> Self {
            let reuse_guard = ReuseGuard::new(value.reuse_guard);

            Self {
                sender: LeafIndex(value.sender),
                generation: value.generation,
                reuse_guard,
            }
        }
    }

    /// Serializable mirror of `SenderDataAAD`; the content type is not stored
    /// in the fixture (see the `From` impl below).
    #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
    struct TestSenderDataAAD {
        epoch: u64,
        #[serde(with = "hex::serde")]
        group_id: Vec<u8>,
    }

    // Every fixture AAD is treated as `ContentType::Application`.
    impl From<TestSenderDataAAD> for SenderDataAAD {
        fn from(value: TestSenderDataAAD) -> Self {
            Self {
                epoch: value.epoch,
                group_id: value.group_id,
                content_type: ContentType::Application,
            }
        }
    }

    /// Generates fresh test vectors for every `CipherSuite::all()` suite, using
    /// random secrets, ciphertexts and reuse guards (sync builds only).
    #[cfg(not(mls_build_async))]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_test_vector() -> Vec<TestCase> {
        let test_cases = CipherSuite::all().map(test_cipher_suite_provider).map(
            #[cfg_attr(coverage_nightly, coverage(off))]
            |provider| {
                let ext_size = provider.kdf_extract_size();
                let secret = random_bytes(ext_size).into();
                // Ciphertexts shorter than, equal to, and longer than the
                // extract size exercise both outcomes of the
                // `get(0..extract_size).unwrap_or(ciphertext)` sampling in
                // `SenderDataKey::new`.
                // NOTE(review): `ext_size - 5` assumes an extract size of at
                // least 5 bytes.
                let ciphertext_sizes = [ext_size - 5, ext_size, ext_size + 5];

                let sender_data = TestSenderData {
                    sender: 0,
                    generation: 13,
                    reuse_guard: random_bytes(4),
                };

                let sender_data_aad = TestSenderDataAAD {
                    group_id: b"group".to_vec(),
                    epoch: 42,
                };

                ciphertext_sizes.into_iter().map(
                    #[cfg_attr(coverage_nightly, coverage(off))]
                    move |ciphertext_size| {
                        let ciphertext_bytes = random_bytes(ciphertext_size);

                        let sender_data_key =
                            SenderDataKey::new(&secret, &ciphertext_bytes, &provider).unwrap();

                        let expected_ciphertext = sender_data_key
                            .seal(&sender_data.clone().into(), &sender_data_aad.clone().into())
                            .unwrap();

                        TestCase {
                            cipher_suite: provider.cipher_suite().into(),
                            secret: secret.to_vec(),
                            ciphertext_bytes,
                            expected_key: sender_data_key.key.to_vec(),
                            expected_nonce: sender_data_key.nonce.to_vec(),
                            sender_data: sender_data.clone(),
                            sender_data_aad: sender_data_aad.clone(),
                            expected_ciphertext,
                        }
                    },
                )
            },
        );

        test_cases.flatten().collect()
    }

    /// Async builds cannot generate vectors (the generator above calls
    /// `SenderDataKey::new` synchronously); this stub panics if invoked.
    #[cfg(mls_build_async)]
    fn generate_test_vector() -> Vec<TestCase> {
        panic!("Tests cannot be generated in async mode");
    }

    /// Loads the `sender_data_key_test_vector` fixture through
    /// `load_test_case_json!`, which is also given `generate_test_vector()`
    /// (presumably used when the fixture is absent — see the macro).
    fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(sender_data_key_test_vector, generate_test_vector())
    }

    // Checks, for each fixture entry: key/nonce derivation, byte-exact
    // sealing against the recorded ciphertext, and an open round trip.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn sender_data_key_test_vector() {
        for test_case in load_test_cases() {
            // Skip suites for which no test provider is available.
            let Some(provider) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
                continue;
            };

            let sender_data_key = SenderDataKey::new(
                &test_case.secret.into(),
                &test_case.ciphertext_bytes,
                &provider,
            )
            .await
            .unwrap();

            assert_eq!(sender_data_key.key.to_vec(), test_case.expected_key);
            assert_eq!(sender_data_key.nonce.to_vec(), test_case.expected_nonce);

            let sender_data = test_case.sender_data.into();
            let sender_data_aad = test_case.sender_data_aad.into();

            // The byte-for-byte comparison relies on sealing being
            // deterministic for a fixed key, nonce, plaintext and AAD.
            let ciphertext = sender_data_key
                .seal(&sender_data, &sender_data_aad)
                .await
                .unwrap();

            assert_eq!(ciphertext, test_case.expected_ciphertext);

            let plaintext = sender_data_key
                .open(&ciphertext, &sender_data_aad)
                .await
                .unwrap();

            assert_eq!(plaintext, sender_data);
        }
    }
}