1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use core::ops::Deref; 6 7 use super::*; 8 use crate::hash_reference::HashReference; 9 10 #[cfg_attr( 11 all(feature = "ffi", not(test)), 12 safer_ffi_gen::ffi_type(clone, opaque) 13 )] 14 #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode)] 15 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 16 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 17 /// Unique identifier for a proposal message. 18 pub struct ProposalRef(HashReference); 19 20 impl Deref for ProposalRef { 21 type Target = [u8]; 22 deref(&self) -> &Self::Target23 fn deref(&self) -> &Self::Target { 24 &self.0 25 } 26 } 27 28 #[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)] 29 impl ProposalRef { 30 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] from_content<CS: CipherSuiteProvider>( cipher_suite_provider: &CS, content: &AuthenticatedContent, ) -> Result<Self, MlsError>31 pub(crate) async fn from_content<CS: CipherSuiteProvider>( 32 cipher_suite_provider: &CS, 33 content: &AuthenticatedContent, 34 ) -> Result<Self, MlsError> { 35 let bytes = &content.mls_encode_to_vec()?; 36 37 Ok(ProposalRef( 38 HashReference::compute(bytes, b"MLS 1.0 Proposal Reference", cipher_suite_provider) 39 .await?, 40 )) 41 } 42 as_slice(&self) -> &[u8]43 pub fn as_slice(&self) -> &[u8] { 44 &self.0 45 } 46 } 47 48 #[cfg(test)] 49 pub(crate) mod test_utils { 50 use super::*; 51 use crate::group::test_utils::{random_bytes, TEST_GROUP}; 52 use alloc::boxed::Box; 53 54 impl ProposalRef { new_fake(bytes: Vec<u8>) -> Self55 pub fn new_fake(bytes: Vec<u8>) -> Self { 56 Self(bytes.into()) 57 } 58 } 59 auth_content_from_proposal<S>(proposal: Proposal, sender: S) -> AuthenticatedContent where S: Into<Sender>,60 pub fn auth_content_from_proposal<S>(proposal: Proposal, sender: S) -> AuthenticatedContent 61 where 62 S: Into<Sender>, 63 { 64 AuthenticatedContent { 65 wire_format: WireFormat::PublicMessage, 66 content: FramedContent { 67 group_id: TEST_GROUP.to_vec(), 68 epoch: 0, 69 sender: sender.into(), 70 authenticated_data: vec![], 71 content: Content::Proposal(Box::new(proposal)), 72 }, 73 auth: FramedContentAuthData { 74 signature: MessageSignature::from(random_bytes(128)), 75 confirmation_tag: None, 76 }, 77 } 78 } 79 } 80 81 #[cfg(test)] 82 mod test { 83 use super::test_utils::auth_content_from_proposal; 84 use super::*; 85 use crate::{ 86 crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider}, 87 key_package::test_utils::test_key_package, 88 tree_kem::leaf_node::test_utils::get_basic_test_node, 89 }; 90 use alloc::boxed::Box; 91 92 use crate::extension::RequiredCapabilitiesExt; 93 94 #[cfg_attr(coverage_nightly, coverage(off))] get_test_extension_list() -> ExtensionList95 fn get_test_extension_list() -> ExtensionList { 96 let test_extension = RequiredCapabilitiesExt { 97 extensions: vec![42.into()], 98 proposals: Default::default(), 99 credentials: vec![], 100 }; 101 102 let mut extension_list = ExtensionList::new(); 103 extension_list.set_from(test_extension).unwrap(); 104 105 extension_list 106 } 107 108 #[derive(serde::Serialize, serde::Deserialize)] 109 struct TestCase { 110 cipher_suite: u16, 111 #[serde(with = "hex::serde")] 112 input: Vec<u8>, 113 #[serde(with = "hex::serde")] 114 output: Vec<u8>, 115 } 116 117 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] 118 #[cfg_attr(coverage_nightly, coverage(off))] generate_proposal_test_cases() -> Vec<TestCase>119 async fn generate_proposal_test_cases() -> Vec<TestCase> { 120 let mut test_cases = Vec::new(); 121 122 for (protocol_version, cipher_suite) in 123 ProtocolVersion::all().flat_map(|p| CipherSuite::all().map(move |cs| (p, cs))) 124 { 125 let sender = LeafIndex(0); 126 127 let add = auth_content_from_proposal( 128 Proposal::Add(Box::new(AddProposal { 129 key_package: test_key_package(protocol_version, cipher_suite, "alice").await, 130 })), 131 sender, 132 ); 133 134 let update = auth_content_from_proposal( 135 Proposal::Update(UpdateProposal { 136 leaf_node: get_basic_test_node(cipher_suite, "foo").await, 137 }), 138 sender, 139 ); 140 141 let remove = auth_content_from_proposal( 142 Proposal::Remove(RemoveProposal { 143 to_remove: LeafIndex(1), 144 }), 145 sender, 146 ); 147 148 let group_context_ext = auth_content_from_proposal( 149 Proposal::GroupContextExtensions(get_test_extension_list()), 150 sender, 151 ); 152 153 let cipher_suite_provider = test_cipher_suite_provider(cipher_suite); 154 155 test_cases.push(TestCase { 156 cipher_suite: cipher_suite.into(), 157 input: add.mls_encode_to_vec().unwrap(), 158 output: ProposalRef::from_content(&cipher_suite_provider, &add) 159 .await 160 .unwrap() 161 .to_vec(), 162 }); 163 164 test_cases.push(TestCase { 165 cipher_suite: cipher_suite.into(), 166 input: update.mls_encode_to_vec().unwrap(), 167 output: ProposalRef::from_content(&cipher_suite_provider, &update) 168 .await 169 .unwrap() 170 .to_vec(), 171 }); 172 173 test_cases.push(TestCase { 174 cipher_suite: cipher_suite.into(), 175 input: remove.mls_encode_to_vec().unwrap(), 176 output: ProposalRef::from_content(&cipher_suite_provider, &remove) 177 .await 178 .unwrap() 179 .to_vec(), 180 }); 181 182 test_cases.push(TestCase { 183 cipher_suite: cipher_suite.into(), 184 input: group_context_ext.mls_encode_to_vec().unwrap(), 185 output: ProposalRef::from_content(&cipher_suite_provider, &group_context_ext) 186 .await 187 .unwrap() 188 .to_vec(), 189 }); 190 } 191 192 test_cases 193 } 194 195 #[cfg(mls_build_async)] load_test_cases() -> Vec<TestCase>196 async fn load_test_cases() -> Vec<TestCase> { 197 load_test_case_json!(proposal_ref, generate_proposal_test_cases().await) 198 } 199 200 #[cfg(not(mls_build_async))] load_test_cases() -> Vec<TestCase>201 fn load_test_cases() -> Vec<TestCase> { 202 load_test_case_json!(proposal_ref, generate_proposal_test_cases()) 203 } 204 205 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_proposal_ref()206 async fn test_proposal_ref() { 207 let test_cases = load_test_cases().await; 208 209 for one_case in test_cases { 210 let Some(cs_provider) = try_test_cipher_suite_provider(one_case.cipher_suite) else { 211 continue; 212 }; 213 214 let proposal_content = 215 AuthenticatedContent::mls_decode(&mut one_case.input.as_slice()).unwrap(); 216 217 let proposal_ref = ProposalRef::from_content(&cs_provider, &proposal_content) 218 .await 219 .unwrap(); 220 221 let expected_out = ProposalRef(HashReference::from(one_case.output)); 222 223 assert_eq!(expected_out, proposal_ref); 224 } 225 } 226 } 227