1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::vec::Vec;
6 use core::fmt::{self, Debug};
7 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
8 use mls_rs_core::error::IntoAnyError;
9 use zeroize::Zeroizing;
10 
11 use crate::{
12     client::MlsError,
13     crypto::CipherSuiteProvider,
14     group::{epoch::SenderDataSecret, framing::ContentType, key_schedule::kdf_expand_with_label},
15     tree_kem::node::LeafIndex,
16 };
17 
18 use super::ReuseGuard;
19 
/// Sender metadata that is encrypted separately from the message content.
///
/// The MLS encoding of this struct is the AEAD plaintext sealed by
/// `SenderDataKey::seal` and recovered by `SenderDataKey::open`. The derived codec
/// encodes fields in declaration order, so do not reorder them.
#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
pub(crate) struct SenderData {
    /// Leaf index of the sending member.
    pub sender: LeafIndex,
    /// Key generation used by the sender. Presumably the sender's secret-tree ratchet
    /// generation — TODO confirm against the ciphertext processor.
    pub generation: u32,
    /// Reuse guard value transmitted with the message (see `ReuseGuard`).
    pub reuse_guard: ReuseGuard,
}
26 
/// Additional authenticated data bound to an encrypted `SenderData`.
///
/// Its MLS encoding is passed as the AEAD associated data to both
/// `SenderDataKey::seal` and `SenderDataKey::open`. A sender-data ciphertext therefore
/// only opens with the same group id, epoch and content type it was sealed with.
#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
pub(crate) struct SenderDataAAD {
    /// Group identifier, encoded as a variable-length byte vector.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    pub group_id: Vec<u8>,
    /// Epoch of the protected message.
    pub epoch: u64,
    /// Content type of the protected message.
    pub content_type: ContentType,
}
34 
35 impl Debug for SenderDataAAD {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result36     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37         f.debug_struct("SenderDataAAD")
38             .field(
39                 "group_id",
40                 &mls_rs_core::debug::pretty_group_id(&self.group_id),
41             )
42             .field("epoch", &self.epoch)
43             .field("content_type", &self.content_type)
44             .finish()
45     }
46 }
47 
/// AEAD key and nonce used to seal or open one message's `SenderData`.
///
/// Both values are derived (see `SenderDataKey::new`) from the current epoch's sender
/// data secret and a sample of the message ciphertext. They are bundled with the
/// provider that performs the AEAD operations.
pub(crate) struct SenderDataKey<'a, CP: CipherSuiteProvider> {
    /// AEAD key, derived with length `aead_key_size()`; wiped on drop via `Zeroizing`.
    pub(crate) key: Zeroizing<Vec<u8>>,
    /// AEAD nonce, derived with length `aead_nonce_size()`; wiped on drop via `Zeroizing`.
    pub(crate) nonce: Zeroizing<Vec<u8>>,
    /// Borrowed cipher suite provider used by `seal` / `open`.
    cipher_suite_provider: &'a CP,
}
53 
54 impl<CP: CipherSuiteProvider + Debug> Debug for SenderDataKey<'_, CP> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result55     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56         f.debug_struct("SenderDataKey")
57             .field("key", &mls_rs_core::debug::pretty_bytes(&self.key))
58             .field("nonce", &mls_rs_core::debug::pretty_bytes(&self.nonce))
59             .field("cipher_suite_provider", self.cipher_suite_provider)
60             .finish()
61     }
62 }
63 
64 impl<'a, CP: CipherSuiteProvider> SenderDataKey<'a, CP> {
65     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
new( sender_data_secret: &SenderDataSecret, ciphertext: &[u8], cipher_suite_provider: &'a CP, ) -> Result<SenderDataKey<'a, CP>, MlsError>66     pub(super) async fn new(
67         sender_data_secret: &SenderDataSecret,
68         ciphertext: &[u8],
69         cipher_suite_provider: &'a CP,
70     ) -> Result<SenderDataKey<'a, CP>, MlsError> {
71         // Sample the first extract_size bytes of the ciphertext, and if it is shorter, just use
72         // the ciphertext itself
73         let extract_size = cipher_suite_provider.kdf_extract_size();
74         let ciphertext_sample = ciphertext.get(0..extract_size).unwrap_or(ciphertext);
75 
76         // Generate a sender data key and nonce using the sender_data_secret from the current
77         // epoch's key schedule
78         let key = kdf_expand_with_label(
79             cipher_suite_provider,
80             sender_data_secret,
81             b"key",
82             ciphertext_sample,
83             Some(cipher_suite_provider.aead_key_size()),
84         )
85         .await?;
86 
87         let nonce = kdf_expand_with_label(
88             cipher_suite_provider,
89             sender_data_secret,
90             b"nonce",
91             ciphertext_sample,
92             Some(cipher_suite_provider.aead_nonce_size()),
93         )
94         .await?;
95 
96         Ok(Self {
97             key,
98             nonce,
99             cipher_suite_provider,
100         })
101     }
102 
103     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
seal( &self, sender_data: &SenderData, aad: &SenderDataAAD, ) -> Result<Vec<u8>, MlsError>104     pub(crate) async fn seal(
105         &self,
106         sender_data: &SenderData,
107         aad: &SenderDataAAD,
108     ) -> Result<Vec<u8>, MlsError> {
109         self.cipher_suite_provider
110             .aead_seal(
111                 &self.key,
112                 &sender_data.mls_encode_to_vec()?,
113                 Some(&aad.mls_encode_to_vec()?),
114                 &self.nonce,
115             )
116             .await
117             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
118     }
119 
120     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
open( &self, sender_data: &[u8], aad: &SenderDataAAD, ) -> Result<SenderData, MlsError>121     pub(crate) async fn open(
122         &self,
123         sender_data: &[u8],
124         aad: &SenderDataAAD,
125     ) -> Result<SenderData, MlsError> {
126         self.cipher_suite_provider
127             .aead_open(
128                 &self.key,
129                 sender_data,
130                 Some(&aad.mls_encode_to_vec()?),
131                 &self.nonce,
132             )
133             .await
134             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
135             .and_then(|data| SenderData::mls_decode(&mut &**data).map_err(From::from))
136     }
137 }
138 
#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::vec::Vec;
    use mls_rs_core::crypto::CipherSuiteProvider;

    use super::SenderDataKey;

    /// Interop vector for sender data key derivation.
    ///
    /// Holds a sender data secret and a ciphertext, together with the key and nonce
    /// expected from `SenderDataKey::new` for those inputs. All fields are hex-encoded
    /// when serialized.
    #[derive(Debug, serde::Serialize, serde::Deserialize)]
    pub struct InteropSenderData {
        #[serde(with = "hex::serde")]
        pub sender_data_secret: Vec<u8>,
        #[serde(with = "hex::serde")]
        pub ciphertext: Vec<u8>,
        #[serde(with = "hex::serde")]
        pub key: Vec<u8>,
        #[serde(with = "hex::serde")]
        pub nonce: Vec<u8>,
    }

    impl InteropSenderData {
        /// Generates a vector from a random secret and a random 77-byte ciphertext,
        /// recording the key and nonce that `cs` derives from them.
        #[cfg(not(mls_build_async))]
        #[cfg_attr(coverage_nightly, coverage(off))]
        pub(crate) fn new<P: CipherSuiteProvider>(cs: &P) -> Self {
            let secret = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap().into();
            let ciphertext = cs.random_bytes_vec(77).unwrap();
            let derived = SenderDataKey::new(&secret, &ciphertext, cs).unwrap();

            Self {
                sender_data_secret: (*secret).clone(),
                key: derived.key.to_vec(),
                nonce: derived.nonce.to_vec(),
                ciphertext,
            }
        }

        /// Re-derives the key and nonce with `cs` and panics if either differs from
        /// the recorded values.
        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
        pub async fn verify<P: CipherSuiteProvider>(&self, cs: &P) {
            let secret = self.sender_data_secret.clone().into();
            let derived = SenderDataKey::new(&secret, &self.ciphertext, cs)
                .await
                .unwrap();

            assert_eq!(
                derived.key.as_slice(),
                self.key.as_slice(),
                "sender data key mismatch"
            );
            assert_eq!(
                derived.nonce.as_slice(),
                self.nonce.as_slice(),
                "sender data nonce mismatch"
            );
        }
    }
}
188 
#[cfg(test)]
mod tests {

