1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use core::{
6     fmt::{self, Debug},
7     ops::Deref,
8 };
9 
10 use crate::error::{AnyError, IntoAnyError};
11 use alloc::vec::Vec;
12 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
13 
14 mod list;
15 
16 pub use list::*;
17 
/// Wrapper type representing an extension identifier along with default values
/// defined by the MLS RFC.
///
/// The wrapped value is the raw 16-bit extension type code. Because of
/// `#[repr(transparent)]`, this type has the same memory layout as a `u16`.
#[derive(
    Debug, PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode,
)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(transparent)]
pub struct ExtensionType(u16);
28 
29 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
30 impl ExtensionType {
31     pub const APPLICATION_ID: ExtensionType = ExtensionType(1);
32     pub const RATCHET_TREE: ExtensionType = ExtensionType(2);
33     pub const REQUIRED_CAPABILITIES: ExtensionType = ExtensionType(3);
34     pub const EXTERNAL_PUB: ExtensionType = ExtensionType(4);
35     pub const EXTERNAL_SENDERS: ExtensionType = ExtensionType(5);
36 
37     /// Default extension types defined
38     /// in [RFC 9420](https://www.rfc-editor.org/rfc/rfc9420.html#name-leaf-node-contents)
39     pub const DEFAULT: &'static [ExtensionType] = &[
40         ExtensionType::APPLICATION_ID,
41         ExtensionType::RATCHET_TREE,
42         ExtensionType::REQUIRED_CAPABILITIES,
43         ExtensionType::EXTERNAL_PUB,
44         ExtensionType::EXTERNAL_SENDERS,
45     ];
46 
47     /// Extension type from a raw value
new(raw_value: u16) -> Self48     pub const fn new(raw_value: u16) -> Self {
49         ExtensionType(raw_value)
50     }
51 
52     /// Raw numerical wrapped value.
raw_value(&self) -> u1653     pub const fn raw_value(&self) -> u16 {
54         self.0
55     }
56 
57     /// Determines if this extension type is required to be implemented
58     /// by the MLS RFC.
is_default(&self) -> bool59     pub const fn is_default(&self) -> bool {
60         self.0 <= 5
61     }
62 }
63 
64 impl From<u16> for ExtensionType {
from(value: u16) -> Self65     fn from(value: u16) -> Self {
66         ExtensionType(value)
67     }
68 }
69 
70 impl Deref for ExtensionType {
71     type Target = u16;
72 
deref(&self) -> &Self::Target73     fn deref(&self) -> &Self::Target {
74         &self.0
75     }
76 }
77 
/// Errors produced when converting between a generic [`Extension`] and a
/// typed [`MlsExtension`] value.
#[derive(Debug)]
#[cfg_attr(feature = "std", derive(thiserror::Error))]
pub enum ExtensionError {
    /// The typed value failed to serialize into extension bytes
    /// (raised by `MlsExtension::into_extension`).
    #[cfg_attr(feature = "std", error(transparent))]
    SerializationError(AnyError),
    /// The extension bytes failed to deserialize into the typed value
    /// (raised by `MlsExtension::from_extension`).
    #[cfg_attr(feature = "std", error(transparent))]
    DeserializationError(AnyError),
    /// The extension's type did not match the type the caller expected.
    /// Carries the type that was actually found on the extension.
    #[cfg_attr(feature = "std", error("incorrect extension type: {0:?}"))]
    IncorrectType(ExtensionType),
}
88 
89 impl IntoAnyError for ExtensionError {
90     #[cfg(feature = "std")]
into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self>91     fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
92         Ok(self.into())
93     }
94 }
95 
#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(
    all(feature = "ffi", not(test)),
    safer_ffi_gen::ffi_type(clone, opaque)
)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
/// An MLS protocol [extension](https://www.rfc-editor.org/rfc/rfc9420.html#name-extensions).
///
/// Extensions are used as customization points in various parts of the
/// MLS protocol and are inserted into an [ExtensionList](self::ExtensionList).
///
/// An extension pairs an [ExtensionType] with opaque payload bytes. Typed
/// values are converted to and from this form through the [MlsExtension]
/// trait. Because the struct is `#[non_exhaustive]`, code outside this crate
/// must construct it with [Extension::new].
pub struct Extension {
    /// Extension type of this extension.
    pub extension_type: ExtensionType,
    /// Opaque payload bytes held within this extension. The MLS codec encodes
    /// them through `mls_rs_codec::byte_vec`; with the `serde` feature they
    /// are (de)serialized through `crate::vec_serde`.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "crate::vec_serde"))]
    pub extension_data: Vec<u8>,
}
116 
117 impl Debug for Extension {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result118     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
119         f.debug_struct("Extension")
120             .field("extension_type", &self.extension_type)
121             .field(
122                 "extension_data",
123                 &crate::debug::pretty_bytes(&self.extension_data),
124             )
125             .finish()
126     }
127 }
128 
129 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
130 impl Extension {
131     /// Create an extension with specified type and data properties.
new(extension_type: ExtensionType, extension_data: Vec<u8>) -> Extension132     pub fn new(extension_type: ExtensionType, extension_data: Vec<u8>) -> Extension {
133         Extension {
134             extension_type,
135             extension_data,
136         }
137     }
138 
139     /// Extension type of this extension
140     #[cfg(feature = "ffi")]
extension_type(&self) -> ExtensionType141     pub fn extension_type(&self) -> ExtensionType {
142         self.extension_type
143     }
144 
145     /// Data held within this extension
146     #[cfg(feature = "ffi")]
extension_data(&self) -> &[u8]147     pub fn extension_data(&self) -> &[u8] {
148         &self.extension_data
149     }
150 }
151 
152 /// Trait used to convert a type to and from an [Extension]
153 pub trait MlsExtension: Sized {
154     /// Error type of the underlying serializer that can convert this type into a `Vec<u8>`.
155     type SerializationError: IntoAnyError;
156 
157     /// Error type of the underlying deserializer that can convert a `Vec<u8>` into this type.
158     type DeserializationError: IntoAnyError;
159 
160     /// Extension type value that this type represents.
extension_type() -> ExtensionType161     fn extension_type() -> ExtensionType;
162 
163     /// Convert this type to opaque bytes.
to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError>164     fn to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError>;
165 
166     /// Create this type from opaque bytes.
from_bytes(data: &[u8]) -> Result<Self, Self::DeserializationError>167     fn from_bytes(data: &[u8]) -> Result<Self, Self::DeserializationError>;
168 
169     /// Convert this type into an [Extension].
into_extension(self) -> Result<Extension, ExtensionError>170     fn into_extension(self) -> Result<Extension, ExtensionError> {
171         Ok(Extension::new(
172             Self::extension_type(),
173             self.to_bytes()
174                 .map_err(|e| ExtensionError::SerializationError(e.into_any_error()))?,
175         ))
176     }
177 
178     /// Create this type from an [Extension].
from_extension(ext: &Extension) -> Result<Self, ExtensionError>179     fn from_extension(ext: &Extension) -> Result<Self, ExtensionError> {
180         if ext.extension_type != Self::extension_type() {
181             return Err(ExtensionError::IncorrectType(ext.extension_type));
182         }
183 
184         Self::from_bytes(&ext.extension_data)
185             .map_err(|e| ExtensionError::DeserializationError(e.into_any_error()))
186     }
187 }
188 
189 /// Convenience trait for custom extension types that use
190 /// [mls_rs_codec] as an underlying serialization mechanism
191 pub trait MlsCodecExtension: MlsSize + MlsEncode + MlsDecode {
extension_type() -> ExtensionType192     fn extension_type() -> ExtensionType;
193 }
194 
195 impl<T> MlsExtension for T
196 where
197     T: MlsCodecExtension,
198 {
199     type SerializationError = mls_rs_codec::Error;
200     type DeserializationError = mls_rs_codec::Error;
201 
extension_type() -> ExtensionType202     fn extension_type() -> ExtensionType {
203         <Self as MlsCodecExtension>::extension_type()
204     }
205 
to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError>206     fn to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError> {
207         self.mls_encode_to_vec()
208     }
209 
from_bytes(data: &[u8]) -> Result<Self, Self::DeserializationError>210     fn from_bytes(data: &[u8]) -> Result<Self, Self::DeserializationError> {
211         Self::mls_decode(&mut &*data)
212     }
213 }
214 
#[cfg(test)]
mod tests {
    use core::convert::Infallible;

