1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use alloc::vec; 6 use alloc::vec::Vec; 7 use core::{ 8 fmt::{self, Debug}, 9 ops::Deref, 10 }; 11 use mls_rs_core::crypto::CipherSuiteProvider; 12 use zeroize::Zeroizing; 13 14 #[cfg(feature = "psk")] 15 use mls_rs_codec::MlsEncode; 16 17 #[cfg(feature = "psk")] 18 use mls_rs_core::{error::IntoAnyError, psk::PreSharedKey}; 19 20 #[cfg(feature = "psk")] 21 use crate::{ 22 client::MlsError, 23 group::key_schedule::kdf_expand_with_label, 24 psk::{PSKLabel, PreSharedKeyID}, 25 }; 26 27 #[cfg(feature = "psk")] 28 #[derive(Clone)] 29 pub(crate) struct PskSecretInput { 30 pub id: PreSharedKeyID, 31 pub psk: PreSharedKey, 32 } 33 34 #[derive(PartialEq, Eq, Clone)] 35 pub(crate) struct PskSecret(Zeroizing<Vec<u8>>); 36 37 impl Debug for PskSecret { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result38 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 39 mls_rs_core::debug::pretty_bytes(&self.0) 40 .named("PskSecret") 41 .fmt(f) 42 } 43 } 44 45 #[cfg(test)] 46 impl From<Vec<u8>> for PskSecret { from(value: Vec<u8>) -> Self47 fn from(value: Vec<u8>) -> Self { 48 PskSecret(Zeroizing::new(value)) 49 } 50 } 51 52 impl Deref for PskSecret { 53 type Target = [u8]; 54 deref(&self) -> &Self::Target55 fn deref(&self) -> &Self::Target { 56 &self.0 57 } 58 } 59 60 impl PskSecret { new<P: CipherSuiteProvider>(provider: &P) -> PskSecret61 pub(crate) fn new<P: CipherSuiteProvider>(provider: &P) -> PskSecret { 62 PskSecret(Zeroizing::new(vec![0u8; provider.kdf_extract_size()])) 63 } 64 65 #[cfg(feature = "psk")] 66 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] calculate<P: CipherSuiteProvider>( input: &[PskSecretInput], cipher_suite_provider: &P, ) -> Result<PskSecret, MlsError>67 pub(crate) async fn calculate<P: CipherSuiteProvider>( 68 input: &[PskSecretInput], 69 cipher_suite_provider: &P, 70 ) -> Result<PskSecret, MlsError> { 71 let len = u16::try_from(input.len()).map_err(|_| MlsError::TooManyPskIds)?; 72 let mut psk_secret = PskSecret::new(cipher_suite_provider); 73 74 for (index, psk_secret_input) in input.iter().enumerate() { 75 let index = index as u16; 76 77 let label = PSKLabel { 78 id: &psk_secret_input.id, 79 index, 80 count: len, 81 }; 82 83 let psk_extracted = cipher_suite_provider 84 .kdf_extract( 85 &vec![0; cipher_suite_provider.kdf_extract_size()], 86 &psk_secret_input.psk, 87 ) 88 .await 89 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?; 90 91 let psk_input = kdf_expand_with_label( 92 cipher_suite_provider, 93 &psk_extracted, 94 b"derived psk", 95 &label.mls_encode_to_vec()?, 96 None, 97 ) 98 .await?; 99 100 psk_secret = cipher_suite_provider 101 .kdf_extract(&psk_input, &psk_secret) 102 .await 103 .map(PskSecret) 104 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?; 105 } 106 107 Ok(psk_secret) 108 } 109 } 110 111 #[cfg(feature = "psk")] 112 #[cfg(test)] 113 mod tests { 114 use alloc::vec::Vec; 115 #[cfg(not(mls_build_async))] 116 use core::iter; 117 use serde::{Deserialize, Serialize}; 118 119 use crate::{ 120 crypto::test_utils::try_test_cipher_suite_provider, 121 psk::ExternalPskId, 122 psk::{JustPreSharedKeyID, PreSharedKeyID, PskNonce}, 123 CipherSuiteProvider, 124 }; 125 126 #[cfg(not(mls_build_async))] 127 use crate::{ 128 crypto::test_utils::test_cipher_suite_provider, psk::test_utils::make_external_psk_id, 129 CipherSuite, 130 }; 131 132 use super::{PskSecret, PskSecretInput}; 133 134 #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] 135 struct PskInfo { 136 #[serde(with = "hex::serde")] 137 id: Vec<u8>, 138 #[serde(with = "hex::serde")] 139 psk: Vec<u8>, 140 #[serde(with = "hex::serde")] 141 nonce: Vec<u8>, 142 } 143 144 impl From<PskInfo> for PskSecretInput { from(info: PskInfo) -> Self145 fn from(info: PskInfo) -> Self { 146 let id = PreSharedKeyID { 147 key_id: JustPreSharedKeyID::External(ExternalPskId::new(info.id)), 148 psk_nonce: PskNonce(info.nonce), 149 }; 150 151 PskSecretInput { 152 id, 153 psk: info.psk.into(), 154 } 155 } 156 } 157 158 #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] 159 struct TestScenario { 160 cipher_suite: u16, 161 psks: Vec<PskInfo>, 162 #[serde(with = "hex::serde")] 163 psk_secret: Vec<u8>, 164 } 165 166 impl TestScenario { 167 #[cfg_attr(coverage_nightly, coverage(off))] 168 #[cfg(not(mls_build_async))] make_psk_list<CS: CipherSuiteProvider>(cs: &CS, n: usize) -> Vec<PskInfo>169 fn make_psk_list<CS: CipherSuiteProvider>(cs: &CS, n: usize) -> Vec<PskInfo> { 170 iter::repeat_with( 171 #[cfg_attr(coverage_nightly, coverage(off))] 172 || PskInfo { 173 id: make_external_psk_id(cs).to_vec(), 174 psk: cs.random_bytes_vec(cs.kdf_extract_size()).unwrap(), 175 nonce: crate::psk::test_utils::make_nonce(cs.cipher_suite()).0, 176 }, 177 ) 178 .take(n) 179 .collect::<Vec<_>>() 180 } 181 182 #[cfg(not(mls_build_async))] 183 #[cfg_attr(coverage_nightly, coverage(off))] generate() -> Vec<TestScenario>184 fn generate() -> Vec<TestScenario> { 185 CipherSuite::all() 186 .flat_map( 187 #[cfg_attr(coverage_nightly, coverage(off))] 188 |cs| (1..=10).map(move |n| (cs, n)), 189 ) 190 .map( 191 #[cfg_attr(coverage_nightly, coverage(off))] 192 |(cs, n)| { 193 let provider = test_cipher_suite_provider(cs); 194 let psks = Self::make_psk_list(&provider, n); 195 let psk_secret = Self::compute_psk_secret(&provider, psks.clone()); 196 TestScenario { 197 cipher_suite: cs.into(), 198 psks: psks.to_vec(), 199 psk_secret: psk_secret.to_vec(), 200 } 201 }, 202 ) 203 .collect() 204 } 205 206 #[cfg(mls_build_async)] generate() -> Vec<TestScenario>207 fn generate() -> Vec<TestScenario> { 208 panic!("Tests cannot be generated in async mode"); 209 } 210 211 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] compute_psk_secret<P: CipherSuiteProvider>( provider: &P, psks: Vec<PskInfo>, ) -> PskSecret212 async fn compute_psk_secret<P: CipherSuiteProvider>( 213 provider: &P, 214 psks: Vec<PskInfo>, 215 ) -> PskSecret { 216 let input = psks 217 .into_iter() 218 .map(PskSecretInput::from) 219 .collect::<Vec<_>>(); 220 221 PskSecret::calculate(&input, provider).await.unwrap() 222 } 223 } 224 225 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] expected_psk_secret_is_produced()226 async fn expected_psk_secret_is_produced() { 227 let scenarios: Vec<TestScenario> = 228 load_test_case_json!(psk_secret, TestScenario::generate()); 229 230 for scenario in scenarios { 231 if let Some(provider) = try_test_cipher_suite_provider(scenario.cipher_suite) { 232 let computed = 233 TestScenario::compute_psk_secret(&provider, scenario.psks.clone()).await; 234 235 assert_eq!(scenario.psk_secret, computed.to_vec()); 236 } 237 } 238 } 239 } 240