1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use crate::client::MlsError;
6 use crate::crypto::{CipherSuiteProvider, HpkePublicKey, HpkeSecretKey};
7 use crate::group::key_schedule::kdf_derive_secret;
8 use alloc::vec;
9 use alloc::vec::Vec;
10 use core::{
11     fmt::{self, Debug},
12     ops::Deref,
13 };
14 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
15 use mls_rs_core::error::IntoAnyError;
16 use zeroize::Zeroizing;
17 
18 use super::hpke_encryption::HpkeEncryptable;
19 
/// Secret value attached to a node of an update path.
///
/// The bytes live in a [`Zeroizing`] buffer so they are wiped when the value is dropped.
/// Wire encoding goes through `mls_rs_codec::byte_vec`. With the `serde` feature enabled,
/// (de)serialization goes through `mls_rs_core::zeroizing_serde`.
#[derive(Clone, Eq, PartialEq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PathSecret(
    // Raw secret bytes; see the `Deref` impl for read access.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
    Zeroizing<Vec<u8>>,
);
27 
28 impl Debug for PathSecret {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result29     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30         mls_rs_core::debug::pretty_bytes(&self.0)
31             .named("PathSecret")
32             .fmt(f)
33     }
34 }
35 
36 impl Deref for PathSecret {
37     type Target = Vec<u8>;
38 
deref(&self) -> &Self::Target39     fn deref(&self) -> &Self::Target {
40         &self.0
41     }
42 }
43 
44 impl From<Vec<u8>> for PathSecret {
from(data: Vec<u8>) -> Self45     fn from(data: Vec<u8>) -> Self {
46         PathSecret(Zeroizing::new(data))
47     }
48 }
49 
50 impl From<Zeroizing<Vec<u8>>> for PathSecret {
from(data: Zeroizing<Vec<u8>>) -> Self51     fn from(data: Zeroizing<Vec<u8>>) -> Self {
52         PathSecret(data)
53     }
54 }
55 
56 impl PathSecret {
random<P: CipherSuiteProvider>( cipher_suite_provider: &P, ) -> Result<PathSecret, MlsError>57     pub fn random<P: CipherSuiteProvider>(
58         cipher_suite_provider: &P,
59     ) -> Result<PathSecret, MlsError> {
60         cipher_suite_provider
61             .random_bytes_vec(cipher_suite_provider.kdf_extract_size())
62             .map(Into::into)
63             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
64     }
65 
empty<P: CipherSuiteProvider>(cipher_suite_provider: &P) -> Self66     pub fn empty<P: CipherSuiteProvider>(cipher_suite_provider: &P) -> Self {
67         // Define commit_secret as the all-zero vector of the same length as a path_secret
68         PathSecret::from(vec![0u8; cipher_suite_provider.kdf_extract_size()])
69     }
70 }
71 
72 impl HpkeEncryptable for PathSecret {
73     const ENCRYPT_LABEL: &'static str = "UpdatePathNode";
74 
from_bytes(bytes: Vec<u8>) -> Result<Self, MlsError>75     fn from_bytes(bytes: Vec<u8>) -> Result<Self, MlsError> {
76         Ok(Self(Zeroizing::new(bytes)))
77     }
78 
get_bytes(&self) -> Result<Vec<u8>, MlsError>79     fn get_bytes(&self) -> Result<Vec<u8>, MlsError> {
80         Ok(self.to_vec())
81     }
82 }
83 
84 impl PathSecret {
85     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
to_hpke_key_pair<P: CipherSuiteProvider>( &self, cs: &P, ) -> Result<(HpkeSecretKey, HpkePublicKey), MlsError>86     pub async fn to_hpke_key_pair<P: CipherSuiteProvider>(
87         &self,
88         cs: &P,
89     ) -> Result<(HpkeSecretKey, HpkePublicKey), MlsError> {
90         let node_secret = Zeroizing::new(kdf_derive_secret(cs, self, b"node").await?);
91 
92         cs.kem_derive(&node_secret)
93             .await
94             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
95     }
96 }
97 
/// Produces a sequence of path secrets.
///
/// The first call to `next_secret` yields either the secret supplied through `starting_with`
/// or a freshly random one. Each later call derives the next secret from the previously
/// returned one.
#[derive(Clone, Debug)]
pub struct PathSecretGenerator<'a, P> {
    // Provider used both for random generation and for KDF derivation.
    cipher_suite_provider: &'a P,
    // Most recently returned secret; input to the next derivation.
    last: Option<PathSecret>,
    // Seed returned verbatim by the first `next_secret` call, if set.
    starting_with: Option<PathSecret>,
}
104 
105 impl<'a, P: CipherSuiteProvider> PathSecretGenerator<'a, P> {
new(cipher_suite_provider: &'a P) -> Self106     pub fn new(cipher_suite_provider: &'a P) -> Self {
107         Self {
108             cipher_suite_provider,
109             last: None,
110             starting_with: None,
111         }
112     }
113 
starting_with(cipher_suite_provider: &'a P, secret: PathSecret) -> Self114     pub fn starting_with(cipher_suite_provider: &'a P, secret: PathSecret) -> Self {
115         Self {
116             starting_with: Some(secret),
117             ..Self::new(cipher_suite_provider)
118         }
119     }
120 
121     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
next_secret(&mut self) -> Result<PathSecret, MlsError>122     pub async fn next_secret(&mut self) -> Result<PathSecret, MlsError> {
123         let secret = if let Some(starting_with) = self.starting_with.take() {
124             Ok(starting_with)
125         } else if let Some(last) = self.last.take() {
126             kdf_derive_secret(self.cipher_suite_provider, &last, b"path")
127                 .await
128                 .map(PathSecret::from)
129         } else {
130             PathSecret::random(self.cipher_suite_provider)
131         }?;
132 
133         self.last = Some(secret.clone());
134 
135         Ok(secret)
136     }
137 }
138 
#[cfg(test)]
mod tests {
    use crate::{
        cipher_suite::CipherSuite,
        client::test_utils::TEST_CIPHER_SUITE,
        crypto::test_utils::{
            test_cipher_suite_provider, try_test_cipher_suite_provider, TestCryptoProvider,
        },
    };

