1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::vec::Vec;
6 use core::fmt::{self, Debug};
7 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
8 use mls_rs_core::extension::{ExtensionType, MlsCodecExtension};
9 
10 use mls_rs_core::{group::ProposalType, identity::CredentialType};
11 
12 #[cfg(feature = "by_ref_proposal")]
13 use mls_rs_core::{
14     extension::ExtensionList,
15     identity::{IdentityProvider, SigningIdentity},
16     time::MlsTime,
17 };
18 
19 use crate::group::ExportedTree;
20 
21 use mls_rs_core::crypto::HpkePublicKey;
22 
23 /// Application specific identifier.
24 ///
25 /// A custom application level identifier that can be optionally stored
26 /// within the `leaf_node_extensions` of a group [Member](crate::group::Member).
27 #[cfg_attr(
28     all(feature = "ffi", not(test)),
29     safer_ffi_gen::ffi_type(clone, opaque)
30 )]
31 #[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
32 pub struct ApplicationIdExt {
33     /// Application level identifier presented by this extension.
34     #[mls_codec(with = "mls_rs_codec::byte_vec")]
35     pub identifier: Vec<u8>,
36 }
37 
38 impl Debug for ApplicationIdExt {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result39     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40         f.debug_struct("ApplicationIdExt")
41             .field(
42                 "identifier",
43                 &mls_rs_core::debug::pretty_bytes(&self.identifier),
44             )
45             .finish()
46     }
47 }
48 
49 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
50 impl ApplicationIdExt {
51     /// Create a new application level identifier extension.
new(identifier: Vec<u8>) -> Self52     pub fn new(identifier: Vec<u8>) -> Self {
53         ApplicationIdExt { identifier }
54     }
55 
56     /// Get the application level identifier presented by this extension.
57     #[cfg(feature = "ffi")]
identifier(&self) -> &[u8]58     pub fn identifier(&self) -> &[u8] {
59         &self.identifier
60     }
61 }
62 
63 impl MlsCodecExtension for ApplicationIdExt {
extension_type() -> ExtensionType64     fn extension_type() -> ExtensionType {
65         ExtensionType::APPLICATION_ID
66     }
67 }
68 
69 /// Representation of an MLS ratchet tree.
70 ///
71 /// Used to provide new members
72 /// a copy of the current group state in-band.
73 #[cfg_attr(
74     all(feature = "ffi", not(test)),
75     safer_ffi_gen::ffi_type(clone, opaque)
76 )]
77 #[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
78 pub struct RatchetTreeExt {
79     pub tree_data: ExportedTree<'static>,
80 }
81 
82 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
83 impl RatchetTreeExt {
84     /// Required custom extension types.
85     #[cfg(feature = "ffi")]
tree_data(&self) -> &ExportedTree<'static>86     pub fn tree_data(&self) -> &ExportedTree<'static> {
87         &self.tree_data
88     }
89 }
90 
91 impl MlsCodecExtension for RatchetTreeExt {
extension_type() -> ExtensionType92     fn extension_type() -> ExtensionType {
93         ExtensionType::RATCHET_TREE
94     }
95 }
96 
97 /// Require members to have certain capabilities.
98 ///
99 /// Used within a
100 /// [Group Context Extensions Proposal](crate::group::proposal::Proposal)
101 /// in order to require that all current and future members of a group MUST
102 /// support specific extensions, proposals, or credentials.
103 ///
104 /// # Warning
105 ///
106 /// Extension, proposal, and credential types defined by the MLS RFC and
107 /// provided are considered required by default and should NOT be used
108 /// within this extension.
109 #[cfg_attr(
110     all(feature = "ffi", not(test)),
111     safer_ffi_gen::ffi_type(clone, opaque)
112 )]
113 #[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode, Default)]
114 pub struct RequiredCapabilitiesExt {
115     pub extensions: Vec<ExtensionType>,
116     pub proposals: Vec<ProposalType>,
117     pub credentials: Vec<CredentialType>,
118 }
119 
120 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
121 impl RequiredCapabilitiesExt {
122     /// Create a required capabilities extension.
new( extensions: Vec<ExtensionType>, proposals: Vec<ProposalType>, credentials: Vec<CredentialType>, ) -> Self123     pub fn new(
124         extensions: Vec<ExtensionType>,
125         proposals: Vec<ProposalType>,
126         credentials: Vec<CredentialType>,
127     ) -> Self {
128         Self {
129             extensions,
130             proposals,
131             credentials,
132         }
133     }
134 
135     /// Required custom extension types.
136     #[cfg(feature = "ffi")]
extensions(&self) -> &[ExtensionType]137     pub fn extensions(&self) -> &[ExtensionType] {
138         &self.extensions
139     }
140 
141     /// Required custom proposal types.
142     #[cfg(feature = "ffi")]
proposals(&self) -> &[ProposalType]143     pub fn proposals(&self) -> &[ProposalType] {
144         &self.proposals
145     }
146 
147     /// Required custom credential types.
148     #[cfg(feature = "ffi")]
credentials(&self) -> &[CredentialType]149     pub fn credentials(&self) -> &[CredentialType] {
150         &self.credentials
151     }
152 }
153 
154 impl MlsCodecExtension for RequiredCapabilitiesExt {
extension_type() -> ExtensionType155     fn extension_type() -> ExtensionType {
156         ExtensionType::REQUIRED_CAPABILITIES
157     }
158 }
159 
160 /// External public key used for [External Commits](crate::Client::commit_external).
161 ///
162 /// This proposal type is optionally provided as part of a
163 /// [Group Info](crate::group::Group::group_info_message).
164 #[cfg_attr(
165     all(feature = "ffi", not(test)),
166     safer_ffi_gen::ffi_type(clone, opaque)
167 )]
168 #[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
169 pub struct ExternalPubExt {
170     /// Public key to be used for an external commit.
171     #[mls_codec(with = "mls_rs_codec::byte_vec")]
172     pub external_pub: HpkePublicKey,
173 }
174 
175 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
176 impl ExternalPubExt {
177     /// Get the public key to be used for an external commit.
178     #[cfg(feature = "ffi")]
external_pub(&self) -> &HpkePublicKey179     pub fn external_pub(&self) -> &HpkePublicKey {
180         &self.external_pub
181     }
182 }
183 
184 impl MlsCodecExtension for ExternalPubExt {
extension_type() -> ExtensionType185     fn extension_type() -> ExtensionType {
186         ExtensionType::EXTERNAL_PUB
187     }
188 }
189 
/// Enable proposals by an [ExternalClient](crate::external_client::ExternalClient).
///
/// Marked `#[non_exhaustive]`, so code outside this crate builds it through
/// [`ExternalSendersExt::new`].
#[cfg(feature = "by_ref_proposal")]
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[non_exhaustive]
pub struct ExternalSendersExt {
    /// Signing identities permitted to act as external senders.
    pub allowed_senders: Vec<SigningIdentity>,
}
201 
#[cfg(feature = "by_ref_proposal")]
#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
impl ExternalSendersExt {
    /// Build an external senders extension that allows `allowed_senders`.
    pub fn new(allowed_senders: Vec<SigningIdentity>) -> Self {
        ExternalSendersExt { allowed_senders }
    }

