1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::string::String;
6 use alloc::{format, vec};
7 use core::borrow::BorrowMut;
8 
9 use debug_tree::TreeBuilder;
10 
11 use super::node::{NodeIndex, NodeVec};
12 use crate::{client::MlsError, tree_kem::math::TreeIndex};
13 
build_tree( tree: &mut TreeBuilder, nodes: &NodeVec, idx: NodeIndex, ) -> Result<(), MlsError>14 pub(crate) fn build_tree(
15     tree: &mut TreeBuilder,
16     nodes: &NodeVec,
17     idx: NodeIndex,
18 ) -> Result<(), MlsError> {
19     let blank_tag = if nodes.is_blank(idx)? { "Blank " } else { "" };
20 
21     // Leaf Node
22     if nodes.is_leaf(idx) {
23         let leaf_tag = format!("{blank_tag}Leaf ({idx})");
24         tree.add_leaf(&leaf_tag);
25         return Ok(());
26     }
27 
28     // Parent Leaf
29     let mut parent_tag = format!("{blank_tag}Parent ({idx})");
30 
31     if nodes.total_leaf_count().root() == idx {
32         parent_tag = format!("{blank_tag}Root ({idx})");
33     }
34 
35     // Add unmerged leaves indexes
36     let unmerged_leaves_idxs = match nodes.borrow_as_parent(idx) {
37         Ok(parent) => parent
38             .unmerged_leaves
39             .iter()
40             .map(|leaf_idx| format!("{}", leaf_idx.0))
41             .collect(),
42         Err(_) => {
43             // Empty parent nodes throw `NotParent` error when borrow as Parent
44             vec![]
45         }
46     };
47 
48     if !unmerged_leaves_idxs.is_empty() {
49         let unmerged_leaves_tag =
50             format!(" unmerged leaves idxs: {}", unmerged_leaves_idxs.join(","));
51         parent_tag.push_str(&unmerged_leaves_tag);
52     }
53 
54     let mut branch = tree.add_branch(&parent_tag);
55 
56     //This cannot panic, as we already checked that idx is not a leaf
57     build_tree(tree, nodes, idx.left_unchecked())?;
58     build_tree(tree, nodes, idx.right_unchecked())?;
59 
60     branch.release();
61 
62     Ok(())
63 }
64 
build_ascii_tree(nodes: &NodeVec) -> String65 pub(crate) fn build_ascii_tree(nodes: &NodeVec) -> String {
66     let leaves_count: u32 = nodes.total_leaf_count();
67     let mut tree = TreeBuilder::new();
68     build_tree(tree.borrow_mut(), nodes, leaves_count.root()).unwrap();
69     tree.string()
70 }
71 
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::build_ascii_tree;

    use crate::{
        client::test_utils::TEST_CIPHER_SUITE,
        crypto::test_utils::test_cipher_suite_provider,
        identity::basic::BasicIdentityProvider,
        tree_kem::{
            node::Parent,
            parent_hash::ParentHash,
            test_utils::{get_test_leaf_nodes, get_test_tree},
        },
    };

    // After adding all test leaves, the rendering is expected to show blank
    // parents and no blank leaves.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn print_fully_populated_tree() {
        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        // Start from the test tree and add every test leaf node to it.
        let mut public_tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
        let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        public_tree
            .add_leaves(leaf_nodes, &BasicIdentityProvider, &provider)
            .await
            .unwrap();

        let expected = concat!(
            "Blank Root (3)\n",
            "├╼ Blank Parent (1)\n",
            "│ ├╼ Leaf (0)\n",
            "│ └╼ Leaf (2)\n",
            "└╼ Blank Parent (5)\n",
            "  ├╼ Leaf (4)\n",
            "  └╼ Leaf (6)",
        );

        assert_eq!(build_ascii_tree(&public_tree.nodes), expected);
    }

    // Removing the first added leaf is expected to render it as a blank leaf.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn print_tree_blank_leaves() {
        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        // Start from the test tree and add every test leaf node to it.
        let mut public_tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
        let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        let added = public_tree
            .add_leaves(leaf_nodes, &BasicIdentityProvider, &provider)
            .await
            .unwrap();

        // Remove the first of the leaves that were just added.
        let removed = added[0];

        public_tree
            .remove_leaves(vec![removed], &BasicIdentityProvider, &provider)
            .await
            .unwrap();

        let expected = concat!(
            "Blank Root (3)\n",
            "├╼ Blank Parent (1)\n",
            "│ ├╼ Leaf (0)\n",
            "│ └╼ Blank Leaf (2)\n",
            "└╼ Blank Parent (5)\n",
            "  ├╼ Leaf (4)\n",
            "  └╼ Leaf (6)",
        );

        assert_eq!(build_ascii_tree(&public_tree.nodes), expected);
    }

    // A leaf added after node 3 has been filled in is expected to be listed
    // as unmerged on that node's label.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn print_tree_unmerged_leaves_on_parent() {
        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        // Start from the test tree and add the first two test leaf nodes.
        let mut public_tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
        let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        public_tree
            .add_leaves(
                vec![leaf_nodes[0].clone(), leaf_nodes[1].clone()],
                &BasicIdentityProvider,
                &provider,
            )
            .await
            .unwrap();

        // Overwrite node 3 with a parent that has no unmerged leaves yet.
        let filled_parent = Parent {
            public_key: vec![].into(),
            parent_hash: ParentHash::empty(),
            unmerged_leaves: vec![],
        };
        public_tree.nodes[3] = filled_parent.into();

        // Add the third test leaf node after node 3 has been filled in.
        public_tree
            .add_leaves(
                vec![leaf_nodes[2].clone()],
                &BasicIdentityProvider,
                &provider,
            )
            .await
            .unwrap();

        let expected = concat!(
            "Root (3) unmerged leaves idxs: 3\n",
            "├╼ Blank Parent (1)\n",
            "│ ├╼ Leaf (0)\n",
            "│ └╼ Leaf (2)\n",
            "└╼ Blank Parent (5)\n",
            "  ├╼ Leaf (4)\n",
            "  └╼ Leaf (6)",
        );

        assert_eq!(build_ascii_tree(&public_tree.nodes), expected);
    }
}
192