1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::vec::Vec;
6 use core::{
7     fmt::{self, Debug},
8     ops::{Deref, DerefMut},
9 };
10 
11 use zeroize::Zeroizing;
12 
13 use crate::{client::MlsError, tree_kem::math::TreeIndex, CipherSuiteProvider};
14 
15 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
16 use mls_rs_core::error::IntoAnyError;
17 
18 #[cfg(feature = "std")]
19 use std::collections::HashMap;
20 
21 #[cfg(not(feature = "std"))]
22 use alloc::collections::BTreeMap;
23 
24 use super::key_schedule::kdf_expand_with_label;
25 
/// Maximum number of generations a [`SecretKeyRatchet`] may be advanced past its current
/// generation in a single `get_message_key` call. Requests further ahead fail with
/// `MlsError::InvalidFutureGeneration`. This also bounds how many skipped keys a single call can
/// add to the out-of-order history.
pub(crate) const MAX_RATCHET_BACK_HISTORY: u32 = 1024;
27 
/// A node stored in the secret tree.
///
/// Intermediate nodes and not-yet-used leaves hold a plain [`TreeSecret`]. Once a leaf has been
/// used, its secret is replaced by the [`SecretRatchets`] derived from it.
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
// NOTE(review): the explicit discriminants presumably act as the encoded variant tags; keep them
// stable so previously serialized trees still decode.
enum SecretTreeNode {
    /// Secret for this node that has not been expanded yet.
    Secret(TreeSecret) = 0u8,
    /// Handshake/application ratchets for a leaf that has already been used.
    Ratchet(SecretRatchets) = 1u8,
}
35 
36 impl SecretTreeNode {
into_secret(self) -> Option<TreeSecret>37     fn into_secret(self) -> Option<TreeSecret> {
38         if let SecretTreeNode::Secret(secret) = self {
39             Some(secret)
40         } else {
41             None
42         }
43     }
44 }
45 
/// Secret bytes belonging to a single secret-tree node (or a ratchet's current secret).
///
/// The bytes are wrapped in [`Zeroizing`] so they are wiped when the value is dropped. They are
/// encoded as a variable-length byte vector.
#[derive(Clone, PartialEq, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct TreeSecret(
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
    Zeroizing<Vec<u8>>,
);
53 
54 impl Debug for TreeSecret {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result55     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56         mls_rs_core::debug::pretty_bytes(&self.0)
57             .named("TreeSecret")
58             .fmt(f)
59     }
60 }
61 
62 impl Deref for TreeSecret {
63     type Target = Vec<u8>;
64 
deref(&self) -> &Self::Target65     fn deref(&self) -> &Self::Target {
66         &self.0
67     }
68 }
69 
70 impl DerefMut for TreeSecret {
deref_mut(&mut self) -> &mut Self::Target71     fn deref_mut(&mut self) -> &mut Self::Target {
72         &mut self.0
73     }
74 }
75 
76 impl AsRef<[u8]> for TreeSecret {
as_ref(&self) -> &[u8]77     fn as_ref(&self) -> &[u8] {
78         &self.0
79     }
80 }
81 
82 impl From<Vec<u8>> for TreeSecret {
from(vec: Vec<u8>) -> Self83     fn from(vec: Vec<u8>) -> Self {
84         TreeSecret(Zeroizing::new(vec))
85     }
86 }
87 
88 impl From<Zeroizing<Vec<u8>>> for TreeSecret {
from(vec: Zeroizing<Vec<u8>>) -> Self89     fn from(vec: Zeroizing<Vec<u8>>) -> Self {
90         TreeSecret(vec)
91     }
92 }
93 
/// Sparse storage of the secret-tree nodes currently known, keyed by node index.
///
/// With `std` this is a `HashMap`. Without `std` it is a `Vec` of `(index, node)` pairs that is
/// searched linearly.
#[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct TreeSecretsVec<T: TreeIndex> {
    #[cfg(feature = "std")]
    inner: HashMap<T, SecretTreeNode>,
    #[cfg(not(feature = "std"))]
    inner: Vec<(T, SecretTreeNode)>,
}
102 
#[cfg(feature = "std")]
impl<T: TreeIndex> TreeSecretsVec<T> {
    /// Stores `value` under `index`, discarding whatever node was there before.
    fn set_node(&mut self, index: T, value: SecretTreeNode) {
        let _ = self.inner.insert(index, value);
    }

