1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use core::ops::Deref;
6 
7 use super::*;
8 use crate::hash_reference::HashReference;
9 
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
/// Unique identifier for a proposal message.
///
/// This is a newtype around a `HashReference`. The raw bytes of the reference are
/// available through `Deref<Target = [u8]>` or `as_slice`.
pub struct ProposalRef(HashReference);
19 
20 impl Deref for ProposalRef {
21     type Target = [u8];
22 
deref(&self) -> &Self::Target23     fn deref(&self) -> &Self::Target {
24         &self.0
25     }
26 }
27 
28 #[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
29 impl ProposalRef {
30     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
from_content<CS: CipherSuiteProvider>( cipher_suite_provider: &CS, content: &AuthenticatedContent, ) -> Result<Self, MlsError>31     pub(crate) async fn from_content<CS: CipherSuiteProvider>(
32         cipher_suite_provider: &CS,
33         content: &AuthenticatedContent,
34     ) -> Result<Self, MlsError> {
35         let bytes = &content.mls_encode_to_vec()?;
36 
37         Ok(ProposalRef(
38             HashReference::compute(bytes, b"MLS 1.0 Proposal Reference", cipher_suite_provider)
39                 .await?,
40         ))
41     }
42 
as_slice(&self) -> &[u8]43     pub fn as_slice(&self) -> &[u8] {
44         &self.0
45     }
46 }
47 
48 #[cfg(test)]
49 pub(crate) mod test_utils {
50     use super::*;
51     use crate::group::test_utils::{random_bytes, TEST_GROUP};
52     use alloc::boxed::Box;
53 
54     impl ProposalRef {
new_fake(bytes: Vec<u8>) -> Self55         pub fn new_fake(bytes: Vec<u8>) -> Self {
56             Self(bytes.into())
57         }
58     }
59 
auth_content_from_proposal<S>(proposal: Proposal, sender: S) -> AuthenticatedContent where S: Into<Sender>,60     pub fn auth_content_from_proposal<S>(proposal: Proposal, sender: S) -> AuthenticatedContent
61     where
62         S: Into<Sender>,
63     {
64         AuthenticatedContent {
65             wire_format: WireFormat::PublicMessage,
66             content: FramedContent {
67                 group_id: TEST_GROUP.to_vec(),
68                 epoch: 0,
69                 sender: sender.into(),
70                 authenticated_data: vec![],
71                 content: Content::Proposal(Box::new(proposal)),
72             },
73             auth: FramedContentAuthData {
74                 signature: MessageSignature::from(random_bytes(128)),
75                 confirmation_tag: None,
76             },
77         }
78     }
79 }
80 
81 #[cfg(test)]
82 mod test {
83     use super::test_utils::auth_content_from_proposal;
84     use super::*;
85     use crate::{
86         crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
87         key_package::test_utils::test_key_package,
88         tree_kem::leaf_node::test_utils::get_basic_test_node,
89     };
90     use alloc::boxed::Box;
91 
92     use crate::extension::RequiredCapabilitiesExt;
93 
94     #[cfg_attr(coverage_nightly, coverage(off))]
get_test_extension_list() -> ExtensionList95     fn get_test_extension_list() -> ExtensionList {
96         let test_extension = RequiredCapabilitiesExt {
97             extensions: vec![42.into()],
98             proposals: Default::default(),
99             credentials: vec![],
100         };
101 
102         let mut extension_list = ExtensionList::new();
103         extension_list.set_from(test_extension).unwrap();
104 
105         extension_list
106     }
107 
    /// One proposal-reference test vector, (de)serialized with serde.
    #[derive(serde::Serialize, serde::Deserialize)]
    struct TestCase {
        /// Numeric identifier of the cipher suite used to compute `output`.
        cipher_suite: u16,
        /// MLS-encoded `AuthenticatedContent` carrying the proposal (hex string when serialized).
        #[serde(with = "hex::serde")]
        input: Vec<u8>,
        /// Expected proposal reference bytes for `input` (hex string when serialized).
        #[serde(with = "hex::serde")]
        output: Vec<u8>,
    }
116 
117     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
118     #[cfg_attr(coverage_nightly, coverage(off))]
generate_proposal_test_cases() -> Vec<TestCase>119     async fn generate_proposal_test_cases() -> Vec<TestCase> {
120         let mut test_cases = Vec::new();
121 
122         for (protocol_version, cipher_suite) in
123             ProtocolVersion::all().flat_map(|p| CipherSuite::all().map(move |cs| (p, cs)))
124         {
125             let sender = LeafIndex(0);
126 
127             let add = auth_content_from_proposal(
128                 Proposal::Add(Box::new(AddProposal {
129                     key_package: test_key_package(protocol_version, cipher_suite, "alice").await,
130                 })),
131                 sender,
132             );
133 
134             let update = auth_content_from_proposal(
135                 Proposal::Update(UpdateProposal {
136                     leaf_node: get_basic_test_node(cipher_suite, "foo").await,
137                 }),
138                 sender,
139             );
140 
141             let remove = auth_content_from_proposal(
142                 Proposal::Remove(RemoveProposal {
143                     to_remove: LeafIndex(1),
144                 }),
145                 sender,
146             );
147 
148             let group_context_ext = auth_content_from_proposal(
149                 Proposal::GroupContextExtensions(get_test_extension_list()),
150                 sender,
151             );
152 
153             let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
154 
155             test_cases.push(TestCase {
156                 cipher_suite: cipher_suite.into(),
157                 input: add.mls_encode_to_vec().unwrap(),
158                 output: ProposalRef::from_content(&cipher_suite_provider, &add)
159                     .await
160                     .unwrap()
161                     .to_vec(),
162             });
163 
164             test_cases.push(TestCase {
165                 cipher_suite: cipher_suite.into(),
166                 input: update.mls_encode_to_vec().unwrap(),
167                 output: ProposalRef::from_content(&cipher_suite_provider, &update)
168                     .await
169                     .unwrap()
170                     .to_vec(),
171             });
172 
173             test_cases.push(TestCase {
174                 cipher_suite: cipher_suite.into(),
175                 input: remove.mls_encode_to_vec().unwrap(),
176                 output: ProposalRef::from_content(&cipher_suite_provider, &remove)
177                     .await
178                     .unwrap()
179                     .to_vec(),
180             });
181 
182             test_cases.push(TestCase {
183                 cipher_suite: cipher_suite.into(),
184                 input: group_context_ext.mls_encode_to_vec().unwrap(),
185                 output: ProposalRef::from_content(&cipher_suite_provider, &group_context_ext)
186                     .await
187                     .unwrap()
188                     .to_vec(),
189             });
190         }
191 
192         test_cases
193     }
194 
    #[cfg(mls_build_async)]
    /// Async-build variant: obtains the `proposal_ref` test cases through
    /// `load_test_case_json!`, passing it the awaited generator call.
    // NOTE(review): presumably the macro reads a stored `proposal_ref` JSON vector file and
    // only falls back to the generator expression when needed; confirm against the macro.
    async fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(proposal_ref, generate_proposal_test_cases().await)
    }
199 
    #[cfg(not(mls_build_async))]
    /// Sync-build variant: obtains the `proposal_ref` test cases through
    /// `load_test_case_json!`, passing it the generator call.
    // NOTE(review): presumably the macro reads a stored `proposal_ref` JSON vector file and
    // only falls back to the generator expression when needed; confirm against the macro.
    fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(proposal_ref, generate_proposal_test_cases())
    }
204 
205     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_proposal_ref()206     async fn test_proposal_ref() {
207         let test_cases = load_test_cases().await;
208 
209         for one_case in test_cases {
210             let Some(cs_provider) = try_test_cipher_suite_provider(one_case.cipher_suite) else {
211                 continue;
212             };
213 
214             let proposal_content =
215                 AuthenticatedContent::mls_decode(&mut one_case.input.as_slice()).unwrap();
216 
217             let proposal_ref = ProposalRef::from_content(&cs_provider, &proposal_content)
218                 .await
219                 .unwrap();
220 
221             let expected_out = ProposalRef(HashReference::from(one_case.output));
222 
223             assert_eq!(expected_out, proposal_ref);
224         }
225     }
226 }
227