1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::vec;
6 use alloc::vec::Vec;
7 use mls_rs_codec::{MlsDecode, MlsEncode};
8 use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider, SignaturePublicKey};
9 
10 use crate::{
11     client::test_utils::{TestClientConfig, TEST_PROTOCOL_VERSION},
12     crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
13     group::{
14         confirmation_tag::ConfirmationTag,
15         epoch::EpochSecrets,
16         framing::{Content, WireFormat},
17         message_processor::{EventOrContent, MessageProcessor},
18         mls_rules::EncryptionOptions,
19         padding::PaddingMode,
20         proposal::{Proposal, RemoveProposal},
21         secret_tree::test_utils::get_test_tree,
22         test_utils::{random_bytes, test_group_custom_config},
23         AuthenticatedContent, Commit, Group, GroupContext, MlsMessage, Sender,
24     },
25     mls_rules::DefaultMlsRules,
26     test_utils::is_edwards,
27     tree_kem::{leaf_node::test_utils::get_basic_test_node, node::LeafIndex},
28 };
29 
/// Leaf count used when building the test secret tree in `make_group`
/// (leaf 0 is the receiver, leaf 1 the sender).
const FRAMING_N_LEAVES: u32 = 2;
31 
/// One entry of the message-protection ("framing") test vector.
///
/// All byte fields are hex-encoded in JSON (`hex::serde`). The group context
/// fields are flattened into the same JSON object as the other fields.
#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
struct FramingTestCase {
    /// Group context the sender and receiver groups are forced into.
    #[serde(flatten)]
    pub context: InteropGroupContext,

    /// Sender's signature private key. For Edwards suites, `random` keeps only
    /// the first half of the generated key, and the public key is appended back
    /// before signing.
    #[serde(with = "hex::serde")]
    pub signature_priv: Vec<u8>,
    /// Sender's signature public key, installed on leaf 1 of the test tree.
    #[serde(with = "hex::serde")]
    pub signature_pub: Vec<u8>,

    /// Secret the test secret tree is derived from.
    #[serde(with = "hex::serde")]
    pub encryption_secret: Vec<u8>,
    /// Sender-data secret installed in the epoch secrets.
    #[serde(with = "hex::serde")]
    pub sender_data_secret: Vec<u8>,
    /// Membership key installed in the key schedule.
    #[serde(with = "hex::serde")]
    pub membership_key: Vec<u8>,

    /// MLS-encoded `Proposal`.
    #[serde(with = "hex::serde")]
    pub proposal: Vec<u8>,
    /// MLS-encoded message carrying `proposal`, produced with encryption enabled.
    #[serde(with = "hex::serde")]
    pub proposal_priv: Vec<u8>,
    /// MLS-encoded message carrying `proposal`, produced with encryption disabled.
    #[serde(with = "hex::serde")]
    pub proposal_pub: Vec<u8>,

    /// MLS-encoded `Commit`.
    #[serde(with = "hex::serde")]
    pub commit: Vec<u8>,
    /// MLS-encoded message carrying `commit`, signed as a PrivateMessage.
    #[serde(with = "hex::serde")]
    pub commit_priv: Vec<u8>,
    /// MLS-encoded message carrying `commit`, signed as a PublicMessage.
    #[serde(with = "hex::serde")]
    pub commit_pub: Vec<u8>,

