1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use super::*;
6 #[cfg(feature = "tree_index")]
7 use core::fmt::{self, Debug};
8 
9 #[cfg(all(feature = "tree_index", feature = "custom_proposal"))]
10 use crate::group::proposal::ProposalType;
11 
12 #[cfg(feature = "tree_index")]
13 use crate::identity::CredentialType;
14 
15 #[cfg(feature = "tree_index")]
16 use mls_rs_core::crypto::SignaturePublicKey;
17 
18 #[cfg(all(feature = "tree_index", feature = "std"))]
19 use itertools::Itertools;
20 
21 #[cfg(all(feature = "tree_index", not(feature = "std")))]
22 use alloc::collections::{btree_map::Entry, BTreeMap};
23 
24 #[cfg(all(feature = "tree_index", feature = "std"))]
25 use std::collections::{hash_map::Entry, HashMap};
26 
27 #[cfg(all(feature = "tree_index", not(feature = "std")))]
28 use alloc::collections::BTreeSet;
29 
30 #[cfg(feature = "tree_index")]
31 use mls_rs_core::crypto::HpkePublicKey;
32 
/// Raw identity bytes of a group member, wrapped so they can serve as a map key.
///
/// The derived `Hash`/`Eq`/`Ord` implementations let the same type key either a
/// `HashMap` or a `BTreeMap`, depending on which `TreeIndex` variant is compiled.
/// The inner bytes are encoded through the `mls_rs_codec::byte_vec` helper.
#[cfg(feature = "tree_index")]
#[derive(Clone, Default, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode, Hash, PartialOrd, Ord)]
pub struct Identifier(#[mls_codec(with = "mls_rs_codec::byte_vec")] Vec<u8>);
36 
#[cfg(feature = "tree_index")]
impl Debug for Identifier {
    /// Formats the wrapped bytes with the shared byte pretty-printer, labelled
    /// `Identifier`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The formatter value is built and consumed within a single expression so
        // that it never outlives the borrow of `self.0`.
        Debug::fmt(
            &mls_rs_core::debug::pretty_bytes(&self.0).named("Identifier"),
            f,
        )
    }
}
45 
/// Lookup tables over the leaves of a tree (`std` variant, backed by `HashMap`).
///
/// The index detects duplicate leaf data on insertion and keeps per-type counters
/// describing which credential (and, with `custom_proposal`, proposal) types the
/// indexed leaves use and support.
#[cfg(all(feature = "tree_index", feature = "std"))]
#[derive(Clone, Debug, Default, PartialEq, MlsSize, MlsEncode, MlsDecode)]
pub struct TreeIndex {
    /// Signature public key of each indexed leaf, mapped to that leaf's index.
    credential_signature_key: HashMap<SignaturePublicKey, LeafIndex>,
    /// HPKE public key of each indexed leaf, mapped to that leaf's index.
    hpke_key: HashMap<HpkePublicKey, LeafIndex>,
    /// Identity bytes of each indexed leaf, mapped to that leaf's index.
    identities: HashMap<Identifier, LeafIndex>,
    /// Per credential type: how many leaves list it in their capabilities
    /// (`supported`) and how many leaves carry a credential of that type (`used`).
    credential_type_counters: HashMap<CredentialType, TypeCounter>,
    /// Per proposal type: how many leaves list it in their capabilities.
    #[cfg(feature = "custom_proposal")]
    proposal_type_counter: HashMap<ProposalType, u32>,
}
56 
/// Lookup tables over the leaves of a tree (`no_std` variant, backed by `BTreeMap`).
///
/// The index detects duplicate leaf data on insertion and keeps per-type counters
/// describing which credential (and, with `custom_proposal`, proposal) types the
/// indexed leaves use and support.
#[cfg(all(feature = "tree_index", not(feature = "std")))]
#[derive(Clone, Debug, Default, PartialEq, MlsSize, MlsEncode, MlsDecode)]
pub struct TreeIndex {
    /// Signature public key of each indexed leaf, mapped to that leaf's index.
    credential_signature_key: BTreeMap<SignaturePublicKey, LeafIndex>,
    /// HPKE public key of each indexed leaf, mapped to that leaf's index.
    hpke_key: BTreeMap<HpkePublicKey, LeafIndex>,
    /// Identity bytes of each indexed leaf, mapped to that leaf's index.
    identities: BTreeMap<Identifier, LeafIndex>,
    /// Per credential type: how many leaves list it in their capabilities
    /// (`supported`) and how many leaves carry a credential of that type (`used`).
    credential_type_counters: BTreeMap<CredentialType, TypeCounter>,
    /// Per proposal type: how many leaves list it in their capabilities.
    #[cfg(feature = "custom_proposal")]
    proposal_type_counter: BTreeMap<ProposalType, u32>,
}
67 
/// Resolves the identity of `new_leaf` through `id_provider` and records the leaf
/// in `tree_index` under `new_leaf_idx`.
///
/// # Errors
///
/// Returns [`MlsError::IdentityProviderError`] when the identity provider fails,
/// otherwise whatever `TreeIndex::insert` reports for the leaf.
#[cfg(feature = "tree_index")]
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
pub(super) async fn index_insert<I: IdentityProvider>(
    tree_index: &mut TreeIndex,
    new_leaf: &LeafNode,
    new_leaf_idx: LeafIndex,
    id_provider: &I,
    extensions: &ExtensionList,
) -> Result<(), MlsError> {
    let identity_lookup = id_provider
        .identity(&new_leaf.signing_identity, extensions)
        .await;

    match identity_lookup {
        Ok(identity) => tree_index.insert(new_leaf_idx, new_leaf, identity),
        Err(e) => Err(MlsError::IdentityProviderError(e.into_any_error())),
    }
}
84 
85 #[cfg(not(feature = "tree_index"))]
86 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
index_insert<I: IdentityProvider>( nodes: &NodeVec, new_leaf: &LeafNode, new_leaf_idx: LeafIndex, id_provider: &I, extensions: &ExtensionList, ) -> Result<(), MlsError>87 pub(super) async fn index_insert<I: IdentityProvider>(
88     nodes: &NodeVec,
89     new_leaf: &LeafNode,
90     new_leaf_idx: LeafIndex,
91     id_provider: &I,
92     extensions: &ExtensionList,
93 ) -> Result<(), MlsError> {
94     let new_id = id_provider
95         .identity(&new_leaf.signing_identity, extensions)
96         .await
97         .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
98 
99     for (i, leaf) in nodes.non_empty_leaves().filter(|(i, _)| i != &new_leaf_idx) {
100         (new_leaf.public_key != leaf.public_key)
101             .then_some(())
102             .ok_or(MlsError::DuplicateLeafData(*i))?;
103 
104         (new_leaf.signing_identity.signature_key != leaf.signing_identity.signature_key)
105             .then_some(())
106             .ok_or(MlsError::DuplicateLeafData(*i))?;
107 
108         let id = id_provider
109             .identity(&leaf.signing_identity, extensions)
110             .await
111             .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
112 
113         (new_id != id)
114             .then_some(())
115             .ok_or(MlsError::DuplicateLeafData(*i))?;
116 
117         let cred_type = leaf.signing_identity.credential.credential_type();
118 
119         new_leaf
120             .capabilities
121             .credentials
122             .contains(&cred_type)
123             .then_some(())
124             .ok_or(MlsError::InUseCredentialTypeUnsupportedByNewLeaf)?;
125 
126         let new_cred_type = new_leaf.signing_identity.credential.credential_type();
127 
128         leaf.capabilities
129             .credentials
130             .contains(&new_cred_type)
131             .then_some(())
132             .ok_or(MlsError::CredentialTypeOfNewLeafIsUnsupported)?;
133     }
134 
135     Ok(())
136 }
137 
#[cfg(feature = "tree_index")]
impl TreeIndex {
    /// Creates an empty index.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns `true` once at least one identity has been recorded.
    pub fn is_initialized(&self) -> bool {
        !self.identities.is_empty()
    }

