1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5 use super::*;
6 #[cfg(feature = "tree_index")]
7 use core::fmt::{self, Debug};
8
9 #[cfg(all(feature = "tree_index", feature = "custom_proposal"))]
10 use crate::group::proposal::ProposalType;
11
12 #[cfg(feature = "tree_index")]
13 use crate::identity::CredentialType;
14
15 #[cfg(feature = "tree_index")]
16 use mls_rs_core::crypto::SignaturePublicKey;
17
18 #[cfg(all(feature = "tree_index", feature = "std"))]
19 use itertools::Itertools;
20
21 #[cfg(all(feature = "tree_index", not(feature = "std")))]
22 use alloc::collections::{btree_map::Entry, BTreeMap};
23
24 #[cfg(all(feature = "tree_index", feature = "std"))]
25 use std::collections::{hash_map::Entry, HashMap};
26
27 #[cfg(all(feature = "tree_index", not(feature = "std")))]
28 use alloc::collections::BTreeSet;
29
30 #[cfg(feature = "tree_index")]
31 use mls_rs_core::crypto::HpkePublicKey;
32
/// Opaque identity bytes for a leaf, as returned by the identity provider.
///
/// Serves as the key of the identity index in [`TreeIndex`]; the bytes are
/// encoded through `mls_rs_codec::byte_vec`.
#[cfg(feature = "tree_index")]
#[derive(Clone, Default, Hash, PartialEq, Eq, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode)]
pub struct Identifier(#[mls_codec(with = "mls_rs_codec::byte_vec")] Vec<u8>);
36
#[cfg(feature = "tree_index")]
impl Debug for Identifier {
    /// Renders the raw identity bytes via `pretty_bytes`, labelled `Identifier`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(raw) = self;
        let pretty = mls_rs_core::debug::pretty_bytes(raw).named("Identifier");
        pretty.fmt(f)
    }
}
45
/// Lookup index over the leaves of a ratchet tree (`std` build, hash maps).
///
/// Maps each leaf's signature key, HPKE key and identity to its [`LeafIndex`]
/// so duplicates can be rejected on insert, and keeps per-credential-type
/// (and, with `custom_proposal`, per-proposal-type) counters used by the
/// capability checks in `TreeIndex::insert`.
// NOTE(review): the derived `MlsEncode`/`MlsDecode` presumably encode fields in
// declaration order — do not reorder fields without checking the wire format.
#[cfg(all(feature = "tree_index", feature = "std"))]
#[derive(Clone, Debug, Default, PartialEq, MlsSize, MlsEncode, MlsDecode)]
pub struct TreeIndex {
    // Leaf signature public key -> index of the leaf holding it.
    credential_signature_key: HashMap<SignaturePublicKey, LeafIndex>,
    // Leaf HPKE public key -> index of the leaf holding it.
    hpke_key: HashMap<HpkePublicKey, LeafIndex>,
    // Identity (from the identity provider) -> index of the leaf holding it.
    identities: HashMap<Identifier, LeafIndex>,
    // Per credential type: how many indexed leaves support it / use it.
    credential_type_counters: HashMap<CredentialType, TypeCounter>,
    // Per proposal type: how many indexed leaves list it in their capabilities.
    #[cfg(feature = "custom_proposal")]
    proposal_type_counter: HashMap<ProposalType, u32>,
}
56
/// Lookup index over the leaves of a ratchet tree (`no_std` build, B-tree maps).
///
/// Same layout and semantics as the `std` variant, but backed by `BTreeMap`
/// so it only needs `alloc`.
// NOTE(review): the derived `MlsEncode`/`MlsDecode` presumably encode fields in
// declaration order — do not reorder fields without checking the wire format.
#[cfg(all(feature = "tree_index", not(feature = "std")))]
#[derive(Clone, Debug, Default, PartialEq, MlsSize, MlsEncode, MlsDecode)]
pub struct TreeIndex {
    // Leaf signature public key -> index of the leaf holding it.
    credential_signature_key: BTreeMap<SignaturePublicKey, LeafIndex>,
    // Leaf HPKE public key -> index of the leaf holding it.
    hpke_key: BTreeMap<HpkePublicKey, LeafIndex>,
    // Identity (from the identity provider) -> index of the leaf holding it.
    identities: BTreeMap<Identifier, LeafIndex>,
    // Per credential type: how many indexed leaves support it / use it.
    credential_type_counters: BTreeMap<CredentialType, TypeCounter>,
    // Per proposal type: how many indexed leaves list it in their capabilities.
    #[cfg(feature = "custom_proposal")]
    proposal_type_counter: BTreeMap<ProposalType, u32>,
}
67
/// Resolves the identity of `new_leaf` through `id_provider` and records the
/// leaf in `tree_index` at `new_leaf_idx`.
///
/// # Errors
///
/// Returns `MlsError::IdentityProviderError` if the identity lookup fails,
/// otherwise whatever `TreeIndex::insert` reports (duplicates or credential
/// type incompatibilities).
#[cfg(feature = "tree_index")]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(super) async fn index_insert<I: IdentityProvider>(
    tree_index: &mut TreeIndex,
    new_leaf: &LeafNode,
    new_leaf_idx: LeafIndex,
    id_provider: &I,
    extensions: &ExtensionList,
) -> Result<(), MlsError> {
    let lookup = id_provider
        .identity(&new_leaf.signing_identity, extensions)
        .await;

    let new_id = match lookup {
        Ok(identity) => identity,
        Err(e) => return Err(MlsError::IdentityProviderError(e.into_any_error())),
    };

    tree_index.insert(new_leaf_idx, new_leaf, new_id)
}
84
85 #[cfg(not(feature = "tree_index"))]
86 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
index_insert<I: IdentityProvider>( nodes: &NodeVec, new_leaf: &LeafNode, new_leaf_idx: LeafIndex, id_provider: &I, extensions: &ExtensionList, ) -> Result<(), MlsError>87 pub(super) async fn index_insert<I: IdentityProvider>(
88 nodes: &NodeVec,
89 new_leaf: &LeafNode,
90 new_leaf_idx: LeafIndex,
91 id_provider: &I,
92 extensions: &ExtensionList,
93 ) -> Result<(), MlsError> {
94 let new_id = id_provider
95 .identity(&new_leaf.signing_identity, extensions)
96 .await
97 .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
98
99 for (i, leaf) in nodes.non_empty_leaves().filter(|(i, _)| i != &new_leaf_idx) {
100 (new_leaf.public_key != leaf.public_key)
101 .then_some(())
102 .ok_or(MlsError::DuplicateLeafData(*i))?;
103
104 (new_leaf.signing_identity.signature_key != leaf.signing_identity.signature_key)
105 .then_some(())
106 .ok_or(MlsError::DuplicateLeafData(*i))?;
107
108 let id = id_provider
109 .identity(&leaf.signing_identity, extensions)
110 .await
111 .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
112
113 (new_id != id)
114 .then_some(())
115 .ok_or(MlsError::DuplicateLeafData(*i))?;
116
117 let cred_type = leaf.signing_identity.credential.credential_type();
118
119 new_leaf
120 .capabilities
121 .credentials
122 .contains(&cred_type)
123 .then_some(())
124 .ok_or(MlsError::InUseCredentialTypeUnsupportedByNewLeaf)?;
125
126 let new_cred_type = new_leaf.signing_identity.credential.credential_type();
127
128 leaf.capabilities
129 .credentials
130 .contains(&new_cred_type)
131 .then_some(())
132 .ok_or(MlsError::CredentialTypeOfNewLeafIsUnsupported)?;
133 }
134
135 Ok(())
136 }
137
#[cfg(feature = "tree_index")]
impl TreeIndex {
    /// Creates an empty index.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns `true` once at least one identity has been indexed.
    pub fn is_initialized(&self) -> bool {
        !self.identities.is_empty()
    }