    /// Raw application data.
    #[serde(with = "hex::serde")]
    pub application: Vec<u8>,
    /// MLS-encoded encrypted message carrying `application`.
    #[serde(with = "hex::serde")]
    pub application_priv: Vec<u8>,
}
68 
69 impl FramingTestCase {
70     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
71     #[cfg_attr(coverage_nightly, coverage(off))]
random<P: CipherSuiteProvider>(cs: &P) -> Self72     async fn random<P: CipherSuiteProvider>(cs: &P) -> Self {
73         let mut context = InteropGroupContext::random(cs);
74         context.cipher_suite = cs.cipher_suite().into();
75 
76         let (mut signature_priv, signature_pub) = cs.signature_key_generate().await.unwrap();
77 
78         if is_edwards(*cs.cipher_suite()) {
79             signature_priv = signature_priv[0..signature_priv.len() / 2].to_vec().into();
80         }
81 
82         Self {
83             context,
84             signature_priv: signature_priv.to_vec(),
85             signature_pub: signature_pub.to_vec(),
86             encryption_secret: random_bytes(cs.kdf_extract_size()),
87             sender_data_secret: random_bytes(cs.kdf_extract_size()),
88             membership_key: random_bytes(cs.kdf_extract_size()),
89             ..Default::default()
90         }
91     }
92 }
93 
/// Group context fields as they appear in the interop test vector JSON.
///
/// Byte fields are hex-encoded. Converted into a `GroupContext` via `From`.
#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
pub struct InteropGroupContext {
    /// Numeric cipher suite identifier.
    pub cipher_suite: u16,
    #[serde(with = "hex::serde")]
    pub group_id: Vec<u8>,
    pub epoch: u64,
    #[serde(with = "hex::serde")]
    pub tree_hash: Vec<u8>,
    #[serde(with = "hex::serde")]
    pub confirmed_transcript_hash: Vec<u8>,
}
105 
106 impl InteropGroupContext {
107     #[cfg_attr(coverage_nightly, coverage(off))]
random<P: CipherSuiteProvider>(cs: &P) -> Self108     fn random<P: CipherSuiteProvider>(cs: &P) -> Self {
109         Self {
110             cipher_suite: cs.cipher_suite().into(),
111             group_id: random_bytes(cs.kdf_extract_size()),
112             epoch: 0x121212,
113             tree_hash: random_bytes(cs.kdf_extract_size()),
114             confirmed_transcript_hash: random_bytes(cs.kdf_extract_size()),
115         }
116     }
117 }
118 
119 impl From<InteropGroupContext> for GroupContext {
120     #[cfg_attr(coverage_nightly, coverage(off))]
from(ctx: InteropGroupContext) -> Self121     fn from(ctx: InteropGroupContext) -> Self {
122         Self {
123             cipher_suite: ctx.cipher_suite.into(),
124             protocol_version: TEST_PROTOCOL_VERSION,
125             group_id: ctx.group_id,
126             epoch: ctx.epoch,
127             tree_hash: ctx.tree_hash,
128             confirmed_transcript_hash: ctx.confirmed_transcript_hash.into(),
129             extensions: vec![].into(),
130         }
131     }
132 }
133 
134 // The test vector can be found here:
135 // https://github.com/mlswg/mls-implementations/blob/main/test-vectors/message-protection.json
136 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
framing_proposal()137 async fn framing_proposal() {
138     #[cfg(not(mls_build_async))]
139     let test_cases: Vec<FramingTestCase> =
140         load_test_case_json!(framing, generate_framing_test_vector());
141 
142     #[cfg(mls_build_async)]
143     let test_cases: Vec<FramingTestCase> =
144         load_test_case_json!(framing, generate_framing_test_vector().await);
145 
146     for test_case in test_cases.into_iter() {
147         let Some(cs) = try_test_cipher_suite_provider(test_case.context.cipher_suite) else {
148             continue;
149         };
150 
151         let to_check = vec![
152             test_case.proposal_priv.clone(),
153             test_case.proposal_pub.clone(),
154         ];
155 
156         // Wasm uses incompatible signature secret key format
157         #[cfg(not(target_arch = "wasm32"))]
158         let mut to_check = to_check;
159 
160         #[cfg(not(target_arch = "wasm32"))]
161         for enable_encryption in [true, false] {
162             let proposal = Proposal::mls_decode(&mut &*test_case.proposal).unwrap();
163 
164             let built = make_group(&test_case, true, enable_encryption, &cs)
165                 .await
166                 .proposal_message(proposal, vec![])
167                 .await
168                 .unwrap()
169                 .mls_encode_to_vec()
170                 .unwrap();
171 
172             to_check.push(built);
173         }
174 
175         let proposal = Proposal::mls_decode(&mut &*test_case.proposal).unwrap();
176 
177         for message in to_check {
178             match process_message(&test_case, &message, &cs).await {
179                 Content::Proposal(p) => assert_eq!(p.as_ref(), &proposal),
180                 _ => panic!("received value not proposal"),
181             };
182         }
183     }
184 }
185 
186 // The test vector can be found here:
187 // https://github.com/mlswg/mls-implementations/blob/main/test-vectors/message-protection.json
188 // Wasm uses incompatible signature secret key format
189 #[cfg(not(target_arch = "wasm32"))]
190 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
framing_application()191 async fn framing_application() {
192     #[cfg(not(mls_build_async))]
193     let test_cases: Vec<FramingTestCase> =
194         load_test_case_json!(framing, generate_framing_test_vector());
195 
196     #[cfg(mls_build_async)]
197     let test_cases: Vec<FramingTestCase> =
198         load_test_case_json!(framing, generate_framing_test_vector().await);
199 
200     for test_case in test_cases.into_iter() {
201         let Some(cs) = try_test_cipher_suite_provider(test_case.context.cipher_suite) else {
202             continue;
203         };
204 
205         let built_priv = make_group(&test_case, true, true, &cs)
206             .await
207             .encrypt_application_message(&test_case.application, vec![])
208             .await
209             .unwrap()
210             .mls_encode_to_vec()
211             .unwrap();
212 
213         for message in [&test_case.application_priv, &built_priv] {
214             match process_message(&test_case, message, &cs).await {
215                 Content::Application(data) => assert_eq!(data.as_ref(), &test_case.application),
216                 _ => panic!("decrypted value not application data"),
217             };
218         }
219     }
220 }
221 
222 // The test vector can be found here:
223 // https://github.com/mlswg/mls-implementations/blob/main/test-vectors/message-protection.json
224 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
framing_commit()225 async fn framing_commit() {
226     #[cfg(not(mls_build_async))]
227     let test_cases: Vec<FramingTestCase> =
228         load_test_case_json!(framing, generate_framing_test_vector());
229 
230     #[cfg(mls_build_async)]
231     let test_cases: Vec<FramingTestCase> =
232         load_test_case_json!(framing, generate_framing_test_vector().await);
233 
234     for test_case in test_cases.into_iter() {
235         let Some(cs) = try_test_cipher_suite_provider(test_case.context.cipher_suite) else {
236             continue;
237         };
238 
239         let commit = Commit::mls_decode(&mut &*test_case.commit).unwrap();
240 
241         let to_check = vec![test_case.commit_priv.clone(), test_case.commit_pub.clone()];
242 
243         // Wasm uses incompatible signature secret key format
244         #[cfg(not(target_arch = "wasm32"))]
245         let to_check = {
246             let mut to_check = to_check;
247 
248             let mut signature_priv = test_case.signature_priv.clone();
249 
250             if is_edwards(test_case.context.cipher_suite) {
251                 signature_priv.extend(test_case.signature_pub.iter());
252             }
253 
254             let mut auth_content = AuthenticatedContent::new_signed(
255                 &cs,
256                 &test_case.context.clone().into(),
257                 Sender::Member(1),
258                 Content::Commit(alloc::boxed::Box::new(commit.clone())),
259                 &signature_priv.into(),
260                 WireFormat::PublicMessage,
261                 vec![],
262             )
263             .await
264             .unwrap();
265 
266             auth_content.auth.confirmation_tag = Some(ConfirmationTag::empty(&cs).await);
267 
268             for enable_encryption in [true, false] {
269                 let built = make_group(&test_case, true, enable_encryption, &cs)
270                     .await
271                     .format_for_wire(auth_content.clone())
272                     .await
273                     .unwrap()
274                     .mls_encode_to_vec()
275                     .unwrap();
276 
277                 to_check.push(built);
278             }
279 
280             to_check
281         };
282 
283         for message in to_check {
284             match process_message(&test_case, &message, &cs).await {
285                 Content::Commit(c) => assert_eq!(&*c, &commit),
286                 _ => panic!("received value not commit"),
287             };
288         }
289         let commit = Commit::mls_decode(&mut &*test_case.commit).unwrap();
290 
291         match process_message(&test_case, &test_case.commit_priv.clone(), &cs).await {
292             Content::Commit(c) => assert_eq!(&*c, &commit),
293             _ => panic!("received value not commit"),
294         };
295     }
296 }
297 
298 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
299 #[cfg_attr(coverage_nightly, coverage(off))]
generate_framing_test_vector() -> Vec<FramingTestCase>300 async fn generate_framing_test_vector() -> Vec<FramingTestCase> {
301     let mut test_vector = vec![];
302 
303     for cs in CipherSuite::all() {
304         let cs = test_cipher_suite_provider(cs);
305 
306         let mut test_case = FramingTestCase::random(&cs).await;
307 
308         // Generate private application message
309         test_case.application = cs.random_bytes_vec(42).unwrap();
310 
311         let application_priv = make_group(&test_case, true, true, &cs)
312             .await
313             .encrypt_application_message(&test_case.application, vec![])
314             .await
315             .unwrap();
316 
317         test_case.application_priv = application_priv.mls_encode_to_vec().unwrap();
318 
319         // Generate private and public proposal message
320         let proposal = Proposal::Remove(RemoveProposal {
321             to_remove: LeafIndex(2),
322         });
323 
324         test_case.proposal = proposal.mls_encode_to_vec().unwrap();
325 
326         let mut group = make_group(&test_case, true, false, &cs).await;
327         let proposal_pub = group.proposal_message(proposal.clone(), vec![]).await;
328         test_case.proposal_pub = proposal_pub.unwrap().mls_encode_to_vec().unwrap();
329 
330         let mut group = make_group(&test_case, true, true, &cs).await;
331         let proposal_priv = group.proposal_message(proposal, vec![]).await.unwrap();
332         test_case.proposal_priv = proposal_priv.mls_encode_to_vec().unwrap();
333 
334         // Generate private and public commit message
335         let commit = Commit {
336             proposals: vec![],
337             path: None,
338         };
339 
340         test_case.commit = commit.mls_encode_to_vec().unwrap();
341 
342         let mut auth_content = AuthenticatedContent::new_signed(
343             &cs,
344             group.context(),
345             Sender::Member(1),
346             Content::Commit(alloc::boxed::Box::new(commit.clone())),
347             &group.signer,
348             WireFormat::PublicMessage,
349             vec![],
350         )
351         .await
352         .unwrap();
353 
354         auth_content.auth.confirmation_tag = Some(ConfirmationTag::empty(&cs).await);
355 
356         let mut group = make_group(&test_case, true, false, &cs).await;
357         let commit_pub = group.format_for_wire(auth_content.clone()).await.unwrap();
358         test_case.commit_pub = commit_pub.mls_encode_to_vec().unwrap();
359 
360         let mut auth_content = AuthenticatedContent::new_signed(
361             &cs,
362             group.context(),
363             Sender::Member(1),
364             Content::Commit(alloc::boxed::Box::new(commit)),
365             &group.signer,
366             WireFormat::PrivateMessage,
367             vec![],
368         )
369         .await
370         .unwrap();
371 
372         auth_content.auth.confirmation_tag = Some(ConfirmationTag::empty(&cs).await);
373 
374         let mut group = make_group(&test_case, true, true, &cs).await;
375         let commit_priv = group.format_for_wire(auth_content.clone()).await.unwrap();
376         test_case.commit_priv = commit_priv.mls_encode_to_vec().unwrap();
377 
378         test_vector.push(test_case);
379     }
380 
381     test_vector
382 }
383 
384 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
make_group<P: CipherSuiteProvider>( test_case: &FramingTestCase, for_send: bool, control_encryption_enabled: bool, cs: &P, ) -> Group<TestClientConfig>385 async fn make_group<P: CipherSuiteProvider>(
386     test_case: &FramingTestCase,
387     for_send: bool,
388     control_encryption_enabled: bool,
389     cs: &P,
390 ) -> Group<TestClientConfig> {
391     let mut group =
392         test_group_custom_config(
393             TEST_PROTOCOL_VERSION,
394             test_case.context.cipher_suite.into(),
395             |b| {
396                 b.mls_rules(DefaultMlsRules::default().with_encryption_options(
397                     EncryptionOptions::new(control_encryption_enabled, PaddingMode::None),
398                 ))
399             },
400         )
401         .await
402         .group;
403 
404     // Add a leaf for the sender. It will get index 1.
405     let mut leaf = get_basic_test_node(cs.cipher_suite(), "leaf").await;
406 
407     leaf.signing_identity.signature_key = SignaturePublicKey::from(test_case.signature_pub.clone());
408 
409     group
410         .state
411         .public_tree
412         .add_leaves(vec![leaf], &group.config.0.identity_provider, cs)
413         .await
414         .unwrap();
415 
416     // Convince the group that their index is 1 if they send or 0 if they receive.
417     group.private_tree.self_index = LeafIndex(if for_send { 1 } else { 0 });
418 
419     // Convince the group that their signing key is the one from the test case
420     let mut signature_priv = test_case.signature_priv.clone();
421 
422     if is_edwards(test_case.context.cipher_suite) {
423         signature_priv.extend(test_case.signature_pub.iter());
424     }
425 
426     group.signer = signature_priv.into();
427 
428     // Set the group context and secrets
429     let context = GroupContext::from(test_case.context.clone());
430     let secret_tree = get_test_tree(test_case.encryption_secret.clone(), FRAMING_N_LEAVES);
431 
432     let secrets = EpochSecrets {
433         secret_tree,
434         resumption_secret: vec![0_u8; cs.kdf_extract_size()].into(),
435         sender_data_secret: test_case.sender_data_secret.clone().into(),
436     };
437 
438     group.epoch_secrets = secrets;
439     group.state.context = context;
440     let membership_key = test_case.membership_key.clone();
441     group.key_schedule.set_membership_key(membership_key);
442 
443     group
444 }
445 
446 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
process_message<P: CipherSuiteProvider>( test_case: &FramingTestCase, message: &[u8], cs: &P, ) -> Content447 async fn process_message<P: CipherSuiteProvider>(
448     test_case: &FramingTestCase,
449     message: &[u8],
450     cs: &P,
451 ) -> Content {
452     // Enabling encryption doesn't matter for processing
453     let mut group = make_group(test_case, false, true, cs).await;
454     let message = MlsMessage::mls_decode(&mut &*message).unwrap();
455     let evt_or_cont = group.get_event_from_incoming_message(message);
456 
457     match evt_or_cont.await.unwrap() {
458         EventOrContent::Content(content) => content.content.content,
459         EventOrContent::Event(_) => panic!("expected content, got event"),
460     }
461 }
462