1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::{vec, vec::Vec};
6 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
7 use mls_rs_core::{error::IntoAnyError, identity::IdentityProvider};
8 
9 use super::{
10     leaf_node::LeafNode,
11     leaf_node_validator::{LeafNodeValidator, ValidationContext},
12     node::LeafIndex,
13 };
14 use crate::{
15     client::MlsError,
16     crypto::{CipherSuiteProvider, HpkeCiphertext, HpkePublicKey},
17 };
18 use crate::{group::message_processor::ProvisionalState, time::MlsTime};
19 
/// A single node of an [`UpdatePath`]: a new HPKE public key together with the
/// HPKE ciphertexts carrying the encrypted path secret for that node.
///
/// NOTE(review): presumably mirrors `UpdatePathNode` from RFC 9420 §7.6. The
/// derived `MlsEncode`/`MlsDecode` impls are assumed to encode fields in
/// declaration order, so field order should be treated as part of the wire
/// format.
#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct UpdatePathNode {
    /// New HPKE public key for this path node.
    pub public_key: HpkePublicKey,
    /// HPKE ciphertexts of the path secret for this node.
    pub encrypted_path_secret: Vec<HpkeCiphertext>,
}
27 
/// An update path as received from a committer: the sender's new leaf node and
/// the list of path nodes.
///
/// `nodes` only holds entries for the sender's non-filtered direct-path
/// positions; `validate_update_path` checks the leaf and expands `nodes` into a
/// [`ValidatedUpdatePath`] aligned with the filtered direct path.
///
/// NOTE(review): presumably mirrors `UpdatePath` from RFC 9420 §7.6; field
/// order is assumed to be part of the derived wire encoding.
#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct UpdatePath {
    /// The sender's new leaf node.
    pub leaf_node: LeafNode,
    /// Path nodes, one per non-filtered entry of the sender's direct path.
    pub nodes: Vec<UpdatePathNode>,
}
35 
/// An [`UpdatePath`] that has passed `validate_update_path`.
///
/// Unlike [`UpdatePath::nodes`], `nodes` here is positional: it holds `None`
/// for each direct-path entry that `filtered` marks as skipped and
/// `Some(node)` for each node received in the original path, in order.
#[derive(Clone, Debug, PartialEq)]
pub struct ValidatedUpdatePath {
    /// The sender's validated new leaf node.
    pub leaf_node: LeafNode,
    /// Path nodes aligned with the sender's direct path (`None` = filtered).
    pub nodes: Vec<Option<UpdatePathNode>>,
}
41 
42 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
validate_update_path<C: IdentityProvider, CSP: CipherSuiteProvider>( identity_provider: &C, cipher_suite_provider: &CSP, path: UpdatePath, state: &ProvisionalState, sender: LeafIndex, commit_time: Option<MlsTime>, ) -> Result<ValidatedUpdatePath, MlsError>43 pub(crate) async fn validate_update_path<C: IdentityProvider, CSP: CipherSuiteProvider>(
44     identity_provider: &C,
45     cipher_suite_provider: &CSP,
46     path: UpdatePath,
47     state: &ProvisionalState,
48     sender: LeafIndex,
49     commit_time: Option<MlsTime>,
50 ) -> Result<ValidatedUpdatePath, MlsError> {
51     let group_context_extensions = &state.group_context.extensions;
52 
53     let leaf_validator = LeafNodeValidator::new(
54         cipher_suite_provider,
55         identity_provider,
56         Some(group_context_extensions),
57     );
58 
59     leaf_validator
60         .check_if_valid(
61             &path.leaf_node,
62             ValidationContext::Commit((&state.group_context.group_id, *sender, commit_time)),
63         )
64         .await?;
65 
66     let check_identity_eq = state.applied_proposals.external_initializations.is_empty();
67 
68     if check_identity_eq {
69         let existing_leaf = state.public_tree.nodes.borrow_as_leaf(sender)?;
70         let original_leaf_node = existing_leaf.clone();
71 
72         identity_provider
73             .valid_successor(
74                 &original_leaf_node.signing_identity,
75                 &path.leaf_node.signing_identity,
76                 group_context_extensions,
77             )
78             .await
79             .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?
80             .then_some(())
81             .ok_or(MlsError::InvalidSuccessor)?;
82 
83         (existing_leaf.public_key != path.leaf_node.public_key)
84             .then_some(())
85             .ok_or(MlsError::SameHpkeKey(*sender))?;
86     }
87 
88     // Unfilter the update path
89     let filtered = state.public_tree.nodes.filtered(sender)?;
90     let mut unfiltered_nodes = vec![];
91     let mut i = 0;
92 
93     for n in path.nodes {
94         while *filtered.get(i).ok_or(MlsError::WrongPathLen)? {
95             unfiltered_nodes.push(None);
96             i += 1;
97         }
98 
99         unfiltered_nodes.push(Some(n));
100         i += 1;
101     }
102 
103     Ok(ValidatedUpdatePath {
104         leaf_node: path.leaf_node,
105         nodes: unfiltered_nodes,
106     })
107 }
108 
#[cfg(test)]
mod tests {
    use alloc::vec;
    use assert_matches::assert_matches;