    /// Records `leaf_node` at `index` under the given `identity`.
    ///
    /// All validation happens before any table is modified, so a rejected insert
    /// leaves the index exactly as it was.
    ///
    /// # Errors
    ///
    /// * [`MlsError::DuplicateLeafData`] when the leaf's signature key, HPKE key or
    ///   identity is already indexed; the payload is the index of the leaf that
    ///   already holds it.
    /// * [`MlsError::InUseCredentialTypeUnsupportedByNewLeaf`] when a credential
    ///   type used by an indexed leaf is missing from the new leaf's capabilities.
    /// * [`MlsError::CredentialTypeOfNewLeafIsUnsupported`] when the number of
    ///   indexed leaves supporting the new leaf's credential type differs from the
    ///   number of leaves already indexed.
    fn insert(
        &mut self,
        index: LeafIndex,
        leaf_node: &LeafNode,
        identity: Vec<u8>,
    ) -> Result<(), MlsError> {
        let old_leaf_count = self.credential_signature_key.len();

        let pub_key = leaf_node.signing_identity.signature_key.clone();
        let credential_entry = self.credential_signature_key.entry(pub_key);

        if let Entry::Occupied(entry) = credential_entry {
            return Err(MlsError::DuplicateLeafData(**entry.get()));
        }

        let hpke_entry = self.hpke_key.entry(leaf_node.public_key.clone());

        if let Entry::Occupied(entry) = hpke_entry {
            return Err(MlsError::DuplicateLeafData(**entry.get()));
        }

        let identity_entry = self.identities.entry(Identifier(identity));

        if let Entry::Occupied(entry) = identity_entry {
            return Err(MlsError::DuplicateLeafData(**entry.get()));
        }

        // Every credential type currently in use must be supported by the new leaf.
        let in_use_cred_type_unsupported_by_new_leaf =
            self.credential_type_counters
                .iter()
                .any(|(cred_type, counters)| {
                    counters.used > 0 && !leaf_node.capabilities.credentials.contains(cred_type)
                });

        if in_use_cred_type_unsupported_by_new_leaf {
            return Err(MlsError::InUseCredentialTypeUnsupportedByNewLeaf);
        }

        let new_leaf_cred_type = leaf_node.signing_identity.credential.credential_type();

        // Read the counter without creating it: inserting a default entry here
        // would leave a stray zeroed counter behind when the check below fails.
        let supporting_leaves = self
            .credential_type_counters
            .get(&new_leaf_cred_type)
            .map_or(0, |counters| counters.supported);

        // The new leaf's credential type is acceptable only if every leaf that is
        // already indexed lists it as supported.
        if supporting_leaves != old_leaf_count as u32 {
            return Err(MlsError::CredentialTypeOfNewLeafIsUnsupported);
        }

        // Validation is complete; from here on the index is only updated.
        self.credential_type_counters
            .entry(new_leaf_cred_type)
            .or_default()
            .used += 1;

        // Each capability is counted once per leaf, even if the leaf lists it
        // several times.
        let credential_type_iter = leaf_node.capabilities.credentials.iter().copied();

        #[cfg(feature = "std")]
        let credential_type_iter = credential_type_iter.unique();

        #[cfg(not(feature = "std"))]
        let credential_type_iter = credential_type_iter.collect::<BTreeSet<_>>().into_iter();

        for cred_type in credential_type_iter {
            self.credential_type_counters
                .entry(cred_type)
                .or_default()
                .supported += 1;
        }

        #[cfg(feature = "custom_proposal")]
        {
            let proposal_type_iter = leaf_node.capabilities.proposals.iter().copied();

            #[cfg(feature = "std")]
            let proposal_type_iter = proposal_type_iter.unique();

            #[cfg(not(feature = "std"))]
            let proposal_type_iter = proposal_type_iter.collect::<BTreeSet<_>>().into_iter();

            for proposal_type in proposal_type_iter {
                *self.proposal_type_counter.entry(proposal_type).or_default() += 1;
            }
        }

        identity_entry.or_insert(index);
        credential_entry.or_insert(index);
        hpke_entry.or_insert(index);

        Ok(())
    }

