1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::vec;
6 use alloc::vec::Vec;
7 use core::{
8     fmt::{self, Debug},
9     ops::Deref,
10 };
11 use mls_rs_core::crypto::CipherSuiteProvider;
12 use zeroize::Zeroizing;
13 
14 #[cfg(feature = "psk")]
15 use mls_rs_codec::MlsEncode;
16 
17 #[cfg(feature = "psk")]
18 use mls_rs_core::{error::IntoAnyError, psk::PreSharedKey};
19 
20 #[cfg(feature = "psk")]
21 use crate::{
22     client::MlsError,
23     group::key_schedule::kdf_expand_with_label,
24     psk::{PSKLabel, PreSharedKeyID},
25 };
26 
/// A pre-shared key together with the identifier it is bound to.
///
/// A slice of these is what `PskSecret::calculate` consumes.
#[cfg(feature = "psk")]
#[derive(Clone)]
pub(crate) struct PskSecretInput {
    /// Identifier that is encoded into the per-key derivation label.
    pub id: PreSharedKeyID,
    /// Key material used as the input keying material of the first extract step.
    pub psk: PreSharedKey,
}
33 
34 #[derive(PartialEq, Eq, Clone)]
35 pub(crate) struct PskSecret(Zeroizing<Vec<u8>>);
36 
37 impl Debug for PskSecret {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result38     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39         mls_rs_core::debug::pretty_bytes(&self.0)
40             .named("PskSecret")
41             .fmt(f)
42     }
43 }
44 
#[cfg(test)]
impl From<Vec<u8>> for PskSecret {
    /// Test-only conversion that wraps raw bytes as a `PskSecret`.
    fn from(value: Vec<u8>) -> Self {
        let secret = Zeroizing::new(value);
        Self(secret)
    }
}
51 
52 impl Deref for PskSecret {
53     type Target = [u8];
54 
deref(&self) -> &Self::Target55     fn deref(&self) -> &Self::Target {
56         &self.0
57     }
58 }
59 
60 impl PskSecret {
new<P: CipherSuiteProvider>(provider: &P) -> PskSecret61     pub(crate) fn new<P: CipherSuiteProvider>(provider: &P) -> PskSecret {
62         PskSecret(Zeroizing::new(vec![0u8; provider.kdf_extract_size()]))
63     }
64 
65     #[cfg(feature = "psk")]
66     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
calculate<P: CipherSuiteProvider>( input: &[PskSecretInput], cipher_suite_provider: &P, ) -> Result<PskSecret, MlsError>67     pub(crate) async fn calculate<P: CipherSuiteProvider>(
68         input: &[PskSecretInput],
69         cipher_suite_provider: &P,
70     ) -> Result<PskSecret, MlsError> {
71         let len = u16::try_from(input.len()).map_err(|_| MlsError::TooManyPskIds)?;
72         let mut psk_secret = PskSecret::new(cipher_suite_provider);
73 
74         for (index, psk_secret_input) in input.iter().enumerate() {
75             let index = index as u16;
76 
77             let label = PSKLabel {
78                 id: &psk_secret_input.id,
79                 index,
80                 count: len,
81             };
82 
83             let psk_extracted = cipher_suite_provider
84                 .kdf_extract(
85                     &vec![0; cipher_suite_provider.kdf_extract_size()],
86                     &psk_secret_input.psk,
87                 )
88                 .await
89                 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
90 
91             let psk_input = kdf_expand_with_label(
92                 cipher_suite_provider,
93                 &psk_extracted,
94                 b"derived psk",
95                 &label.mls_encode_to_vec()?,
96                 None,
97             )
98             .await?;
99 
100             psk_secret = cipher_suite_provider
101                 .kdf_extract(&psk_input, &psk_secret)
102                 .await
103                 .map(PskSecret)
104                 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
105         }
106 
107         Ok(psk_secret)
108     }
109 }
110 
#[cfg(feature = "psk")]
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    #[cfg(not(mls_build_async))]
    use core::iter;
    use serde::{Deserialize, Serialize};

    use crate::{
        crypto::test_utils::try_test_cipher_suite_provider,
        psk::ExternalPskId,
        psk::{JustPreSharedKeyID, PreSharedKeyID, PskNonce},
        CipherSuiteProvider,
    };

    #[cfg(not(mls_build_async))]
    use crate::{
        crypto::test_utils::test_cipher_suite_provider, psk::test_utils::make_external_psk_id,
        CipherSuite,
    };

    use super::{PskSecret, PskSecretInput};

    /// One pre-shared key of a test vector; all fields are hex strings when serialized.
    #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
    struct PskInfo {
        #[serde(with = "hex::serde")]
        id: Vec<u8>,
        #[serde(with = "hex::serde")]
        psk: Vec<u8>,
        #[serde(with = "hex::serde")]
        nonce: Vec<u8>,
    }

    impl From<PskInfo> for PskSecretInput {
        /// Builds an external-PSK input from the raw id, key and nonce bytes.
        fn from(info: PskInfo) -> Self {
            let PskInfo { id, psk, nonce } = info;

            PskSecretInput {
                id: PreSharedKeyID {
                    key_id: JustPreSharedKeyID::External(ExternalPskId::new(id)),
                    psk_nonce: PskNonce(nonce),
                },
                psk: psk.into(),
            }
        }
    }

    /// A test vector: a cipher suite, its PSK list, and the expected combined secret.
    #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
    struct TestScenario {
        cipher_suite: u16,
        psks: Vec<PskInfo>,
        #[serde(with = "hex::serde")]
        psk_secret: Vec<u8>,
    }

    impl TestScenario {
        /// Produces `n` freshly generated PSK entries for the given provider.
        #[cfg(not(mls_build_async))]
        #[cfg_attr(coverage_nightly, coverage(off))]
        fn make_psk_list<CS: CipherSuiteProvider>(cs: &CS, n: usize) -> Vec<PskInfo> {
            iter::repeat_with(
                #[cfg_attr(coverage_nightly, coverage(off))]
                || {
                    let id = make_external_psk_id(cs).to_vec();
                    let psk = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap();
                    let nonce = crate::psk::test_utils::make_nonce(cs.cipher_suite()).0;

                    PskInfo { id, psk, nonce }
                },
            )
            .take(n)
            .collect()
        }

        /// Generates scenarios for every cipher suite with PSK lists of length 1 through 10.
        #[cfg(not(mls_build_async))]
        #[cfg_attr(coverage_nightly, coverage(off))]
        fn generate() -> Vec<TestScenario> {
            let mut scenarios = Vec::new();

            for cs in CipherSuite::all() {
                for n in 1..=10 {
                    let provider = test_cipher_suite_provider(cs);
                    let psks = Self::make_psk_list(&provider, n);
                    let psk_secret = Self::compute_psk_secret(&provider, psks.clone()).to_vec();

                    scenarios.push(TestScenario {
                        cipher_suite: cs.into(),
                        psks,
                        psk_secret,
                    });
                }
            }

            scenarios
        }

        /// Generation is only available in sync builds; the async variant always panics.
        #[cfg(mls_build_async)]
        fn generate() -> Vec<TestScenario> {
            panic!("Tests cannot be generated in async mode");
        }

        /// Converts the entries to `PskSecretInput`s and runs `PskSecret::calculate` on them.
        ///
        /// Panics if the calculation returns an error.
        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
        async fn compute_psk_secret<P: CipherSuiteProvider>(
            provider: &P,
            psks: Vec<PskInfo>,
        ) -> PskSecret {
            let input: Vec<PskSecretInput> = psks.into_iter().map(Into::into).collect();

            PskSecret::calculate(&input, provider).await.unwrap()
        }
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn expected_psk_secret_is_produced() {
        // NOTE(review): `load_test_case_json!` is defined elsewhere; presumably it loads the
        // stored `psk_secret` vectors and uses the generator expression to create them — confirm.
        let scenarios: Vec<TestScenario> =
            load_test_case_json!(psk_secret, TestScenario::generate());

        for TestScenario {
            cipher_suite,
            psks,
            psk_secret,
        } in scenarios
        {
            // Scenarios whose cipher suite has no available test provider are skipped.
            if let Some(provider) = try_test_cipher_suite_provider(cipher_suite) {
                let computed = TestScenario::compute_psk_secret(&provider, psks).await;

                assert_eq!(psk_secret, computed.to_vec());
            }
        }
    }
}
240