    /// Borrow the signing identities permitted to act as external senders.
    ///
    /// Only compiled with the `ffi` feature; Rust callers can read the public
    /// `allowed_senders` field directly.
    #[cfg(feature = "ffi")]
    pub fn allowed_senders(&self) -> &[SigningIdentity] {
        self.allowed_senders.as_slice()
    }

    /// Validate every allowed sender with `provider`.
    ///
    /// Each identity is passed, in order, to the provider's
    /// `validate_external_sender` together with `timestamp` and the group
    /// context extensions. Validation stops at the first error, which is
    /// returned as-is.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn verify_all<I: IdentityProvider>(
        &self,
        provider: &I,
        timestamp: Option<MlsTime>,
        group_context_extensions: &ExtensionList,
    ) -> Result<(), I::Error> {
        // The same borrowed extension list is supplied for every sender.
        let extensions = Some(group_context_extensions);

        for sender in &self.allowed_senders {
            provider
                .validate_external_sender(sender, timestamp, extensions)
                .await?;
        }

        Ok(())
    }
}
230 
#[cfg(feature = "by_ref_proposal")]
impl MlsCodecExtension for ExternalSendersExt {
    /// Tags this payload as the `EXTERNAL_SENDERS` extension type.
    fn extension_type() -> ExtensionType {
        ExtensionType::EXTERNAL_SENDERS
    }
}
237 
#[cfg(test)]
mod tests {
    use super::*;