    use alloc::vec;
    use alloc::vec::Vec;
    use assert_matches::assert_matches;
    use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};

    use super::{Extension, ExtensionError, ExtensionType, MlsCodecExtension, MlsExtension};

    /// Hand-written `MlsExtension` of type 42 whose payload is always `[0]`.
    struct TestExtension;

    /// Codec-backed extension of type 43, used to check type mismatches.
    #[derive(Debug, MlsSize, MlsEncode, MlsDecode)]
    struct AnotherTestExtension;

    impl MlsExtension for TestExtension {
        type SerializationError = Infallible;
        type DeserializationError = Infallible;

        fn extension_type() -> ExtensionType {
            ExtensionType::new(42)
        }

        fn to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError> {
            Ok(vec![0])
        }

        fn from_bytes(_data: &[u8]) -> Result<Self, Self::DeserializationError> {
            Ok(Self)
        }
    }

    impl MlsCodecExtension for AnotherTestExtension {
        fn extension_type() -> ExtensionType {
            ExtensionType::new(43)
        }
    }

    #[test]
    fn into_extension() {
        let actual = TestExtension.into_extension().unwrap();
        let expected = Extension::new(ExtensionType::new(42), vec![0]);

        assert_eq!(actual, expected)
    }

    #[test]
    fn incorrect_type_is_discovered() {
        let foreign = Extension::new(ExtensionType::new(42), vec![0]);
        let result = AnotherTestExtension::from_extension(&foreign);

        assert_matches!(
            result,
            Err(ExtensionError::IncorrectType(found)) if found == ExtensionType::new(42)
        );
    }
}
269