    /// Looks up the index of the leaf recorded under `identity`, if any.
    pub(crate) fn get_leaf_index_with_identity(&self, identity: &[u8]) -> Option<LeafIndex> {
        self.identities.get(&Identifier(identity.to_vec())).copied()
    }

    /// Removes `leaf_node` and its `identity` from the index.
    ///
    /// The signature-key and HPKE-key entries are always dropped. The type
    /// counters are only adjusted when `identity` was actually present, and they
    /// saturate at zero instead of underflowing if the bookkeeping and the passed
    /// leaf disagree.
    pub fn remove(&mut self, leaf_node: &LeafNode, identity: &[u8]) {
        let existed = self
            .identities
            .remove(&Identifier(identity.to_vec()))
            .is_some();

        self.credential_signature_key
            .remove(&leaf_node.signing_identity.signature_key);

        self.hpke_key.remove(&leaf_node.public_key);

        if !existed {
            return;
        }

        // Decrement credential type counters
        let leaf_cred_type = leaf_node.signing_identity.credential.credential_type();

        if let Some(counters) = self.credential_type_counters.get_mut(&leaf_cred_type) {
            counters.used = counters.used.saturating_sub(1);
        }

        // Mirror `insert`: each listed capability was counted once per leaf.
        let credential_type_iter = leaf_node.capabilities.credentials.iter();

        #[cfg(feature = "std")]
        let credential_type_iter = credential_type_iter.unique();

        #[cfg(not(feature = "std"))]
        let credential_type_iter = credential_type_iter.collect::<BTreeSet<_>>().into_iter();

        for cred_type in credential_type_iter {
            if let Some(counters) = self.credential_type_counters.get_mut(cred_type) {
                counters.supported = counters.supported.saturating_sub(1);
            }
        }

        #[cfg(feature = "custom_proposal")]
        {
            let proposal_type_iter = leaf_node.capabilities.proposals.iter();

            #[cfg(feature = "std")]
            let proposal_type_iter = proposal_type_iter.unique();

            #[cfg(not(feature = "std"))]
            let proposal_type_iter = proposal_type_iter.collect::<BTreeSet<_>>().into_iter();

            // Decrement proposal type counters
            for proposal_type in proposal_type_iter {
                if let Some(supported) = self.proposal_type_counter.get_mut(proposal_type) {
                    *supported = supported.saturating_sub(1);
                }
            }
        }
    }

