1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use alloc::vec::Vec; 6 use core::{ 7 fmt::{self, Debug}, 8 ops::Deref, 9 }; 10 11 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; 12 use mls_rs_core::{crypto::CipherSuiteProvider, error::IntoAnyError}; 13 14 use crate::{ 15 client::MlsError, 16 group::{framing::FramedContent, MessageSignature}, 17 WireFormat, 18 }; 19 20 use super::{AuthenticatedContent, ConfirmationTag}; 21 22 #[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)] 23 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 24 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 25 pub struct ConfirmedTranscriptHash( 26 #[mls_codec(with = "mls_rs_codec::byte_vec")] 27 #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))] 28 Vec<u8>, 29 ); 30 31 impl Debug for ConfirmedTranscriptHash { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result32 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 33 mls_rs_core::debug::pretty_bytes(&self.0) 34 .named("ConfirmedTranscriptHash") 35 .fmt(f) 36 } 37 } 38 39 impl Deref for ConfirmedTranscriptHash { 40 type Target = Vec<u8>; 41 deref(&self) -> &Self::Target42 fn deref(&self) -> &Self::Target { 43 &self.0 44 } 45 } 46 47 impl From<Vec<u8>> for ConfirmedTranscriptHash { from(value: Vec<u8>) -> Self48 fn from(value: Vec<u8>) -> Self { 49 Self(value) 50 } 51 } 52 53 impl ConfirmedTranscriptHash { 54 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] create<P: CipherSuiteProvider>( cipher_suite_provider: &P, interim_transcript_hash: &InterimTranscriptHash, content: &AuthenticatedContent, ) -> Result<Self, MlsError>55 pub(crate) async fn create<P: CipherSuiteProvider>( 56 cipher_suite_provider: &P, 57 interim_transcript_hash: &InterimTranscriptHash, 58 content: &AuthenticatedContent, 59 ) -> Result<Self, MlsError> { 60 #[derive(Debug, MlsSize, MlsEncode)] 61 struct ConfirmedTranscriptHashInput<'a> { 62 wire_format: WireFormat, 63 content: &'a FramedContent, 64 signature: &'a MessageSignature, 65 } 66 67 let input = ConfirmedTranscriptHashInput { 68 wire_format: content.wire_format, 69 content: &content.content, 70 signature: &content.auth.signature, 71 }; 72 73 let hash_input = [ 74 interim_transcript_hash.deref(), 75 input.mls_encode_to_vec()?.deref(), 76 ] 77 .concat(); 78 79 cipher_suite_provider 80 .hash(&hash_input) 81 .await 82 .map(Into::into) 83 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error())) 84 } 85 } 86 87 #[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)] 88 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 89 pub(crate) struct InterimTranscriptHash( 90 #[mls_codec(with = "mls_rs_codec::byte_vec")] 91 #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))] 92 Vec<u8>, 93 ); 94 95 impl Debug for InterimTranscriptHash { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result96 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 97 mls_rs_core::debug::pretty_bytes(&self.0) 98 .named("InterimTranscriptHash") 99 .fmt(f) 100 } 101 } 102 103 impl Deref for InterimTranscriptHash { 104 type Target = Vec<u8>; 105 deref(&self) -> &Self::Target106 fn deref(&self) -> &Self::Target { 107 &self.0 108 } 109 } 110 111 impl From<Vec<u8>> for InterimTranscriptHash { from(value: Vec<u8>) -> Self112 fn from(value: Vec<u8>) -> Self { 113 Self(value) 114 } 115 } 116 117 impl InterimTranscriptHash { 118 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] create<P: CipherSuiteProvider>( cipher_suite_provider: &P, confirmed: &ConfirmedTranscriptHash, confirmation_tag: &ConfirmationTag, ) -> Result<Self, MlsError>119 pub async fn create<P: CipherSuiteProvider>( 120 cipher_suite_provider: &P, 121 confirmed: &ConfirmedTranscriptHash, 122 confirmation_tag: &ConfirmationTag, 123 ) -> Result<Self, MlsError> { 124 #[derive(Debug, MlsSize, MlsEncode)] 125 struct InterimTranscriptHashInput<'a> { 126 confirmation_tag: &'a ConfirmationTag, 127 } 128 129 let input = InterimTranscriptHashInput { confirmation_tag }.mls_encode_to_vec()?; 130 131 cipher_suite_provider 132 .hash(&[confirmed.0.deref(), &input].concat()) 133 .await 134 .map(Into::into) 135 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error())) 136 } 137 } 138 139 // Test vectors come from the MLS interop repository and contain a proposal by reference. 140 #[cfg(feature = "by_ref_proposal")] 141 #[cfg(test)] 142 mod tests { 143 use alloc::vec::Vec; 144 145 use mls_rs_codec::MlsDecode; 146 147 use crate::{ 148 crypto::test_utils::try_test_cipher_suite_provider, 149 group::{framing::ContentType, message_signature::AuthenticatedContent, transcript_hashes}, 150 }; 151 152 #[cfg(not(mls_build_async))] 153 use alloc::{boxed::Box, vec}; 154 155 #[cfg(not(mls_build_async))] 156 use crate::{ 157 crypto::test_utils::test_cipher_suite_provider, 158 group::{ 159 confirmation_tag::ConfirmationTag, 160 framing::Content, 161 proposal::{Proposal, ProposalOrRef, RemoveProposal}, 162 test_utils::get_test_group_context, 163 Commit, LeafIndex, Sender, 164 }, 165 mls_rs_codec::MlsEncode, 166 CipherSuite, CipherSuiteProvider, WireFormat, 167 }; 168 169 #[cfg(not(mls_build_async))] 170 use super::{ConfirmedTranscriptHash, InterimTranscriptHash}; 171 172 #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)] 173 struct TestCase { 174 pub cipher_suite: u16, 175 176 #[serde(with = "hex::serde")] 177 pub confirmation_key: Vec<u8>, 178 #[serde(with = "hex::serde")] 179 pub authenticated_content: Vec<u8>, 180 #[serde(with = "hex::serde")] 181 pub interim_transcript_hash_before: Vec<u8>, 182 183 #[serde(with = "hex::serde")] 184 pub confirmed_transcript_hash_after: Vec<u8>, 185 #[serde(with = "hex::serde")] 186 pub interim_transcript_hash_after: Vec<u8>, 187 } 188 189 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] transcript_hash()190 async fn transcript_hash() { 191 let test_cases: Vec<TestCase> = 192 load_test_case_json!(interop_transcript_hashes, generate_test_vector()); 193 194 for test_case in test_cases.into_iter() { 195 let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) else { 196 continue; 197 }; 198 199 let auth_content = 200 AuthenticatedContent::mls_decode(&mut &*test_case.authenticated_content).unwrap(); 201 202 assert!(auth_content.content.content_type() == ContentType::Commit); 203 204 let conf_key = &test_case.confirmation_key; 205 let conf_hash_after = test_case.confirmed_transcript_hash_after.into(); 206 let conf_tag = auth_content.auth.confirmation_tag.clone().unwrap(); 207 208 let matches = conf_tag 209 .matches(conf_key, &conf_hash_after, &cs) 210 .await 211 .unwrap(); 212 213 assert!(matches); 214 215 let (expected_interim, expected_conf) = transcript_hashes( 216 &cs, 217 &test_case.interim_transcript_hash_before.into(), 218 &auth_content, 219 ) 220 .await 221 .unwrap(); 222 223 assert_eq!(*expected_interim, test_case.interim_transcript_hash_after); 224 assert_eq!(expected_conf, conf_hash_after); 225 } 226 } 227 228 #[cfg(not(mls_build_async))] 229 #[cfg_attr(coverage_nightly, coverage(off))] generate_test_vector() -> Vec<TestCase>230 fn generate_test_vector() -> Vec<TestCase> { 231 CipherSuite::all().fold(vec![], |mut test_cases, cs| { 232 let cs = test_cipher_suite_provider(cs); 233 234 let context = get_test_group_context(0x3456, cs.cipher_suite()); 235 236 let proposal = Proposal::Remove(RemoveProposal { 237 to_remove: LeafIndex(1), 238 }); 239 240 let proposal = ProposalOrRef::Proposal(Box::new(proposal)); 241 242 let commit = Commit { 243 proposals: vec![proposal], 244 path: None, 245 }; 246 247 let signer = cs.signature_key_generate().unwrap().0; 248 249 let mut auth_content = AuthenticatedContent::new_signed( 250 &cs, 251 &context, 252 Sender::Member(0), 253 Content::Commit(alloc::boxed::Box::new(commit)), 254 &signer, 255 WireFormat::PublicMessage, 256 vec![], 257 ) 258 .unwrap(); 259 260 let interim_hash_before = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap().into(); 261 262 let conf_hash_after = 263 ConfirmedTranscriptHash::create(&cs, &interim_hash_before, &auth_content).unwrap(); 264 265 let conf_key = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap(); 266 let conf_tag = ConfirmationTag::create(&conf_key, &conf_hash_after, &cs).unwrap(); 267 268 let interim_hash_after = 269 InterimTranscriptHash::create(&cs, &conf_hash_after, &conf_tag).unwrap(); 270 271 auth_content.auth.confirmation_tag = Some(conf_tag); 272 273 let test_case = TestCase { 274 cipher_suite: cs.cipher_suite().into(), 275 276 confirmation_key: conf_key, 277 authenticated_content: auth_content.mls_encode_to_vec().unwrap(), 278 interim_transcript_hash_before: interim_hash_before.0, 279 280 confirmed_transcript_hash_after: conf_hash_after.0, 281 interim_transcript_hash_after: interim_hash_after.0, 282 }; 283 284 test_cases.push(test_case); 285 test_cases 286 }) 287 } 288 289 #[cfg(mls_build_async)] generate_test_vector() -> Vec<TestCase>290 fn generate_test_vector() -> Vec<TestCase> { 291 panic!("Tests cannot be generated in async mode"); 292 } 293 } 294