1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5 use alloc::vec;
6 use alloc::vec::Vec;
7 use mls_rs_codec::{MlsDecode, MlsEncode};
8 use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider, SignaturePublicKey};
9
10 use crate::{
11 client::test_utils::{TestClientConfig, TEST_PROTOCOL_VERSION},
12 crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
13 group::{
14 confirmation_tag::ConfirmationTag,
15 epoch::EpochSecrets,
16 framing::{Content, WireFormat},
17 message_processor::{EventOrContent, MessageProcessor},
18 mls_rules::EncryptionOptions,
19 padding::PaddingMode,
20 proposal::{Proposal, RemoveProposal},
21 secret_tree::test_utils::get_test_tree,
22 test_utils::{random_bytes, test_group_custom_config},
23 AuthenticatedContent, Commit, Group, GroupContext, MlsMessage, Sender,
24 },
25 mls_rules::DefaultMlsRules,
26 test_utils::is_edwards,
27 tree_kem::{leaf_node::test_utils::get_basic_test_node, node::LeafIndex},
28 };
29
/// Number of leaves of the secret tree that `make_group` derives from a test case's
/// encryption secret.
const FRAMING_N_LEAVES: u32 = 2;
31
32 #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
33 struct FramingTestCase {
34 #[serde(flatten)]
35 pub context: InteropGroupContext,
36
37 #[serde(with = "hex::serde")]
38 pub signature_priv: Vec<u8>,
39 #[serde(with = "hex::serde")]
40 pub signature_pub: Vec<u8>,
41
42 #[serde(with = "hex::serde")]
43 pub encryption_secret: Vec<u8>,
44 #[serde(with = "hex::serde")]
45 pub sender_data_secret: Vec<u8>,
46 #[serde(with = "hex::serde")]
47 pub membership_key: Vec<u8>,
48
49 #[serde(with = "hex::serde")]
50 pub proposal: Vec<u8>,
51 #[serde(with = "hex::serde")]
52 pub proposal_priv: Vec<u8>,
53 #[serde(with = "hex::serde")]
54 pub proposal_pub: Vec<u8>,
55
56 #[serde(with = "hex::serde")]
57 pub commit: Vec<u8>,
58 #[serde(with = "hex::serde")]
59 pub commit_priv: Vec<u8>,
60 #[serde(with = "hex::serde")]
61 pub commit_pub: Vec<u8>,
62
63 #[serde(with = "hex::serde")]
64 pub application: Vec<u8>,
65 #[serde(with = "hex::serde")]
66 pub application_priv: Vec<u8>,
67 }
68
69 impl FramingTestCase {
70 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
71 #[cfg_attr(coverage_nightly, coverage(off))]
random<P: CipherSuiteProvider>(cs: &P) -> Self72 async fn random<P: CipherSuiteProvider>(cs: &P) -> Self {
73 let mut context = InteropGroupContext::random(cs);
74 context.cipher_suite = cs.cipher_suite().into();
75
76 let (mut signature_priv, signature_pub) = cs.signature_key_generate().await.unwrap();
77
78 if is_edwards(*cs.cipher_suite()) {
79 signature_priv = signature_priv[0..signature_priv.len() / 2].to_vec().into();
80 }
81
82 Self {
83 context,
84 signature_priv: signature_priv.to_vec(),
85 signature_pub: signature_pub.to_vec(),
86 encryption_secret: random_bytes(cs.kdf_extract_size()),
87 sender_data_secret: random_bytes(cs.kdf_extract_size()),
88 membership_key: random_bytes(cs.kdf_extract_size()),
89 ..Default::default()
90 }
91 }
92 }
93
94 #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
95 pub struct InteropGroupContext {
96 pub cipher_suite: u16,
97 #[serde(with = "hex::serde")]
98 pub group_id: Vec<u8>,
99 pub epoch: u64,
100 #[serde(with = "hex::serde")]
101 pub tree_hash: Vec<u8>,
102 #[serde(with = "hex::serde")]
103 pub confirmed_transcript_hash: Vec<u8>,
104 }
105
106 impl InteropGroupContext {
107 #[cfg_attr(coverage_nightly, coverage(off))]
random<P: CipherSuiteProvider>(cs: &P) -> Self108 fn random<P: CipherSuiteProvider>(cs: &P) -> Self {
109 Self {
110 cipher_suite: cs.cipher_suite().into(),
111 group_id: random_bytes(cs.kdf_extract_size()),
112 epoch: 0x121212,
113 tree_hash: random_bytes(cs.kdf_extract_size()),
114 confirmed_transcript_hash: random_bytes(cs.kdf_extract_size()),
115 }
116 }
117 }
118
119 impl From<InteropGroupContext> for GroupContext {
120 #[cfg_attr(coverage_nightly, coverage(off))]
from(ctx: InteropGroupContext) -> Self121 fn from(ctx: InteropGroupContext) -> Self {
122 Self {
123 cipher_suite: ctx.cipher_suite.into(),
124 protocol_version: TEST_PROTOCOL_VERSION,
125 group_id: ctx.group_id,
126 epoch: ctx.epoch,
127 tree_hash: ctx.tree_hash,
128 confirmed_transcript_hash: ctx.confirmed_transcript_hash.into(),
129 extensions: vec![].into(),
130 }
131 }
132 }
133
134 // The test vector can be found here:
135 // https://github.com/mlswg/mls-implementations/blob/main/test-vectors/message-protection.json
136 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
framing_proposal()137 async fn framing_proposal() {
138 #[cfg(not(mls_build_async))]
139 let test_cases: Vec<FramingTestCase> =
140 load_test_case_json!(framing, generate_framing_test_vector());
141
142 #[cfg(mls_build_async)]
143 let test_cases: Vec<FramingTestCase> =
144 load_test_case_json!(framing, generate_framing_test_vector().await);
145
146 for test_case in test_cases.into_iter() {
147 let Some(cs) = try_test_cipher_suite_provider(test_case.context.cipher_suite) else {
148 continue;
149 };
150
151 let to_check = vec![
152 test_case.proposal_priv.clone(),
153 test_case.proposal_pub.clone(),
154 ];
155
156 // Wasm uses incompatible signature secret key format
157 #[cfg(not(target_arch = "wasm32"))]
158 let mut to_check = to_check;
159
160 #[cfg(not(target_arch = "wasm32"))]
161 for enable_encryption in [true, false] {
162 let proposal = Proposal::mls_decode(&mut &*test_case.proposal).unwrap();
163
164 let built = make_group(&test_case, true, enable_encryption, &cs)
165 .await
166 .proposal_message(proposal, vec![])
167 .await
168 .unwrap()
169 .mls_encode_to_vec()
170 .unwrap();
171
172 to_check.push(built);
173 }
174
175 let proposal = Proposal::mls_decode(&mut &*test_case.proposal).unwrap();
176
177 for message in to_check {
178 match process_message(&test_case, &message, &cs).await {
179 Content::Proposal(p) => assert_eq!(p.as_ref(), &proposal),
180 _ => panic!("received value not proposal"),
181 };
182 }
183 }
184 }
185
186 // The test vector can be found here:
187 // https://github.com/mlswg/mls-implementations/blob/main/test-vectors/message-protection.json
188 // Wasm uses incompatible signature secret key format
189 #[cfg(not(target_arch = "wasm32"))]
190 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
framing_application()191 async fn framing_application() {
192 #[cfg(not(mls_build_async))]
193 let test_cases: Vec<FramingTestCase> =
194 load_test_case_json!(framing, generate_framing_test_vector());
195
196 #[cfg(mls_build_async)]
197 let test_cases: Vec<FramingTestCase> =
198 load_test_case_json!(framing, generate_framing_test_vector().await);
199
200 for test_case in test_cases.into_iter() {
201 let Some(cs) = try_test_cipher_suite_provider(test_case.context.cipher_suite) else {
202 continue;
203 };
204
205 let built_priv = make_group(&test_case, true, true, &cs)
206 .await
207 .encrypt_application_message(&test_case.application, vec![])
208 .await
209 .unwrap()
210 .mls_encode_to_vec()
211 .unwrap();
212
213 for message in [&test_case.application_priv, &built_priv] {
214 match process_message(&test_case, message, &cs).await {
215 Content::Application(data) => assert_eq!(data.as_ref(), &test_case.application),
216 _ => panic!("decrypted value not application data"),
217 };
218 }
219 }
220 }
221
222 // The test vector can be found here:
223 // https://github.com/mlswg/mls-implementations/blob/main/test-vectors/message-protection.json
224 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
framing_commit()225 async fn framing_commit() {
226 #[cfg(not(mls_build_async))]
227 let test_cases: Vec<FramingTestCase> =
228 load_test_case_json!(framing, generate_framing_test_vector());
229
230 #[cfg(mls_build_async)]
231 let test_cases: Vec<FramingTestCase> =
232 load_test_case_json!(framing, generate_framing_test_vector().await);
233
234 for test_case in test_cases.into_iter() {
235 let Some(cs) = try_test_cipher_suite_provider(test_case.context.cipher_suite) else {
236 continue;
237 };
238
239 let commit = Commit::mls_decode(&mut &*test_case.commit).unwrap();
240
241 let to_check = vec![test_case.commit_priv.clone(), test_case.commit_pub.clone()];
242
243 // Wasm uses incompatible signature secret key format
244 #[cfg(not(target_arch = "wasm32"))]
245 let to_check = {
246 let mut to_check = to_check;
247
248 let mut signature_priv = test_case.signature_priv.clone();
249
250 if is_edwards(test_case.context.cipher_suite) {
251 signature_priv.extend(test_case.signature_pub.iter());
252 }
253
254 let mut auth_content = AuthenticatedContent::new_signed(
255 &cs,
256 &test_case.context.clone().into(),
257 Sender::Member(1),
258 Content::Commit(alloc::boxed::Box::new(commit.clone())),
259 &signature_priv.into(),
260 WireFormat::PublicMessage,
261 vec![],
262 )
263 .await
264 .unwrap();
265
266 auth_content.auth.confirmation_tag = Some(ConfirmationTag::empty(&cs).await);
267
268 for enable_encryption in [true, false] {
269 let built = make_group(&test_case, true, enable_encryption, &cs)
270 .await
271 .format_for_wire(auth_content.clone())
272 .await
273 .unwrap()
274 .mls_encode_to_vec()
275 .unwrap();
276
277 to_check.push(built);
278 }
279
280 to_check
281 };
282
283 for message in to_check {
284 match process_message(&test_case, &message, &cs).await {
285 Content::Commit(c) => assert_eq!(&*c, &commit),
286 _ => panic!("received value not commit"),
287 };
288 }
289 let commit = Commit::mls_decode(&mut &*test_case.commit).unwrap();
290
291 match process_message(&test_case, &test_case.commit_priv.clone(), &cs).await {
292 Content::Commit(c) => assert_eq!(&*c, &commit),
293 _ => panic!("received value not commit"),
294 };
295 }
296 }
297
298 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
299 #[cfg_attr(coverage_nightly, coverage(off))]
generate_framing_test_vector() -> Vec<FramingTestCase>300 async fn generate_framing_test_vector() -> Vec<FramingTestCase> {
301 let mut test_vector = vec![];
302
303 for cs in CipherSuite::all() {
304 let cs = test_cipher_suite_provider(cs);
305
306 let mut test_case = FramingTestCase::random(&cs).await;
307
308 // Generate private application message
309 test_case.application = cs.random_bytes_vec(42).unwrap();
310
311 let application_priv = make_group(&test_case, true, true, &cs)
312 .await
313 .encrypt_application_message(&test_case.application, vec![])
314 .await
315 .unwrap();
316
317 test_case.application_priv = application_priv.mls_encode_to_vec().unwrap();
318
319 // Generate private and public proposal message
320 let proposal = Proposal::Remove(RemoveProposal {
321 to_remove: LeafIndex(2),
322 });
323
324 test_case.proposal = proposal.mls_encode_to_vec().unwrap();
325
326 let mut group = make_group(&test_case, true, false, &cs).await;
327 let proposal_pub = group.proposal_message(proposal.clone(), vec![]).await;
328 test_case.proposal_pub = proposal_pub.unwrap().mls_encode_to_vec().unwrap();
329
330 let mut group = make_group(&test_case, true, true, &cs).await;
331 let proposal_priv = group.proposal_message(proposal, vec![]).await.unwrap();
332 test_case.proposal_priv = proposal_priv.mls_encode_to_vec().unwrap();
333
334 // Generate private and public commit message
335 let commit = Commit {
336 proposals: vec![],
337 path: None,
338 };
339
340 test_case.commit = commit.mls_encode_to_vec().unwrap();
341
342 let mut auth_content = AuthenticatedContent::new_signed(
343 &cs,
344 group.context(),
345 Sender::Member(1),
346 Content::Commit(alloc::boxed::Box::new(commit.clone())),
347 &group.signer,
348 WireFormat::PublicMessage,
349 vec![],
350 )
351 .await
352 .unwrap();
353
354 auth_content.auth.confirmation_tag = Some(ConfirmationTag::empty(&cs).await);
355
356 let mut group = make_group(&test_case, true, false, &cs).await;
357 let commit_pub = group.format_for_wire(auth_content.clone()).await.unwrap();
358 test_case.commit_pub = commit_pub.mls_encode_to_vec().unwrap();
359
360 let mut auth_content = AuthenticatedContent::new_signed(
361 &cs,
362 group.context(),
363 Sender::Member(1),
364 Content::Commit(alloc::boxed::Box::new(commit)),
365 &group.signer,
366 WireFormat::PrivateMessage,
367 vec![],
368 )
369 .await
370 .unwrap();
371
372 auth_content.auth.confirmation_tag = Some(ConfirmationTag::empty(&cs).await);
373
374 let mut group = make_group(&test_case, true, true, &cs).await;
375 let commit_priv = group.format_for_wire(auth_content.clone()).await.unwrap();
376 test_case.commit_priv = commit_priv.mls_encode_to_vec().unwrap();
377
378 test_vector.push(test_case);
379 }
380
381 test_vector
382 }
383
384 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
make_group<P: CipherSuiteProvider>( test_case: &FramingTestCase, for_send: bool, control_encryption_enabled: bool, cs: &P, ) -> Group<TestClientConfig>385 async fn make_group<P: CipherSuiteProvider>(
386 test_case: &FramingTestCase,
387 for_send: bool,
388 control_encryption_enabled: bool,
389 cs: &P,
390 ) -> Group<TestClientConfig> {
391 let mut group =
392 test_group_custom_config(
393 TEST_PROTOCOL_VERSION,
394 test_case.context.cipher_suite.into(),
395 |b| {
396 b.mls_rules(DefaultMlsRules::default().with_encryption_options(
397 EncryptionOptions::new(control_encryption_enabled, PaddingMode::None),
398 ))
399 },
400 )
401 .await
402 .group;
403
404 // Add a leaf for the sender. It will get index 1.
405 let mut leaf = get_basic_test_node(cs.cipher_suite(), "leaf").await;
406
407 leaf.signing_identity.signature_key = SignaturePublicKey::from(test_case.signature_pub.clone());
408
409 group
410 .state
411 .public_tree
412 .add_leaves(vec![leaf], &group.config.0.identity_provider, cs)
413 .await
414 .unwrap();
415
416 // Convince the group that their index is 1 if they send or 0 if they receive.
417 group.private_tree.self_index = LeafIndex(if for_send { 1 } else { 0 });
418
419 // Convince the group that their signing key is the one from the test case
420 let mut signature_priv = test_case.signature_priv.clone();
421
422 if is_edwards(test_case.context.cipher_suite) {
423 signature_priv.extend(test_case.signature_pub.iter());
424 }
425
426 group.signer = signature_priv.into();
427
428 // Set the group context and secrets
429 let context = GroupContext::from(test_case.context.clone());
430 let secret_tree = get_test_tree(test_case.encryption_secret.clone(), FRAMING_N_LEAVES);
431
432 let secrets = EpochSecrets {
433 secret_tree,
434 resumption_secret: vec![0_u8; cs.kdf_extract_size()].into(),
435 sender_data_secret: test_case.sender_data_secret.clone().into(),
436 };
437
438 group.epoch_secrets = secrets;
439 group.state.context = context;
440 let membership_key = test_case.membership_key.clone();
441 group.key_schedule.set_membership_key(membership_key);
442
443 group
444 }
445
446 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
process_message<P: CipherSuiteProvider>( test_case: &FramingTestCase, message: &[u8], cs: &P, ) -> Content447 async fn process_message<P: CipherSuiteProvider>(
448 test_case: &FramingTestCase,
449 message: &[u8],
450 cs: &P,
451 ) -> Content {
452 // Enabling encryption doesn't matter for processing
453 let mut group = make_group(test_case, false, true, cs).await;
454 let message = MlsMessage::mls_decode(&mut &*message).unwrap();
455 let evt_or_cont = group.get_event_from_incoming_message(message);
456
457 match evt_or_cont.await.unwrap() {
458 EventOrContent::Content(content) => content.content.content,
459 EventOrContent::Event(_) => panic!("expected content, got event"),
460 }
461 }
462