1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use alloc::vec::Vec; 6 use core::{ 7 fmt::{self, Debug}, 8 ops::{Deref, DerefMut}, 9 }; 10 11 use zeroize::Zeroizing; 12 13 use crate::{client::MlsError, tree_kem::math::TreeIndex, CipherSuiteProvider}; 14 15 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; 16 use mls_rs_core::error::IntoAnyError; 17 18 #[cfg(feature = "std")] 19 use std::collections::HashMap; 20 21 #[cfg(not(feature = "std"))] 22 use alloc::collections::BTreeMap; 23 24 use super::key_schedule::kdf_expand_with_label; 25 26 pub(crate) const MAX_RATCHET_BACK_HISTORY: u32 = 1024; 27 28 #[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)] 29 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 30 #[repr(u8)] 31 enum SecretTreeNode { 32 Secret(TreeSecret) = 0u8, 33 Ratchet(SecretRatchets) = 1u8, 34 } 35 36 impl SecretTreeNode { into_secret(self) -> Option<TreeSecret>37 fn into_secret(self) -> Option<TreeSecret> { 38 if let SecretTreeNode::Secret(secret) = self { 39 Some(secret) 40 } else { 41 None 42 } 43 } 44 } 45 46 #[derive(Clone, PartialEq, MlsEncode, MlsDecode, MlsSize)] 47 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 48 struct TreeSecret( 49 #[mls_codec(with = "mls_rs_codec::byte_vec")] 50 #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))] 51 Zeroizing<Vec<u8>>, 52 ); 53 54 impl Debug for TreeSecret { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result55 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 56 mls_rs_core::debug::pretty_bytes(&self.0) 57 .named("TreeSecret") 58 .fmt(f) 59 } 60 } 61 62 impl Deref for TreeSecret { 63 type Target = Vec<u8>; 64 deref(&self) -> &Self::Target65 fn deref(&self) -> &Self::Target { 66 &self.0 67 } 68 } 69 70 impl DerefMut for TreeSecret { deref_mut(&mut self) -> &mut Self::Target71 fn deref_mut(&mut self) -> &mut Self::Target { 72 &mut self.0 73 } 74 } 75 76 impl AsRef<[u8]> for TreeSecret { as_ref(&self) -> &[u8]77 fn as_ref(&self) -> &[u8] { 78 &self.0 79 } 80 } 81 82 impl From<Vec<u8>> for TreeSecret { from(vec: Vec<u8>) -> Self83 fn from(vec: Vec<u8>) -> Self { 84 TreeSecret(Zeroizing::new(vec)) 85 } 86 } 87 88 impl From<Zeroizing<Vec<u8>>> for TreeSecret { from(vec: Zeroizing<Vec<u8>>) -> Self89 fn from(vec: Zeroizing<Vec<u8>>) -> Self { 90 TreeSecret(vec) 91 } 92 } 93 94 #[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize, Default)] 95 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 96 struct TreeSecretsVec<T: TreeIndex> { 97 #[cfg(feature = "std")] 98 inner: HashMap<T, SecretTreeNode>, 99 #[cfg(not(feature = "std"))] 100 inner: Vec<(T, SecretTreeNode)>, 101 } 102 103 #[cfg(feature = "std")] 104 impl<T: TreeIndex> TreeSecretsVec<T> { set_node(&mut self, index: T, value: SecretTreeNode)105 fn set_node(&mut self, index: T, value: SecretTreeNode) { 106 self.inner.insert(index, value); 107 } 108 take_node(&mut self, index: &T) -> Option<SecretTreeNode>109 fn take_node(&mut self, index: &T) -> Option<SecretTreeNode> { 110 self.inner.remove(index) 111 } 112 } 113 114 #[cfg(not(feature = "std"))] 115 impl<T: TreeIndex> TreeSecretsVec<T> { set_node(&mut self, index: T, value: SecretTreeNode)116 fn set_node(&mut self, index: T, value: SecretTreeNode) { 117 if let Some(i) = self.find_node(&index) { 118 self.inner[i] = (index, value) 119 } else { 120 self.inner.push((index, value)) 121 } 122 } 123 take_node(&mut self, index: &T) -> Option<SecretTreeNode>124 fn take_node(&mut self, index: &T) -> Option<SecretTreeNode> { 125 self.find_node(index).map(|i| self.inner.remove(i).1) 126 } 127 find_node(&self, index: &T) -> Option<usize>128 fn find_node(&self, index: &T) -> Option<usize> { 129 use itertools::Itertools; 130 131 self.inner 132 .iter() 133 .find_position(|(i, _)| i == index) 134 .map(|(i, _)| i) 135 } 136 } 137 138 #[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize)] 139 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 140 pub struct SecretTree<T: TreeIndex> { 141 known_secrets: TreeSecretsVec<T>, 142 leaf_count: T, 143 } 144 145 impl<T: TreeIndex> SecretTree<T> { empty() -> SecretTree<T>146 pub(crate) fn empty() -> SecretTree<T> { 147 SecretTree { 148 known_secrets: Default::default(), 149 leaf_count: T::zero(), 150 } 151 } 152 } 153 154 #[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)] 155 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 156 pub struct SecretRatchets { 157 pub application: SecretKeyRatchet, 158 pub handshake: SecretKeyRatchet, 159 } 160 161 impl SecretRatchets { 162 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] message_key_generation<P: CipherSuiteProvider>( &mut self, cipher_suite_provider: &P, generation: u32, key_type: KeyType, ) -> Result<MessageKeyData, MlsError>163 pub async fn message_key_generation<P: CipherSuiteProvider>( 164 &mut self, 165 cipher_suite_provider: &P, 166 generation: u32, 167 key_type: KeyType, 168 ) -> Result<MessageKeyData, MlsError> { 169 match key_type { 170 KeyType::Handshake => { 171 self.handshake 172 .get_message_key(cipher_suite_provider, generation) 173 .await 174 } 175 KeyType::Application => { 176 self.application 177 .get_message_key(cipher_suite_provider, generation) 178 .await 179 } 180 } 181 } 182 183 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] next_message_key<P: CipherSuiteProvider>( &mut self, cipher_suite: &P, key_type: KeyType, ) -> Result<MessageKeyData, MlsError>184 pub async fn next_message_key<P: CipherSuiteProvider>( 185 &mut self, 186 cipher_suite: &P, 187 key_type: KeyType, 188 ) -> Result<MessageKeyData, MlsError> { 189 match key_type { 190 KeyType::Handshake => self.handshake.next_message_key(cipher_suite).await, 191 KeyType::Application => self.application.next_message_key(cipher_suite).await, 192 } 193 } 194 } 195 196 impl<T: TreeIndex> SecretTree<T> { new(leaf_count: T, encryption_secret: Zeroizing<Vec<u8>>) -> SecretTree<T>197 pub fn new(leaf_count: T, encryption_secret: Zeroizing<Vec<u8>>) -> SecretTree<T> { 198 let mut known_secrets = TreeSecretsVec::default(); 199 200 let root_secret = SecretTreeNode::Secret(TreeSecret::from(encryption_secret)); 201 known_secrets.set_node(leaf_count.root(), root_secret); 202 203 Self { 204 known_secrets, 205 leaf_count, 206 } 207 } 208 209 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] consume_node<P: CipherSuiteProvider>( &mut self, cipher_suite_provider: &P, index: &T, ) -> Result<(), MlsError>210 async fn consume_node<P: CipherSuiteProvider>( 211 &mut self, 212 cipher_suite_provider: &P, 213 index: &T, 214 ) -> Result<(), MlsError> { 215 let node = self.known_secrets.take_node(index); 216 217 if let Some(secret) = node.and_then(|n| n.into_secret()) { 218 let left_index = index.left().ok_or(MlsError::LeafNodeNoChildren)?; 219 let right_index = index.right().ok_or(MlsError::LeafNodeNoChildren)?; 220 221 let left_secret = 222 kdf_expand_with_label(cipher_suite_provider, &secret, b"tree", b"left", None) 223 .await?; 224 225 let right_secret = 226 kdf_expand_with_label(cipher_suite_provider, &secret, b"tree", b"right", None) 227 .await?; 228 229 self.known_secrets 230 .set_node(left_index, SecretTreeNode::Secret(left_secret.into())); 231 232 self.known_secrets 233 .set_node(right_index, SecretTreeNode::Secret(right_secret.into())); 234 } 235 236 Ok(()) 237 } 238 239 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] take_leaf_ratchet<P: CipherSuiteProvider>( &mut self, cipher_suite: &P, leaf_index: &T, ) -> Result<SecretRatchets, MlsError>240 async fn take_leaf_ratchet<P: CipherSuiteProvider>( 241 &mut self, 242 cipher_suite: &P, 243 leaf_index: &T, 244 ) -> Result<SecretRatchets, MlsError> { 245 let node_index = leaf_index; 246 247 let node = match self.known_secrets.take_node(node_index) { 248 Some(node) => node, 249 None => { 250 // Start at the root node and work your way down consuming any intermediates needed 251 for i in node_index.direct_copath(&self.leaf_count).into_iter().rev() { 252 self.consume_node(cipher_suite, &i.path).await?; 253 } 254 255 self.known_secrets 256 .take_node(node_index) 257 .ok_or(MlsError::InvalidLeafConsumption)? 258 } 259 }; 260 261 Ok(match node { 262 SecretTreeNode::Ratchet(ratchet) => ratchet, 263 SecretTreeNode::Secret(secret) => SecretRatchets { 264 application: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Application) 265 .await?, 266 handshake: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Handshake).await?, 267 }, 268 }) 269 } 270 271 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] next_message_key<P: CipherSuiteProvider>( &mut self, cipher_suite: &P, leaf_index: T, key_type: KeyType, ) -> Result<MessageKeyData, MlsError>272 pub async fn next_message_key<P: CipherSuiteProvider>( 273 &mut self, 274 cipher_suite: &P, 275 leaf_index: T, 276 key_type: KeyType, 277 ) -> Result<MessageKeyData, MlsError> { 278 let mut ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?; 279 let res = ratchet.next_message_key(cipher_suite, key_type).await?; 280 281 self.known_secrets 282 .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet)); 283 284 Ok(res) 285 } 286 287 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] message_key_generation<P: CipherSuiteProvider>( &mut self, cipher_suite: &P, leaf_index: T, key_type: KeyType, generation: u32, ) -> Result<MessageKeyData, MlsError>288 pub async fn message_key_generation<P: CipherSuiteProvider>( 289 &mut self, 290 cipher_suite: &P, 291 leaf_index: T, 292 key_type: KeyType, 293 generation: u32, 294 ) -> Result<MessageKeyData, MlsError> { 295 let mut ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?; 296 297 let res = ratchet 298 .message_key_generation(cipher_suite, generation, key_type) 299 .await?; 300 301 self.known_secrets 302 .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet)); 303 304 Ok(res) 305 } 306 } 307 308 #[derive(Clone, Copy)] 309 pub enum KeyType { 310 Handshake, 311 Application, 312 } 313 314 #[cfg_attr( 315 all(feature = "ffi", not(test)), 316 safer_ffi_gen::ffi_type(clone, opaque) 317 )] 318 #[derive(Clone, PartialEq, Eq, MlsEncode, MlsDecode, MlsSize)] 319 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 320 /// AEAD key derived by the MLS secret tree. 321 pub struct MessageKeyData { 322 #[mls_codec(with = "mls_rs_codec::byte_vec")] 323 #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))] 324 pub(crate) nonce: Zeroizing<Vec<u8>>, 325 #[mls_codec(with = "mls_rs_codec::byte_vec")] 326 #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))] 327 pub(crate) key: Zeroizing<Vec<u8>>, 328 pub(crate) generation: u32, 329 } 330 331 impl Debug for MessageKeyData { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result332 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 333 f.debug_struct("MessageKeyData") 334 .field("nonce", &mls_rs_core::debug::pretty_bytes(&self.nonce)) 335 .field("key", &mls_rs_core::debug::pretty_bytes(&self.key)) 336 .field("generation", &self.generation) 337 .finish() 338 } 339 } 340 341 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)] 342 impl MessageKeyData { 343 /// AEAD nonce. 344 #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))] nonce(&self) -> &[u8]345 pub fn nonce(&self) -> &[u8] { 346 &self.nonce 347 } 348 349 /// AEAD key. 350 #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))] key(&self) -> &[u8]351 pub fn key(&self) -> &[u8] { 352 &self.key 353 } 354 355 /// Generation of this key within the key schedule. 356 #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))] generation(&self) -> u32357 pub fn generation(&self) -> u32 { 358 self.generation 359 } 360 } 361 362 #[derive(Debug, Clone, PartialEq)] 363 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 364 pub struct SecretKeyRatchet { 365 secret: TreeSecret, 366 generation: u32, 367 #[cfg(all(feature = "out_of_order", feature = "std"))] 368 history: HashMap<u32, MessageKeyData>, 369 #[cfg(all(feature = "out_of_order", not(feature = "std")))] 370 history: BTreeMap<u32, MessageKeyData>, 371 } 372 373 impl MlsSize for SecretKeyRatchet { mls_encoded_len(&self) -> usize374 fn mls_encoded_len(&self) -> usize { 375 let len = mls_rs_codec::byte_vec::mls_encoded_len(&self.secret) 376 + self.generation.mls_encoded_len(); 377 378 #[cfg(feature = "out_of_order")] 379 return len + mls_rs_codec::iter::mls_encoded_len(self.history.values()); 380 #[cfg(not(feature = "out_of_order"))] 381 return len; 382 } 383 } 384 385 #[cfg(feature = "out_of_order")] 386 impl MlsEncode for SecretKeyRatchet { mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error>387 fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> { 388 mls_rs_codec::byte_vec::mls_encode(&self.secret, writer)?; 389 self.generation.mls_encode(writer)?; 390 mls_rs_codec::iter::mls_encode(self.history.values(), writer) 391 } 392 } 393 394 #[cfg(not(feature = "out_of_order"))] 395 impl MlsEncode for SecretKeyRatchet { mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error>396 fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> { 397 mls_rs_codec::byte_vec::mls_encode(&self.secret, writer)?; 398 self.generation.mls_encode(writer) 399 } 400 } 401 402 impl MlsDecode for SecretKeyRatchet { mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error>403 fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> { 404 Ok(Self { 405 secret: mls_rs_codec::byte_vec::mls_decode(reader)?, 406 generation: u32::mls_decode(reader)?, 407 #[cfg(all(feature = "std", feature = "out_of_order"))] 408 history: mls_rs_codec::iter::mls_decode_collection(reader, |data| { 409 let mut items = HashMap::default(); 410 411 while !data.is_empty() { 412 let item = MessageKeyData::mls_decode(data)?; 413 items.insert(item.generation, item); 414 } 415 416 Ok(items) 417 })?, 418 #[cfg(all(not(feature = "std"), feature = "out_of_order"))] 419 history: mls_rs_codec::iter::mls_decode_collection(reader, |data| { 420 let mut items = alloc::collections::BTreeMap::default(); 421 422 while !data.is_empty() { 423 let item = MessageKeyData::mls_decode(data)?; 424 items.insert(item.generation, item); 425 } 426 427 Ok(items) 428 })?, 429 }) 430 } 431 } 432 433 impl SecretKeyRatchet { 434 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] new<P: CipherSuiteProvider>( cipher_suite_provider: &P, secret: &[u8], key_type: KeyType, ) -> Result<Self, MlsError>435 async fn new<P: CipherSuiteProvider>( 436 cipher_suite_provider: &P, 437 secret: &[u8], 438 key_type: KeyType, 439 ) -> Result<Self, MlsError> { 440 let label = match key_type { 441 KeyType::Handshake => b"handshake".as_slice(), 442 KeyType::Application => b"application".as_slice(), 443 }; 444 445 let secret = kdf_expand_with_label(cipher_suite_provider, secret, label, &[], None) 446 .await 447 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?; 448 449 Ok(Self { 450 secret: TreeSecret::from(secret), 451 generation: 0, 452 #[cfg(feature = "out_of_order")] 453 history: Default::default(), 454 }) 455 } 456 457 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] get_message_key<P: CipherSuiteProvider>( &mut self, cipher_suite_provider: &P, generation: u32, ) -> Result<MessageKeyData, MlsError>458 async fn get_message_key<P: CipherSuiteProvider>( 459 &mut self, 460 cipher_suite_provider: &P, 461 generation: u32, 462 ) -> Result<MessageKeyData, MlsError> { 463 #[cfg(feature = "out_of_order")] 464 if generation < self.generation { 465 return self 466 .history 467 .remove_entry(&generation) 468 .map(|(_, mk)| mk) 469 .ok_or(MlsError::KeyMissing(generation)); 470 } 471 472 #[cfg(not(feature = "out_of_order"))] 473 if generation < self.generation { 474 return Err(MlsError::KeyMissing(generation)); 475 } 476 477 let max_generation_allowed = self.generation + MAX_RATCHET_BACK_HISTORY; 478 479 if generation > max_generation_allowed { 480 return Err(MlsError::InvalidFutureGeneration(generation)); 481 } 482 483 #[cfg(not(feature = "out_of_order"))] 484 while self.generation < generation { 485 self.next_message_key(cipher_suite_provider)?; 486 } 487 488 #[cfg(feature = "out_of_order")] 489 while self.generation < generation { 490 let key_data = self.next_message_key(cipher_suite_provider).await?; 491 self.history.insert(key_data.generation, key_data); 492 } 493 494 self.next_message_key(cipher_suite_provider).await 495 } 496 497 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] next_message_key<P: CipherSuiteProvider>( &mut self, cipher_suite_provider: &P, ) -> Result<MessageKeyData, MlsError>498 async fn next_message_key<P: CipherSuiteProvider>( 499 &mut self, 500 cipher_suite_provider: &P, 501 ) -> Result<MessageKeyData, MlsError> { 502 let generation = self.generation; 503 504 let key = MessageKeyData { 505 nonce: self 506 .derive_secret( 507 cipher_suite_provider, 508 b"nonce", 509 cipher_suite_provider.aead_nonce_size(), 510 ) 511 .await?, 512 key: self 513 .derive_secret( 514 cipher_suite_provider, 515 b"key", 516 cipher_suite_provider.aead_key_size(), 517 ) 518 .await?, 519 generation, 520 }; 521 522 self.secret = self 523 .derive_secret( 524 cipher_suite_provider, 525 b"secret", 526 cipher_suite_provider.kdf_extract_size(), 527 ) 528 .await? 529 .into(); 530 531 self.generation = generation + 1; 532 533 Ok(key) 534 } 535 536 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] derive_secret<P: CipherSuiteProvider>( &self, cipher_suite_provider: &P, label: &[u8], len: usize, ) -> Result<Zeroizing<Vec<u8>>, MlsError>537 async fn derive_secret<P: CipherSuiteProvider>( 538 &self, 539 cipher_suite_provider: &P, 540 label: &[u8], 541 len: usize, 542 ) -> Result<Zeroizing<Vec<u8>>, MlsError> { 543 kdf_expand_with_label( 544 cipher_suite_provider, 545 self.secret.as_ref(), 546 label, 547 &self.generation.to_be_bytes(), 548 Some(len), 549 ) 550 .await 551 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error())) 552 } 553 } 554 555 #[cfg(test)] 556 pub(crate) mod test_utils { 557 use alloc::{string::String, vec::Vec}; 558 use mls_rs_core::crypto::CipherSuiteProvider; 559 use zeroize::Zeroizing; 560 561 use crate::{crypto::test_utils::try_test_cipher_suite_provider, tree_kem::math::TreeIndex}; 562 563 use super::{KeyType, SecretKeyRatchet, SecretTree}; 564 get_test_tree<T: TreeIndex>(secret: Vec<u8>, leaf_count: T) -> SecretTree<T>565 pub(crate) fn get_test_tree<T: TreeIndex>(secret: Vec<u8>, leaf_count: T) -> SecretTree<T> { 566 SecretTree::new(leaf_count, Zeroizing::new(secret)) 567 } 568 569 impl SecretTree<u32> { get_root_secret(&self) -> Vec<u8>570 pub(crate) fn get_root_secret(&self) -> Vec<u8> { 571 self.known_secrets 572 .clone() 573 .take_node(&self.leaf_count.root()) 574 .unwrap() 575 .into_secret() 576 .unwrap() 577 .to_vec() 578 } 579 } 580 581 #[derive(Debug, serde::Serialize, serde::Deserialize)] 582 pub struct RatchetInteropTestCase { 583 #[serde(with = "hex::serde")] 584 secret: Vec<u8>, 585 label: String, 586 generation: u32, 587 length: usize, 588 #[serde(with = "hex::serde")] 589 out: Vec<u8>, 590 } 591 592 #[derive(Debug, serde::Serialize, serde::Deserialize)] 593 pub struct InteropTestCase { 594 cipher_suite: u16, 595 derive_tree_secret: RatchetInteropTestCase, 596 } 597 598 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_basic_crypto_test_vectors()599 async fn test_basic_crypto_test_vectors() { 600 let test_cases: Vec<InteropTestCase> = 601 load_test_case_json!(basic_crypto, Vec::<InteropTestCase>::new()); 602 603 for test_case in test_cases { 604 if let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) { 605 test_case.derive_tree_secret.verify(&cs).await 606 } 607 } 608 } 609 610 impl RatchetInteropTestCase { 611 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] verify<P: CipherSuiteProvider>(&self, cs: &P)612 pub async fn verify<P: CipherSuiteProvider>(&self, cs: &P) { 613 let mut ratchet = SecretKeyRatchet::new(cs, &self.secret, KeyType::Application) 614 .await 615 .unwrap(); 616 617 ratchet.secret = self.secret.clone().into(); 618 ratchet.generation = self.generation; 619 620 let computed = ratchet 621 .derive_secret(cs, self.label.as_bytes(), self.length) 622 .await 623 .unwrap(); 624 625 assert_eq!(&computed.to_vec(), &self.out); 626 } 627 } 628 } 629 630 #[cfg(test)] 631 mod tests { 632 use alloc::vec; 633 634 use crate::{ 635 cipher_suite::CipherSuite, 636 client::test_utils::TEST_CIPHER_SUITE, 637 crypto::test_utils::{ 638 test_cipher_suite_provider, try_test_cipher_suite_provider, TestCryptoProvider, 639 }, 640 tree_kem::node::NodeIndex, 641 }; 642 643 #[cfg(not(mls_build_async))] 644 use crate::group::test_utils::random_bytes; 645 646 use super::{test_utils::get_test_tree, *}; 647 648 use assert_matches::assert_matches; 649 650 #[cfg(target_arch = "wasm32")] 651 use wasm_bindgen_test::wasm_bindgen_test as test; 652 653 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_secret_tree()654 async fn test_secret_tree() { 655 test_secret_tree_custom(16u32, (0..16).map(|i| 2 * i).collect(), true).await; 656 test_secret_tree_custom(1u64 << 62, (1..62).map(|i| 1u64 << i).collect(), false).await; 657 } 658 659 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] test_secret_tree_custom<T: TreeIndex>( leaf_count: T, leaves_to_check: Vec<T>, all_deleted: bool, )660 async fn test_secret_tree_custom<T: TreeIndex>( 661 leaf_count: T, 662 leaves_to_check: Vec<T>, 663 all_deleted: bool, 664 ) { 665 for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() { 666 let cs_provider = test_cipher_suite_provider(cipher_suite); 667 668 let test_secret = vec![0u8; cs_provider.kdf_extract_size()]; 669 let mut test_tree = get_test_tree(test_secret, leaf_count.clone()); 670 671 let mut secrets = Vec::<SecretRatchets>::new(); 672 673 for i in &leaves_to_check { 674 let secret = test_tree 675 .take_leaf_ratchet(&test_cipher_suite_provider(cipher_suite), i) 676 .await 677 .unwrap(); 678 679 secrets.push(secret); 680 } 681 682 // Verify the tree is now completely empty 683 assert!(!all_deleted || test_tree.known_secrets.inner.is_empty()); 684 685 // Verify that all the secrets are unique 686 let count = secrets.len(); 687 secrets.dedup(); 688 assert_eq!(count, secrets.len()); 689 } 690 } 691 692 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_secret_key_ratchet()693 async fn test_secret_key_ratchet() { 694 for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() { 695 let provider = test_cipher_suite_provider(cipher_suite); 696 697 let mut app_ratchet = SecretKeyRatchet::new( 698 &provider, 699 &vec![0u8; provider.kdf_extract_size()], 700 KeyType::Application, 701 ) 702 .await 703 .unwrap(); 704 705 let mut handshake_ratchet = SecretKeyRatchet::new( 706 &provider, 707 &vec![0u8; provider.kdf_extract_size()], 708 KeyType::Handshake, 709 ) 710 .await 711 .unwrap(); 712 713 let app_key_one = app_ratchet.next_message_key(&provider).await.unwrap(); 714 let app_key_two = app_ratchet.next_message_key(&provider).await.unwrap(); 715 let app_keys = vec![app_key_one, app_key_two]; 716 717 let handshake_key_one = handshake_ratchet.next_message_key(&provider).await.unwrap(); 718 let handshake_key_two = handshake_ratchet.next_message_key(&provider).await.unwrap(); 719 let handshake_keys = vec![handshake_key_one, handshake_key_two]; 720 721 // Verify that the keys have different outcomes due to their different labels 722 assert_ne!(app_keys, handshake_keys); 723 724 // Verify that the keys at each generation are different 725 assert_ne!(handshake_keys[0], handshake_keys[1]); 726 } 727 } 728 729 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_get_key()730 async fn test_get_key() { 731 for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() { 732 let provider = test_cipher_suite_provider(cipher_suite); 733 734 let mut ratchet = SecretKeyRatchet::new( 735 &test_cipher_suite_provider(cipher_suite), 736 &vec![0u8; provider.kdf_extract_size()], 737 KeyType::Application, 738 ) 739 .await 740 .unwrap(); 741 742 let mut ratchet_clone = ratchet.clone(); 743 744 // This will generate keys 0 and 1 in ratchet_clone 745 let _ = ratchet_clone.next_message_key(&provider).await.unwrap(); 746 let clone_2 = ratchet_clone.next_message_key(&provider).await.unwrap(); 747 748 // Going back in time should result in an error 749 let res = ratchet_clone.get_message_key(&provider, 0).await; 750 assert!(res.is_err()); 751 752 // Calling get key should be the same as calling next until hitting the desired generation 753 let second_key = ratchet 754 .get_message_key(&provider, ratchet_clone.generation - 1) 755 .await 756 .unwrap(); 757 758 assert_eq!(clone_2, second_key) 759 } 760 } 761 762 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_secret_ratchet()763 async fn test_secret_ratchet() { 764 for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() { 765 let provider = test_cipher_suite_provider(cipher_suite); 766 767 let mut ratchet = SecretKeyRatchet::new( 768 &provider, 769 &vec![0u8; provider.kdf_extract_size()], 770 KeyType::Application, 771 ) 772 .await 773 .unwrap(); 774 775 let original_secret = ratchet.secret.clone(); 776 let _ = ratchet.next_message_key(&provider).await.unwrap(); 777 let new_secret = ratchet.secret; 778 assert_ne!(original_secret, new_secret) 779 } 780 } 781 782 #[cfg(feature = "out_of_order")] 783 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_out_of_order_keys()784 async fn test_out_of_order_keys() { 785 let cipher_suite = TEST_CIPHER_SUITE; 786 let provider = test_cipher_suite_provider(cipher_suite); 787 788 let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake) 789 .await 790 .unwrap(); 791 let mut ratchet_clone = ratchet.clone(); 792 793 // Ask for all the keys in order from the original ratchet 794 let mut ordered_keys = Vec::<MessageKeyData>::new(); 795 796 for i in 0..=MAX_RATCHET_BACK_HISTORY { 797 ordered_keys.push(ratchet.get_message_key(&provider, i).await.unwrap()); 798 } 799 800 // Ask for a key at index MAX_RATCHET_BACK_HISTORY in the clone 801 let last_key = ratchet_clone 802 .get_message_key(&provider, MAX_RATCHET_BACK_HISTORY) 803 .await 804 .unwrap(); 805 806 assert_eq!(last_key, ordered_keys[ordered_keys.len() - 1]); 807 808 // Get all the other keys 809 let mut back_history_keys = Vec::<MessageKeyData>::new(); 810 811 for i in 0..MAX_RATCHET_BACK_HISTORY - 1 { 812 back_history_keys.push(ratchet_clone.get_message_key(&provider, i).await.unwrap()); 813 } 814 815 assert_eq!( 816 back_history_keys, 817 ordered_keys[..(MAX_RATCHET_BACK_HISTORY as usize) - 1] 818 ); 819 } 820 821 #[cfg(not(feature = "out_of_order"))] 822 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] out_of_order_keys_should_throw_error()823 async fn out_of_order_keys_should_throw_error() { 824 let cipher_suite = TEST_CIPHER_SUITE; 825 let provider = test_cipher_suite_provider(cipher_suite); 826 827 let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake) 828 .await 829 .unwrap(); 830 831 ratchet.get_message_key(&provider, 10).await.unwrap(); 832 let res = ratchet.get_message_key(&provider, 9).await; 833 assert_matches!(res, Err(MlsError::KeyMissing(9))) 834 } 835 836 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_too_out_of_order()837 async fn test_too_out_of_order() { 838 let cipher_suite = TEST_CIPHER_SUITE; 839 let provider = test_cipher_suite_provider(cipher_suite); 840 841 let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake) 842 .await 843 .unwrap(); 844 845 let res = ratchet 846 .get_message_key(&provider, MAX_RATCHET_BACK_HISTORY + 1) 847 .await; 848 849 let invalid_generation = MAX_RATCHET_BACK_HISTORY + 1; 850 851 assert_matches!( 852 res, 853 Err(MlsError::InvalidFutureGeneration(invalid)) 854 if invalid == invalid_generation 855 ) 856 } 857 858 #[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)] 859 struct Ratchet { 860 application_keys: Vec<Vec<u8>>, 861 handshake_keys: Vec<Vec<u8>>, 862 } 863 864 #[derive(Debug, serde::Serialize, serde::Deserialize)] 865 struct TestCase { 866 cipher_suite: u16, 867 #[serde(with = "hex::serde")] 868 encryption_secret: Vec<u8>, 869 ratchets: Vec<Ratchet>, 870 } 871 872 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] get_ratchet_data( secret_tree: &mut SecretTree<NodeIndex>, cipher_suite: CipherSuite, ) -> Vec<Ratchet>873 async fn get_ratchet_data( 874 secret_tree: &mut SecretTree<NodeIndex>, 875 cipher_suite: CipherSuite, 876 ) -> Vec<Ratchet> { 877 let provider = test_cipher_suite_provider(cipher_suite); 878 let mut ratchet_data = Vec::new(); 879 880 for index in 0..16 { 881 let mut ratchets = secret_tree 882 .take_leaf_ratchet(&provider, &(index * 2)) 883 .await 884 .unwrap(); 885 886 let mut application_keys = Vec::new(); 887 888 for _ in 0..20 { 889 let key = ratchets 890 .handshake 891 .next_message_key(&provider) 892 .await 893 .unwrap() 894 .mls_encode_to_vec() 895 .unwrap(); 896 897 application_keys.push(key); 898 } 899 900 let mut handshake_keys = Vec::new(); 901 902 for _ in 0..20 { 903 let key = ratchets 904 .handshake 905 .next_message_key(&provider) 906 .await 907 .unwrap() 908 .mls_encode_to_vec() 909 .unwrap(); 910 911 handshake_keys.push(key); 912 } 913 914 ratchet_data.push(Ratchet { 915 application_keys, 916 handshake_keys, 917 }); 918 } 919 920 ratchet_data 921 } 922 923 #[cfg(not(mls_build_async))] 924 #[cfg_attr(coverage_nightly, coverage(off))] generate_test_vector() -> Vec<TestCase>925 fn generate_test_vector() -> Vec<TestCase> { 926 CipherSuite::all() 927 .map(|cipher_suite| { 928 let provider = test_cipher_suite_provider(cipher_suite); 929 let encryption_secret = random_bytes(provider.kdf_extract_size()); 930 931 let mut secret_tree = 932 SecretTree::new(16, Zeroizing::new(encryption_secret.clone())); 933 934 TestCase { 935 cipher_suite: cipher_suite.into(), 936 encryption_secret, 937 ratchets: get_ratchet_data(&mut secret_tree, cipher_suite), 938 } 939 }) 940 .collect() 941 } 942 943 #[cfg(mls_build_async)] generate_test_vector() -> Vec<TestCase>944 fn generate_test_vector() -> Vec<TestCase> { 945 panic!("Tests cannot be generated in async mode"); 946 } 947 948 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_secret_tree_test_vectors()949 async fn test_secret_tree_test_vectors() { 950 let test_cases: Vec<TestCase> = load_test_case_json!(secret_tree, generate_test_vector()); 951 952 for case in test_cases { 953 let Some(cs_provider) = try_test_cipher_suite_provider(case.cipher_suite) else { 954 continue; 955 }; 956 957 let mut secret_tree = SecretTree::new(16, Zeroizing::new(case.encryption_secret)); 958 let ratchet_data = get_ratchet_data(&mut secret_tree, cs_provider.cipher_suite()).await; 959 960 assert_eq!(ratchet_data, case.ratchets); 961 } 962 } 963 } 964 965 #[cfg(all(test, feature = "rfc_compliant", feature = "std"))] 966 mod interop_tests { 967 #[cfg(not(mls_build_async))] 968 use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider}; 969 use zeroize::Zeroizing; 970 971 use crate::{ 972 crypto::test_utils::try_test_cipher_suite_provider, 973 group::{ciphertext_processor::InteropSenderData, secret_tree::KeyType}, 974 }; 975 976 use super::SecretTree; 977 978 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] interop_test_vector()979 async fn interop_test_vector() { 980 // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/secret-tree.json 981 let test_cases = load_interop_test_cases(); 982 983 for case in test_cases { 984 let Some(cs) = try_test_cipher_suite_provider(case.cipher_suite) else { 985 continue; 986 }; 987 988 case.sender_data.verify(&cs).await; 989 990 let mut tree = SecretTree::new( 991 case.leaves.len() as u32, 992 Zeroizing::new(case.encryption_secret), 993 ); 994 995 for (index, leaves) in case.leaves.iter().enumerate() { 996 for leaf in leaves.iter() { 997 let key = tree 998 .message_key_generation( 999 &cs, 1000 (index as u32) * 2, 1001 KeyType::Application, 1002 leaf.generation, 1003 ) 1004 .await 1005 .unwrap(); 1006 1007 assert_eq!(key.key.to_vec(), leaf.application_key); 1008 assert_eq!(key.nonce.to_vec(), leaf.application_nonce); 1009 1010 let key = tree 1011 .message_key_generation( 1012 &cs, 1013 (index as u32) * 2, 1014 KeyType::Handshake, 1015 leaf.generation, 1016 ) 1017 .await 1018 .unwrap(); 1019 1020 assert_eq!(key.key.to_vec(), leaf.handshake_key); 1021 assert_eq!(key.nonce.to_vec(), leaf.handshake_nonce); 1022 } 1023 } 1024 } 1025 } 1026 1027 #[derive(Debug, serde::Serialize, serde::Deserialize)] 1028 struct InteropTestCase { 1029 cipher_suite: u16, 1030 #[serde(with = "hex::serde")] 1031 encryption_secret: Vec<u8>, 1032 sender_data: InteropSenderData, 1033 leaves: Vec<Vec<InteropLeaf>>, 1034 } 1035 1036 #[derive(Debug, serde::Serialize, serde::Deserialize)] 1037 struct InteropLeaf { 1038 generation: u32, 1039 #[serde(with = "hex::serde")] 1040 application_key: Vec<u8>, 1041 #[serde(with = "hex::serde")] 1042 application_nonce: Vec<u8>, 1043 #[serde(with = "hex::serde")] 1044 handshake_key: Vec<u8>, 1045 #[serde(with = "hex::serde")] 1046 handshake_nonce: Vec<u8>, 1047 } 1048 load_interop_test_cases() -> Vec<InteropTestCase>1049 fn load_interop_test_cases() -> Vec<InteropTestCase> { 1050 load_test_case_json!(secret_tree_interop, generate_test_vector()) 1051 } 1052 1053 #[cfg(not(mls_build_async))] 1054 #[cfg_attr(coverage_nightly, coverage(off))] generate_test_vector() -> Vec<InteropTestCase>1055 fn generate_test_vector() -> Vec<InteropTestCase> { 1056 let mut test_cases = vec![]; 1057 1058 for cs in CipherSuite::all() { 1059 let Some(cs) = try_test_cipher_suite_provider(*cs) else { 1060 continue; 1061 }; 1062 1063 let gens = [0, 15]; 1064 let tree_sizes = [1, 8, 32]; 1065 1066 for n_leaves in tree_sizes { 1067 let encryption_secret = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap(); 1068 1069 let mut tree = SecretTree::new(n_leaves, Zeroizing::new(encryption_secret.clone())); 1070 1071 let leaves = (0..n_leaves) 1072 .map(|leaf| { 1073 gens.into_iter() 1074 .map(|gen| { 1075 let index = leaf * 2u32; 1076 1077 let handshake_key = tree 1078 .message_key_generation(&cs, index, KeyType::Handshake, gen) 1079 .unwrap(); 1080 1081 let app_key = tree 1082 .message_key_generation(&cs, index, KeyType::Application, gen) 1083 .unwrap(); 1084 1085 InteropLeaf { 1086 generation: gen, 1087 application_key: app_key.key.to_vec(), 1088 application_nonce: app_key.nonce.to_vec(), 1089 handshake_key: handshake_key.key.to_vec(), 1090 handshake_nonce: handshake_key.nonce.to_vec(), 1091 } 1092 }) 1093 .collect() 1094 }) 1095 .collect(); 1096 1097 let case = InteropTestCase { 1098 cipher_suite: *cs.cipher_suite(), 1099 encryption_secret, 1100 sender_data: InteropSenderData::new(&cs), 1101 leaves, 1102 }; 1103 1104 test_cases.push(case); 1105 } 1106 } 1107 1108 test_cases 1109 } 1110 1111 #[cfg(mls_build_async)] generate_test_vector() -> Vec<InteropTestCase>1112 fn generate_test_vector() -> Vec<InteropTestCase> { 1113 panic!("Tests cannot be generated in async mode"); 1114 } 1115 } 1116