    /// Removes the node stored under `index` and hands it back, if there was one.
    fn take_node(&mut self, index: &T) -> Option<SecretTreeNode> {
        self.inner
            .remove_entry(index)
            .map(|(_, node)| node)
    }
}
113 
114 #[cfg(not(feature = "std"))]
115 impl<T: TreeIndex> TreeSecretsVec<T> {
set_node(&mut self, index: T, value: SecretTreeNode)116     fn set_node(&mut self, index: T, value: SecretTreeNode) {
117         if let Some(i) = self.find_node(&index) {
118             self.inner[i] = (index, value)
119         } else {
120             self.inner.push((index, value))
121         }
122     }
123 
take_node(&mut self, index: &T) -> Option<SecretTreeNode>124     fn take_node(&mut self, index: &T) -> Option<SecretTreeNode> {
125         self.find_node(index).map(|i| self.inner.remove(i).1)
126     }
127 
find_node(&self, index: &T) -> Option<usize>128     fn find_node(&self, index: &T) -> Option<usize> {
129         use itertools::Itertools;
130 
131         self.inner
132             .iter()
133             .find_position(|(i, _)| i == index)
134             .map(|(i, _)| i)
135     }
136 }
137 
/// The secret tree of a group epoch.
///
/// Only nodes that have been derived but not yet consumed are stored. Expanding a parent removes
/// it and stores its two children, and a used leaf is stored as its ratchets.
#[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SecretTree<T: TreeIndex> {
    /// Nodes currently holding a secret or leaf ratchets, keyed by node index.
    known_secrets: TreeSecretsVec<T>,
    /// Number of leaves; used to locate the root and to walk a leaf's path.
    leaf_count: T,
}
144 
145 impl<T: TreeIndex> SecretTree<T> {
empty() -> SecretTree<T>146     pub(crate) fn empty() -> SecretTree<T> {
147         SecretTree {
148             known_secrets: Default::default(),
149             leaf_count: T::zero(),
150         }
151     }
152 }
153 
/// The pair of symmetric ratchets held for one leaf: one per [`KeyType`].
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SecretRatchets {
    /// Ratchet producing keys for application messages.
    pub application: SecretKeyRatchet,
    /// Ratchet producing keys for handshake messages.
    pub handshake: SecretKeyRatchet,
}
160 
161 impl SecretRatchets {
162     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
message_key_generation<P: CipherSuiteProvider>( &mut self, cipher_suite_provider: &P, generation: u32, key_type: KeyType, ) -> Result<MessageKeyData, MlsError>163     pub async fn message_key_generation<P: CipherSuiteProvider>(
164         &mut self,
165         cipher_suite_provider: &P,
166         generation: u32,
167         key_type: KeyType,
168     ) -> Result<MessageKeyData, MlsError> {
169         match key_type {
170             KeyType::Handshake => {
171                 self.handshake
172                     .get_message_key(cipher_suite_provider, generation)
173                     .await
174             }
175             KeyType::Application => {
176                 self.application
177                     .get_message_key(cipher_suite_provider, generation)
178                     .await
179             }
180         }
181     }
182 
183     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
next_message_key<P: CipherSuiteProvider>( &mut self, cipher_suite: &P, key_type: KeyType, ) -> Result<MessageKeyData, MlsError>184     pub async fn next_message_key<P: CipherSuiteProvider>(
185         &mut self,
186         cipher_suite: &P,
187         key_type: KeyType,
188     ) -> Result<MessageKeyData, MlsError> {
189         match key_type {
190             KeyType::Handshake => self.handshake.next_message_key(cipher_suite).await,
191             KeyType::Application => self.application.next_message_key(cipher_suite).await,
192         }
193     }
194 }
195 
196 impl<T: TreeIndex> SecretTree<T> {
new(leaf_count: T, encryption_secret: Zeroizing<Vec<u8>>) -> SecretTree<T>197     pub fn new(leaf_count: T, encryption_secret: Zeroizing<Vec<u8>>) -> SecretTree<T> {
198         let mut known_secrets = TreeSecretsVec::default();
199 
200         let root_secret = SecretTreeNode::Secret(TreeSecret::from(encryption_secret));
201         known_secrets.set_node(leaf_count.root(), root_secret);
202 
203         Self {
204             known_secrets,
205             leaf_count,
206         }
207     }
208 
209     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
consume_node<P: CipherSuiteProvider>( &mut self, cipher_suite_provider: &P, index: &T, ) -> Result<(), MlsError>210     async fn consume_node<P: CipherSuiteProvider>(
211         &mut self,
212         cipher_suite_provider: &P,
213         index: &T,
214     ) -> Result<(), MlsError> {
215         let node = self.known_secrets.take_node(index);
216 
217         if let Some(secret) = node.and_then(|n| n.into_secret()) {
218             let left_index = index.left().ok_or(MlsError::LeafNodeNoChildren)?;
219             let right_index = index.right().ok_or(MlsError::LeafNodeNoChildren)?;
220 
221             let left_secret =
222                 kdf_expand_with_label(cipher_suite_provider, &secret, b"tree", b"left", None)
223                     .await?;
224 
225             let right_secret =
226                 kdf_expand_with_label(cipher_suite_provider, &secret, b"tree", b"right", None)
227                     .await?;
228 
229             self.known_secrets
230                 .set_node(left_index, SecretTreeNode::Secret(left_secret.into()));
231 
232             self.known_secrets
233                 .set_node(right_index, SecretTreeNode::Secret(right_secret.into()));
234         }
235 
236         Ok(())
237     }
238 
239     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
take_leaf_ratchet<P: CipherSuiteProvider>( &mut self, cipher_suite: &P, leaf_index: &T, ) -> Result<SecretRatchets, MlsError>240     async fn take_leaf_ratchet<P: CipherSuiteProvider>(
241         &mut self,
242         cipher_suite: &P,
243         leaf_index: &T,
244     ) -> Result<SecretRatchets, MlsError> {
245         let node_index = leaf_index;
246 
247         let node = match self.known_secrets.take_node(node_index) {
248             Some(node) => node,
249             None => {
250                 // Start at the root node and work your way down consuming any intermediates needed
251                 for i in node_index.direct_copath(&self.leaf_count).into_iter().rev() {
252                     self.consume_node(cipher_suite, &i.path).await?;
253                 }
254 
255                 self.known_secrets
256                     .take_node(node_index)
257                     .ok_or(MlsError::InvalidLeafConsumption)?
258             }
259         };
260 
261         Ok(match node {
262             SecretTreeNode::Ratchet(ratchet) => ratchet,
263             SecretTreeNode::Secret(secret) => SecretRatchets {
264                 application: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Application)
265                     .await?,
266                 handshake: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Handshake).await?,
267             },
268         })
269     }
270 
271     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
next_message_key<P: CipherSuiteProvider>( &mut self, cipher_suite: &P, leaf_index: T, key_type: KeyType, ) -> Result<MessageKeyData, MlsError>272     pub async fn next_message_key<P: CipherSuiteProvider>(
273         &mut self,
274         cipher_suite: &P,
275         leaf_index: T,
276         key_type: KeyType,
277     ) -> Result<MessageKeyData, MlsError> {
278         let mut ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?;
279         let res = ratchet.next_message_key(cipher_suite, key_type).await?;
280 
281         self.known_secrets
282             .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet));
283 
284         Ok(res)
285     }
286 
287     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
message_key_generation<P: CipherSuiteProvider>( &mut self, cipher_suite: &P, leaf_index: T, key_type: KeyType, generation: u32, ) -> Result<MessageKeyData, MlsError>288     pub async fn message_key_generation<P: CipherSuiteProvider>(
289         &mut self,
290         cipher_suite: &P,
291         leaf_index: T,
292         key_type: KeyType,
293         generation: u32,
294     ) -> Result<MessageKeyData, MlsError> {
295         let mut ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?;
296 
297         let res = ratchet
298             .message_key_generation(cipher_suite, generation, key_type)
299             .await?;
300 
301         self.known_secrets
302             .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet));
303 
304         Ok(res)
305     }
306 }
307 
/// Selects which of a leaf's two ratchets a message key comes from.
#[derive(Clone, Copy)]
pub enum KeyType {
    /// Handshake ratchet (initialised with the label `"handshake"`).
    Handshake,
    /// Application ratchet (initialised with the label `"application"`).
    Application,
}
313 
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone, PartialEq, Eq, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
/// AEAD key derived by the MLS secret tree.
///
/// Produced by `SecretKeyRatchet::next_message_key` for a single ratchet generation. The nonce
/// and key bytes are zeroized on drop.
pub struct MessageKeyData {
    /// AEAD nonce for this generation.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
    pub(crate) nonce: Zeroizing<Vec<u8>>,
    /// AEAD key for this generation.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
    pub(crate) key: Zeroizing<Vec<u8>>,
    /// Ratchet generation this key was derived at.
    pub(crate) generation: u32,
}
330 
331 impl Debug for MessageKeyData {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result332     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
333         f.debug_struct("MessageKeyData")
334             .field("nonce", &mls_rs_core::debug::pretty_bytes(&self.nonce))
335             .field("key", &mls_rs_core::debug::pretty_bytes(&self.key))
336             .field("generation", &self.generation)
337             .finish()
338     }
339 }
340 
341 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
342 impl MessageKeyData {
343     /// AEAD nonce.
344     #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))]
nonce(&self) -> &[u8]345     pub fn nonce(&self) -> &[u8] {
346         &self.nonce
347     }
348 
349     /// AEAD key.
350     #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))]
key(&self) -> &[u8]351     pub fn key(&self) -> &[u8] {
352         &self.key
353     }
354 
355     /// Generation of this key within the key schedule.
356     #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))]
generation(&self) -> u32357     pub fn generation(&self) -> u32 {
358         self.generation
359     }
360 }
361 
/// A symmetric key ratchet for one leaf and one [`KeyType`].
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SecretKeyRatchet {
    /// Ratchet secret for the current `generation`.
    secret: TreeSecret,
    /// Generation of the next key `next_message_key` will produce.
    generation: u32,
    /// Keys for generations skipped while jumping ahead in `get_message_key`, keyed by
    /// generation, so earlier messages can still be decrypted. An entry is removed once it has
    /// been retrieved.
    #[cfg(all(feature = "out_of_order", feature = "std"))]
    history: HashMap<u32, MessageKeyData>,
    /// Keys for generations skipped while jumping ahead in `get_message_key` (`no_std` variant),
    /// keyed by generation. An entry is removed once it has been retrieved.
    #[cfg(all(feature = "out_of_order", not(feature = "std")))]
    history: BTreeMap<u32, MessageKeyData>,
}
372 
373 impl MlsSize for SecretKeyRatchet {
mls_encoded_len(&self) -> usize374     fn mls_encoded_len(&self) -> usize {
375         let len = mls_rs_codec::byte_vec::mls_encoded_len(&self.secret)
376             + self.generation.mls_encoded_len();
377 
378         #[cfg(feature = "out_of_order")]
379         return len + mls_rs_codec::iter::mls_encoded_len(self.history.values());
380         #[cfg(not(feature = "out_of_order"))]
381         return len;
382     }
383 }
384 
#[cfg(feature = "out_of_order")]
impl MlsEncode for SecretKeyRatchet {
    /// Writes the secret, the generation, then every history entry as one collection.
    fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
        let Self {
            secret,
            generation,
            history,
        } = self;