    use alloc::vec::Vec;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    use crate::{
        crypto::test_utils::try_test_cipher_suite_provider,
        group::{ciphertext_processor::reuse_guard::ReuseGuard, framing::ContentType},
        tree_kem::node::LeafIndex,
    };

    use super::{SenderData, SenderDataAAD, SenderDataKey};

    #[cfg(not(mls_build_async))]
    use crate::{
        cipher_suite::CipherSuite, crypto::test_utils::test_cipher_suite_provider,
        group::test_utils::random_bytes, CipherSuiteProvider,
    };

    /// One entry of the `sender_data_key_test_vector` JSON fixture.
    ///
    /// Holds the inputs to `SenderDataKey::new` / `seal` and the outputs they are
    /// expected to produce. Byte fields are hex-encoded in the JSON.
    #[derive(serde::Deserialize, serde::Serialize)]
    struct TestCase {
        cipher_suite: u16,
        #[serde(with = "hex::serde")]
        secret: Vec<u8>,
        #[serde(with = "hex::serde")]
        ciphertext_bytes: Vec<u8>,
        #[serde(with = "hex::serde")]
        expected_key: Vec<u8>,
        #[serde(with = "hex::serde")]
        expected_nonce: Vec<u8>,
        sender_data: TestSenderData,
        sender_data_aad: TestSenderDataAAD,
        #[serde(with = "hex::serde")]
        expected_ciphertext: Vec<u8>,
    }

    /// Serializable mirror of `SenderData` used in the fixture.
    #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
    struct TestSenderData {
        sender: u32,
        generation: u32,
        #[serde(with = "hex::serde")]
        reuse_guard: Vec<u8>,
    }

