1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use core::{
6     fmt::{self, Debug},
7     ops::Deref,
8 };
9 
10 use crate::client::MlsError;
11 use crate::CipherSuiteProvider;
12 use alloc::vec::Vec;
13 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
14 use mls_rs_core::error::IntoAnyError;
15 
/// Input to the MLS `RefHash` construction (RFC 9420, Section 5.2).
///
/// [`HashReference::compute`] MLS-encodes this struct and hashes the result.
/// Both fields are encoded as variable-length byte vectors
/// (`mls_rs_codec::byte_vec`). Keep `label` before `value`: RFC 9420 defines
/// `RefHashInput` with the label first, so reordering would change the hash.
#[derive(MlsSize, MlsEncode)]
struct RefHashInput<'a> {
    // Domain-separation label supplied by the caller of `compute`.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    pub label: &'a [u8],
    // The bytes being referenced.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    pub value: &'a [u8],
}
23 
24 impl Debug for RefHashInput<'_> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result25     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26         f.debug_struct("RefHashInput")
27             .field("label", &mls_rs_core::debug::pretty_bytes(self.label))
28             .field("value", &mls_rs_core::debug::pretty_bytes(self.value))
29             .finish()
30     }
31 }
32 
/// A hash-based reference: the raw digest bytes of an MLS `RefHash`.
///
/// Normally produced by [`HashReference::compute`]. It can also be wrapped
/// around arbitrary bytes via `From<Vec<u8>>`. It is MLS-encoded as a
/// variable-length byte vector and, with the `serde` feature, serialized
/// through `mls_rs_core::vec_serde`.
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Clone, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct HashReference(
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
    Vec<u8>,
);
41 
42 impl Debug for HashReference {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result43     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44         mls_rs_core::debug::pretty_bytes(&self.0)
45             .named("HashReference")
46             .fmt(f)
47     }
48 }
49 
50 impl Deref for HashReference {
51     type Target = [u8];
52 
deref(&self) -> &Self::Target53     fn deref(&self) -> &Self::Target {
54         &self.0
55     }
56 }
57 
58 impl AsRef<[u8]> for HashReference {
as_ref(&self) -> &[u8]59     fn as_ref(&self) -> &[u8] {
60         &self.0
61     }
62 }
63 
64 impl From<Vec<u8>> for HashReference {
from(val: Vec<u8>) -> Self65     fn from(val: Vec<u8>) -> Self {
66         Self(val)
67     }
68 }
69 
70 impl HashReference {
71     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
compute<P: CipherSuiteProvider>( value: &[u8], label: &[u8], cipher_suite: &P, ) -> Result<HashReference, MlsError>72     pub async fn compute<P: CipherSuiteProvider>(
73         value: &[u8],
74         label: &[u8],
75         cipher_suite: &P,
76     ) -> Result<HashReference, MlsError> {
77         let input = RefHashInput { label, value };
78         let input_bytes = input.mls_encode_to_vec()?;
79 
80         cipher_suite
81             .hash(&input_bytes)
82             .await
83             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
84             .map(HashReference)
85     }
86 }
87 
#[cfg(test)]
mod tests {
    use crate::crypto::test_utils::try_test_cipher_suite_provider;

    #[cfg(not(mls_build_async))]
    use crate::{cipher_suite::CipherSuite, crypto::test_utils::test_cipher_suite_provider};

    use super::*;
    use alloc::string::String;
    use serde::{Deserialize, Serialize};

    #[cfg(not(mls_build_async))]
    use alloc::string::ToString;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    /// The `ref_hash` section of one crypto-basics test vector entry.
    #[derive(Debug, Deserialize, Serialize)]
    struct HashRefTestCase {
        label: String,
        #[serde(with = "hex::serde")]
        value: Vec<u8>,
        #[serde(with = "hex::serde")]
        out: Vec<u8>,
    }

    /// One crypto-basics test vector entry, keyed by numeric cipher suite id.
    #[derive(Debug, Serialize, Deserialize)]
    pub struct InteropTestCase {
        cipher_suite: u16,
        ref_hash: HashRefTestCase,
    }

    /// Builds one `RefHash` test case per known cipher suite from fixed inputs.
    #[cfg(not(mls_build_async))]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_test_vector() -> Vec<InteropTestCase> {
        const INPUT: &[u8] = b"test input";
        const LABEL: &str = "test label";

        let mut cases = Vec::new();

        for cipher_suite in CipherSuite::all() {
            let provider = test_cipher_suite_provider(cipher_suite);
            let output = HashReference::compute(INPUT, LABEL.as_bytes(), &provider).unwrap();

            cases.push(InteropTestCase {
                cipher_suite: cipher_suite.into(),
                ref_hash: HashRefTestCase {
                    label: LABEL.to_string(),
                    value: INPUT.to_vec(),
                    out: output.to_vec(),
                },
            });
        }

        cases
    }

    #[cfg(mls_build_async)]
    fn generate_test_vector() -> Vec<InteropTestCase> {
        panic!("Tests cannot be generated in async mode");
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_basic_crypto_test_vectors() {
        // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/crypto-basics.json
        let test_cases: Vec<InteropTestCase> =
            load_test_case_json!(basic_crypto, generate_test_vector());

        for InteropTestCase {
            cipher_suite,
            ref_hash,
        } in test_cases
        {
            // Cipher suites the active crypto provider does not support are skipped.
            let provider = match try_test_cipher_suite_provider(cipher_suite) {
                Some(provider) => provider,
                None => continue,
            };

            let computed =
                HashReference::compute(&ref_hash.value, ref_hash.label.as_bytes(), &provider)
                    .await
                    .unwrap();

            assert_eq!(computed.to_vec(), ref_hash.out);
        }
    }
}
167