        mls_rs_codec::byte_vec::mls_encode(secret, writer)?;
        generation.mls_encode(writer)?;
        mls_rs_codec::iter::mls_encode(history.values(), writer)
    }
}
393 
394 #[cfg(not(feature = "out_of_order"))]
395 impl MlsEncode for SecretKeyRatchet {
mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error>396     fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
397         mls_rs_codec::byte_vec::mls_encode(&self.secret, writer)?;
398         self.generation.mls_encode(writer)
399     }
400 }
401 
402 impl MlsDecode for SecretKeyRatchet {
mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error>403     fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
404         Ok(Self {
405             secret: mls_rs_codec::byte_vec::mls_decode(reader)?,
406             generation: u32::mls_decode(reader)?,
407             #[cfg(all(feature = "std", feature = "out_of_order"))]
408             history: mls_rs_codec::iter::mls_decode_collection(reader, |data| {
409                 let mut items = HashMap::default();
410 
411                 while !data.is_empty() {
412                     let item = MessageKeyData::mls_decode(data)?;
413                     items.insert(item.generation, item);
414                 }
415 
416                 Ok(items)
417             })?,
418             #[cfg(all(not(feature = "std"), feature = "out_of_order"))]
419             history: mls_rs_codec::iter::mls_decode_collection(reader, |data| {
420                 let mut items = alloc::collections::BTreeMap::default();
421 
422                 while !data.is_empty() {
423                     let item = MessageKeyData::mls_decode(data)?;
424                     items.insert(item.generation, item);
425                 }
426 
427                 Ok(items)
428             })?,
429         })
430     }
431 }
432 
433 impl SecretKeyRatchet {
434     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
new<P: CipherSuiteProvider>( cipher_suite_provider: &P, secret: &[u8], key_type: KeyType, ) -> Result<Self, MlsError>435     async fn new<P: CipherSuiteProvider>(
436         cipher_suite_provider: &P,
437         secret: &[u8],
438         key_type: KeyType,
439     ) -> Result<Self, MlsError> {
440         let label = match key_type {
441             KeyType::Handshake => b"handshake".as_slice(),
442             KeyType::Application => b"application".as_slice(),
443         };
444 
445         let secret = kdf_expand_with_label(cipher_suite_provider, secret, label, &[], None)
446             .await
447             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
448 
449         Ok(Self {
450             secret: TreeSecret::from(secret),
451             generation: 0,
452             #[cfg(feature = "out_of_order")]
453             history: Default::default(),
454         })
455     }
456 
457     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_message_key<P: CipherSuiteProvider>( &mut self, cipher_suite_provider: &P, generation: u32, ) -> Result<MessageKeyData, MlsError>458     async fn get_message_key<P: CipherSuiteProvider>(
459         &mut self,
460         cipher_suite_provider: &P,
461         generation: u32,
462     ) -> Result<MessageKeyData, MlsError> {
463         #[cfg(feature = "out_of_order")]
464         if generation < self.generation {
465             return self
466                 .history
467                 .remove_entry(&generation)
468                 .map(|(_, mk)| mk)
469                 .ok_or(MlsError::KeyMissing(generation));
470         }
471 
472         #[cfg(not(feature = "out_of_order"))]
473         if generation < self.generation {
474             return Err(MlsError::KeyMissing(generation));
475         }
476 
477         let max_generation_allowed = self.generation + MAX_RATCHET_BACK_HISTORY;
478 
479         if generation > max_generation_allowed {
480             return Err(MlsError::InvalidFutureGeneration(generation));
481         }
482 
483         #[cfg(not(feature = "out_of_order"))]
484         while self.generation < generation {
485             self.next_message_key(cipher_suite_provider)?;
486         }
487 
488         #[cfg(feature = "out_of_order")]
489         while self.generation < generation {
490             let key_data = self.next_message_key(cipher_suite_provider).await?;
491             self.history.insert(key_data.generation, key_data);
492         }
493 
494         self.next_message_key(cipher_suite_provider).await
495     }
496 
497     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
next_message_key<P: CipherSuiteProvider>( &mut self, cipher_suite_provider: &P, ) -> Result<MessageKeyData, MlsError>498     async fn next_message_key<P: CipherSuiteProvider>(
499         &mut self,
500         cipher_suite_provider: &P,
501     ) -> Result<MessageKeyData, MlsError> {
502         let generation = self.generation;
503 
504         let key = MessageKeyData {
505             nonce: self
506                 .derive_secret(
507                     cipher_suite_provider,
508                     b"nonce",
509                     cipher_suite_provider.aead_nonce_size(),
510                 )
511                 .await?,
512             key: self
513                 .derive_secret(
514                     cipher_suite_provider,
515                     b"key",
516                     cipher_suite_provider.aead_key_size(),
517                 )
518                 .await?,
519             generation,
520         };
521 
522         self.secret = self
523             .derive_secret(
524                 cipher_suite_provider,
525                 b"secret",
526                 cipher_suite_provider.kdf_extract_size(),
527             )
528             .await?
529             .into();
530 
531         self.generation = generation + 1;
532 
533         Ok(key)
534     }
535 
536     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
derive_secret<P: CipherSuiteProvider>( &self, cipher_suite_provider: &P, label: &[u8], len: usize, ) -> Result<Zeroizing<Vec<u8>>, MlsError>537     async fn derive_secret<P: CipherSuiteProvider>(
538         &self,
539         cipher_suite_provider: &P,
540         label: &[u8],
541         len: usize,
542     ) -> Result<Zeroizing<Vec<u8>>, MlsError> {
543         kdf_expand_with_label(
544             cipher_suite_provider,
545             self.secret.as_ref(),
546             label,
547             &self.generation.to_be_bytes(),
548             Some(len),
549         )
550         .await
551         .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
552     }
553 }
554 
#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::{string::String, vec::Vec};
    use mls_rs_core::crypto::CipherSuiteProvider;
    use zeroize::Zeroizing;

    use crate::{crypto::test_utils::try_test_cipher_suite_provider, tree_kem::math::TreeIndex};

    use super::{KeyType, SecretKeyRatchet, SecretTree};

    /// Builds a secret tree with `leaf_count` leaves whose root holds `secret`.
    pub(crate) fn get_test_tree<T: TreeIndex>(secret: Vec<u8>, leaf_count: T) -> SecretTree<T> {
        let encryption_secret = Zeroizing::new(secret);
        SecretTree::new(leaf_count, encryption_secret)
    }

    impl SecretTree<u32> {
        /// Returns a copy of the secret stored at the root, without modifying the tree.
        ///
        /// Panics if the root is no longer stored or does not hold a plain secret.
        pub(crate) fn get_root_secret(&self) -> Vec<u8> {
            let mut snapshot = self.known_secrets.clone();
            let root = self.leaf_count.root();

            let node = snapshot.take_node(&root).unwrap();
            node.into_secret().unwrap().to_vec()
        }
    }

    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    pub struct RatchetInteropTestCase {
        #[serde(with = "hex::serde")]
        secret: Vec<u8>,
        label: String,
        generation: u32,
        length: usize,
        #[serde(with = "hex::serde")]
        out: Vec<u8>,
    }

    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    pub struct InteropTestCase {
        cipher_suite: u16,
        derive_tree_secret: RatchetInteropTestCase,
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_basic_crypto_test_vectors() {
        let test_cases: Vec<InteropTestCase> =
            load_test_case_json!(basic_crypto, Vec::<InteropTestCase>::new());

        for case in test_cases {
            // Vectors for cipher suites the test provider does not support are skipped.
            match try_test_cipher_suite_provider(case.cipher_suite) {
                Some(cs) => case.derive_tree_secret.verify(&cs).await,
                None => continue,
            }
        }
    }

    impl RatchetInteropTestCase {
        /// Forces a ratchet to this vector's secret and generation, then checks that
        /// `derive_secret(label, length)` reproduces `out`.
        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
        pub async fn verify<P: CipherSuiteProvider>(&self, cs: &P) {
            let mut ratchet = SecretKeyRatchet::new(cs, &self.secret, KeyType::Application)
                .await
                .unwrap();

            ratchet.secret = self.secret.clone().into();
            ratchet.generation = self.generation;

            let label = self.label.as_bytes();
            let computed = ratchet.derive_secret(cs, label, self.length).await.unwrap();

            assert_eq!(computed.as_slice(), self.out.as_slice());
        }
    }
}
629 
630 #[cfg(test)]
631 mod tests {
632     use alloc::vec;
633 
634     use crate::{
635         cipher_suite::CipherSuite,
636         client::test_utils::TEST_CIPHER_SUITE,
637         crypto::test_utils::{
638             test_cipher_suite_provider, try_test_cipher_suite_provider, TestCryptoProvider,
639         },
640         tree_kem::node::NodeIndex,
641     };
642 
643     #[cfg(not(mls_build_async))]
644     use crate::group::test_utils::random_bytes;
645 
646     use super::{test_utils::get_test_tree, *};
647 
648     use assert_matches::assert_matches;
649 
650     #[cfg(target_arch = "wasm32")]
651     use wasm_bindgen_test::wasm_bindgen_test as test;
652 
653     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_secret_tree()654     async fn test_secret_tree() {
655         test_secret_tree_custom(16u32, (0..16).map(|i| 2 * i).collect(), true).await;
656         test_secret_tree_custom(1u64 << 62, (1..62).map(|i| 1u64 << i).collect(), false).await;
657     }
658 
659     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_secret_tree_custom<T: TreeIndex>( leaf_count: T, leaves_to_check: Vec<T>, all_deleted: bool, )660     async fn test_secret_tree_custom<T: TreeIndex>(
661         leaf_count: T,
662         leaves_to_check: Vec<T>,
663         all_deleted: bool,
664     ) {
665         for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
666             let cs_provider = test_cipher_suite_provider(cipher_suite);
667 
668             let test_secret = vec![0u8; cs_provider.kdf_extract_size()];
669             let mut test_tree = get_test_tree(test_secret, leaf_count.clone());
670 
671             let mut secrets = Vec::<SecretRatchets>::new();
672 
673             for i in &leaves_to_check {
674                 let secret = test_tree
675                     .take_leaf_ratchet(&test_cipher_suite_provider(cipher_suite), i)
676                     .await
677                     .unwrap();
678 
679                 secrets.push(secret);
680             }
681 
682             // Verify the tree is now completely empty
683             assert!(!all_deleted || test_tree.known_secrets.inner.is_empty());
684 
685             // Verify that all the secrets are unique
686             let count = secrets.len();
687             secrets.dedup();
688             assert_eq!(count, secrets.len());
689         }
690     }
691 
692     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_secret_key_ratchet()693     async fn test_secret_key_ratchet() {
694         for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
695             let provider = test_cipher_suite_provider(cipher_suite);
696 
697             let mut app_ratchet = SecretKeyRatchet::new(
698                 &provider,
699                 &vec![0u8; provider.kdf_extract_size()],
700                 KeyType::Application,
701             )
702             .await
703             .unwrap();
704 
705             let mut handshake_ratchet = SecretKeyRatchet::new(
706                 &provider,
707                 &vec![0u8; provider.kdf_extract_size()],
708                 KeyType::Handshake,
709             )
710             .await
711             .unwrap();
712 
713             let app_key_one = app_ratchet.next_message_key(&provider).await.unwrap();
714             let app_key_two = app_ratchet.next_message_key(&provider).await.unwrap();
715             let app_keys = vec![app_key_one, app_key_two];
716 
717             let handshake_key_one = handshake_ratchet.next_message_key(&provider).await.unwrap();
718             let handshake_key_two = handshake_ratchet.next_message_key(&provider).await.unwrap();
719             let handshake_keys = vec![handshake_key_one, handshake_key_two];
720 
721             // Verify that the keys have different outcomes due to their different labels
722             assert_ne!(app_keys, handshake_keys);
723 
724             // Verify that the keys at each generation are different
725             assert_ne!(handshake_keys[0], handshake_keys[1]);
726         }
727     }
728 
729     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_get_key()730     async fn test_get_key() {
731         for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
732             let provider = test_cipher_suite_provider(cipher_suite);
733 
734             let mut ratchet = SecretKeyRatchet::new(
735                 &test_cipher_suite_provider(cipher_suite),
736                 &vec![0u8; provider.kdf_extract_size()],
737                 KeyType::Application,
738             )
739             .await
740             .unwrap();
741 
742             let mut ratchet_clone = ratchet.clone();
743 
744             // This will generate keys 0 and 1 in ratchet_clone
745             let _ = ratchet_clone.next_message_key(&provider).await.unwrap();
746             let clone_2 = ratchet_clone.next_message_key(&provider).await.unwrap();
747 
748             // Going back in time should result in an error
749             let res = ratchet_clone.get_message_key(&provider, 0).await;
750             assert!(res.is_err());
751 
752             // Calling get key should be the same as calling next until hitting the desired generation
753             let second_key = ratchet
754                 .get_message_key(&provider, ratchet_clone.generation - 1)
755                 .await
756                 .unwrap();
757 
758             assert_eq!(clone_2, second_key)
759         }
760     }
761 
762     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_secret_ratchet()763     async fn test_secret_ratchet() {
764         for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
765             let provider = test_cipher_suite_provider(cipher_suite);
766 
767             let mut ratchet = SecretKeyRatchet::new(
768                 &provider,
769                 &vec![0u8; provider.kdf_extract_size()],
770                 KeyType::Application,
771             )
772             .await
773             .unwrap();
774 
775             let original_secret = ratchet.secret.clone();
776             let _ = ratchet.next_message_key(&provider).await.unwrap();
777             let new_secret = ratchet.secret;
778             assert_ne!(original_secret, new_secret)
779         }
780     }
781 
782     #[cfg(feature = "out_of_order")]
783     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_out_of_order_keys()784     async fn test_out_of_order_keys() {
785         let cipher_suite = TEST_CIPHER_SUITE;
786         let provider = test_cipher_suite_provider(cipher_suite);
787 
788         let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
789             .await
790             .unwrap();
791         let mut ratchet_clone = ratchet.clone();
792 
793         // Ask for all the keys in order from the original ratchet
794         let mut ordered_keys = Vec::<MessageKeyData>::new();
795 
796         for i in 0..=MAX_RATCHET_BACK_HISTORY {
797             ordered_keys.push(ratchet.get_message_key(&provider, i).await.unwrap());
798         }
799 
800         // Ask for a key at index MAX_RATCHET_BACK_HISTORY in the clone
801         let last_key = ratchet_clone
802             .get_message_key(&provider, MAX_RATCHET_BACK_HISTORY)
803             .await
804             .unwrap();
805 
806         assert_eq!(last_key, ordered_keys[ordered_keys.len() - 1]);
807 
808         // Get all the other keys
809         let mut back_history_keys = Vec::<MessageKeyData>::new();
810 
811         for i in 0..MAX_RATCHET_BACK_HISTORY - 1 {
812             back_history_keys.push(ratchet_clone.get_message_key(&provider, i).await.unwrap());
813         }
814 
815         assert_eq!(
816             back_history_keys,
817             ordered_keys[..(MAX_RATCHET_BACK_HISTORY as usize) - 1]
818         );
819     }
820 
821     #[cfg(not(feature = "out_of_order"))]
822     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
out_of_order_keys_should_throw_error()823     async fn out_of_order_keys_should_throw_error() {
824         let cipher_suite = TEST_CIPHER_SUITE;
825         let provider = test_cipher_suite_provider(cipher_suite);
826 
827         let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
828             .await
829             .unwrap();
830 
831         ratchet.get_message_key(&provider, 10).await.unwrap();
832         let res = ratchet.get_message_key(&provider, 9).await;
833         assert_matches!(res, Err(MlsError::KeyMissing(9)))
834     }
835 
836     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_too_out_of_order()837     async fn test_too_out_of_order() {
838         let cipher_suite = TEST_CIPHER_SUITE;
839         let provider = test_cipher_suite_provider(cipher_suite);
840 
841         let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
842             .await
843             .unwrap();
844 
845         let res = ratchet
846             .get_message_key(&provider, MAX_RATCHET_BACK_HISTORY + 1)
847             .await;
848 
849         let invalid_generation = MAX_RATCHET_BACK_HISTORY + 1;
850 
851         assert_matches!(
852             res,
853             Err(MlsError::InvalidFutureGeneration(invalid))
854             if invalid == invalid_generation
855         )
856     }
857 
    /// Message keys captured from one leaf by `get_ratchet_data`; every entry is a
    /// `MessageKeyData` serialized with `mls_encode_to_vec`.
    ///
    /// NOTE(review): `get_ratchet_data` fills BOTH fields from the leaf's handshake
    /// ratchet (the first 20 keys go here, the next 20 into `handshake_keys`).
    /// Field names double as the serde schema of the stored test vectors.
    #[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
    struct Ratchet {
        /// Keys recorded under the "application" label (see NOTE above).
        application_keys: Vec<Vec<u8>>,
        /// Keys recorded under the "handshake" label.
        handshake_keys: Vec<Vec<u8>>,
    }