    /// Indexes `leaf_node` at `index` under `identity`, updating the
    /// credential-type and (with `custom_proposal`) proposal-type counters.
    ///
    /// # Errors
    ///
    /// * `MlsError::DuplicateLeafData` if the signature key, HPKE key or
    ///   identity is already indexed (carries the existing leaf's index).
    /// * `MlsError::InUseCredentialTypeUnsupportedByNewLeaf` if a credential
    ///   type used by an indexed leaf is missing from the new leaf's capabilities.
    /// * `MlsError::CredentialTypeOfNewLeafIsUnsupported` if the number of
    ///   indexed leaves supporting the new leaf's credential type differs from
    ///   the number of indexed leaves.
    ///
    /// On error the index is left unchanged.
    fn insert(
        &mut self,
        index: LeafIndex,
        leaf_node: &LeafNode,
        identity: Vec<u8>,
    ) -> Result<(), MlsError> {
        let old_leaf_count = self.credential_signature_key.len();

        let pub_key = leaf_node.signing_identity.signature_key.clone();
        let credential_entry = self.credential_signature_key.entry(pub_key);

        if let Entry::Occupied(entry) = credential_entry {
            return Err(MlsError::DuplicateLeafData(**entry.get()));
        }

        let hpke_entry = self.hpke_key.entry(leaf_node.public_key.clone());

        if let Entry::Occupied(entry) = hpke_entry {
            return Err(MlsError::DuplicateLeafData(**entry.get()));
        }

        let identity_entry = self.identities.entry(Identifier(identity));

        if let Entry::Occupied(entry) = identity_entry {
            return Err(MlsError::DuplicateLeafData(**entry.get()));
        }

        // Every credential type currently in use must be supported by the new leaf.
        let in_use_cred_type_unsupported_by_new_leaf =
            self.credential_type_counters
                .iter()
                .any(|(cred_type, counters)| {
                    counters.used > 0 && !leaf_node.capabilities.credentials.contains(cred_type)
                });

        if in_use_cred_type_unsupported_by_new_leaf {
            return Err(MlsError::InUseCredentialTypeUnsupportedByNewLeaf);
        }

        let new_leaf_cred_type = leaf_node.signing_identity.credential.credential_type();

        // Read-only lookup: creating the counter entry before this check would
        // leave a stray `(0, 0)` counter behind when the leaf is rejected.
        let supporting_leaves = self
            .credential_type_counters
            .get(&new_leaf_cred_type)
            .map_or(0, |counters| counters.supported);

        // `u32 -> usize` is a widening conversion on supported targets.
        if supporting_leaves as usize != old_leaf_count {
            return Err(MlsError::CredentialTypeOfNewLeafIsUnsupported);
        }

        // All validation passed; from here on the index is mutated.
        self.credential_type_counters
            .entry(new_leaf_cred_type)
            .or_default()
            .used += 1;

        let credential_type_iter = leaf_node.capabilities.credentials.iter().copied();

        #[cfg(feature = "std")]
        let credential_type_iter = credential_type_iter.unique();

        #[cfg(not(feature = "std"))]
        let credential_type_iter = credential_type_iter.collect::<BTreeSet<_>>().into_iter();

        // Each distinct credential type advertised by the leaf counts once.
        credential_type_iter.for_each(|cred_type| {
            self.credential_type_counters
                .entry(cred_type)
                .or_default()
                .supported += 1;
        });

        #[cfg(feature = "custom_proposal")]
        {
            let proposal_type_iter = leaf_node.capabilities.proposals.iter().copied();

            #[cfg(feature = "std")]
            let proposal_type_iter = proposal_type_iter.unique();

            #[cfg(not(feature = "std"))]
            let proposal_type_iter = proposal_type_iter.collect::<BTreeSet<_>>().into_iter();

            // Each distinct proposal type advertised by the leaf counts once.
            proposal_type_iter.for_each(|proposal_type| {
                *self.proposal_type_counter.entry(proposal_type).or_default() += 1;
            });
        }

        identity_entry.or_insert(index);
        credential_entry.or_insert(index);
        hpke_entry.or_insert(index);

        Ok(())
    }

