1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use core::{ 6 fmt::{self, Debug}, 7 ops::Deref, 8 }; 9 10 use crate::client::MlsError; 11 use crate::CipherSuiteProvider; 12 use alloc::vec::Vec; 13 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; 14 use mls_rs_core::error::IntoAnyError; 15 16 #[derive(MlsSize, MlsEncode)] 17 struct RefHashInput<'a> { 18 #[mls_codec(with = "mls_rs_codec::byte_vec")] 19 pub label: &'a [u8], 20 #[mls_codec(with = "mls_rs_codec::byte_vec")] 21 pub value: &'a [u8], 22 } 23 24 impl Debug for RefHashInput<'_> { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result25 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 26 f.debug_struct("RefHashInput") 27 .field("label", &mls_rs_core::debug::pretty_bytes(self.label)) 28 .field("value", &mls_rs_core::debug::pretty_bytes(self.value)) 29 .finish() 30 } 31 } 32 33 #[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Clone, MlsSize, MlsEncode, MlsDecode)] 34 #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] 35 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 36 pub struct HashReference( 37 #[mls_codec(with = "mls_rs_codec::byte_vec")] 38 #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))] 39 Vec<u8>, 40 ); 41 42 impl Debug for HashReference { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 44 mls_rs_core::debug::pretty_bytes(&self.0) 45 .named("HashReference") 46 .fmt(f) 47 } 48 } 49 50 impl Deref for HashReference { 51 type Target = [u8]; 52 deref(&self) -> &Self::Target53 fn deref(&self) -> &Self::Target { 54 &self.0 55 } 56 } 57 58 impl AsRef<[u8]> for HashReference { as_ref(&self) -> &[u8]59 fn as_ref(&self) -> &[u8] { 60 &self.0 61 } 62 } 63 64 impl From<Vec<u8>> for HashReference { from(val: Vec<u8>) -> Self65 fn from(val: Vec<u8>) -> Self { 66 Self(val) 67 } 68 } 69 70 impl HashReference { 71 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] compute<P: CipherSuiteProvider>( value: &[u8], label: &[u8], cipher_suite: &P, ) -> Result<HashReference, MlsError>72 pub async fn compute<P: CipherSuiteProvider>( 73 value: &[u8], 74 label: &[u8], 75 cipher_suite: &P, 76 ) -> Result<HashReference, MlsError> { 77 let input = RefHashInput { label, value }; 78 let input_bytes = input.mls_encode_to_vec()?; 79 80 cipher_suite 81 .hash(&input_bytes) 82 .await 83 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error())) 84 .map(HashReference) 85 } 86 } 87 88 #[cfg(test)] 89 mod tests { 90 use crate::crypto::test_utils::try_test_cipher_suite_provider; 91 92 #[cfg(not(mls_build_async))] 93 use crate::{cipher_suite::CipherSuite, crypto::test_utils::test_cipher_suite_provider}; 94 95 use super::*; 96 use alloc::string::String; 97 use serde::{Deserialize, Serialize}; 98 99 #[cfg(not(mls_build_async))] 100 use alloc::string::ToString; 101 102 #[cfg(target_arch = "wasm32")] 103 use wasm_bindgen_test::wasm_bindgen_test as test; 104 105 #[derive(Debug, Deserialize, Serialize)] 106 struct HashRefTestCase { 107 label: String, 108 #[serde(with = "hex::serde")] 109 value: Vec<u8>, 110 #[serde(with = "hex::serde")] 111 out: Vec<u8>, 112 } 113 114 #[derive(Debug, serde::Serialize, serde::Deserialize)] 115 pub struct InteropTestCase { 116 cipher_suite: u16, 117 ref_hash: HashRefTestCase, 118 } 119 120 #[cfg(not(mls_build_async))] 121 #[cfg_attr(coverage_nightly, coverage(off))] generate_test_vector() -> Vec<InteropTestCase>122 fn generate_test_vector() -> Vec<InteropTestCase> { 123 CipherSuite::all() 124 .map(|cipher_suite| { 125 let provider = test_cipher_suite_provider(cipher_suite); 126 127 let input = b"test input"; 128 let label = "test label"; 129 130 let output = HashReference::compute(input, label.as_bytes(), &provider).unwrap(); 131 132 let ref_hash = HashRefTestCase { 133 label: label.to_string(), 134 value: input.to_vec(), 135 out: output.to_vec(), 136 }; 137 138 InteropTestCase { 139 cipher_suite: cipher_suite.into(), 140 ref_hash, 141 } 142 }) 143 .collect() 144 } 145 146 #[cfg(mls_build_async)] generate_test_vector() -> Vec<InteropTestCase>147 fn generate_test_vector() -> Vec<InteropTestCase> { 148 panic!("Tests cannot be generated in async mode"); 149 } 150 151 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_basic_crypto_test_vectors()152 async fn test_basic_crypto_test_vectors() { 153 // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/crypto-basics.json 154 let test_cases: Vec<InteropTestCase> = 155 load_test_case_json!(basic_crypto, generate_test_vector()); 156 157 for test_case in test_cases { 158 if let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) { 159 let label = test_case.ref_hash.label.as_bytes(); 160 let value = &test_case.ref_hash.value; 161 let computed = HashReference::compute(value, label, &cs).await.unwrap(); 162 assert_eq!(&*computed, &test_case.ref_hash.out); 163 } 164 } 165 } 166 } 167