1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use alloc::vec; 6 use alloc::vec::Vec; 7 use mls_rs_codec::{MlsDecode, MlsEncode}; 8 use mls_rs_core::{error::IntoAnyError, identity::IdentityProvider, key_package::KeyPackageData}; 9 10 use crate::client::MlsError; 11 use crate::{ 12 crypto::{HpkeSecretKey, SignatureSecretKey}, 13 group::framing::MlsMessagePayload, 14 identity::SigningIdentity, 15 protocol_version::ProtocolVersion, 16 signer::Signable, 17 tree_kem::{ 18 leaf_node::{ConfigProperties, LeafNode}, 19 Capabilities, Lifetime, 20 }, 21 CipherSuiteProvider, ExtensionList, MlsMessage, 22 }; 23 24 use super::{KeyPackage, KeyPackageRef}; 25 26 #[derive(Clone, Debug)] 27 pub struct KeyPackageGenerator<'a, IP, CP> 28 where 29 IP: IdentityProvider, 30 CP: CipherSuiteProvider, 31 { 32 pub protocol_version: ProtocolVersion, 33 pub cipher_suite_provider: &'a CP, 34 pub signing_identity: &'a SigningIdentity, 35 pub signing_key: &'a SignatureSecretKey, 36 pub identity_provider: &'a IP, 37 } 38 39 #[derive(Clone, Debug)] 40 pub struct KeyPackageGeneration { 41 pub(crate) reference: KeyPackageRef, 42 pub(crate) key_package: KeyPackage, 43 pub(crate) init_secret_key: HpkeSecretKey, 44 pub(crate) leaf_node_secret_key: HpkeSecretKey, 45 } 46 47 impl KeyPackageGeneration { to_storage(&self) -> Result<(Vec<u8>, KeyPackageData), MlsError>48 pub fn to_storage(&self) -> Result<(Vec<u8>, KeyPackageData), MlsError> { 49 let id = self.reference.to_vec(); 50 51 let data = KeyPackageData::new( 52 self.key_package.mls_encode_to_vec()?, 53 self.init_secret_key.clone(), 54 self.leaf_node_secret_key.clone(), 55 self.key_package.expiration()?, 56 ); 57 58 Ok((id, data)) 59 } 60 from_storage(id: Vec<u8>, data: KeyPackageData) -> Result<Self, MlsError>61 pub fn from_storage(id: Vec<u8>, data: KeyPackageData) -> Result<Self, MlsError> { 62 Ok(KeyPackageGeneration { 63 reference: KeyPackageRef::from(id), 64 key_package: KeyPackage::mls_decode(&mut &*data.key_package_bytes)?, 65 init_secret_key: data.init_key, 66 leaf_node_secret_key: data.leaf_node_key, 67 }) 68 } 69 key_package_message(&self) -> MlsMessage70 pub fn key_package_message(&self) -> MlsMessage { 71 MlsMessage::new( 72 self.key_package.version, 73 MlsMessagePayload::KeyPackage(self.key_package.clone()), 74 ) 75 } 76 } 77 78 impl<'a, IP, CP> KeyPackageGenerator<'a, IP, CP> 79 where 80 IP: IdentityProvider, 81 CP: CipherSuiteProvider, 82 { 83 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] sign(&self, package: &mut KeyPackage) -> Result<(), MlsError>84 pub(super) async fn sign(&self, package: &mut KeyPackage) -> Result<(), MlsError> { 85 package 86 .sign(self.cipher_suite_provider, self.signing_key, &()) 87 .await 88 } 89 90 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] generate( &self, lifetime: Lifetime, capabilities: Capabilities, key_package_extensions: ExtensionList, leaf_node_extensions: ExtensionList, ) -> Result<KeyPackageGeneration, MlsError>91 pub async fn generate( 92 &self, 93 lifetime: Lifetime, 94 capabilities: Capabilities, 95 key_package_extensions: ExtensionList, 96 leaf_node_extensions: ExtensionList, 97 ) -> Result<KeyPackageGeneration, MlsError> { 98 let (init_secret_key, public_init) = self 99 .cipher_suite_provider 100 .kem_generate() 101 .await 102 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?; 103 104 let properties = ConfigProperties { 105 capabilities, 106 extensions: leaf_node_extensions, 107 }; 108 109 let (leaf_node, leaf_node_secret) = LeafNode::generate( 110 self.cipher_suite_provider, 111 properties, 112 self.signing_identity.clone(), 113 self.signing_key, 114 lifetime, 115 ) 116 .await?; 117 118 let mut package = KeyPackage { 119 version: self.protocol_version, 120 cipher_suite: self.cipher_suite_provider.cipher_suite(), 121 hpke_init_key: public_init, 122 leaf_node, 123 extensions: key_package_extensions, 124 signature: vec![], 125 }; 126 127 package.grease(self.cipher_suite_provider)?; 128 129 self.sign(&mut package).await?; 130 131 let reference = package.to_reference(self.cipher_suite_provider).await?; 132 133 Ok(KeyPackageGeneration { 134 key_package: package, 135 init_secret_key, 136 leaf_node_secret_key: leaf_node_secret, 137 reference, 138 }) 139 } 140 } 141 142 #[cfg(test)] 143 mod tests { 144 use assert_matches::assert_matches; 145 use mls_rs_core::crypto::CipherSuiteProvider; 146 147 use crate::{ 148 crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider}, 149 extension::test_utils::TestExtension, 150 group::test_utils::random_bytes, 151 identity::basic::BasicIdentityProvider, 152 identity::test_utils::get_test_signing_identity, 153 key_package::validate_key_package_properties, 154 protocol_version::ProtocolVersion, 155 tree_kem::{ 156 leaf_node::{test_utils::get_test_capabilities, LeafNodeSource}, 157 leaf_node_validator::{LeafNodeValidator, ValidationContext}, 158 Lifetime, 159 }, 160 ExtensionList, 161 }; 162 163 use super::KeyPackageGenerator; 164 test_key_package_ext(val: u8) -> ExtensionList165 fn test_key_package_ext(val: u8) -> ExtensionList { 166 let mut ext_list = ExtensionList::new(); 167 ext_list.set_from(TestExtension::from(val)).unwrap(); 168 ext_list 169 } 170 test_leaf_node_ext(val: u8) -> ExtensionList171 fn test_leaf_node_ext(val: u8) -> ExtensionList { 172 let mut ext_list = ExtensionList::new(); 173 ext_list.set_from(TestExtension::from(val)).unwrap(); 174 ext_list 175 } 176 test_lifetime() -> Lifetime177 fn test_lifetime() -> Lifetime { 178 Lifetime::years(1).unwrap() 179 } 180 181 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_key_generation()182 async fn test_key_generation() { 183 for (protocol_version, cipher_suite) in ProtocolVersion::all().flat_map(|p| { 184 TestCryptoProvider::all_supported_cipher_suites() 185 .into_iter() 186 .map(move |cs| (p, cs)) 187 }) { 188 let cipher_suite_provider = test_cipher_suite_provider(cipher_suite); 189 190 let (signing_identity, signing_key) = 191 get_test_signing_identity(cipher_suite, b"foo").await; 192 193 let key_package_ext = test_key_package_ext(32); 194 let leaf_node_ext = test_leaf_node_ext(42); 195 let lifetime = test_lifetime(); 196 197 let test_generator = KeyPackageGenerator { 198 protocol_version, 199 cipher_suite_provider: &cipher_suite_provider, 200 signing_identity: &signing_identity, 201 signing_key: &signing_key, 202 identity_provider: &BasicIdentityProvider, 203 }; 204 205 let mut capabilities = get_test_capabilities(); 206 capabilities.extensions.push(42.into()); 207 capabilities.extensions.push(43.into()); 208 capabilities.extensions.push(32.into()); 209 210 let generated = test_generator 211 .generate( 212 lifetime.clone(), 213 capabilities.clone(), 214 key_package_ext.clone(), 215 leaf_node_ext.clone(), 216 ) 217 .await 218 .unwrap(); 219 220 assert_matches!(generated.key_package.leaf_node.leaf_node_source, 221 LeafNodeSource::KeyPackage(ref lt) if lt == &lifetime); 222 223 assert_eq!( 224 generated.key_package.leaf_node.ungreased_capabilities(), 225 capabilities 226 ); 227 228 assert_eq!( 229 generated.key_package.leaf_node.ungreased_extensions(), 230 leaf_node_ext 231 ); 232 233 assert_eq!( 234 generated.key_package.ungreased_extensions(), 235 key_package_ext 236 ); 237 238 assert_ne!( 239 generated.key_package.hpke_init_key.as_ref(), 240 generated.key_package.leaf_node.public_key.as_ref() 241 ); 242 243 assert_eq!(generated.key_package.cipher_suite, cipher_suite); 244 assert_eq!(generated.key_package.version, protocol_version); 245 246 // Verify that the hpke key pair generated will work 247 let test_data = random_bytes(32); 248 249 let sealed = cipher_suite_provider 250 .hpke_seal(&generated.key_package.hpke_init_key, &[], None, &test_data) 251 .await 252 .unwrap(); 253 254 let opened = cipher_suite_provider 255 .hpke_open( 256 &sealed, 257 &generated.init_secret_key, 258 &generated.key_package.hpke_init_key, 259 &[], 260 None, 261 ) 262 .await 263 .unwrap(); 264 265 assert_eq!(opened, test_data); 266 267 let validator = 268 LeafNodeValidator::new(&cipher_suite_provider, &BasicIdentityProvider, None); 269 270 validator 271 .check_if_valid( 272 &generated.key_package.leaf_node, 273 ValidationContext::Add(None), 274 ) 275 .await 276 .unwrap(); 277 278 validate_key_package_properties( 279 &generated.key_package, 280 protocol_version, 281 &cipher_suite_provider, 282 ) 283 .await 284 .unwrap(); 285 } 286 } 287 288 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_randomness()289 async fn test_randomness() { 290 for (protocol_version, cipher_suite) in ProtocolVersion::all().flat_map(|p| { 291 TestCryptoProvider::all_supported_cipher_suites() 292 .into_iter() 293 .map(move |cs| (p, cs)) 294 }) { 295 let (signing_identity, signing_key) = 296 get_test_signing_identity(cipher_suite, b"foo").await; 297 298 let test_generator = KeyPackageGenerator { 299 protocol_version, 300 cipher_suite_provider: &test_cipher_suite_provider(cipher_suite), 301 signing_identity: &signing_identity, 302 signing_key: &signing_key, 303 identity_provider: &BasicIdentityProvider, 304 }; 305 306 let first_key_package = test_generator 307 .generate( 308 test_lifetime(), 309 get_test_capabilities(), 310 ExtensionList::default(), 311 ExtensionList::default(), 312 ) 313 .await 314 .unwrap(); 315 316 for _ in 0..100 { 317 let next_key_package = test_generator 318 .generate( 319 test_lifetime(), 320 get_test_capabilities(), 321 ExtensionList::default(), 322 ExtensionList::default(), 323 ) 324 .await 325 .unwrap(); 326 327 assert_ne!( 328 first_key_package.key_package.hpke_init_key, 329 next_key_package.key_package.hpke_init_key 330 ); 331 332 assert_ne!( 333 first_key_package.key_package.leaf_node.public_key, 334 next_key_package.key_package.leaf_node.public_key 335 ); 336 } 337 } 338 } 339 } 340