1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::vec;
6 use alloc::vec::Vec;
7 use mls_rs_codec::{MlsDecode, MlsEncode};
8 use mls_rs_core::{error::IntoAnyError, identity::IdentityProvider, key_package::KeyPackageData};
9 
10 use crate::client::MlsError;
11 use crate::{
12     crypto::{HpkeSecretKey, SignatureSecretKey},
13     group::framing::MlsMessagePayload,
14     identity::SigningIdentity,
15     protocol_version::ProtocolVersion,
16     signer::Signable,
17     tree_kem::{
18         leaf_node::{ConfigProperties, LeafNode},
19         Capabilities, Lifetime,
20     },
21     CipherSuiteProvider, ExtensionList, MlsMessage,
22 };
23 
24 use super::{KeyPackage, KeyPackageRef};
25 
26 #[derive(Clone, Debug)]
27 pub struct KeyPackageGenerator<'a, IP, CP>
28 where
29     IP: IdentityProvider,
30     CP: CipherSuiteProvider,
31 {
32     pub protocol_version: ProtocolVersion,
33     pub cipher_suite_provider: &'a CP,
34     pub signing_identity: &'a SigningIdentity,
35     pub signing_key: &'a SignatureSecretKey,
36     pub identity_provider: &'a IP,
37 }
38 
/// Result of [`KeyPackageGenerator::generate`]: a signed key package together
/// with the private keys needed to use it later.
#[derive(Clone, Debug)]
pub struct KeyPackageGeneration {
    /// Reference computed from the signed key package (`KeyPackage::to_reference`).
    pub(crate) reference: KeyPackageRef,
    /// The signed, greased key package.
    pub(crate) key_package: KeyPackage,
    /// Private half of the key pair whose public half is `key_package.hpke_init_key`.
    pub(crate) init_secret_key: HpkeSecretKey,
    /// Secret returned by `LeafNode::generate` alongside `key_package.leaf_node`.
    pub(crate) leaf_node_secret_key: HpkeSecretKey,
}
46 
47 impl KeyPackageGeneration {
to_storage(&self) -> Result<(Vec<u8>, KeyPackageData), MlsError>48     pub fn to_storage(&self) -> Result<(Vec<u8>, KeyPackageData), MlsError> {
49         let id = self.reference.to_vec();
50 
51         let data = KeyPackageData::new(
52             self.key_package.mls_encode_to_vec()?,
53             self.init_secret_key.clone(),
54             self.leaf_node_secret_key.clone(),
55             self.key_package.expiration()?,
56         );
57 
58         Ok((id, data))
59     }
60 
from_storage(id: Vec<u8>, data: KeyPackageData) -> Result<Self, MlsError>61     pub fn from_storage(id: Vec<u8>, data: KeyPackageData) -> Result<Self, MlsError> {
62         Ok(KeyPackageGeneration {
63             reference: KeyPackageRef::from(id),
64             key_package: KeyPackage::mls_decode(&mut &*data.key_package_bytes)?,
65             init_secret_key: data.init_key,
66             leaf_node_secret_key: data.leaf_node_key,
67         })
68     }
69 
key_package_message(&self) -> MlsMessage70     pub fn key_package_message(&self) -> MlsMessage {
71         MlsMessage::new(
72             self.key_package.version,
73             MlsMessagePayload::KeyPackage(self.key_package.clone()),
74         )
75     }
76 }
77 
78 impl<'a, IP, CP> KeyPackageGenerator<'a, IP, CP>
79 where
80     IP: IdentityProvider,
81     CP: CipherSuiteProvider,
82 {
83     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
sign(&self, package: &mut KeyPackage) -> Result<(), MlsError>84     pub(super) async fn sign(&self, package: &mut KeyPackage) -> Result<(), MlsError> {
85         package
86             .sign(self.cipher_suite_provider, self.signing_key, &())
87             .await
88     }
89 
90     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
generate( &self, lifetime: Lifetime, capabilities: Capabilities, key_package_extensions: ExtensionList, leaf_node_extensions: ExtensionList, ) -> Result<KeyPackageGeneration, MlsError>91     pub async fn generate(
92         &self,
93         lifetime: Lifetime,
94         capabilities: Capabilities,
95         key_package_extensions: ExtensionList,
96         leaf_node_extensions: ExtensionList,
97     ) -> Result<KeyPackageGeneration, MlsError> {
98         let (init_secret_key, public_init) = self
99             .cipher_suite_provider
100             .kem_generate()
101             .await
102             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
103 
104         let properties = ConfigProperties {
105             capabilities,
106             extensions: leaf_node_extensions,
107         };
108 
109         let (leaf_node, leaf_node_secret) = LeafNode::generate(
110             self.cipher_suite_provider,
111             properties,
112             self.signing_identity.clone(),
113             self.signing_key,
114             lifetime,
115         )
116         .await?;
117 
118         let mut package = KeyPackage {
119             version: self.protocol_version,
120             cipher_suite: self.cipher_suite_provider.cipher_suite(),
121             hpke_init_key: public_init,
122             leaf_node,
123             extensions: key_package_extensions,
124             signature: vec![],
125         };
126 
127         package.grease(self.cipher_suite_provider)?;
128 
129         self.sign(&mut package).await?;
130 
131         let reference = package.to_reference(self.cipher_suite_provider).await?;
132 
133         Ok(KeyPackageGeneration {
134             key_package: package,
135             init_secret_key,
136             leaf_node_secret_key: leaf_node_secret,
137             reference,
138         })
139     }
140 }
141 
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use mls_rs_codec::MlsEncode;
    use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider};

    use crate::{
        crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider},
        extension::test_utils::TestExtension,
        group::test_utils::random_bytes,
        identity::basic::BasicIdentityProvider,
        identity::test_utils::get_test_signing_identity,
        key_package::validate_key_package_properties,
        protocol_version::ProtocolVersion,
        tree_kem::{
            leaf_node::{test_utils::get_test_capabilities, LeafNodeSource},
            leaf_node_validator::{LeafNodeValidator, ValidationContext},
            Lifetime,
        },
        ExtensionList,
    };

    use super::{KeyPackageGeneration, KeyPackageGenerator};

    /// Every (protocol version, cipher suite) pair supported by the test crypto provider.
    fn version_and_suite_pairs() -> impl Iterator<Item = (ProtocolVersion, CipherSuite)> {
        ProtocolVersion::all().flat_map(|p| {
            TestCryptoProvider::all_supported_cipher_suites()
                .into_iter()
                .map(move |cs| (p, cs))
        })
    }

    /// An extension list holding a single `TestExtension` built from `val`.
    /// Used for both key package and leaf node extensions.
    fn test_ext(val: u8) -> ExtensionList {
        let mut ext_list = ExtensionList::new();
        ext_list.set_from(TestExtension::from(val)).unwrap();
        ext_list
    }

    fn test_lifetime() -> Lifetime {
        Lifetime::years(1).unwrap()
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_key_generation() {
        for (protocol_version, cipher_suite) in version_and_suite_pairs() {
            let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);

            let (signing_identity, signing_key) =
                get_test_signing_identity(cipher_suite, b"foo").await;

            let key_package_ext = test_ext(32);
            let leaf_node_ext = test_ext(42);
            let lifetime = test_lifetime();

            let test_generator = KeyPackageGenerator {
                protocol_version,
                cipher_suite_provider: &cipher_suite_provider,
                signing_identity: &signing_identity,
                signing_key: &signing_key,
                identity_provider: &BasicIdentityProvider,
            };

            let mut capabilities = get_test_capabilities();
            capabilities.extensions.push(42.into());
            capabilities.extensions.push(43.into());
            capabilities.extensions.push(32.into());

            let generated = test_generator
                .generate(
                    lifetime.clone(),
                    capabilities.clone(),
                    key_package_ext.clone(),
                    leaf_node_ext.clone(),
                )
                .await
                .unwrap();

            assert_matches!(generated.key_package.leaf_node.leaf_node_source,
                            LeafNodeSource::KeyPackage(ref lt) if lt == &lifetime);

            assert_eq!(
                generated.key_package.leaf_node.ungreased_capabilities(),
                capabilities
            );

            assert_eq!(
                generated.key_package.leaf_node.ungreased_extensions(),
                leaf_node_ext
            );

            assert_eq!(
                generated.key_package.ungreased_extensions(),
                key_package_ext
            );

            assert_ne!(
                generated.key_package.hpke_init_key.as_ref(),
                generated.key_package.leaf_node.public_key.as_ref()
            );

            assert_eq!(generated.key_package.cipher_suite, cipher_suite);
            assert_eq!(generated.key_package.version, protocol_version);

            // Verify that the hpke key pair generated will work
            let test_data = random_bytes(32);

            let sealed = cipher_suite_provider
                .hpke_seal(&generated.key_package.hpke_init_key, &[], None, &test_data)
                .await
                .unwrap();

            let opened = cipher_suite_provider
                .hpke_open(
                    &sealed,
                    &generated.init_secret_key,
                    &generated.key_package.hpke_init_key,
                    &[],
                    None,
                )
                .await
                .unwrap();

            assert_eq!(opened, test_data);

            let validator =
                LeafNodeValidator::new(&cipher_suite_provider, &BasicIdentityProvider, None);

            validator
                .check_if_valid(
                    &generated.key_package.leaf_node,
                    ValidationContext::Add(None),
                )
                .await
                .unwrap();

            validate_key_package_properties(
                &generated.key_package,
                protocol_version,
                &cipher_suite_provider,
            )
            .await
            .unwrap();
        }
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_storage_round_trip() {
        for (protocol_version, cipher_suite) in version_and_suite_pairs() {
            let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);

            let (signing_identity, signing_key) =
                get_test_signing_identity(cipher_suite, b"foo").await;

            let test_generator = KeyPackageGenerator {
                protocol_version,
                cipher_suite_provider: &cipher_suite_provider,
                signing_identity: &signing_identity,
                signing_key: &signing_key,
                identity_provider: &BasicIdentityProvider,
            };

            let generated = test_generator
                .generate(
                    test_lifetime(),
                    get_test_capabilities(),
                    ExtensionList::default(),
                    ExtensionList::default(),
                )
                .await
                .unwrap();

            let (id, data) = generated.to_storage().unwrap();

            // The storage id is the raw key package reference.
            assert_eq!(id, generated.reference.to_vec());

            let restored = KeyPackageGeneration::from_storage(id, data).unwrap();

            assert_eq!(restored.reference.to_vec(), generated.reference.to_vec());

            assert_eq!(
                restored.key_package.mls_encode_to_vec().unwrap(),
                generated.key_package.mls_encode_to_vec().unwrap()
            );

            // The restored init secret must still open data sealed to the init key.
            let test_data = random_bytes(32);

            let sealed = cipher_suite_provider
                .hpke_seal(&restored.key_package.hpke_init_key, &[], None, &test_data)
                .await
                .unwrap();

            let opened = cipher_suite_provider
                .hpke_open(
                    &sealed,
                    &restored.init_secret_key,
                    &restored.key_package.hpke_init_key,
                    &[],
                    None,
                )
                .await
                .unwrap();

            assert_eq!(opened, test_data);
        }
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_randomness() {
        for (protocol_version, cipher_suite) in version_and_suite_pairs() {
            let (signing_identity, signing_key) =
                get_test_signing_identity(cipher_suite, b"foo").await;

            let test_generator = KeyPackageGenerator {
                protocol_version,
                cipher_suite_provider: &test_cipher_suite_provider(cipher_suite),
                signing_identity: &signing_identity,
                signing_key: &signing_key,
                identity_provider: &BasicIdentityProvider,
            };

            let first_key_package = test_generator
                .generate(
                    test_lifetime(),
                    get_test_capabilities(),
                    ExtensionList::default(),
                    ExtensionList::default(),
                )
                .await
                .unwrap();

            for _ in 0..100 {
                let next_key_package = test_generator
                    .generate(
                        test_lifetime(),
                        get_test_capabilities(),
                        ExtensionList::default(),
                        ExtensionList::default(),
                    )
                    .await
                    .unwrap();

                assert_ne!(
                    first_key_package.key_package.hpke_init_key,
                    next_key_package.key_package.hpke_init_key
                );

                assert_ne!(
                    first_key_package.key_package.leaf_node.public_key,
                    next_key_package.key_package.leaf_node.public_key
                );
            }
        }
    }
}
340