    /// Returns the index of the leaf registered under `identity`, if any.
    pub(crate) fn get_leaf_index_with_identity(&self, identity: &[u8]) -> Option<LeafIndex> {
        self.identities.get(&Identifier(identity.to_vec())).copied()
    }

    /// Removes the entries for `leaf_node` / `identity` from the index.
    ///
    /// The signature-key and HPKE-key entries are always removed. The credential
    /// and proposal counters are only decremented when `identity` was actually
    /// indexed, so removing an unknown leaf does not skew them.
    pub fn remove(&mut self, leaf_node: &LeafNode, identity: &[u8]) {
        let existed = self
            .identities
            .remove(&Identifier(identity.to_vec()))
            .is_some();

        self.credential_signature_key
            .remove(&leaf_node.signing_identity.signature_key);

        self.hpke_key.remove(&leaf_node.public_key);

        if !existed {
            return;
        }

        // Decrements below saturate instead of wrapping: if the caller passes a
        // leaf whose capabilities differ from what was inserted, a wrap to
        // `u32::MAX` in release builds would permanently corrupt the counters.
        let leaf_cred_type = leaf_node.signing_identity.credential.credential_type();

        if let Some(counters) = self.credential_type_counters.get_mut(&leaf_cred_type) {
            counters.used = counters.used.saturating_sub(1);
        }

        let credential_type_iter = leaf_node.capabilities.credentials.iter();

        #[cfg(feature = "std")]
        let credential_type_iter = credential_type_iter.unique();

        #[cfg(not(feature = "std"))]
        let credential_type_iter = credential_type_iter.collect::<BTreeSet<_>>().into_iter();

        credential_type_iter.for_each(|cred_type| {
            if let Some(counters) = self.credential_type_counters.get_mut(cred_type) {
                counters.supported = counters.supported.saturating_sub(1);
            }
        });

        #[cfg(feature = "custom_proposal")]
        {
            let proposal_type_iter = leaf_node.capabilities.proposals.iter();

            #[cfg(feature = "std")]
            let proposal_type_iter = proposal_type_iter.unique();

            #[cfg(not(feature = "std"))]
            let proposal_type_iter = proposal_type_iter.collect::<BTreeSet<_>>().into_iter();

            proposal_type_iter.for_each(|proposal_type| {
                if let Some(supported) = self.proposal_type_counter.get_mut(proposal_type) {
                    *supported = supported.saturating_sub(1);
                }
            })
        }
    }

