1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use crate::client::MlsError; 6 use crate::crypto::{CipherSuiteProvider, HpkePublicKey}; 7 use crate::tree_kem::math as tree_math; 8 use crate::tree_kem::node::{LeafIndex, Node, NodeIndex}; 9 use crate::tree_kem::TreeKemPublic; 10 use alloc::vec::Vec; 11 use core::{ 12 fmt::{self, Debug}, 13 ops::Deref, 14 }; 15 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; 16 use mls_rs_core::error::IntoAnyError; 17 use tree_math::TreeIndex; 18 19 use super::leaf_node::LeafNodeSource; 20 21 #[cfg(feature = "std")] 22 use std::collections::HashSet; 23 24 #[cfg(not(feature = "std"))] 25 use alloc::collections::BTreeSet; 26 27 #[derive(Clone, Debug, MlsSize, MlsEncode)] 28 struct ParentHashInput<'a> { 29 #[mls_codec(with = "mls_rs_codec::byte_vec")] 30 public_key: &'a HpkePublicKey, 31 #[mls_codec(with = "mls_rs_codec::byte_vec")] 32 parent_hash: &'a [u8], 33 #[mls_codec(with = "mls_rs_codec::byte_vec")] 34 original_sibling_tree_hash: &'a [u8], 35 } 36 37 #[derive(Clone, MlsSize, MlsEncode, MlsDecode, PartialEq, Eq)] 38 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 39 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 40 pub struct ParentHash( 41 #[mls_codec(with = "mls_rs_codec::byte_vec")] 42 #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))] 43 Vec<u8>, 44 ); 45 46 impl Debug for ParentHash { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result47 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 48 mls_rs_core::debug::pretty_bytes(&self.0) 49 .named("ParentHash") 50 .fmt(f) 51 } 52 } 53 54 impl From<Vec<u8>> for ParentHash { from(v: Vec<u8>) -> Self55 fn from(v: Vec<u8>) -> Self { 56 Self(v) 57 } 58 } 59 60 impl Deref for ParentHash { 61 type Target = Vec<u8>; 62 deref(&self) -> &Self::Target63 fn deref(&self) -> &Self::Target { 64 &self.0 65 } 66 } 67 68 impl ParentHash { 69 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] new<P: CipherSuiteProvider>( cipher_suite_provider: &P, public_key: &HpkePublicKey, parent_hash: &ParentHash, original_sibling_tree_hash: &[u8], ) -> Result<Self, MlsError>70 pub async fn new<P: CipherSuiteProvider>( 71 cipher_suite_provider: &P, 72 public_key: &HpkePublicKey, 73 parent_hash: &ParentHash, 74 original_sibling_tree_hash: &[u8], 75 ) -> Result<Self, MlsError> { 76 let input = ParentHashInput { 77 public_key, 78 parent_hash, 79 original_sibling_tree_hash, 80 }; 81 82 let input_bytes = input.mls_encode_to_vec()?; 83 84 let hash = cipher_suite_provider 85 .hash(&input_bytes) 86 .await 87 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?; 88 89 Ok(Self(hash)) 90 } 91 empty() -> Self92 pub fn empty() -> Self { 93 ParentHash(Vec::new()) 94 } 95 matches(&self, hash: &ParentHash) -> bool96 pub fn matches(&self, hash: &ParentHash) -> bool { 97 //TODO: Constant time equals 98 hash == self 99 } 100 } 101 102 impl Node { get_parent_hash(&self) -> Option<ParentHash>103 fn get_parent_hash(&self) -> Option<ParentHash> { 104 match self { 105 Node::Parent(p) => Some(p.parent_hash.clone()), 106 Node::Leaf(l) => match &l.leaf_node_source { 107 LeafNodeSource::Commit(parent_hash) => Some(parent_hash.clone()), 108 _ => None, 109 }, 110 } 111 } 112 } 113 114 impl TreeKemPublic { 115 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] parent_hash_for_leaf<P: CipherSuiteProvider>( &mut self, cipher_suite_provider: &P, index: LeafIndex, ) -> Result<ParentHash, MlsError>116 async fn parent_hash_for_leaf<P: CipherSuiteProvider>( 117 &mut self, 118 cipher_suite_provider: &P, 119 index: LeafIndex, 120 ) -> Result<ParentHash, MlsError> { 121 let mut hash = ParentHash::empty(); 122 123 for node in self.nodes.direct_copath(index).into_iter().rev() { 124 if self.nodes.is_resolution_empty(node.copath) { 125 continue; 126 } 127 128 let parent = self.nodes.borrow_as_parent_mut(node.path)?; 129 130 let calculated = ParentHash::new( 131 cipher_suite_provider, 132 &parent.public_key, 133 &hash, 134 &self.tree_hashes.current[node.copath as usize], 135 ) 136 .await?; 137 138 (parent.parent_hash, hash) = (hash, calculated); 139 } 140 141 Ok(hash) 142 } 143 144 // Updates all of the required parent hash values, and returns the calculated parent hash value for the leaf node 145 // If an update path is provided, additionally verify that the calculated parent hash matches 146 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] update_parent_hashes<P: CipherSuiteProvider>( &mut self, index: LeafIndex, verify_leaf_hash: bool, cipher_suite_provider: &P, ) -> Result<(), MlsError>147 pub(crate) async fn update_parent_hashes<P: CipherSuiteProvider>( 148 &mut self, 149 index: LeafIndex, 150 verify_leaf_hash: bool, 151 cipher_suite_provider: &P, 152 ) -> Result<(), MlsError> { 153 // First update the relevant original hashes used for parent hash computation. 154 self.update_hashes(&[index], cipher_suite_provider).await?; 155 156 let leaf_hash = self 157 .parent_hash_for_leaf(cipher_suite_provider, index) 158 .await?; 159 160 let leaf = self.nodes.borrow_as_leaf_mut(index)?; 161 162 if verify_leaf_hash { 163 // Verify the parent hash of the new sender leaf node and update the parent hash values 164 // in the local tree 165 if let LeafNodeSource::Commit(parent_hash) = &leaf.leaf_node_source { 166 if !leaf_hash.matches(parent_hash) { 167 return Err(MlsError::ParentHashMismatch); 168 } 169 } else { 170 return Err(MlsError::InvalidLeafNodeSource); 171 } 172 } else { 173 leaf.leaf_node_source = LeafNodeSource::Commit(leaf_hash); 174 } 175 176 // Update hashes after changes to the tree. 177 self.update_hashes(&[index], cipher_suite_provider).await 178 } 179 180 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] validate_parent_hashes<P: CipherSuiteProvider>( &self, cipher_suite_provider: &P, ) -> Result<(), MlsError>181 pub(super) async fn validate_parent_hashes<P: CipherSuiteProvider>( 182 &self, 183 cipher_suite_provider: &P, 184 ) -> Result<(), MlsError> { 185 let original_hashes = self.compute_original_hashes(cipher_suite_provider).await?; 186 187 let nodes_to_validate = self 188 .nodes 189 .non_empty_parents() 190 .map(|(node_index, _)| node_index); 191 192 #[cfg(feature = "std")] 193 let mut nodes_to_validate = nodes_to_validate.collect::<HashSet<_>>(); 194 #[cfg(not(feature = "std"))] 195 let mut nodes_to_validate = nodes_to_validate.collect::<BTreeSet<_>>(); 196 197 let num_leaves = self.total_leaf_count(); 198 199 // For each leaf l, validate all non-blank nodes on the chain from l up the tree. 200 for (leaf_index, _) in self.nodes.non_empty_leaves() { 201 let mut n = NodeIndex::from(leaf_index); 202 203 while let Some(mut ps) = n.parent_sibling(&num_leaves) { 204 // Find the first non-blank ancestor p of n and p's co-path child s. 205 while self.nodes.is_blank(ps.parent)? { 206 // If we reached the root, we're done with this chain. 207 let Some(ps_parent) = ps.parent.parent_sibling(&num_leaves) else { 208 return Ok(()); 209 }; 210 211 ps = ps_parent; 212 } 213 214 // Check is n's parent_hash field matches the parent hash of p with co-path child s. 215 let p_parent = self.nodes.borrow_as_parent(ps.parent)?; 216 217 let n_node = self 218 .nodes 219 .borrow_node(n)? 220 .as_ref() 221 .ok_or(MlsError::ExpectedNode)?; 222 223 let calculated = ParentHash::new( 224 cipher_suite_provider, 225 &p_parent.public_key, 226 &p_parent.parent_hash, 227 &original_hashes[ps.sibling as usize], 228 ) 229 .await?; 230 231 if n_node.get_parent_hash() == Some(calculated) { 232 // Check that "n is in the resolution of c, and the intersection of p's unmerged_leaves with the subtree 233 // under c is equal to the resolution of c with n removed". 234 let Some(cp) = ps.sibling.parent_sibling(&num_leaves) else { 235 return Err(MlsError::ParentHashMismatch); 236 }; 237 238 let c = cp.sibling; 239 let c_resolution = self.nodes.get_resolution_index(c)?.into_iter(); 240 241 #[cfg(feature = "std")] 242 let mut c_resolution = c_resolution.collect::<HashSet<_>>(); 243 #[cfg(not(feature = "std"))] 244 let mut c_resolution = c_resolution.collect::<BTreeSet<_>>(); 245 246 let p_unmerged_in_c_subtree = self 247 .unmerged_in_subtree(ps.parent, c)? 248 .iter() 249 .copied() 250 .map(|x| *x * 2); 251 252 #[cfg(feature = "std")] 253 let p_unmerged_in_c_subtree = p_unmerged_in_c_subtree.collect::<HashSet<_>>(); 254 #[cfg(not(feature = "std"))] 255 let p_unmerged_in_c_subtree = p_unmerged_in_c_subtree.collect::<BTreeSet<_>>(); 256 257 if c_resolution.remove(&n) 258 && c_resolution == p_unmerged_in_c_subtree 259 && nodes_to_validate.remove(&ps.parent) 260 { 261 // If n's parent_hash field matches and p has not been validated yet, mark p as validated and continue. 262 n = ps.parent; 263 } else { 264 // If p is validated for the second time, the check fails ("all non-blank parent nodes are covered by exactly one such chain"). 265 return Err(MlsError::ParentHashMismatch); 266 } 267 } else { 268 // If n's parent_hash field doesn't match, we're done with this chain. 269 break; 270 } 271 } 272 } 273 274 // The check passes iff all non-blank nodes are validated. 275 if nodes_to_validate.is_empty() { 276 Ok(()) 277 } else { 278 Err(MlsError::ParentHashMismatch) 279 } 280 } 281 } 282 283 #[cfg(test)] 284 pub(crate) mod test_utils { 285 286 use super::*; 287 use crate::{ 288 cipher_suite::CipherSuite, 289 crypto::test_utils::test_cipher_suite_provider, 290 identity::basic::BasicIdentityProvider, 291 tree_kem::{leaf_node::test_utils::get_basic_test_node, node::Parent}, 292 }; 293 294 use alloc::vec; 295 296 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] test_parent( cipher_suite: CipherSuite, unmerged_leaves: Vec<LeafIndex>, ) -> Parent297 pub(crate) async fn test_parent( 298 cipher_suite: CipherSuite, 299 unmerged_leaves: Vec<LeafIndex>, 300 ) -> Parent { 301 let (_, public_key) = test_cipher_suite_provider(cipher_suite) 302 .kem_generate() 303 .await 304 .unwrap(); 305 306 Parent { 307 public_key, 308 parent_hash: ParentHash::empty(), 309 unmerged_leaves, 310 } 311 } 312 313 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] test_parent_node( cipher_suite: CipherSuite, unmerged_leaves: Vec<LeafIndex>, ) -> Node314 pub(crate) async fn test_parent_node( 315 cipher_suite: CipherSuite, 316 unmerged_leaves: Vec<LeafIndex>, 317 ) -> Node { 318 Node::Parent(test_parent(cipher_suite, unmerged_leaves).await) 319 } 320 321 // Create figure 12 from MLS RFC 322 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] get_test_tree_fig_12(cipher_suite: CipherSuite) -> TreeKemPublic323 pub(crate) async fn get_test_tree_fig_12(cipher_suite: CipherSuite) -> TreeKemPublic { 324 let cipher_suite_provider = test_cipher_suite_provider(cipher_suite); 325 326 let mut tree = TreeKemPublic::new(); 327 328 let mut leaves = Vec::new(); 329 330 for l in ["A", "B", "C", "D", "E", "F", "G"] { 331 leaves.push(get_basic_test_node(cipher_suite, l).await); 332 } 333 334 tree.add_leaves(leaves, &BasicIdentityProvider, &cipher_suite_provider) 335 .await 336 .unwrap(); 337 338 tree.nodes[1] = Some(test_parent_node(cipher_suite, vec![]).await); 339 tree.nodes[3] = Some(test_parent_node(cipher_suite, vec![LeafIndex(3)]).await); 340 341 tree.nodes[7] = 342 Some(test_parent_node(cipher_suite, vec![LeafIndex(3), LeafIndex(6)]).await); 343 344 tree.nodes[9] = Some(test_parent_node(cipher_suite, vec![LeafIndex(5)]).await); 345 346 tree.nodes[11] = 347 Some(test_parent_node(cipher_suite, vec![LeafIndex(5), LeafIndex(6)]).await); 348 349 tree.update_parent_hashes(LeafIndex(0), false, &cipher_suite_provider) 350 .await 351 .unwrap(); 352 353 tree.update_parent_hashes(LeafIndex(4), false, &cipher_suite_provider) 354 .await 355 .unwrap(); 356 357 tree 358 } 359 } 360 361 #[cfg(test)] 362 mod tests { 363 use super::*; 364 use crate::client::test_utils::TEST_CIPHER_SUITE; 365 use crate::crypto::test_utils::test_cipher_suite_provider; 366 use crate::tree_kem::leaf_node::test_utils::get_basic_test_node; 367 use crate::tree_kem::leaf_node::LeafNodeSource; 368 use crate::tree_kem::test_utils::TreeWithSigners; 369 use crate::tree_kem::MlsError; 370 use assert_matches::assert_matches; 371 372 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_missing_parent_hash()373 async fn test_missing_parent_hash() { 374 let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE); 375 let mut test_tree = TreeWithSigners::make_full_tree(8, &cs).await.tree; 376 377 *test_tree.nodes.borrow_as_leaf_mut(LeafIndex(0)).unwrap() = 378 get_basic_test_node(TEST_CIPHER_SUITE, "foo").await; 379 380 let missing_parent_hash_res = test_tree 381 .update_parent_hashes( 382 LeafIndex(0), 383 true, 384 &test_cipher_suite_provider(TEST_CIPHER_SUITE), 385 ) 386 .await; 387 388 assert_matches!( 389 missing_parent_hash_res, 390 Err(MlsError::InvalidLeafNodeSource) 391 ); 392 } 393 394 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_parent_hash_mismatch()395 async fn test_parent_hash_mismatch() { 396 let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE); 397 let mut test_tree = TreeWithSigners::make_full_tree(8, &cs).await.tree; 398 399 let unexpected_parent_hash = ParentHash::from(hex!("f00d")); 400 401 test_tree 402 .nodes 403 .borrow_as_leaf_mut(LeafIndex(0)) 404 .unwrap() 405 .leaf_node_source = LeafNodeSource::Commit(unexpected_parent_hash); 406 407 let invalid_parent_hash_res = test_tree 408 .update_parent_hashes( 409 LeafIndex(0), 410 true, 411 &test_cipher_suite_provider(TEST_CIPHER_SUITE), 412 ) 413 .await; 414 415 assert_matches!(invalid_parent_hash_res, Err(MlsError::ParentHashMismatch)); 416 } 417 418 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_parent_hash_invalid()419 async fn test_parent_hash_invalid() { 420 let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE); 421 let mut test_tree = TreeWithSigners::make_full_tree(8, &cs).await.tree; 422 423 test_tree.nodes[2] = None; 424 425 let res = test_tree 426 .validate_parent_hashes(&test_cipher_suite_provider(TEST_CIPHER_SUITE)) 427 .await; 428 429 assert_matches!(res, Err(MlsError::ParentHashMismatch)); 430 } 431 } 432