863 
    /// One entry of the stored `secret_tree` test vectors
    /// (loaded by `test_secret_tree_test_vectors`, produced by `generate_test_vector`).
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct TestCase {
        /// Raw numeric cipher suite identifier.
        cipher_suite: u16,
        /// Secret used to seed the `SecretTree`; stored hex-encoded.
        #[serde(with = "hex::serde")]
        encryption_secret: Vec<u8>,
        /// Expected keys, one `Ratchet` per leaf, as produced by `get_ratchet_data`.
        ratchets: Vec<Ratchet>,
    }
871 
872     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_ratchet_data( secret_tree: &mut SecretTree<NodeIndex>, cipher_suite: CipherSuite, ) -> Vec<Ratchet>873     async fn get_ratchet_data(
874         secret_tree: &mut SecretTree<NodeIndex>,
875         cipher_suite: CipherSuite,
876     ) -> Vec<Ratchet> {
877         let provider = test_cipher_suite_provider(cipher_suite);
878         let mut ratchet_data = Vec::new();
879 
880         for index in 0..16 {
881             let mut ratchets = secret_tree
882                 .take_leaf_ratchet(&provider, &(index * 2))
883                 .await
884                 .unwrap();
885 
886             let mut application_keys = Vec::new();
887 
888             for _ in 0..20 {
889                 let key = ratchets
890                     .handshake
891                     .next_message_key(&provider)
892                     .await
893                     .unwrap()
894                     .mls_encode_to_vec()
895                     .unwrap();
896 
897                 application_keys.push(key);
898             }
899 
900             let mut handshake_keys = Vec::new();
901 
902             for _ in 0..20 {
903                 let key = ratchets
904                     .handshake
905                     .next_message_key(&provider)
906                     .await
907                     .unwrap()
908                     .mls_encode_to_vec()
909                     .unwrap();
910 
911                 handshake_keys.push(key);
912             }
913 
914             ratchet_data.push(Ratchet {
915                 application_keys,
916                 handshake_keys,
917             });
918         }
919 
920         ratchet_data
921     }
922 
923     #[cfg(not(mls_build_async))]
924     #[cfg_attr(coverage_nightly, coverage(off))]
generate_test_vector() -> Vec<TestCase>925     fn generate_test_vector() -> Vec<TestCase> {
926         CipherSuite::all()
927             .map(|cipher_suite| {
928                 let provider = test_cipher_suite_provider(cipher_suite);
929                 let encryption_secret = random_bytes(provider.kdf_extract_size());
930 
931                 let mut secret_tree =
932                     SecretTree::new(16, Zeroizing::new(encryption_secret.clone()));
933 
934                 TestCase {
935                     cipher_suite: cipher_suite.into(),
936                     encryption_secret,
937                     ratchets: get_ratchet_data(&mut secret_tree, cipher_suite),
938                 }
939             })
940             .collect()
941     }
942 
943     #[cfg(mls_build_async)]
generate_test_vector() -> Vec<TestCase>944     fn generate_test_vector() -> Vec<TestCase> {
945         panic!("Tests cannot be generated in async mode");
946     }
947 
948     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_secret_tree_test_vectors()949     async fn test_secret_tree_test_vectors() {
950         let test_cases: Vec<TestCase> = load_test_case_json!(secret_tree, generate_test_vector());
951 
952         for case in test_cases {
953             let Some(cs_provider) = try_test_cipher_suite_provider(case.cipher_suite) else {
954                 continue;
955             };
956 
957             let mut secret_tree = SecretTree::new(16, Zeroizing::new(case.encryption_secret));
958             let ratchet_data = get_ratchet_data(&mut secret_tree, cs_provider.cipher_suite()).await;
959 
960             assert_eq!(ratchet_data, case.ratchets);
961         }
962     }
963 }
964 
// Checks against the MLS working group's secret-tree interop vectors; only built for
// tests with both the `rfc_compliant` and `std` features enabled.
#[cfg(all(test, feature = "rfc_compliant", feature = "std"))]
mod interop_tests {
    #[cfg(not(mls_build_async))]
    use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider};
    use zeroize::Zeroizing;

    use crate::{
        crypto::test_utils::try_test_cipher_suite_provider,
        group::{ciphertext_processor::InteropSenderData, secret_tree::KeyType},
    };

    use super::SecretTree;

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn interop_test_vector() {
        // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/secret-tree.json
        // For each supported cipher suite: verify the sender-data section, then derive
        // the application and handshake key/nonce for every recorded (leaf, generation)
        // pair and compare them with the expected values.
        for case in load_interop_test_cases() {
            let Some(cs) = try_test_cipher_suite_provider(case.cipher_suite) else {
                continue;
            };

            case.sender_data.verify(&cs).await;

            let mut tree = SecretTree::new(
                case.leaves.len() as u32,
                Zeroizing::new(case.encryption_secret),
            );

            for (leaf_position, expected_leaves) in case.leaves.iter().enumerate() {
                // Leaf `i` is addressed by node index `2 * i`.
                let node_index = (leaf_position as u32) * 2;

                for expected in expected_leaves {
                    let checks = [
                        (
                            KeyType::Application,
                            &expected.application_key,
                            &expected.application_nonce,
                        ),
                        (
                            KeyType::Handshake,
                            &expected.handshake_key,
                            &expected.handshake_nonce,
                        ),
                    ];

                    for (key_type, expected_key, expected_nonce) in checks {
                        let derived = tree
                            .message_key_generation(&cs, node_index, key_type, expected.generation)
                            .await
                            .unwrap();

                        assert_eq!(derived.key.to_vec(), *expected_key);
                        assert_eq!(derived.nonce.to_vec(), *expected_nonce);
                    }
                }
            }
        }
    }

    /// One case of the secret-tree interop vectors. Field names match the JSON keys.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct InteropTestCase {
        /// Raw numeric cipher suite identifier.
        cipher_suite: u16,
        /// Secret the `SecretTree` is seeded with; stored hex-encoded.
        #[serde(with = "hex::serde")]
        encryption_secret: Vec<u8>,
        /// Sender-data section, checked via `InteropSenderData::verify`.
        sender_data: InteropSenderData,
        /// Per leaf, the expected keys and nonces at each recorded generation.
        leaves: Vec<Vec<InteropLeaf>>,
    }

    /// Expected application/handshake key and nonce for one leaf at one generation.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    struct InteropLeaf {
        generation: u32,
        #[serde(with = "hex::serde")]
        application_key: Vec<u8>,
        #[serde(with = "hex::serde")]
        application_nonce: Vec<u8>,
        #[serde(with = "hex::serde")]
        handshake_key: Vec<u8>,
        #[serde(with = "hex::serde")]
        handshake_nonce: Vec<u8>,
    }

    /// Loads the stored interop cases, falling back to `generate_test_vector`.
    fn load_interop_test_cases() -> Vec<InteropTestCase> {
        load_test_case_json!(secret_tree_interop, generate_test_vector())
    }

    /// Builds interop cases for every available cipher suite and each tree size,
    /// recording keys for generations 0 and 15 of every leaf.
    #[cfg(not(mls_build_async))]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_test_vector() -> Vec<InteropTestCase> {
        const GENERATIONS: [u32; 2] = [0, 15];
        const TREE_SIZES: [u32; 3] = [1, 8, 32];

        let mut test_cases = Vec::new();

        for suite in CipherSuite::all() {
            let Some(cs) = try_test_cipher_suite_provider(*suite) else {
                continue;
            };

            for n_leaves in TREE_SIZES {
                let encryption_secret = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap();

                let mut tree = SecretTree::new(n_leaves, Zeroizing::new(encryption_secret.clone()));

                let mut leaves = Vec::with_capacity(n_leaves as usize);

                for leaf in 0..n_leaves {
                    let node_index = leaf * 2;
                    let mut per_generation = Vec::with_capacity(GENERATIONS.len());

                    for generation in GENERATIONS {
                        // Handshake key is derived before the application key for each generation.
                        let handshake_key = tree
                            .message_key_generation(&cs, node_index, KeyType::Handshake, generation)
                            .unwrap();

                        let app_key = tree
                            .message_key_generation(
                                &cs,
                                node_index,
                                KeyType::Application,
                                generation,
                            )
                            .unwrap();

                        per_generation.push(InteropLeaf {
                            generation,
                            application_key: app_key.key.to_vec(),
                            application_nonce: app_key.nonce.to_vec(),
                            handshake_key: handshake_key.key.to_vec(),
                            handshake_nonce: handshake_key.nonce.to_vec(),
                        });
                    }

                    leaves.push(per_generation);
                }

                test_cases.push(InteropTestCase {
                    cipher_suite: *cs.cipher_suite(),
                    encryption_secret,
                    sender_data: InteropSenderData::new(&cs),
                    leaves,
                });
            }
        }

        test_cases
    }

    /// Async-build stand-in: interop vectors can only be generated in sync builds.
    #[cfg(mls_build_async)]
    fn generate_test_vector() -> Vec<InteropTestCase> {
        panic!("Tests cannot be generated in async mode");
    }
}
1116