    use crate::client::test_utils::TEST_CIPHER_SUITE;
    use crate::crypto::test_utils::test_cipher_suite_provider;
    use crate::crypto::HpkeCiphertext;
    use crate::group::message_processor::ProvisionalState;
    use crate::group::test_utils::{get_test_group_context, random_bytes, TEST_GROUP};
    use crate::identity::basic::BasicIdentityProvider;
    use crate::tree_kem::leaf_node::test_utils::default_properties;
    use crate::tree_kem::leaf_node::test_utils::get_basic_test_node_sig_key;
    use crate::tree_kem::leaf_node::LeafNodeSource;
    use crate::tree_kem::node::LeafIndex;
    use crate::tree_kem::parent_hash::ParentHash;
    use crate::tree_kem::test_utils::{get_test_leaf_nodes, get_test_tree};
    use crate::tree_kem::validate_update_path;

    use super::{UpdatePath, UpdatePathNode};
    use crate::{cipher_suite::CipherSuite, tree_kem::MlsError};

    use alloc::vec::Vec;

    /// Builds an `UpdatePath` for leaf 0 of `TEST_GROUP`.
    ///
    /// The leaf node is a basic test node for `cred` with a
    /// `LeafNodeSource::Commit` source, re-signed via `commit`. It is followed
    /// by two identical path nodes filled with random bytes.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn test_update_path(cipher_suite: CipherSuite, cred: &str) -> UpdatePath {
        let (mut leaf_node, _, signer) = get_basic_test_node_sig_key(cipher_suite, cred).await;

        leaf_node.leaf_node_source = LeafNodeSource::Commit(ParentHash::from(hex!("beef")));

        leaf_node
            .commit(
                &test_cipher_suite_provider(cipher_suite),
                TEST_GROUP,
                0,
                default_properties(),
                None,
                &signer,
            )
            .await
            .unwrap();

        let node = UpdatePathNode {
            public_key: random_bytes(32).into(),
            encrypted_path_secret: vec![HpkeCiphertext {
                kem_output: random_bytes(32),
                ciphertext: random_bytes(32),
            }],
        };