    /// Returns how many indexed leaves list `proposal_type` in their capabilities
    /// (zero when the type has never been seen).
    #[cfg(feature = "custom_proposal")]
    pub fn count_supporting_proposal(&self, proposal_type: ProposalType) -> u32 {
        self.proposal_type_counter
            .get(&proposal_type)
            .copied()
            .unwrap_or_default()
    }

    /// Number of leaves currently indexed by signature key (test helper).
    #[cfg(test)]
    pub fn len(&self) -> usize {
        self.credential_signature_key.len()
    }
}
308 
/// Per-credential-type bookkeeping kept by `TreeIndex`.
#[cfg(feature = "tree_index")]
#[derive(Clone, Debug, Default, PartialEq, MlsEncode, MlsDecode, MlsSize)]
struct TypeCounter {
    /// Number of indexed leaves that list this type in their capabilities.
    supported: u32,
    /// Number of indexed leaves whose own credential is of this type.
    used: u32,
}
315 
#[cfg(feature = "tree_index")]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        client::test_utils::TEST_CIPHER_SUITE,
        tree_kem::leaf_node::test_utils::{get_basic_test_node, get_test_client_identity},
    };
    use alloc::format;
    use assert_matches::assert_matches;

    /// A leaf node together with the index it is inserted at.
    #[derive(Clone, Debug)]
    struct TestData {
        pub leaf_node: LeafNode,
        pub index: LeafIndex,
    }

    /// Builds a basic test leaf named `foo<index>` for the given position.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn get_test_data(index: LeafIndex) -> TestData {
        let name = format!("foo{}", index.0);
        let leaf_node = get_basic_test_node(TEST_CIPHER_SUITE, &name).await;

        TestData { leaf_node, index }
    }

    /// Creates ten distinct leaves and an index that already contains all of them.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn test_setup() -> (Vec<TestData>, TreeIndex) {
        let mut test_data = Vec::new();

        for i in 0..10 {
            test_data.push(get_test_data(LeafIndex(i)).await);
        }

        let mut test_index = TreeIndex::new();

        for data in &test_data {
            let identity = get_test_client_identity(&data.leaf_node);

            test_index
                .insert(data.index, &data.leaf_node, identity)
                .unwrap();
        }

        (test_data, test_index)
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_insert() {
        let (test_data, test_index) = test_setup().await;
        let expected_len = test_data.len();

        assert_eq!(test_index.credential_signature_key.len(), expected_len);
        assert_eq!(test_index.hpke_key.len(), expected_len);

        // Every leaf must be reachable through both key tables at its own position.
        for (position, data) in test_data.iter().enumerate() {
            let expected_index = LeafIndex(position as u32);
            let signature_key = &data.leaf_node.signing_identity.signature_key;

            assert_eq!(
                test_index.credential_signature_key.get(signature_key),
                Some(&expected_index)
            );

            assert_eq!(
                test_index.hpke_key.get(&data.leaf_node.public_key),
                Some(&expected_index)
            );
        }
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_insert_duplicate_credential_key() {
        let (test_data, mut test_index) = test_setup().await;
        let snapshot = test_index.clone();
        let existing = &test_data[1];

        // A fresh leaf that reuses the signing identity of an indexed leaf.
        let mut new_key_package = get_basic_test_node(TEST_CIPHER_SUITE, "foo").await;
        new_key_package.signing_identity = existing.leaf_node.signing_identity.clone();

        let identity = get_test_client_identity(&new_key_package);
        let res = test_index.insert(existing.index, &new_key_package, identity);

        assert_matches!(res, Err(MlsError::DuplicateLeafData(index))
                        if index == *existing.index);

        // The rejected insert must not have touched the index.
        assert_eq!(snapshot, test_index);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_insert_duplicate_hpke_key() {
        let (test_data, mut test_index) = test_setup().await;
        let snapshot = test_index.clone();
        let existing = &test_data[1];

        // A fresh leaf that reuses the HPKE key of an indexed leaf.
        let mut new_leaf_node = get_basic_test_node(TEST_CIPHER_SUITE, "foo").await;
        new_leaf_node.public_key = existing.leaf_node.public_key.clone();

        let identity = get_test_client_identity(&new_leaf_node);
        let res = test_index.insert(existing.index, &new_leaf_node, identity);

        assert_matches!(res, Err(MlsError::DuplicateLeafData(index))
                        if index == *existing.index);

        // The rejected insert must not have touched the index.
        assert_eq!(snapshot, test_index);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_remove() {
        let (test_data, mut test_index) = test_setup().await;

        let removed = &test_data[1].leaf_node;
        let identity = get_test_client_identity(removed);

        test_index.remove(removed, &identity);

        let remaining = test_data.len() - 1;

        assert_eq!(test_index.credential_signature_key.len(), remaining);
        assert_eq!(test_index.hpke_key.len(), remaining);

        assert!(test_index
            .credential_signature_key
            .get(&removed.signing_identity.signature_key)
            .is_none());

        assert!(test_index.hpke_key.get(&removed.public_key).is_none());
    }

    #[cfg(feature = "custom_proposal")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn custom_proposals() {
        let test_proposal_id = ProposalType::new(42);
        let other_proposal_id = ProposalType::new(45);

        // First leaf supports only the shared proposal type.
        let mut first = get_test_data(LeafIndex(0)).await;
        first.leaf_node.capabilities.proposals.push(test_proposal_id);

        // Second leaf supports the shared type plus one of its own.
        let mut second = get_test_data(LeafIndex(1)).await;
        second
            .leaf_node
            .capabilities
            .proposals
            .extend([test_proposal_id, other_proposal_id]);

        let mut test_index = TreeIndex::new();

        test_index
            .insert(first.index, &first.leaf_node, vec![0])
            .unwrap();

        assert_eq!(test_index.count_supporting_proposal(test_proposal_id), 1);

        test_index
            .insert(second.index, &second.leaf_node, vec![1])
            .unwrap();

        assert_eq!(test_index.count_supporting_proposal(test_proposal_id), 2);
        assert_eq!(test_index.count_supporting_proposal(other_proposal_id), 1);

        test_index.remove(&second.leaf_node, &[1]);

        assert_eq!(test_index.count_supporting_proposal(test_proposal_id), 1);
        assert_eq!(test_index.count_supporting_proposal(other_proposal_id), 0);
    }
}
506