1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5 use alloc::string::String;
6 use alloc::{format, vec};
7 use core::borrow::BorrowMut;
8
9 use debug_tree::TreeBuilder;
10
11 use super::node::{NodeIndex, NodeVec};
12 use crate::{client::MlsError, tree_kem::math::TreeIndex};
13
build_tree( tree: &mut TreeBuilder, nodes: &NodeVec, idx: NodeIndex, ) -> Result<(), MlsError>14 pub(crate) fn build_tree(
15 tree: &mut TreeBuilder,
16 nodes: &NodeVec,
17 idx: NodeIndex,
18 ) -> Result<(), MlsError> {
19 let blank_tag = if nodes.is_blank(idx)? { "Blank " } else { "" };
20
21 // Leaf Node
22 if nodes.is_leaf(idx) {
23 let leaf_tag = format!("{blank_tag}Leaf ({idx})");
24 tree.add_leaf(&leaf_tag);
25 return Ok(());
26 }
27
28 // Parent Leaf
29 let mut parent_tag = format!("{blank_tag}Parent ({idx})");
30
31 if nodes.total_leaf_count().root() == idx {
32 parent_tag = format!("{blank_tag}Root ({idx})");
33 }
34
35 // Add unmerged leaves indexes
36 let unmerged_leaves_idxs = match nodes.borrow_as_parent(idx) {
37 Ok(parent) => parent
38 .unmerged_leaves
39 .iter()
40 .map(|leaf_idx| format!("{}", leaf_idx.0))
41 .collect(),
42 Err(_) => {
43 // Empty parent nodes throw `NotParent` error when borrow as Parent
44 vec![]
45 }
46 };
47
48 if !unmerged_leaves_idxs.is_empty() {
49 let unmerged_leaves_tag =
50 format!(" unmerged leaves idxs: {}", unmerged_leaves_idxs.join(","));
51 parent_tag.push_str(&unmerged_leaves_tag);
52 }
53
54 let mut branch = tree.add_branch(&parent_tag);
55
56 //This cannot panic, as we already checked that idx is not a leaf
57 build_tree(tree, nodes, idx.left_unchecked())?;
58 build_tree(tree, nodes, idx.right_unchecked())?;
59
60 branch.release();
61
62 Ok(())
63 }
64
build_ascii_tree(nodes: &NodeVec) -> String65 pub(crate) fn build_ascii_tree(nodes: &NodeVec) -> String {
66 let leaves_count: u32 = nodes.total_leaf_count();
67 let mut tree = TreeBuilder::new();
68 build_tree(tree.borrow_mut(), nodes, leaves_count.root()).unwrap();
69 tree.string()
70 }
71
#[cfg(test)]
mod tests {
    use alloc::vec;

    use crate::{
        client::test_utils::TEST_CIPHER_SUITE,
        crypto::test_utils::test_cipher_suite_provider,
        identity::basic::BasicIdentityProvider,
        tree_kem::{
            node::Parent,
            parent_hash::ParentHash,
            test_utils::{get_test_leaf_nodes, get_test_tree},
        },
    };

    use super::build_ascii_tree;

    /// After all test leaves are added, every leaf is populated and every
    /// parent is still blank.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn print_fully_populated_tree() {
        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
        let leaves = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        tree.add_leaves(leaves, &BasicIdentityProvider, &provider)
            .await
            .unwrap();

        let expected = concat!(
            "Blank Root (3)\n",
            "├╼ Blank Parent (1)\n",
            "│ ├╼ Leaf (0)\n",
            "│ └╼ Leaf (2)\n",
            "└╼ Blank Parent (5)\n",
            " ├╼ Leaf (4)\n",
            " └╼ Leaf (6)",
        );

        assert_eq!(build_ascii_tree(&tree.nodes), expected);
    }

    /// Removing the first newly added leaf shows that slot as a blank leaf.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn print_tree_blank_leaves() {
        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
        let leaves = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        let added = tree
            .add_leaves(leaves, &BasicIdentityProvider, &provider)
            .await
            .unwrap();

        tree.remove_leaves(vec![added[0]], &BasicIdentityProvider, &provider)
            .await
            .unwrap();

        let expected = concat!(
            "Blank Root (3)\n",
            "├╼ Blank Parent (1)\n",
            "│ ├╼ Leaf (0)\n",
            "│ └╼ Blank Leaf (2)\n",
            "└╼ Blank Parent (5)\n",
            " ├╼ Leaf (4)\n",
            " └╼ Leaf (6)",
        );

        assert_eq!(build_ascii_tree(&tree.nodes), expected);
    }

    /// Once node 3 holds a non-blank parent, adding one more leaf is expected
    /// to show up in that parent's unmerged-leaves suffix.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn print_tree_unmerged_leaves_on_parent() {
        let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
        let leaves = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        tree.add_leaves(leaves[..2].to_vec(), &BasicIdentityProvider, &provider)
            .await
            .unwrap();

        // Make node 3 a non-blank parent with no unmerged leaves yet.
        let empty_parent = Parent {
            public_key: vec![].into(),
            parent_hash: ParentHash::empty(),
            unmerged_leaves: vec![],
        };
        tree.nodes[3] = empty_parent.into();

        tree.add_leaves(leaves[2..3].to_vec(), &BasicIdentityProvider, &provider)
            .await
            .unwrap();

        let expected = concat!(
            "Root (3) unmerged leaves idxs: 3\n",
            "├╼ Blank Parent (1)\n",
            "│ ├╼ Leaf (0)\n",
            "│ └╼ Leaf (2)\n",
            "└╼ Blank Parent (5)\n",
            " ├╼ Leaf (4)\n",
            " └╼ Leaf (6)",
        );

        assert_eq!(build_ascii_tree(&tree.nodes), expected);
    }
}
192