1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::boxed::Box;
6 use alloc::vec;
7 use alloc::vec::Vec;
8 use mls_rs_codec::{MlsDecode, MlsEncode};
9 use mls_rs_core::crypto::CipherSuite;
10 
11 use crate::{
12     client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
13     crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
14     group::{
15         proposal::{AddProposal, Proposal, ProposalOrRef, RemoveProposal, UpdateProposal},
16         proposal_cache::test_utils::CommitReceiver,
17         proposal_ref::ProposalRef,
18         test_utils::TEST_GROUP,
19         LeafIndex, Sender, TreeKemPublic,
20     },
21     identity::basic::BasicIdentityProvider,
22     key_package::test_utils::test_key_package,
23     tree_kem::{
24         leaf_node::test_utils::default_properties, node::NodeVec, test_utils::TreeWithSigners,
25     },
26 };
27 
/// One tree-operations interop test vector.
///
/// Every byte field is stored hex-encoded in the JSON fixture (via `hex::serde`).
/// The tree fields hold the MLS encoding of a `NodeVec`; the proposal field holds
/// the MLS encoding of a single `Proposal`.
#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
struct TreeModsTestCase {
    /// MLS-encoded node vector of the public tree before the proposal is applied.
    #[serde(with = "hex::serde")]
    pub tree_before: Vec<u8>,
    /// MLS-encoded proposal that is applied to `tree_before`.
    #[serde(with = "hex::serde")]
    pub proposal: Vec<u8>,
    /// Leaf index of the member that sent the proposal (used as `Sender::Member`).
    pub proposal_sender: u32,
    /// Expected MLS-encoded node vector of the public tree after the proposal.
    #[serde(with = "hex::serde")]
    pub tree_after: Vec<u8>,
}
38 
39 impl TreeModsTestCase {
40     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
41     #[cfg_attr(coverage_nightly, coverage(off))]
new(tree_before: TreeKemPublic, proposal: Proposal, proposal_sender: u32) -> Self42     async fn new(tree_before: TreeKemPublic, proposal: Proposal, proposal_sender: u32) -> Self {
43         let tree_after = apply_proposal(proposal.clone(), proposal_sender, &tree_before).await;
44 
45         Self {
46             tree_before: tree_before.nodes.mls_encode_to_vec().unwrap(),
47             proposal: proposal.mls_encode_to_vec().unwrap(),
48             tree_after: tree_after.nodes.mls_encode_to_vec().unwrap(),
49             proposal_sender,
50         }
51     }
52 }
53 
54 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
55 #[cfg_attr(coverage_nightly, coverage(off))]
generate_tree_mods_tests() -> Vec<TreeModsTestCase>56 async fn generate_tree_mods_tests() -> Vec<TreeModsTestCase> {
57     let mut test_vector = vec![];
58     let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
59 
60     // Update
61     let tree_before = TreeWithSigners::make_full_tree(8, &cs).await;
62     let update = generate_update(6, &tree_before).await;
63     test_vector.push(TreeModsTestCase::new(tree_before.tree, update, 6).await);
64 
65     // Add in the middle
66     let mut tree_before = TreeWithSigners::make_full_tree(6, &cs).await;
67     tree_before.remove_member(3);
68     test_vector.push(TreeModsTestCase::new(tree_before.tree, generate_add().await, 2).await);
69 
70     // Add at the end
71     let tree_before = TreeWithSigners::make_full_tree(6, &cs).await;
72     test_vector.push(TreeModsTestCase::new(tree_before.tree, generate_add().await, 2).await);
73 
74     // Add at the end, tree grows
75     let tree_before = TreeWithSigners::make_full_tree(8, &cs).await;
76     test_vector.push(TreeModsTestCase::new(tree_before.tree, generate_add().await, 2).await);
77 
78     // Remove in the middle
79     let tree_before = TreeWithSigners::make_full_tree(8, &cs).await;
80     test_vector.push(TreeModsTestCase::new(tree_before.tree, generate_remove(2), 2).await);
81 
82     // Remove at the end
83     let tree_before = TreeWithSigners::make_full_tree(8, &cs).await;
84     test_vector.push(TreeModsTestCase::new(tree_before.tree, generate_remove(7), 2).await);
85 
86     // Remove at the end, tree shrinks
87     let tree_before = TreeWithSigners::make_full_tree(9, &cs).await;
88     test_vector.push(TreeModsTestCase::new(tree_before.tree, generate_remove(8), 2).await);
89 
90     test_vector
91 }
92 
93 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
tree_modifications_interop()94 async fn tree_modifications_interop() {
95     // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/tree-operations.json
96 
97     // All test vectors use cipher suite 1
98     if try_test_cipher_suite_provider(*CipherSuite::CURVE25519_AES128).is_none() {
99         return;
100     }
101 
102     #[cfg(not(mls_build_async))]
103     let test_cases: Vec<TreeModsTestCase> =
104         load_test_case_json!(tree_modifications_interop, generate_tree_mods_tests());
105 
106     #[cfg(mls_build_async)]
107     let test_cases: Vec<TreeModsTestCase> =
108         load_test_case_json!(tree_modifications_interop, generate_tree_mods_tests().await);
109 
110     for test_case in test_cases.into_iter() {
111         let nodes = NodeVec::mls_decode(&mut &*test_case.tree_before).unwrap();
112 
113         let tree_before =
114             TreeKemPublic::import_node_data(nodes, &BasicIdentityProvider, &Default::default())
115                 .await
116                 .unwrap();
117 
118         let proposal = Proposal::mls_decode(&mut &*test_case.proposal).unwrap();
119 
120         let tree_after = apply_proposal(proposal, test_case.proposal_sender, &tree_before).await;
121 
122         let tree_after = tree_after.nodes.mls_encode_to_vec().unwrap();
123 
124         assert_eq!(tree_after, test_case.tree_after);
125     }
126 }
127 
128 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
apply_proposal( proposal: Proposal, sender: u32, tree_before: &TreeKemPublic, ) -> TreeKemPublic129 async fn apply_proposal(
130     proposal: Proposal,
131     sender: u32,
132     tree_before: &TreeKemPublic,
133 ) -> TreeKemPublic {
134     let cs = test_cipher_suite_provider(CipherSuite::CURVE25519_AES128);
135     let p_ref = ProposalRef::new_fake(b"fake ref".to_vec());
136 
137     CommitReceiver::new(tree_before, Sender::Member(0), LeafIndex(1), cs)
138         .cache(p_ref.clone(), proposal, Sender::Member(sender))
139         .receive(vec![ProposalOrRef::Reference(p_ref)])
140         .await
141         .unwrap()
142         .public_tree
143 }
144 
145 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
146 #[cfg_attr(coverage_nightly, coverage(off))]
generate_add() -> Proposal147 async fn generate_add() -> Proposal {
148     let key_package = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "Roger").await;
149     Proposal::Add(Box::new(AddProposal { key_package }))
150 }
151 
152 #[cfg_attr(coverage_nightly, coverage(off))]
generate_remove(i: u32) -> Proposal153 fn generate_remove(i: u32) -> Proposal {
154     let to_remove = LeafIndex(i);
155     Proposal::Remove(RemoveProposal { to_remove })
156 }
157 
158 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
159 #[cfg_attr(coverage_nightly, coverage(off))]
generate_update(i: u32, tree: &TreeWithSigners) -> Proposal160 async fn generate_update(i: u32, tree: &TreeWithSigners) -> Proposal {
161     let signer = tree.signers[i as usize].as_ref().unwrap();
162     let mut leaf_node = tree.tree.get_leaf_node(LeafIndex(i)).unwrap().clone();
163 
164     leaf_node
165         .update(
166             &test_cipher_suite_provider(TEST_CIPHER_SUITE),
167             TEST_GROUP,
168             i,
169             default_properties(),
170             None,
171             signer,
172         )
173         .await
174         .unwrap();
175 
176     Proposal::Update(UpdateProposal { leaf_node })
177 }
178