    /// Number of indexed leaves that list `proposal_type` in their capabilities.
    #[cfg(feature = "custom_proposal")]
    pub fn count_supporting_proposal(&self, proposal_type: ProposalType) -> u32 {
        self.proposal_type_counter
            .get(&proposal_type)
            .copied()
            .unwrap_or_default()
    }

    /// Number of leaves indexed by signature key (test-only helper).
    #[cfg(test)]
    pub fn len(&self) -> usize {
        self.credential_signature_key.len()
    }
}
308
/// Per-credential-type bookkeeping kept by `TreeIndex`.
// NOTE(review): field order presumably determines the derived encoding — do not reorder.
#[cfg(feature = "tree_index")]
#[derive(Clone, Debug, Default, PartialEq, MlsEncode, MlsDecode, MlsSize)]
struct TypeCounter {
    // Number of indexed leaves listing this type in `capabilities.credentials`.
    supported: u32,
    // Number of indexed leaves whose own credential has this type.
    used: u32,
}
315
#[cfg(feature = "tree_index")]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        client::test_utils::TEST_CIPHER_SUITE,
        tree_kem::leaf_node::test_utils::{get_basic_test_node, get_test_client_identity},
    };
    use alloc::format;
    use assert_matches::assert_matches;

    /// A test leaf paired with the index it is inserted at.
    #[derive(Clone, Debug)]
    struct TestData {
        pub leaf_node: LeafNode,
        pub index: LeafIndex,
    }

    /// Builds a basic test leaf named `foo<index>`.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn get_test_data(index: LeafIndex) -> TestData {
        let client_name = format!("foo{}", index.0);
        let leaf_node = get_basic_test_node(TEST_CIPHER_SUITE, &client_name).await;

        TestData { leaf_node, index }
    }

    /// Creates ten test leaves and an index containing all of them.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn test_setup() -> (Vec<TestData>, TreeIndex) {
        let mut leaves = Vec::with_capacity(10);

        for n in 0..10 {
            leaves.push(get_test_data(LeafIndex(n)).await);
        }

        let mut index = TreeIndex::new();

        for leaf in &leaves {
            let identity = get_test_client_identity(&leaf.leaf_node);

            index
                .insert(leaf.index, &leaf.leaf_node, identity)
                .unwrap();
        }

        (leaves, index)
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_insert() {
        let (test_data, test_index) = test_setup().await;
        let expected_len = test_data.len();

        assert_eq!(test_index.credential_signature_key.len(), expected_len);
        assert_eq!(test_index.hpke_key.len(), expected_len);

        for (position, entry) in test_data.into_iter().enumerate() {
            let expected = LeafIndex(position as u32);
            let signature_key = entry.leaf_node.signing_identity.signature_key;

            assert_eq!(
                test_index.credential_signature_key.get(&signature_key),
                Some(&expected)
            );

            assert_eq!(
                test_index.hpke_key.get(&entry.leaf_node.public_key),
                Some(&expected)
            );
        }
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_insert_duplicate_credential_key() {
        let (test_data, mut test_index) = test_setup().await;
        let snapshot = test_index.clone();

        let mut duplicate = get_basic_test_node(TEST_CIPHER_SUITE, "foo").await;
        let target = &test_data[1];
        duplicate.signing_identity = target.leaf_node.signing_identity.clone();

        let identity = get_test_client_identity(&duplicate);
        let outcome = test_index.insert(target.index, &duplicate, identity);

        assert_matches!(outcome, Err(MlsError::DuplicateLeafData(index)) if index == *target.index);
        assert_eq!(snapshot, test_index);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_insert_duplicate_hpke_key() {
        let (test_data, mut test_index) = test_setup().await;
        let snapshot = test_index.clone();

        let mut duplicate = get_basic_test_node(TEST_CIPHER_SUITE, "foo").await;
        let target = &test_data[1];
        duplicate.public_key = target.leaf_node.public_key.clone();

        let identity = get_test_client_identity(&duplicate);
        let outcome = test_index.insert(target.index, &duplicate, identity);

        assert_matches!(outcome, Err(MlsError::DuplicateLeafData(index)) if index == *target.index);
        assert_eq!(snapshot, test_index);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_remove() {
        let (test_data, mut test_index) = test_setup().await;
        let removed = &test_data[1].leaf_node;

        test_index.remove(removed, &get_test_client_identity(removed));

        let remaining = test_data.len() - 1;

        assert_eq!(test_index.credential_signature_key.len(), remaining);
        assert_eq!(test_index.hpke_key.len(), remaining);

        assert!(test_index
            .credential_signature_key
            .get(&removed.signing_identity.signature_key)
            .is_none());

        assert!(test_index.hpke_key.get(&removed.public_key).is_none());
    }

    #[cfg(feature = "custom_proposal")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn custom_proposals() {
        let test_proposal_id = ProposalType::new(42);
        let other_proposal_id = ProposalType::new(45);

        let mut first = get_test_data(LeafIndex(0)).await;
        first.leaf_node.capabilities.proposals.push(test_proposal_id);

        let mut second = get_test_data(LeafIndex(1)).await;
        let second_proposals = &mut second.leaf_node.capabilities.proposals;
        second_proposals.push(test_proposal_id);
        second_proposals.push(other_proposal_id);

        let mut test_index = TreeIndex::new();

        test_index
            .insert(first.index, &first.leaf_node, vec![0])
            .unwrap();

        assert_eq!(test_index.count_supporting_proposal(test_proposal_id), 1);

        test_index
            .insert(second.index, &second.leaf_node, vec![1])
            .unwrap();

        assert_eq!(test_index.count_supporting_proposal(test_proposal_id), 2);
        assert_eq!(test_index.count_supporting_proposal(other_proposal_id), 1);

        test_index.remove(&second.leaf_node, &[1]);

        assert_eq!(test_index.count_supporting_proposal(test_proposal_id), 1);
        assert_eq!(test_index.count_supporting_proposal(other_proposal_id), 0);
    }
}
506