1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::vec::Vec;
6 use core::{
7     fmt::{self, Debug},
8     ops::Deref,
9 };
10 
11 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
12 use mls_rs_core::{crypto::CipherSuiteProvider, error::IntoAnyError};
13 
14 use crate::{
15     client::MlsError,
16     group::{framing::FramedContent, MessageSignature},
17     WireFormat,
18 };
19 
20 use super::{AuthenticatedContent, ConfirmationTag};
21 
/// Confirmed transcript hash of an MLS group.
///
/// Holds the bytes produced by `ConfirmedTranscriptHash::create`: the cipher suite hash of
/// the interim transcript hash followed by the MLS encoding of the wire format, framed
/// content and signature of an `AuthenticatedContent`. It can also be built directly from
/// raw bytes through `From<Vec<u8>>`, and derefs to those bytes.
#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ConfirmedTranscriptHash(
    // Raw hash bytes, MLS-encoded with the `mls_rs_codec::byte_vec` codec and serialized
    // with `mls_rs_core::vec_serde` when the `serde` feature is enabled.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
    Vec<u8>,
);
30 
31 impl Debug for ConfirmedTranscriptHash {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result32     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33         mls_rs_core::debug::pretty_bytes(&self.0)
34             .named("ConfirmedTranscriptHash")
35             .fmt(f)
36     }
37 }
38 
39 impl Deref for ConfirmedTranscriptHash {
40     type Target = Vec<u8>;
41 
deref(&self) -> &Self::Target42     fn deref(&self) -> &Self::Target {
43         &self.0
44     }
45 }
46 
47 impl From<Vec<u8>> for ConfirmedTranscriptHash {
from(value: Vec<u8>) -> Self48     fn from(value: Vec<u8>) -> Self {
49         Self(value)
50     }
51 }
52 
53 impl ConfirmedTranscriptHash {
54     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
create<P: CipherSuiteProvider>( cipher_suite_provider: &P, interim_transcript_hash: &InterimTranscriptHash, content: &AuthenticatedContent, ) -> Result<Self, MlsError>55     pub(crate) async fn create<P: CipherSuiteProvider>(
56         cipher_suite_provider: &P,
57         interim_transcript_hash: &InterimTranscriptHash,
58         content: &AuthenticatedContent,
59     ) -> Result<Self, MlsError> {
60         #[derive(Debug, MlsSize, MlsEncode)]
61         struct ConfirmedTranscriptHashInput<'a> {
62             wire_format: WireFormat,
63             content: &'a FramedContent,
64             signature: &'a MessageSignature,
65         }
66 
67         let input = ConfirmedTranscriptHashInput {
68             wire_format: content.wire_format,
69             content: &content.content,
70             signature: &content.auth.signature,
71         };
72 
73         let hash_input = [
74             interim_transcript_hash.deref(),
75             input.mls_encode_to_vec()?.deref(),
76         ]
77         .concat();
78 
79         cipher_suite_provider
80             .hash(&hash_input)
81             .await
82             .map(Into::into)
83             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
84     }
85 }
86 
87 #[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
88 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
89 pub(crate) struct InterimTranscriptHash(
90     #[mls_codec(with = "mls_rs_codec::byte_vec")]
91     #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
92     Vec<u8>,
93 );
94 
95 impl Debug for InterimTranscriptHash {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result96     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97         mls_rs_core::debug::pretty_bytes(&self.0)
98             .named("InterimTranscriptHash")
99             .fmt(f)
100     }
101 }
102 
103 impl Deref for InterimTranscriptHash {
104     type Target = Vec<u8>;
105 
deref(&self) -> &Self::Target106     fn deref(&self) -> &Self::Target {
107         &self.0
108     }
109 }
110 
111 impl From<Vec<u8>> for InterimTranscriptHash {
from(value: Vec<u8>) -> Self112     fn from(value: Vec<u8>) -> Self {
113         Self(value)
114     }
115 }
116 
117 impl InterimTranscriptHash {
118     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
create<P: CipherSuiteProvider>( cipher_suite_provider: &P, confirmed: &ConfirmedTranscriptHash, confirmation_tag: &ConfirmationTag, ) -> Result<Self, MlsError>119     pub async fn create<P: CipherSuiteProvider>(
120         cipher_suite_provider: &P,
121         confirmed: &ConfirmedTranscriptHash,
122         confirmation_tag: &ConfirmationTag,
123     ) -> Result<Self, MlsError> {
124         #[derive(Debug, MlsSize, MlsEncode)]
125         struct InterimTranscriptHashInput<'a> {
126             confirmation_tag: &'a ConfirmationTag,
127         }
128 
129         let input = InterimTranscriptHashInput { confirmation_tag }.mls_encode_to_vec()?;
130 
131         cipher_suite_provider
132             .hash(&[confirmed.0.deref(), &input].concat())
133             .await
134             .map(Into::into)
135             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
136     }
137 }
138 
// Test vectors come from the MLS interop repository and contain a proposal by reference,
// which is why this module is only compiled with the `by_ref_proposal` feature enabled.
#[cfg(feature = "by_ref_proposal")]
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;

    use mls_rs_codec::MlsDecode;

    use crate::{
        crypto::test_utils::try_test_cipher_suite_provider,
        group::{framing::ContentType, message_signature::AuthenticatedContent, transcript_hashes},
    };

    // Test vector generation exists only in sync builds (see the `cfg` on
    // `generate_test_vector`), so the imports it needs are gated the same way.
    #[cfg(not(mls_build_async))]
    use alloc::{boxed::Box, vec};

    #[cfg(not(mls_build_async))]
    use crate::{
        crypto::test_utils::test_cipher_suite_provider,
        group::{
            confirmation_tag::ConfirmationTag,
            framing::Content,
            proposal::{Proposal, ProposalOrRef, RemoveProposal},
            test_utils::get_test_group_context,
            Commit, LeafIndex, Sender,
        },
        mls_rs_codec::MlsEncode,
        CipherSuite, CipherSuiteProvider, WireFormat,
    };

    #[cfg(not(mls_build_async))]
    use super::{ConfirmedTranscriptHash, InterimTranscriptHash};

    /// One transcript-hash test vector. Byte fields are hex-encoded when (de)serialized.
    #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
    struct TestCase {
        /// Numeric cipher suite identifier.
        pub cipher_suite: u16,

        /// Key used to create and verify the commit's confirmation tag.
        #[serde(with = "hex::serde")]
        pub confirmation_key: Vec<u8>,
        /// MLS-encoded `AuthenticatedContent` of a commit, including its confirmation tag.
        #[serde(with = "hex::serde")]
        pub authenticated_content: Vec<u8>,
        /// Interim transcript hash that is input to the commit's confirmed transcript hash.
        #[serde(with = "hex::serde")]
        pub interim_transcript_hash_before: Vec<u8>,

        /// Expected confirmed transcript hash after the commit.
        #[serde(with = "hex::serde")]
        pub confirmed_transcript_hash_after: Vec<u8>,
        /// Expected interim transcript hash after the commit.
        #[serde(with = "hex::serde")]
        pub interim_transcript_hash_after: Vec<u8>,
    }

    /// Verifies confirmation tags and transcript hash computation against the test vectors.
    ///
    /// NOTE(review): `load_test_case_json!` presumably falls back to `generate_test_vector()`
    /// when the stored `interop_transcript_hashes` vectors are unavailable — confirm against
    /// the macro definition.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn transcript_hash() {
        let test_cases: Vec<TestCase> =
            load_test_case_json!(interop_transcript_hashes, generate_test_vector());

        for test_case in test_cases.into_iter() {
            // Skip cipher suites for which no test provider is available.
            let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
                continue;
            };

            let auth_content =
                AuthenticatedContent::mls_decode(&mut &*test_case.authenticated_content).unwrap();

            assert!(auth_content.content.content_type() == ContentType::Commit);

            // The confirmation tag carried by the commit must verify against the expected
            // confirmed transcript hash.
            let conf_key = &test_case.confirmation_key;
            let conf_hash_after = test_case.confirmed_transcript_hash_after.into();
            let conf_tag = auth_content.auth.confirmation_tag.clone().unwrap();

            let matches = conf_tag
                .matches(conf_key, &conf_hash_after, &cs)
                .await
                .unwrap();

            assert!(matches);

            // Recompute both hashes from the interim hash before the commit. Despite their
            // names, `expected_interim` / `expected_conf` hold the computed values that are
            // compared against the vector.
            let (expected_interim, expected_conf) = transcript_hashes(
                &cs,
                &test_case.interim_transcript_hash_before.into(),
                &auth_content,
            )
            .await
            .unwrap();

            assert_eq!(*expected_interim, test_case.interim_transcript_hash_after);
            assert_eq!(expected_conf, conf_hash_after);
        }
    }

    /// Builds one test case per cipher suite: a signed commit holding a single remove
    /// proposal (by value), a random interim hash and confirmation key, and the hashes and
    /// confirmation tag derived from them.
    #[cfg(not(mls_build_async))]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_test_vector() -> Vec<TestCase> {
        CipherSuite::all().fold(vec![], |mut test_cases, cs| {
            let cs = test_cipher_suite_provider(cs);

            let context = get_test_group_context(0x3456, cs.cipher_suite());

            let proposal = Proposal::Remove(RemoveProposal {
                to_remove: LeafIndex(1),
            });

            let proposal = ProposalOrRef::Proposal(Box::new(proposal));

            let commit = Commit {
                proposals: vec![proposal],
                path: None,
            };

            let signer = cs.signature_key_generate().unwrap().0;

            let mut auth_content = AuthenticatedContent::new_signed(
                &cs,
                &context,
                Sender::Member(0),
                Content::Commit(alloc::boxed::Box::new(commit)),
                &signer,
                WireFormat::PublicMessage,
                vec![],
            )
            .unwrap();

            let interim_hash_before = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap().into();

            let conf_hash_after =
                ConfirmedTranscriptHash::create(&cs, &interim_hash_before, &auth_content).unwrap();

            let conf_key = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap();
            let conf_tag = ConfirmationTag::create(&conf_key, &conf_hash_after, &cs).unwrap();

            let interim_hash_after =
                InterimTranscriptHash::create(&cs, &conf_hash_after, &conf_tag).unwrap();

            // The tag is attached only after hashing: the confirmed transcript hash input
            // covers wire format, content and signature, not the confirmation tag.
            auth_content.auth.confirmation_tag = Some(conf_tag);

            let test_case = TestCase {
                cipher_suite: cs.cipher_suite().into(),

                confirmation_key: conf_key,
                authenticated_content: auth_content.mls_encode_to_vec().unwrap(),
                interim_transcript_hash_before: interim_hash_before.0,

                confirmed_transcript_hash_after: conf_hash_after.0,
                interim_transcript_hash_after: interim_hash_after.0,
            };

            test_cases.push(test_case);
            test_cases
        })
    }

    /// Async builds cannot generate vectors (the generator above is sync-only), so they
    /// must rely on stored test vectors.
    #[cfg(mls_build_async)]
    fn generate_test_vector() -> Vec<TestCase> {
        panic!("Tests cannot be generated in async mode");
    }
}
294