    impl From<TestSenderData> for SenderData {
        fn from(value: TestSenderData) -> Self {
            let reuse_guard = ReuseGuard::new(value.reuse_guard);

            Self {
                sender: LeafIndex(value.sender),
                generation: value.generation,
                reuse_guard,
            }
        }
    }

    /// Serializable mirror of `SenderDataAAD`.
    ///
    /// The content type is not stored; it is always `ContentType::Application`.
    #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
    struct TestSenderDataAAD {
        epoch: u64,
        #[serde(with = "hex::serde")]
        group_id: Vec<u8>,
    }

    impl From<TestSenderDataAAD> for SenderDataAAD {
        fn from(value: TestSenderDataAAD) -> Self {
            Self {
                epoch: value.epoch,
                group_id: value.group_id,
                content_type: ContentType::Application,
            }
        }
    }

    /// Builds fresh test cases for every cipher suite.
    ///
    /// Ciphertext lengths are just below, equal to and just above the KDF extract size,
    /// which exercises both branches of the ciphertext sampling in `SenderDataKey::new`.
    #[cfg(not(mls_build_async))]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_test_vector() -> Vec<TestCase> {
        let test_cases = CipherSuite::all().map(test_cipher_suite_provider).map(
            #[cfg_attr(coverage_nightly, coverage(off))]
            |provider| {
                let ext_size = provider.kdf_extract_size();
                let secret = random_bytes(ext_size).into();
                let ciphertext_sizes = [ext_size - 5, ext_size, ext_size + 5];

                let sender_data = TestSenderData {
                    sender: 0,
                    generation: 13,
                    reuse_guard: random_bytes(4),
                };

                let sender_data_aad = TestSenderDataAAD {
                    group_id: b"group".to_vec(),
                    epoch: 42,
                };

                ciphertext_sizes.into_iter().map(
                    #[cfg_attr(coverage_nightly, coverage(off))]
                    move |ciphertext_size| {
                        let ciphertext_bytes = random_bytes(ciphertext_size);

                        let sender_data_key =
                            SenderDataKey::new(&secret, &ciphertext_bytes, &provider).unwrap();

                        let expected_ciphertext = sender_data_key
                            .seal(&sender_data.clone().into(), &sender_data_aad.clone().into())
                            .unwrap();

                        TestCase {
                            cipher_suite: provider.cipher_suite().into(),
                            secret: secret.to_vec(),
                            ciphertext_bytes,
                            expected_key: sender_data_key.key.to_vec(),
                            expected_nonce: sender_data_key.nonce.to_vec(),
                            sender_data: sender_data.clone(),
                            sender_data_aad: sender_data_aad.clone(),
                            expected_ciphertext,
                        }
                    },
                )
            },
        );

        test_cases.flatten().collect()
    }

    /// Test vectors can only be generated by the synchronous build.
    #[cfg(mls_build_async)]
    fn generate_test_vector() -> Vec<TestCase> {
        panic!("Tests cannot be generated in async mode");
    }

    /// Loads the stored fixture, generating it first if it does not exist.
    fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(sender_data_key_test_vector, generate_test_vector())
    }

    /// Runs key/nonce derivation and sealing against the stored vectors.
    ///
    /// Also checks that the sealed sender data round-trips through `open`, and that
    /// `open` rejects both a mismatched AAD and a modified ciphertext.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn sender_data_key_test_vector() {
        for test_case in load_test_cases() {
            let Some(provider) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
                continue;
            };

            let sender_data_key = SenderDataKey::new(
                &test_case.secret.into(),
                &test_case.ciphertext_bytes,
                &provider,
            )
            .await
            .unwrap();

            assert_eq!(sender_data_key.key.to_vec(), test_case.expected_key);
            assert_eq!(sender_data_key.nonce.to_vec(), test_case.expected_nonce);

            let sender_data = test_case.sender_data.into();
            let sender_data_aad = test_case.sender_data_aad.into();

            let ciphertext = sender_data_key
                .seal(&sender_data, &sender_data_aad)
                .await
                .unwrap();

            assert_eq!(ciphertext, test_case.expected_ciphertext);

            let plaintext = sender_data_key
                .open(&ciphertext, &sender_data_aad)
                .await
                .unwrap();

            assert_eq!(plaintext, sender_data);

            // The AAD is authenticated, so opening under a different epoch must fail.
            // Results are bound before asserting so no `.await` sits inside a macro
            // invocation, which the maybe_async sync rewrite may not reach.
            let mut wrong_aad = sender_data_aad.clone();
            wrong_aad.epoch = wrong_aad.epoch.wrapping_add(1);

            let wrong_aad_result = sender_data_key.open(&ciphertext, &wrong_aad).await;
            assert!(wrong_aad_result.is_err(), "open accepted a mismatched AAD");

            // Flipping a single ciphertext bit must be caught by the AEAD tag check.
            let mut tampered = ciphertext.clone();
            if let Some(first) = tampered.first_mut() {
                *first ^= 0x01;
            }

            let tampered_result = sender_data_key.open(&tampered, &sender_data_aad).await;
            assert!(tampered_result.is_err(), "open accepted a tampered ciphertext");
        }
    }
}
361