    use crate::tree_kem::node::NodeVec;
    #[cfg(feature = "by_ref_proposal")]
    use crate::{
        client::test_utils::TEST_CIPHER_SUITE, identity::test_utils::get_test_signing_identity,
    };

    use mls_rs_core::extension::MlsExtension;

    use mls_rs_core::identity::BasicCredential;

    use alloc::vec;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    // Every test below converts a typed extension into a generic extension,
    // checks the resulting type tag, and decodes it back to confirm the
    // round trip is lossless.

    #[test]
    fn test_application_id_extension() {
        let expected_id = vec![0u8; 32];
        let typed = ApplicationIdExt::new(expected_id.clone());

        let generic = typed.into_extension().unwrap();
        assert_eq!(generic.extension_type, ExtensionType::APPLICATION_ID);

        let decoded = ApplicationIdExt::from_extension(&generic).unwrap();
        assert_eq!(decoded.identifier, expected_id);
    }

    #[test]
    fn test_ratchet_tree() {
        let nodes = NodeVec::from(vec![None, None]);
        let typed = RatchetTreeExt {
            tree_data: ExportedTree::new(nodes),
        };

        let generic = typed.clone().into_extension().unwrap();
        assert_eq!(generic.extension_type, ExtensionType::RATCHET_TREE);

        let decoded = RatchetTreeExt::from_extension(&generic).unwrap();
        assert_eq!(typed, decoded)
    }

    #[test]
    fn test_required_capabilities() {
        let typed = RequiredCapabilitiesExt::new(
            vec![0.into(), 1.into()],
            vec![42.into(), 43.into()],
            vec![BasicCredential::credential_type()],
        );

        let generic = typed.clone().into_extension().unwrap();
        assert_eq!(generic.extension_type, ExtensionType::REQUIRED_CAPABILITIES);

        let decoded = RequiredCapabilitiesExt::from_extension(&generic).unwrap();
        assert_eq!(typed, decoded)
    }

    #[cfg(feature = "by_ref_proposal")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_external_senders() {
        let sender = get_test_signing_identity(TEST_CIPHER_SUITE, &[1]).await.0;
        let typed = ExternalSendersExt::new(vec![sender]);

        let generic = typed.clone().into_extension().unwrap();
        assert_eq!(generic.extension_type, ExtensionType::EXTERNAL_SENDERS);

        let decoded = ExternalSendersExt::from_extension(&generic).unwrap();
        assert_eq!(typed, decoded)
    }

    #[test]
    fn test_external_pub() {
        let typed = ExternalPubExt {
            external_pub: vec![0, 1, 2, 3].into(),
        };

        let generic = typed.clone().into_extension().unwrap();
        assert_eq!(generic.extension_type, ExtensionType::EXTERNAL_PUB);

        let decoded = ExternalPubExt::from_extension(&generic).unwrap();
        assert_eq!(typed, decoded)
    }
}
331