1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5 use alloc::boxed::Box;
6 use alloc::vec;
7 use alloc::vec::Vec;
8 use mls_rs_codec::{MlsDecode, MlsEncode};
9 use mls_rs_core::crypto::CipherSuite;
10
11 use crate::{
12 client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
13 crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
14 group::{
15 proposal::{AddProposal, Proposal, ProposalOrRef, RemoveProposal, UpdateProposal},
16 proposal_cache::test_utils::CommitReceiver,
17 proposal_ref::ProposalRef,
18 test_utils::TEST_GROUP,
19 LeafIndex, Sender, TreeKemPublic,
20 },
21 identity::basic::BasicIdentityProvider,
22 key_package::test_utils::test_key_package,
23 tree_kem::{
24 leaf_node::test_utils::default_properties, node::NodeVec, test_utils::TreeWithSigners,
25 },
26 };
27
28 #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
29 struct TreeModsTestCase {
30 #[serde(with = "hex::serde")]
31 pub tree_before: Vec<u8>,
32 #[serde(with = "hex::serde")]
33 pub proposal: Vec<u8>,
34 pub proposal_sender: u32,
35 #[serde(with = "hex::serde")]
36 pub tree_after: Vec<u8>,
37 }
38
39 impl TreeModsTestCase {
40 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
41 #[cfg_attr(coverage_nightly, coverage(off))]
new(tree_before: TreeKemPublic, proposal: Proposal, proposal_sender: u32) -> Self42 async fn new(tree_before: TreeKemPublic, proposal: Proposal, proposal_sender: u32) -> Self {
43 let tree_after = apply_proposal(proposal.clone(), proposal_sender, &tree_before).await;
44
45 Self {
46 tree_before: tree_before.nodes.mls_encode_to_vec().unwrap(),
47 proposal: proposal.mls_encode_to_vec().unwrap(),
48 tree_after: tree_after.nodes.mls_encode_to_vec().unwrap(),
49 proposal_sender,
50 }
51 }
52 }
53
54 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
55 #[cfg_attr(coverage_nightly, coverage(off))]
generate_tree_mods_tests() -> Vec<TreeModsTestCase>56 async fn generate_tree_mods_tests() -> Vec<TreeModsTestCase> {
57 let mut test_vector = vec![];
58 let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
59
60 // Update
61 let tree_before = TreeWithSigners::make_full_tree(8, &cs).await;
62 let update = generate_update(6, &tree_before).await;
63 test_vector.push(TreeModsTestCase::new(tree_before.tree, update, 6).await);
64
65 // Add in the middle
66 let mut tree_before = TreeWithSigners::make_full_tree(6, &cs).await;
67 tree_before.remove_member(3);
68 test_vector.push(TreeModsTestCase::new(tree_before.tree, generate_add().await, 2).await);
69
70 // Add at the end
71 let tree_before = TreeWithSigners::make_full_tree(6, &cs).await;
72 test_vector.push(TreeModsTestCase::new(tree_before.tree, generate_add().await, 2).await);
73
74 // Add at the end, tree grows
75 let tree_before = TreeWithSigners::make_full_tree(8, &cs).await;
76 test_vector.push(TreeModsTestCase::new(tree_before.tree, generate_add().await, 2).await);
77
78 // Remove in the middle
79 let tree_before = TreeWithSigners::make_full_tree(8, &cs).await;
80 test_vector.push(TreeModsTestCase::new(tree_before.tree, generate_remove(2), 2).await);
81
82 // Remove at the end
83 let tree_before = TreeWithSigners::make_full_tree(8, &cs).await;
84 test_vector.push(TreeModsTestCase::new(tree_before.tree, generate_remove(7), 2).await);
85
86 // Remove at the end, tree shrinks
87 let tree_before = TreeWithSigners::make_full_tree(9, &cs).await;
88 test_vector.push(TreeModsTestCase::new(tree_before.tree, generate_remove(8), 2).await);
89
90 test_vector
91 }
92
93 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
tree_modifications_interop()94 async fn tree_modifications_interop() {
95 // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/tree-operations.json
96
97 // All test vectors use cipher suite 1
98 if try_test_cipher_suite_provider(*CipherSuite::CURVE25519_AES128).is_none() {
99 return;
100 }
101
102 #[cfg(not(mls_build_async))]
103 let test_cases: Vec<TreeModsTestCase> =
104 load_test_case_json!(tree_modifications_interop, generate_tree_mods_tests());
105
106 #[cfg(mls_build_async)]
107 let test_cases: Vec<TreeModsTestCase> =
108 load_test_case_json!(tree_modifications_interop, generate_tree_mods_tests().await);
109
110 for test_case in test_cases.into_iter() {
111 let nodes = NodeVec::mls_decode(&mut &*test_case.tree_before).unwrap();
112
113 let tree_before =
114 TreeKemPublic::import_node_data(nodes, &BasicIdentityProvider, &Default::default())
115 .await
116 .unwrap();
117
118 let proposal = Proposal::mls_decode(&mut &*test_case.proposal).unwrap();
119
120 let tree_after = apply_proposal(proposal, test_case.proposal_sender, &tree_before).await;
121
122 let tree_after = tree_after.nodes.mls_encode_to_vec().unwrap();
123
124 assert_eq!(tree_after, test_case.tree_after);
125 }
126 }
127
128 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
apply_proposal( proposal: Proposal, sender: u32, tree_before: &TreeKemPublic, ) -> TreeKemPublic129 async fn apply_proposal(
130 proposal: Proposal,
131 sender: u32,
132 tree_before: &TreeKemPublic,
133 ) -> TreeKemPublic {
134 let cs = test_cipher_suite_provider(CipherSuite::CURVE25519_AES128);
135 let p_ref = ProposalRef::new_fake(b"fake ref".to_vec());
136
137 CommitReceiver::new(tree_before, Sender::Member(0), LeafIndex(1), cs)
138 .cache(p_ref.clone(), proposal, Sender::Member(sender))
139 .receive(vec![ProposalOrRef::Reference(p_ref)])
140 .await
141 .unwrap()
142 .public_tree
143 }
144
145 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
146 #[cfg_attr(coverage_nightly, coverage(off))]
generate_add() -> Proposal147 async fn generate_add() -> Proposal {
148 let key_package = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "Roger").await;
149 Proposal::Add(Box::new(AddProposal { key_package }))
150 }
151
152 #[cfg_attr(coverage_nightly, coverage(off))]
generate_remove(i: u32) -> Proposal153 fn generate_remove(i: u32) -> Proposal {
154 let to_remove = LeafIndex(i);
155 Proposal::Remove(RemoveProposal { to_remove })
156 }
157
158 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
159 #[cfg_attr(coverage_nightly, coverage(off))]
generate_update(i: u32, tree: &TreeWithSigners) -> Proposal160 async fn generate_update(i: u32, tree: &TreeWithSigners) -> Proposal {
161 let signer = tree.signers[i as usize].as_ref().unwrap();
162 let mut leaf_node = tree.tree.get_leaf_node(LeafIndex(i)).unwrap().clone();
163
164 leaf_node
165 .update(
166 &test_cipher_suite_provider(TEST_CIPHER_SUITE),
167 TEST_GROUP,
168 i,
169 default_properties(),
170 None,
171 signer,
172 )
173 .await
174 .unwrap();
175
176 Proposal::Update(UpdateProposal { leaf_node })
177 }
178