        UpdatePath {
            leaf_node,
            nodes: vec![node.clone(), node],
        }
    }

    /// Builds a provisional state from the test tree plus the test leaf nodes,
    /// with default applied proposals (so no external initializations).
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn test_provisional_state(cipher_suite: CipherSuite) -> ProvisionalState {
        let mut tree = get_test_tree(cipher_suite).await.public;
        let leaf_nodes = get_test_leaf_nodes(cipher_suite).await;

        tree.add_leaves(
            leaf_nodes,
            &BasicIdentityProvider,
            &test_cipher_suite_provider(cipher_suite),
        )
        .await
        .unwrap();

        ProvisionalState {
            public_tree: tree,
            applied_proposals: Default::default(),
            group_context: get_test_group_context(1, cipher_suite).await,
            indexes_of_added_kpkgs: vec![],
            external_init_index: None,
            #[cfg(feature = "state_update")]
            unused_proposals: vec![],
        }
    }

    /// A well-formed path validates, and every node comes back as `Some`.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_valid_leaf_node() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let update_path = test_update_path(TEST_CIPHER_SUITE, "creator").await;

        let validated = validate_update_path(
            &BasicIdentityProvider,
            &cipher_suite_provider,
            update_path.clone(),
            &test_provisional_state(TEST_CIPHER_SUITE).await,
            LeafIndex(0),
            None,
        )
        .await
        .unwrap();

        let expected = update_path.nodes.into_iter().map(Some).collect::<Vec<_>>();

        assert_eq!(validated.nodes, expected);
        assert_eq!(validated.leaf_node, update_path.leaf_node);
    }

    /// A leaf node with a corrupted signature is rejected.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_invalid_key_package() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let mut update_path = test_update_path(TEST_CIPHER_SUITE, "creator").await;
        update_path.leaf_node.signature = random_bytes(32);

        let validated = validate_update_path(
            &BasicIdentityProvider,
            &cipher_suite_provider,
            update_path,
            &test_provisional_state(TEST_CIPHER_SUITE).await,
            LeafIndex(0),
            None,
        )
        .await;

        assert_matches!(validated, Err(MlsError::InvalidSignature));
    }

    /// A leaf node whose identity differs from the sender's current leaf is
    /// not accepted as a valid successor.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn validating_path_fails_with_different_identity() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let update_path = test_update_path(TEST_CIPHER_SUITE, "foobar").await;

        let validated = validate_update_path(
            &BasicIdentityProvider,
            &cipher_suite_provider,
            update_path,
            &test_provisional_state(TEST_CIPHER_SUITE).await,
            LeafIndex(0),
            None,
        )
        .await;

        assert_matches!(validated, Err(MlsError::InvalidSuccessor));
    }

    /// Reusing the sender's current HPKE public key is rejected.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn validating_path_fails_with_same_hpke_key() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let update_path = test_update_path(TEST_CIPHER_SUITE, "creator").await;
        let mut state = test_provisional_state(TEST_CIPHER_SUITE).await;

        state
            .public_tree
            .nodes
            .borrow_as_leaf_mut(LeafIndex(0))
            .unwrap()
            .public_key = update_path.leaf_node.public_key.clone();

        let validated = validate_update_path(
            &BasicIdentityProvider,
            &cipher_suite_provider,
            update_path,
            &state,
            LeafIndex(0),
            None,
        )
        .await;

        assert_matches!(validated, Err(MlsError::SameHpkeKey(_)));
    }

    /// A path with more nodes than the sender's filtered direct path has
    /// entries is rejected with `WrongPathLen`.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn validating_path_fails_with_too_many_nodes() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let state = test_provisional_state(TEST_CIPHER_SUITE).await;
        let mut update_path = test_update_path(TEST_CIPHER_SUITE, "creator").await;

        // Each path node consumes at least one `filtered` entry, so
        // `filtered.len() + 1` nodes always overflow the direct path.
        let filtered = state.public_tree.nodes.filtered(LeafIndex(0)).unwrap();
        let extra_node = update_path.nodes[0].clone();
        update_path.nodes.resize(filtered.len() + 1, extra_node);

        let validated = validate_update_path(
            &BasicIdentityProvider,
            &cipher_suite_provider,
            update_path,
            &state,
            LeafIndex(0),
            None,
        )
        .await;

        assert_matches!(validated, Err(MlsError::WrongPathLen));
    }
}
275