    use super::*;

    use alloc::string::String;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    // One test vector: a numeric cipher suite id plus the hex-encoded secrets returned by
    // successive `next_secret` calls.
    #[derive(serde::Deserialize, serde::Serialize)]
    struct TestCase {
        cipher_suite: u16,
        generations: Vec<String>,
    }

    impl TestCase {
        // Builds vectors for every cipher suite. Each vector has ten generations. The first is
        // random (the generator is created with `new`); the remaining nine are derived from it.
        #[cfg(not(mls_build_async))]
        #[cfg_attr(coverage_nightly, coverage(off))]
        fn generate() -> Vec<TestCase> {
            CipherSuite::all()
                .map(
                    #[cfg_attr(coverage_nightly, coverage(off))]
                    |cipher_suite| {
                        let cs_provider = test_cipher_suite_provider(cipher_suite);
                        let mut generator = PathSecretGenerator::new(&cs_provider);

                        let generations = (0..10)
                            .map(|_| hex::encode(&*generator.next_secret().unwrap()))
                            .collect();

                        TestCase {
                            cipher_suite: cipher_suite.into(),
                            generations,
                        }
                    },
                )
                .collect()
        }

        // Vector generation is only supported in the sync build.
        #[cfg(mls_build_async)]
        fn generate() -> Vec<TestCase> {
            panic!("Tests cannot be generated in async mode");
        }
    }

    // Loads the `path_secret` vectors through `load_test_case_json!`. Presumably
    // `TestCase::generate()` is used when no stored vectors exist; confirm against the macro.
    fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(path_secret, TestCase::generate())
    }

    // Replays each vector. The generator is seeded with generation 0, so it must return
    // generation 0 first and then reproduce the derived chain. Suites the test provider does
    // not support are skipped.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_path_secret_generation() {
        let cases = load_test_cases();

        for test_case in cases {
            let Some(cs_provider) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
                continue;
            };

            let first_secret = PathSecret::from(hex::decode(&test_case.generations[0]).unwrap());
            let mut generator = PathSecretGenerator::starting_with(&cs_provider, first_secret);

            for expected in &test_case.generations {
                let generated = hex::encode(&*generator.next_secret().await.unwrap());
                assert_eq!(expected, &generated);
            }
        }
    }

    // Independently created generators must not produce the same first (random) secret.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_first_path_is_random() {
        let cs_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        let mut generator = PathSecretGenerator::new(&cs_provider);
        let first_secret = generator.next_secret().await.unwrap();

        for _ in 0..100 {
            let mut next_generator = PathSecretGenerator::new(&cs_provider);
            let next_secret = next_generator.next_secret().await.unwrap();
            assert_ne!(first_secret, next_secret);
        }
    }

    // The seed is returned unchanged by the first call, and the second call differs from it.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_starting_with() {
        let cs_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let secret = PathSecret::random(&cs_provider).unwrap();

        let mut generator = PathSecretGenerator::starting_with(&cs_provider, secret.clone());

        let first_secret = generator.next_secret().await.unwrap();
        let second_secret = generator.next_secret().await.unwrap();

        assert_eq!(secret, first_secret);
        assert_ne!(first_secret, second_secret);
    }

    // `PathSecret::empty` must be `kdf_extract_size()` zero bytes for every supported suite.
    #[test]
    fn test_empty_path_secret() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let cs_provider = test_cipher_suite_provider(cipher_suite);
            let empty = PathSecret::empty(&cs_provider);
            assert_eq!(
                empty,
                PathSecret::from(vec![0u8; cs_provider.kdf_extract_size()])
            )
        }
    }

    // Repeated `PathSecret::random` calls must not reproduce the initial value.
    // NOTE(review): unlike the other tests, this hard-codes `P256_AES128` rather than
    // `TEST_CIPHER_SUITE`, so it assumes the test crypto provider supports that suite.
    #[test]
    fn test_random_path_secret() {
        let cs_provider = test_cipher_suite_provider(CipherSuite::P256_AES128);
        let initial = PathSecret::random(&cs_provider).unwrap();

        for _ in 0..100 {
            let next = PathSecret::random(&cs_provider).unwrap();
            assert_ne!(next, initial);
        }
    }
}
266