1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 extern crate alloc;
15 pub use crate::prelude::*;
16 use crate::CryptoProvider;
17 use core::iter;
18 use core::marker::PhantomData;
19 use crypto_provider::hkdf::Hkdf;
20 use hex_literal::hex;
21 use rstest_reuse::template;
22 
/// Generates the test cases to validate the hkdf implementation.
///
/// This is an `rstest_reuse` template. Applying it to a test function produces one rstest case
/// per `#[case]` attribute below. Each case passes the named generic test function in as
/// `testcase`, and the applied test then calls it with a `PhantomData` of the provider under test.
///
/// The cases cover:
/// - a smoke test;
/// - the RFC 5869 SHA-256 vectors;
/// - output-length limits;
/// - multi-part `info` input;
/// - the Wycheproof HKDF-SHA256 and HKDF-SHA512 vectors.
///
/// For example, to test `MyCryptoProvider` (a placeholder for your provider type; the example
/// presumably also needs `rstest_reuse::apply` and `rstest` in scope):
///
/// ```
/// mod tests {
///     use std::marker::PhantomData;
///     use crypto_provider::testing::CryptoProviderTestCase;
///     #[apply(hkdf_test_cases)]
///     fn hkdf_tests(testcase: CryptoProviderTestCase<MyCryptoProvider>){
///         testcase(PhantomData::<MyCryptoProvider>);
///     }
/// }
/// ```
#[template]
#[export]
#[rstest]
#[case::basic_test_hkdf(basic_test_hkdf)]
#[case::test_rfc5869_sha256(test_rfc5869_sha256)]
#[case::test_lengths(test_lengths)]
#[case::test_max_length(test_max_length)]
#[case::test_max_length_exceeded(test_max_length_exceeded)]
#[case::test_unsupported_length(test_unsupported_length)]
#[case::test_expand_multi_info(test_expand_multi_info)]
#[case::run_hkdf_sha256_vectors(run_hkdf_sha256_vectors)]
#[case::run_hkdf_sha512_vectors(run_hkdf_sha512_vectors)]
fn hkdf_test_cases<C: CryptoProvider>(#[case] testcase: CryptoProviderTestCase<C>) {}
49 
/// Largest output HKDF-SHA256 may produce.
///
/// RFC 5869 limits the output to `255 * HashLen` bytes. SHA-256 digests are 32 bytes,
/// so the limit is 8160.
const MAX_SHA256_LENGTH: usize = 255 * 32;
51 
/// Content of a HKDF test-case.
///
/// Every field borrows a byte string for the lifetime `'a`.
pub struct Test<'a> {
    /// Input keying material.
    ikm: &'a [u8],
    /// Salt; may be empty.
    salt: &'a [u8],
    /// Context/application-specific info; may be empty.
    info: &'a [u8],
    /// Expected output keying material.
    okm: &'a [u8],
}
59 
60 /// data taken from sample code in Readme of crates.io page
basic_test_hkdf<C: CryptoProvider>(_: PhantomData<C>)61 pub fn basic_test_hkdf<C: CryptoProvider>(_: PhantomData<C>) {
62     let ikm = hex!("0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b");
63     let salt = hex!("000102030405060708090a0b0c");
64     let info = hex!("f0f1f2f3f4f5f6f7f8f9");
65 
66     let hk = C::HkdfSha256::new(Some(&salt[..]), &ikm);
67     let mut okm = [0u8; 42];
68     hk.expand(&info, &mut okm).expect("42 is a valid length for Sha256 to output");
69 
70     let expected = hex!(
71         "
72         3cb25f25faacd57a90434f64d0362f2a
73         2d2d0a90cf1a5a4c5db02d56ecc4c5bf
74         34007208d5b887185865
75         "
76     );
77     assert_eq!(okm, expected);
78 }
79 
80 #[rustfmt::skip]
81     /// Test Vectors from <https://tools.ietf.org/html/rfc5869>.
test_rfc5869_sha256<C: CryptoProvider>(_: PhantomData<C>)82     pub fn test_rfc5869_sha256<C: CryptoProvider>(_: PhantomData<C>) {
83         let tests = [
84             Test {
85                 // Test Case 1
86                 ikm: &hex!("0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b"),
87                 salt: &hex!("000102030405060708090a0b0c"),
88                 info: &hex!("f0f1f2f3f4f5f6f7f8f9"),
89                 okm: &hex!("
90                     3cb25f25faacd57a90434f64d0362f2a
91                     2d2d0a90cf1a5a4c5db02d56ecc4c5bf
92                     34007208d5b887185865
93                 "),
94             },
95             Test {
96                 // Test Case 2
97                 ikm: &hex!("
98                     000102030405060708090a0b0c0d0e0f
99                     101112131415161718191a1b1c1d1e1f
100                     202122232425262728292a2b2c2d2e2f
101                     303132333435363738393a3b3c3d3e3f
102                     404142434445464748494a4b4c4d4e4f
103                 "),
104                 salt: &hex!("
105                     606162636465666768696a6b6c6d6e6f
106                     707172737475767778797a7b7c7d7e7f
107                     808182838485868788898a8b8c8d8e8f
108                     909192939495969798999a9b9c9d9e9f
109                     a0a1a2a3a4a5a6a7a8a9aaabacadaeaf
110                 "),
111                 info: &hex!("
112                     b0b1b2b3b4b5b6b7b8b9babbbcbdbebf
113                     c0c1c2c3c4c5c6c7c8c9cacbcccdcecf
114                     d0d1d2d3d4d5d6d7d8d9dadbdcdddedf
115                     e0e1e2e3e4e5e6e7e8e9eaebecedeeef
116                     f0f1f2f3f4f5f6f7f8f9fafbfcfdfeff
117                 "),
118                 okm: &hex!("
119                     b11e398dc80327a1c8e7f78c596a4934
120                     4f012eda2d4efad8a050cc4c19afa97c
121                     59045a99cac7827271cb41c65e590e09
122                     da3275600c2f09b8367793a9aca3db71
123                     cc30c58179ec3e87c14c01d5c1f3434f
124                     1d87
125                 "),
126             },
127             Test {
128                 // Test Case 3
129                 ikm: &hex!("0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b"),
130                 salt: &hex!(""),
131                 info: &hex!(""),
132                 okm: &hex!("
133                     8da4e775a563c18f715f802a063c5a31
134                     b8a11f5c5ee1879ec3454e5f3c738d2d
135                     9d201395faa4b61a96c8
136                 "),
137             },
138         ];
139         for Test { ikm, salt, info, okm } in tests.iter() {
140             let salt = if salt.is_empty() {
141                 None
142             } else {
143                 Some(&salt[..])
144             };
145             let hkdf = C::HkdfSha256::new(salt, ikm);
146             let mut okm2 = vec![0u8; okm.len()];
147             assert!(hkdf.expand(&info[..], &mut okm2).is_ok());
148             assert_eq!(okm2[..], okm[..]);
149         }
150     }
151 
152 /// Tests a bunch of HKDFs of differing lengths.
test_lengths<C: CryptoProvider>(_: PhantomData<C>)153 pub fn test_lengths<C: CryptoProvider>(_: PhantomData<C>) {
154     let hkdf = C::HkdfSha256::new(None, &[]);
155     let mut longest = vec![0u8; MAX_SHA256_LENGTH];
156     assert!(hkdf.expand(&[], &mut longest).is_ok());
157     // Runtime is O(length), so exhaustively testing all legal lengths
158     // would take too long (at least without --release). Only test a
159     // subset: the first 500, the last 10, and every 100th in between.
160     // 0 is an invalid key length for openssl, so start at 1
161     let lengths = (1..MAX_SHA256_LENGTH + 1)
162         .filter(|&len| !(500..=MAX_SHA256_LENGTH - 10).contains(&len) || len % 100 == 0);
163 
164     for length in lengths {
165         let mut okm = vec![0u8; length];
166 
167         assert!(hkdf.expand(&[], &mut okm).is_ok());
168         assert_eq!(okm.len(), length);
169         assert_eq!(okm[..], longest[..length]);
170     }
171 }
172 
173 /// Tests an HKDF with the maximum length for Sha256.
test_max_length<C: CryptoProvider>(_: PhantomData<C>)174 pub fn test_max_length<C: CryptoProvider>(_: PhantomData<C>) {
175     let hkdf = C::HkdfSha256::new(Some(&[]), &[]);
176     let mut okm = vec![0u8; MAX_SHA256_LENGTH];
177     assert!(hkdf.expand(&[], &mut okm).is_ok());
178 }
179 
180 /// Tests an HKDF above the maximum length for Sha256.
test_max_length_exceeded<C: CryptoProvider>(_: PhantomData<C>)181 pub fn test_max_length_exceeded<C: CryptoProvider>(_: PhantomData<C>) {
182     let hkdf = C::HkdfSha256::new(Some(&[]), &[]);
183     let mut okm = vec![0u8; MAX_SHA256_LENGTH + 1];
184     assert!(hkdf.expand(&[], &mut okm).is_err());
185 }
186 
187 /// Tests an HKDF with an unsupported length.
test_unsupported_length<C: CryptoProvider>(_: PhantomData<C>)188 pub fn test_unsupported_length<C: CryptoProvider>(_: PhantomData<C>) {
189     let hkdf = C::HkdfSha256::new(Some(&[]), &[]);
190     let mut okm = vec![0u8; 90000];
191     assert!(hkdf.expand(&[], &mut okm).is_err());
192 }
193 
194 /// Tests HKDF-Expand on the concatenation of info components.
test_expand_multi_info<C: CryptoProvider>(_: PhantomData<C>)195 pub fn test_expand_multi_info<C: CryptoProvider>(_: PhantomData<C>) {
196     let info_components = &[
197         &b"09090909090909090909090909090909090909090909"[..],
198         &b"8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a"[..],
199         &b"0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0"[..],
200         &b"4c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4"[..],
201         &b"1d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1d"[..],
202     ];
203 
204     let hkdf = C::HkdfSha256::new(None, b"some ikm here");
205 
206     // Compute HKDF-Expand on the concatenation of all the info components
207     let mut oneshot_res = [0u8; 16];
208     hkdf.expand(&info_components.concat(), &mut oneshot_res).unwrap();
209 
210     // Now iteratively join the components of info_components until it's all 1 component. The value
211     // of HKDF-Expand should be the same throughout
212     let mut num_concatted = 0;
213     let mut info_head = Vec::new();
214 
215     while num_concatted < info_components.len() {
216         info_head.extend(info_components[num_concatted]);
217 
218         // Build the new input to be the info head followed by the remaining components
219         let input: Vec<&[u8]> = iter::once(info_head.as_slice())
220             .chain(info_components.iter().cloned().skip(num_concatted + 1))
221             .collect();
222 
223         // Compute and compare to the one-shot answer
224         let mut multipart_res = [0u8; 16];
225         hkdf.expand_multi_info(&input, &mut multipart_res).unwrap();
226         assert_eq!(multipart_res, oneshot_res);
227         num_concatted += 1;
228     }
229 }
230 
231 /// Runs hkdf test vectors using Sha256.
run_hkdf_sha256_vectors<C: CryptoProvider>(_: PhantomData<C>)232 pub fn run_hkdf_sha256_vectors<C: CryptoProvider>(_: PhantomData<C>) {
233     run_hkdf_test_vectors::<C::HkdfSha256>(HashAlg::Sha256)
234 }
235 
236 /// Runs hkdf test vectors using Sha512.
run_hkdf_sha512_vectors<C: CryptoProvider>(_: PhantomData<C>)237 pub fn run_hkdf_sha512_vectors<C: CryptoProvider>(_: PhantomData<C>) {
238     run_hkdf_test_vectors::<C::HkdfSha512>(HashAlg::Sha512)
239 }
240 
/// Hash function of the HKDF variant under test.
///
/// `run_hkdf_test_vectors` uses it to choose which Wycheproof test set to load.
enum HashAlg {
    /// HKDF-SHA256 (`wycheproof::hkdf::TestName::HkdfSha256`).
    Sha256,
    /// HKDF-SHA512 (`wycheproof::hkdf::TestName::HkdfSha512`).
    Sha512,
}
245 
246 /// Runs the wycheproof test vectors for the given hashing algorithm.
run_hkdf_test_vectors<K: Hkdf>(hash: HashAlg)247 fn run_hkdf_test_vectors<K: Hkdf>(hash: HashAlg) {
248     let test_name = match hash {
249         HashAlg::Sha256 => wycheproof::hkdf::TestName::HkdfSha256,
250         HashAlg::Sha512 => wycheproof::hkdf::TestName::HkdfSha512,
251     };
252 
253     let test_set =
254         wycheproof::hkdf::TestSet::load(test_name).expect("should be able to load test set");
255     for test_group in test_set.test_groups {
256         for test in test_group.tests {
257             let ikm = test.ikm;
258             let salt = test.salt;
259             let info = test.info;
260             let okm = test.okm;
261             let tc_id = test.tc_id;
262             if let Some(desc) =
263                 run_test::<K>(ikm.as_slice(), salt.as_slice(), info.as_slice(), okm.as_slice())
264             {
265                 panic!(
266                     "\n\
267                          Failed test {tc_id}: {desc}\n\
268                          ikm:\t{ikm:?}\n\
269                          salt:\t{salt:?}\n\
270                          info:\t{info:?}\n\
271                          okm:\t{okm:?}\n"
272                 );
273             }
274         }
275     }
276 }
277 
run_test<K: Hkdf>(ikm: &[u8], salt: &[u8], info: &[u8], okm: &[u8]) -> Option<&'static str>278 fn run_test<K: Hkdf>(ikm: &[u8], salt: &[u8], info: &[u8], okm: &[u8]) -> Option<&'static str> {
279     let prk = K::new(Some(salt), ikm);
280     let mut got_okm = vec![0; okm.len()];
281 
282     if prk.expand(info, &mut got_okm).is_err() {
283         return Some("prk expand");
284     }
285     if got_okm != okm {
286         return Some("mismatch in okm");
287     }
288     None
289 }
290