1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use core::{ 6 fmt::{self, Debug}, 7 ops::Deref, 8 }; 9 10 use crate::error::{AnyError, IntoAnyError}; 11 use alloc::vec::Vec; 12 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; 13 14 mod list; 15 16 pub use list::*; 17 18 /// Wrapper type representing an extension identifier along with default values 19 /// defined by the MLS RFC. 20 #[derive( 21 Debug, PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode, 22 )] 23 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 24 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type)] 25 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 26 #[repr(transparent)] 27 pub struct ExtensionType(u16); 28 29 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)] 30 impl ExtensionType { 31 pub const APPLICATION_ID: ExtensionType = ExtensionType(1); 32 pub const RATCHET_TREE: ExtensionType = ExtensionType(2); 33 pub const REQUIRED_CAPABILITIES: ExtensionType = ExtensionType(3); 34 pub const EXTERNAL_PUB: ExtensionType = ExtensionType(4); 35 pub const EXTERNAL_SENDERS: ExtensionType = ExtensionType(5); 36 37 /// Default extension types defined 38 /// in [RFC 9420](https://www.rfc-editor.org/rfc/rfc9420.html#name-leaf-node-contents) 39 pub const DEFAULT: &'static [ExtensionType] = &[ 40 ExtensionType::APPLICATION_ID, 41 ExtensionType::RATCHET_TREE, 42 ExtensionType::REQUIRED_CAPABILITIES, 43 ExtensionType::EXTERNAL_PUB, 44 ExtensionType::EXTERNAL_SENDERS, 45 ]; 46 47 /// Extension type from a raw value new(raw_value: u16) -> Self48 pub const fn new(raw_value: u16) -> Self { 49 ExtensionType(raw_value) 50 } 51 52 /// Raw numerical wrapped value. raw_value(&self) -> u1653 pub const fn raw_value(&self) -> u16 { 54 self.0 55 } 56 57 /// Determines if this extension type is required to be implemented 58 /// by the MLS RFC. is_default(&self) -> bool59 pub const fn is_default(&self) -> bool { 60 self.0 <= 5 61 } 62 } 63 64 impl From<u16> for ExtensionType { from(value: u16) -> Self65 fn from(value: u16) -> Self { 66 ExtensionType(value) 67 } 68 } 69 70 impl Deref for ExtensionType { 71 type Target = u16; 72 deref(&self) -> &Self::Target73 fn deref(&self) -> &Self::Target { 74 &self.0 75 } 76 } 77 78 #[derive(Debug)] 79 #[cfg_attr(feature = "std", derive(thiserror::Error))] 80 pub enum ExtensionError { 81 #[cfg_attr(feature = "std", error(transparent))] 82 SerializationError(AnyError), 83 #[cfg_attr(feature = "std", error(transparent))] 84 DeserializationError(AnyError), 85 #[cfg_attr(feature = "std", error("incorrect extension type: {0:?}"))] 86 IncorrectType(ExtensionType), 87 } 88 89 impl IntoAnyError for ExtensionError { 90 #[cfg(feature = "std")] into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self>91 fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> { 92 Ok(self.into()) 93 } 94 } 95 96 #[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)] 97 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 98 #[cfg_attr( 99 all(feature = "ffi", not(test)), 100 safer_ffi_gen::ffi_type(clone, opaque) 101 )] 102 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 103 #[non_exhaustive] 104 /// An MLS protocol [extension](https://messaginglayersecurity.rocks/mls-protocol/draft-ietf-mls-protocol.html#name-extensions). 105 /// 106 /// Extensions are used as customization points in various parts of the 107 /// MLS protocol and are inserted into an [ExtensionList](self::ExtensionList). 108 pub struct Extension { 109 /// Extension type of this extension 110 pub extension_type: ExtensionType, 111 /// Data held within this extension 112 #[mls_codec(with = "mls_rs_codec::byte_vec")] 113 #[cfg_attr(feature = "serde", serde(with = "crate::vec_serde"))] 114 pub extension_data: Vec<u8>, 115 } 116 117 impl Debug for Extension { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result118 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 119 f.debug_struct("Extension") 120 .field("extension_type", &self.extension_type) 121 .field( 122 "extension_data", 123 &crate::debug::pretty_bytes(&self.extension_data), 124 ) 125 .finish() 126 } 127 } 128 129 #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)] 130 impl Extension { 131 /// Create an extension with specified type and data properties. new(extension_type: ExtensionType, extension_data: Vec<u8>) -> Extension132 pub fn new(extension_type: ExtensionType, extension_data: Vec<u8>) -> Extension { 133 Extension { 134 extension_type, 135 extension_data, 136 } 137 } 138 139 /// Extension type of this extension 140 #[cfg(feature = "ffi")] extension_type(&self) -> ExtensionType141 pub fn extension_type(&self) -> ExtensionType { 142 self.extension_type 143 } 144 145 /// Data held within this extension 146 #[cfg(feature = "ffi")] extension_data(&self) -> &[u8]147 pub fn extension_data(&self) -> &[u8] { 148 &self.extension_data 149 } 150 } 151 152 /// Trait used to convert a type to and from an [Extension] 153 pub trait MlsExtension: Sized { 154 /// Error type of the underlying serializer that can convert this type into a `Vec<u8>`. 155 type SerializationError: IntoAnyError; 156 157 /// Error type of the underlying deserializer that can convert a `Vec<u8>` into this type. 158 type DeserializationError: IntoAnyError; 159 160 /// Extension type value that this type represents. extension_type() -> ExtensionType161 fn extension_type() -> ExtensionType; 162 163 /// Convert this type to opaque bytes. to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError>164 fn to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError>; 165 166 /// Create this type from opaque bytes. from_bytes(data: &[u8]) -> Result<Self, Self::DeserializationError>167 fn from_bytes(data: &[u8]) -> Result<Self, Self::DeserializationError>; 168 169 /// Convert this type into an [Extension]. into_extension(self) -> Result<Extension, ExtensionError>170 fn into_extension(self) -> Result<Extension, ExtensionError> { 171 Ok(Extension::new( 172 Self::extension_type(), 173 self.to_bytes() 174 .map_err(|e| ExtensionError::SerializationError(e.into_any_error()))?, 175 )) 176 } 177 178 /// Create this type from an [Extension]. from_extension(ext: &Extension) -> Result<Self, ExtensionError>179 fn from_extension(ext: &Extension) -> Result<Self, ExtensionError> { 180 if ext.extension_type != Self::extension_type() { 181 return Err(ExtensionError::IncorrectType(ext.extension_type)); 182 } 183 184 Self::from_bytes(&ext.extension_data) 185 .map_err(|e| ExtensionError::DeserializationError(e.into_any_error())) 186 } 187 } 188 189 /// Convenience trait for custom extension types that use 190 /// [mls_rs_codec] as an underlying serialization mechanism 191 pub trait MlsCodecExtension: MlsSize + MlsEncode + MlsDecode { extension_type() -> ExtensionType192 fn extension_type() -> ExtensionType; 193 } 194 195 impl<T> MlsExtension for T 196 where 197 T: MlsCodecExtension, 198 { 199 type SerializationError = mls_rs_codec::Error; 200 type DeserializationError = mls_rs_codec::Error; 201 extension_type() -> ExtensionType202 fn extension_type() -> ExtensionType { 203 <Self as MlsCodecExtension>::extension_type() 204 } 205 to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError>206 fn to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError> { 207 self.mls_encode_to_vec() 208 } 209 from_bytes(data: &[u8]) -> Result<Self, Self::DeserializationError>210 fn from_bytes(data: &[u8]) -> Result<Self, Self::DeserializationError> { 211 Self::mls_decode(&mut &*data) 212 } 213 } 214 215 #[cfg(test)] 216 mod tests { 217 use core::convert::Infallible; 218 219 use alloc::vec; 220 use alloc::vec::Vec; 221 use assert_matches::assert_matches; 222 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; 223 224 use super::{Extension, ExtensionError, ExtensionType, MlsCodecExtension, MlsExtension}; 225 226 struct TestExtension; 227 228 #[derive(Debug, MlsSize, MlsEncode, MlsDecode)] 229 struct AnotherTestExtension; 230 231 impl MlsExtension for TestExtension { 232 type SerializationError = Infallible; 233 type DeserializationError = Infallible; 234 extension_type() -> super::ExtensionType235 fn extension_type() -> super::ExtensionType { 236 ExtensionType(42) 237 } 238 to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError>239 fn to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError> { 240 Ok(vec![0]) 241 } 242 from_bytes(_data: &[u8]) -> Result<Self, Self::DeserializationError>243 fn from_bytes(_data: &[u8]) -> Result<Self, Self::DeserializationError> { 244 Ok(TestExtension) 245 } 246 } 247 248 impl MlsCodecExtension for AnotherTestExtension { extension_type() -> ExtensionType249 fn extension_type() -> ExtensionType { 250 ExtensionType(43) 251 } 252 } 253 254 #[test] into_extension()255 fn into_extension() { 256 assert_eq!( 257 TestExtension.into_extension().unwrap(), 258 Extension::new(42.into(), vec![0]) 259 ) 260 } 261 262 #[test] incorrect_type_is_discovered()263 fn incorrect_type_is_discovered() { 264 let ext = Extension::new(42.into(), vec![0]); 265 266 assert_matches!(AnotherTestExtension::from_extension(&ext), Err(ExtensionError::IncorrectType(found)) if found == 42.into()); 267 } 268 } 269