1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use crate::{
6     client::test_utils::TEST_PROTOCOL_VERSION,
7     crypto::test_utils::try_test_cipher_suite_provider,
8     group::{
9         confirmation_tag::ConfirmationTag, framing::Content, message_processor::MessageProcessor,
10         message_signature::AuthenticatedContent, test_utils::GroupWithoutKeySchedule, Commit,
11         GroupContext, PathSecret, Sender,
12     },
13     identity::basic::BasicIdentityProvider,
14     tree_kem::{
15         node::{LeafIndex, NodeVec},
16         TreeKemPrivate, TreeKemPublic, UpdatePath,
17     },
18     WireFormat,
19 };
20 use alloc::vec;
21 use alloc::vec::Vec;
22 use mls_rs_codec::MlsDecode;
23 use mls_rs_core::{crypto::CipherSuiteProvider, extension::ExtensionList};
24 
25 #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
26 struct TreeKemTestCase {
27     pub cipher_suite: u16,
28 
29     #[serde(with = "hex::serde")]
30     pub group_id: Vec<u8>,
31     epoch: u64,
32     #[serde(with = "hex::serde")]
33     confirmed_transcript_hash: Vec<u8>,
34     #[serde(with = "hex::serde")]
35     ratchet_tree: Vec<u8>,
36 
37     leaves_private: Vec<TestLeafPrivate>,
38     update_paths: Vec<TestUpdatePath>,
39 }
40 
41 #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
42 struct TestLeafPrivate {
43     index: u32,
44     #[serde(with = "hex::serde")]
45     encryption_priv: Vec<u8>,
46     #[serde(with = "hex::serde")]
47     signature_priv: Vec<u8>,
48     path_secrets: Vec<TestPathSecretPrivate>,
49 }
50 
51 #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
52 struct TestPathSecretPrivate {
53     node: u32,
54     #[serde(with = "hex::serde")]
55     path_secret: Vec<u8>,
56 }
57 
58 #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
59 struct TestUpdatePath {
60     sender: u32,
61     #[serde(with = "hex::serde")]
62     update_path: Vec<u8>,
63     #[serde(with = "hex::serde")]
64     tree_hash_after: Vec<u8>,
65     #[serde(with = "hex::serde")]
66     commit_secret: Vec<u8>,
67 }
68 
69 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
tree_kem()70 async fn tree_kem() {
71     // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/treekem.json
72 
73     let test_cases: Vec<TreeKemTestCase> =
74         load_test_case_json!(interop_tree_kem, Vec::<TreeKemTestCase>::new());
75 
76     for test_case in test_cases {
77         let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
78             continue;
79         };
80 
81         // Import the public ratchet tree
82         let nodes = NodeVec::mls_decode(&mut &*test_case.ratchet_tree).unwrap();
83 
84         let mut tree =
85             TreeKemPublic::import_node_data(nodes, &BasicIdentityProvider, &Default::default())
86                 .await
87                 .unwrap();
88 
89         // Construct GroupContext
90         let group_context = GroupContext {
91             protocol_version: TEST_PROTOCOL_VERSION,
92             cipher_suite: cs.cipher_suite(),
93             group_id: test_case.group_id,
94             epoch: test_case.epoch,
95             tree_hash: tree.tree_hash(&cs).await.unwrap(),
96             confirmed_transcript_hash: test_case.confirmed_transcript_hash.into(),
97             extensions: ExtensionList::new(),
98         };
99 
100         for leaf in test_case.leaves_private.iter() {
101             // Construct the private ratchet tree
102             let mut tree_private = TreeKemPrivate::new(LeafIndex(leaf.index));
103 
104             // Set and validate HPKE keys on direct path
105             let path = tree.nodes.direct_copath(tree_private.self_index);
106 
107             tree_private.secret_keys = Vec::new();
108 
109             for dp in path {
110                 let dp = dp.path;
111 
112                 let secret = leaf
113                     .path_secrets
114                     .iter()
115                     .find_map(|s| (s.node == dp).then_some(s.path_secret.clone()));
116 
117                 let private_key = if let Some(secret) = secret {
118                     let (secret_key, public_key) = PathSecret::from(secret)
119                         .to_hpke_key_pair(&cs)
120                         .await
121                         .unwrap();
122 
123                     let tree_public = &tree.nodes.borrow_as_parent(dp).unwrap().public_key;
124                     assert_eq!(&public_key, tree_public);
125 
126                     Some(secret_key)
127                 } else {
128                     None
129                 };
130 
131                 tree_private.secret_keys.push(private_key);
132             }
133 
134             // Set HPKE key for leaf
135             tree_private
136                 .secret_keys
137                 .insert(0, Some(leaf.encryption_priv.clone().into()));
138 
139             let paths = test_case
140                 .update_paths
141                 .iter()
142                 .filter(|path| path.sender != leaf.index);
143 
144             for update_path in paths {
145                 let mut group = GroupWithoutKeySchedule::new(cs.cipher_suite()).await;
146                 group.state.context = group_context.clone();
147                 group.state.public_tree = tree.clone();
148                 group.private_tree = tree_private.clone();
149 
150                 let path = UpdatePath::mls_decode(&mut &*update_path.update_path).unwrap();
151 
152                 let commit = Commit {
153                     proposals: vec![],
154                     path: Some(path),
155                 };
156 
157                 let mut auth_content = AuthenticatedContent::new(
158                     &group_context,
159                     Sender::Member(update_path.sender),
160                     Content::Commit(alloc::boxed::Box::new(commit)),
161                     vec![],
162                     WireFormat::PublicMessage,
163                 );
164 
165                 auth_content.auth.confirmation_tag = Some(ConfirmationTag::empty(&cs).await);
166 
167                 // Hack not to increment epoch
168                 group.state.context.epoch -= 1;
169 
170                 group.process_commit(auth_content, None).await.unwrap();
171 
172                 // Check that we got the expected commit secret and correctly merged the update path.
173                 // This implies that we computed the path secrets correctly.
174                 let commit_secret = group.secrets.unwrap().1;
175 
176                 assert_eq!(&*commit_secret, &update_path.commit_secret);
177 
178                 let new_tree = &mut group.provisional_public_state.unwrap().public_tree;
179                 let new_tree_hash = new_tree.tree_hash(&cs).await.unwrap();
180 
181                 assert_eq!(&new_tree_hash, &update_path.tree_hash_after);
182             }
183         }
184     }
185 }
186