1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5 use crate::{
6 client::test_utils::TEST_PROTOCOL_VERSION,
7 crypto::test_utils::try_test_cipher_suite_provider,
8 group::{
9 confirmation_tag::ConfirmationTag, framing::Content, message_processor::MessageProcessor,
10 message_signature::AuthenticatedContent, test_utils::GroupWithoutKeySchedule, Commit,
11 GroupContext, PathSecret, Sender,
12 },
13 identity::basic::BasicIdentityProvider,
14 tree_kem::{
15 node::{LeafIndex, NodeVec},
16 TreeKemPrivate, TreeKemPublic, UpdatePath,
17 },
18 WireFormat,
19 };
20 use alloc::vec;
21 use alloc::vec::Vec;
22 use mls_rs_codec::MlsDecode;
23 use mls_rs_core::{crypto::CipherSuiteProvider, extension::ExtensionList};
24
25 #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
26 struct TreeKemTestCase {
27 pub cipher_suite: u16,
28
29 #[serde(with = "hex::serde")]
30 pub group_id: Vec<u8>,
31 epoch: u64,
32 #[serde(with = "hex::serde")]
33 confirmed_transcript_hash: Vec<u8>,
34 #[serde(with = "hex::serde")]
35 ratchet_tree: Vec<u8>,
36
37 leaves_private: Vec<TestLeafPrivate>,
38 update_paths: Vec<TestUpdatePath>,
39 }
40
41 #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
42 struct TestLeafPrivate {
43 index: u32,
44 #[serde(with = "hex::serde")]
45 encryption_priv: Vec<u8>,
46 #[serde(with = "hex::serde")]
47 signature_priv: Vec<u8>,
48 path_secrets: Vec<TestPathSecretPrivate>,
49 }
50
51 #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
52 struct TestPathSecretPrivate {
53 node: u32,
54 #[serde(with = "hex::serde")]
55 path_secret: Vec<u8>,
56 }
57
58 #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
59 struct TestUpdatePath {
60 sender: u32,
61 #[serde(with = "hex::serde")]
62 update_path: Vec<u8>,
63 #[serde(with = "hex::serde")]
64 tree_hash_after: Vec<u8>,
65 #[serde(with = "hex::serde")]
66 commit_secret: Vec<u8>,
67 }
68
69 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
tree_kem()70 async fn tree_kem() {
71 // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/treekem.json
72
73 let test_cases: Vec<TreeKemTestCase> =
74 load_test_case_json!(interop_tree_kem, Vec::<TreeKemTestCase>::new());
75
76 for test_case in test_cases {
77 let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
78 continue;
79 };
80
81 // Import the public ratchet tree
82 let nodes = NodeVec::mls_decode(&mut &*test_case.ratchet_tree).unwrap();
83
84 let mut tree =
85 TreeKemPublic::import_node_data(nodes, &BasicIdentityProvider, &Default::default())
86 .await
87 .unwrap();
88
89 // Construct GroupContext
90 let group_context = GroupContext {
91 protocol_version: TEST_PROTOCOL_VERSION,
92 cipher_suite: cs.cipher_suite(),
93 group_id: test_case.group_id,
94 epoch: test_case.epoch,
95 tree_hash: tree.tree_hash(&cs).await.unwrap(),
96 confirmed_transcript_hash: test_case.confirmed_transcript_hash.into(),
97 extensions: ExtensionList::new(),
98 };
99
100 for leaf in test_case.leaves_private.iter() {
101 // Construct the private ratchet tree
102 let mut tree_private = TreeKemPrivate::new(LeafIndex(leaf.index));
103
104 // Set and validate HPKE keys on direct path
105 let path = tree.nodes.direct_copath(tree_private.self_index);
106
107 tree_private.secret_keys = Vec::new();
108
109 for dp in path {
110 let dp = dp.path;
111
112 let secret = leaf
113 .path_secrets
114 .iter()
115 .find_map(|s| (s.node == dp).then_some(s.path_secret.clone()));
116
117 let private_key = if let Some(secret) = secret {
118 let (secret_key, public_key) = PathSecret::from(secret)
119 .to_hpke_key_pair(&cs)
120 .await
121 .unwrap();
122
123 let tree_public = &tree.nodes.borrow_as_parent(dp).unwrap().public_key;
124 assert_eq!(&public_key, tree_public);
125
126 Some(secret_key)
127 } else {
128 None
129 };
130
131 tree_private.secret_keys.push(private_key);
132 }
133
134 // Set HPKE key for leaf
135 tree_private
136 .secret_keys
137 .insert(0, Some(leaf.encryption_priv.clone().into()));
138
139 let paths = test_case
140 .update_paths
141 .iter()
142 .filter(|path| path.sender != leaf.index);
143
144 for update_path in paths {
145 let mut group = GroupWithoutKeySchedule::new(cs.cipher_suite()).await;
146 group.state.context = group_context.clone();
147 group.state.public_tree = tree.clone();
148 group.private_tree = tree_private.clone();
149
150 let path = UpdatePath::mls_decode(&mut &*update_path.update_path).unwrap();
151
152 let commit = Commit {
153 proposals: vec![],
154 path: Some(path),
155 };
156
157 let mut auth_content = AuthenticatedContent::new(
158 &group_context,
159 Sender::Member(update_path.sender),
160 Content::Commit(alloc::boxed::Box::new(commit)),
161 vec![],
162 WireFormat::PublicMessage,
163 );
164
165 auth_content.auth.confirmation_tag = Some(ConfirmationTag::empty(&cs).await);
166
167 // Hack not to increment epoch
168 group.state.context.epoch -= 1;
169
170 group.process_commit(auth_content, None).await.unwrap();
171
172 // Check that we got the expected commit secret and correctly merged the update path.
173 // This implies that we computed the path secrets correctly.
174 let commit_secret = group.secrets.unwrap().1;
175
176 assert_eq!(&*commit_secret, &update_path.commit_secret);
177
178 let new_tree = &mut group.provisional_public_state.unwrap().public_tree;
179 let new_tree_hash = new_tree.tree_hash(&cs).await.unwrap();
180
181 assert_eq!(&new_tree_hash, &update_path.tree_hash_after);
182 }
183 }
184 }
185 }
186