1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
// These tests index and slice buffers directly (e.g. truncating or patching a saved
// session); an out-of-bounds panic there is an acceptable test failure, so the lint is
// allowed for this module.
#![allow(clippy::indexing_slicing)]
16 
17 use crypto_provider::CryptoProvider;
18 use crypto_provider_default::CryptoProviderImpl;
19 use rand::SeedableRng;
20 use rand::{rngs::StdRng, CryptoRng, RngCore};
21 use ukey2_rs::{HandshakeImplementation, NextProtocol};
22 
23 use crate::{
24     crypto_utils::{decrypt_cbc, encrypt_cbc},
25     java_utils, Aes256Key, D2DConnectionContextV1, D2DHandshakeContext, DeserializeError,
26     InitiatorD2DHandshakeContext, ServerD2DHandshakeContext,
27 };
28 
29 type AesCbcPkcs7Padded = <CryptoProviderImpl as CryptoProvider>::AesCbcPkcs7Padded;
30 
/// Round-trips a short message through `encrypt_cbc` and `decrypt_cbc` with an
/// entropy-seeded `StdRng` and checks that the original plaintext comes back.
#[test]
fn crypto_test_encrypt_decrypt() {
    let key = b"42424242424242424242424242424242";
    let plaintext = b"Hello World!";
    let mut rng = StdRng::from_entropy();

    let (ciphertext, iv) = encrypt_cbc::<_, AesCbcPkcs7Padded>(key, plaintext, &mut rng);
    let recovered = decrypt_cbc::<AesCbcPkcs7Padded>(key, ciphertext.as_slice(), &iv)
        .expect("Decrypt should be successful");

    assert_eq!(recovered, plaintext.to_vec());
}
41 
/// Pins the IV and ciphertext that `encrypt_cbc` produces when driven by the constant
/// `MockRng`.
#[test]
fn crypto_test_encrypt_seeded() {
    // Expected values extracted from the results of the current implementation.
    // This test makes sure that we don't accidentally change the encryption logic that
    // causes incompatibility between versions.
    let expected_iv = [1_u8; 16];
    let expected_ciphertext =
        [20_u8, 59, 195, 101, 11, 208, 245, 128, 247, 196, 81, 80, 158, 77, 174, 61];

    let key = b"42424242424242424242424242424242";
    let message = b"Hello World!";
    let (ciphertext, iv) = encrypt_cbc::<_, AesCbcPkcs7Padded>(key, message, &mut MockRng);

    assert_eq!(&iv, &expected_iv);
    assert_eq!(ciphertext, &expected_ciphertext);
}
57 
/// Decrypts the fixed IV/ciphertext pair pinned by `crypto_test_encrypt_seeded` and checks
/// that it yields the original plaintext.
#[test]
fn crypto_test_decrypt_seeded() {
    let key = b"42424242424242424242424242424242";
    let iv = [1; 16];
    let ciphertext = [20, 59, 195, 101, 11, 208, 245, 128, 247, 196, 81, 80, 158, 77, 174, 61];

    let decrypted = decrypt_cbc::<AesCbcPkcs7Padded>(key, &ciphertext, &iv).unwrap();

    assert_eq!(decrypted, b"Hello World!");
}
66 
/// Checks that decrypting with the wrong key never recovers the plaintext: either the
/// padding check rejects the output or the output differs from the message. Decrypting the
/// same ciphertext with the correct key must still succeed.
#[test]
fn decrypt_test_wrong_key() {
    let message = b"Hello World!";
    let good_key = b"42424242424242424242424242424242";
    let bad_key = b"43434343434343434343434343434343";
    let (ciphertext, iv) =
        encrypt_cbc::<_, AesCbcPkcs7Padded>(good_key, message, &mut StdRng::from_entropy());

    match decrypt_cbc::<AesCbcPkcs7Padded>(bad_key, ciphertext.as_slice(), &iv) {
        // The padding happened to be valid, but the keys differ, so the bytes must too.
        Ok(garbage) => assert_ne!(garbage, message),
        // The padding check rejected the output, so nothing was decrypted.
        Err(crypto_provider::aes::cbc::DecryptionError::BadPadding) => {}
    }

    let recovered =
        decrypt_cbc::<AesCbcPkcs7Padded>(good_key, ciphertext.as_slice(), &iv).unwrap();
    assert_eq!(recovered, message.to_vec());
}
88 
run_cbc_handshake() -> (D2DConnectionContextV1, D2DConnectionContextV1)89 fn run_cbc_handshake() -> (D2DConnectionContextV1, D2DConnectionContextV1) {
90     run_handshake_with_rng::<CryptoProviderImpl, _>(
91         rand::rngs::StdRng::from_entropy(),
92         vec![NextProtocol::Aes256CbcHmacSha256],
93     )
94 }
95 
run_gcm_handshake() -> (D2DConnectionContextV1, D2DConnectionContextV1)96 fn run_gcm_handshake() -> (D2DConnectionContextV1, D2DConnectionContextV1) {
97     run_handshake_with_rng::<CryptoProviderImpl, _>(
98         rand::rngs::StdRng::from_entropy(),
99         vec![NextProtocol::Aes256GcmSiv],
100     )
101 }
102 
run_handshake_with_rng<C, R>( mut rng: R, next_protocols: Vec<NextProtocol>, ) -> (D2DConnectionContextV1<R>, D2DConnectionContextV1<R>) where C: CryptoProvider, R: rand::RngCore + rand::CryptoRng + rand::SeedableRng + Send,103 fn run_handshake_with_rng<C, R>(
104     mut rng: R,
105     next_protocols: Vec<NextProtocol>,
106 ) -> (D2DConnectionContextV1<R>, D2DConnectionContextV1<R>)
107 where
108     C: CryptoProvider,
109     R: rand::RngCore + rand::CryptoRng + rand::SeedableRng + Send,
110 {
111     let mut initiator_ctx = InitiatorD2DHandshakeContext::<C, R>::new_impl(
112         HandshakeImplementation::Spec,
113         R::from_rng(&mut rng).unwrap(),
114         next_protocols.clone(),
115     );
116     let mut server_ctx = ServerD2DHandshakeContext::<C, R>::new_impl(
117         HandshakeImplementation::Spec,
118         R::from_rng(&mut rng).unwrap(),
119         &next_protocols,
120     );
121     server_ctx
122         .handle_handshake_message(
123             initiator_ctx.get_next_handshake_message().expect("No message").as_slice(),
124         )
125         .expect("Failed to handle message");
126     initiator_ctx
127         .handle_handshake_message(
128             server_ctx.get_next_handshake_message().expect("No message").as_slice(),
129         )
130         .expect("Failed to handle message");
131     server_ctx
132         .handle_handshake_message(
133             initiator_ctx.get_next_handshake_message().expect("No message").as_slice(),
134         )
135         .expect("Failed to handle message");
136     assert!(initiator_ctx.is_handshake_complete());
137     assert!(server_ctx.is_handshake_complete());
138     (initiator_ctx.to_connection_context().unwrap(), server_ctx.to_connection_context().unwrap())
139 }
140 
// TODO: Find a way to inject the RNG / generated ephemeral secrets into OpenSSL and test them
// here.
/// Runs a CBC handshake and encodes one message using `RustCryptoImpl<MockRng>`, then pins
/// the exact encoded bytes. The server side decodes the result with `CryptoProviderImpl`.
#[cfg(feature = "test_rustcrypto")]
#[test]
fn send_receive_message_seeded() {
    use crypto_provider_rustcrypto::RustCryptoImpl;

    /// RustCrypto provider parameterized with the constant `MockRng`.
    type SeededProvider = RustCryptoImpl<MockRng>;

    // Expected values extracted from the results of the current implementation.
    // This test makes sure that we don't accidentally change the encryption logic that
    // causes incompatibility between versions.
    const EXPECTED_ENCODED: &[u8] = &[
        10, 64, 10, 28, 8, 1, 16, 2, 42, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        50, 4, 8, 13, 16, 1, 18, 32, 23, 58, 102, 24, 40, 222, 59, 212, 182, 181, 96, 44, 57,
        21, 93, 253, 71, 54, 67, 37, 226, 43, 104, 224, 178, 221, 219, 189, 106, 135, 175, 150,
        18, 32, 134, 9, 237, 41, 112, 183, 129, 198, 240, 13, 139, 66, 21, 56, 28, 100, 142,
        240, 155, 52, 242, 11, 211, 132, 175, 230, 15, 241, 208, 185, 15, 105,
    ];

    let message = b"Hello World!";
    let (mut init_conn_ctx, mut server_conn_ctx) = run_handshake_with_rng::<SeededProvider, _>(
        MockRng,
        vec![NextProtocol::Aes256CbcHmacSha256],
    );

    let encoded = init_conn_ctx.encode_message_to_peer::<SeededProvider, &[u8]>(message, None);
    assert_eq!(encoded, EXPECTED_ENCODED);

    let decoded = server_conn_ctx
        .decode_message_from_peer::<CryptoProviderImpl, &[u8]>(&encoded, None)
        .unwrap();
    assert_eq!(message, &decoded[..]);
}
173 
/// Checks that a message encoded by the initiator of a session from `run_cbc_handshake`
/// is decoded back to the same bytes by the server.
#[test]
fn send_receive_message() {
    let (mut initiator, mut server) = run_cbc_handshake();
    let message = b"Hello World!";

    let encoded = initiator.encode_message_to_peer::<CryptoProviderImpl, &[u8]>(message, None);
    let decoded = server
        .decode_message_from_peer::<CryptoProviderImpl, &[u8]>(&encoded, None)
        .expect("Decode should be successful");

    assert_eq!(message.to_vec(), decoded);
}
183 
/// Checks that a message encoded by the initiator of a session from `run_gcm_handshake`
/// is decoded back to the same bytes by the server.
#[test]
fn send_receive_message_gcm() {
    let (mut initiator, mut server) = run_gcm_handshake();
    let message = b"Hello World!";

    let encoded = initiator.encode_message_to_peer::<CryptoProviderImpl, &[u8]>(message, None);
    let decoded = server
        .decode_message_from_peer::<CryptoProviderImpl, &[u8]>(&encoded, None)
        .expect("Decode should be successful");

    assert_eq!(message.to_vec(), decoded);
}
193 
/// Checks that a message encoded with associated data decodes only when the exact same
/// associated data is supplied. Decoding must fail when the associated data is missing or
/// altered.
#[test]
fn send_receive_message_associated_data() {
    let message = b"Hello World!";
    let (mut init_conn_ctx, mut server_conn_ctx) = run_cbc_handshake();
    let encoded = init_conn_ctx
        .encode_message_to_peer::<CryptoProviderImpl, _>(message, Some(b"associated data"));

    // Matching associated data: decoding succeeds.
    let decoded = server_conn_ctx
        .decode_message_from_peer::<CryptoProviderImpl, _>(&encoded, Some(b"associated data"))
        .expect("Decode should be successful");
    assert_eq!(message.to_vec(), decoded);

    // Missing associated data: decoding must fail.
    let missing_ad =
        server_conn_ctx.decode_message_from_peer::<CryptoProviderImpl, &[u8]>(&encoded, None);
    assert!(missing_ad.is_err());

    // Altered associated data (one character changed): decoding must fail.
    let altered_ad = server_conn_ctx
        .decode_message_from_peer::<CryptoProviderImpl, _>(&encoded, Some(b"assoc1ated data"));
    assert!(altered_ad.is_err());
}
216 
/// Saves both sides of an established CBC session, restores each from its serialized
/// bytes, and checks that the restored contexts can still exchange a message.
#[test]
fn test_save_restore_session() {
    let (init_conn_ctx, server_conn_ctx) = run_cbc_handshake();
    let saved_init = init_conn_ctx.save_session();
    let saved_server = server_conn_ctx.save_session();

    let mut restored_init =
        D2DConnectionContextV1::from_saved_session::<CryptoProviderImpl>(&saved_init)
            .expect("failed to restore client session");
    let mut restored_server =
        D2DConnectionContextV1::from_saved_session::<CryptoProviderImpl>(&saved_server)
            .expect("failed to restore server session");

    let message = b"Hello World!";
    let encoded =
        restored_init.encode_message_to_peer::<CryptoProviderImpl, &[u8]>(message, None);
    let decoded = restored_server
        .decode_message_from_peer::<CryptoProviderImpl, &[u8]>(&encoded, None)
        .expect("Decode should be successful");
    assert_eq!(message.to_vec(), decoded);
}
235 
/// Checks that restoring a saved session truncated to its first 60 bytes fails with
/// `DeserializeError::BadDataLength`, while an intact saved session restores fine.
#[test]
fn test_save_restore_bad_session() {
    let (init_conn_ctx, server_conn_ctx) = run_cbc_handshake();
    let init_session = init_conn_ctx.save_session();
    let server_session = server_conn_ctx.save_session();

    // Sanity check: the untouched client session restores successfully.
    let _ = D2DConnectionContextV1::from_saved_session::<CryptoProviderImpl>(&init_session)
        .expect("failed to restore client session");

    let truncated = &server_session[..60];
    let err = D2DConnectionContextV1::from_saved_session::<CryptoProviderImpl>(truncated)
        .unwrap_err();
    assert_eq!(err, DeserializeError::BadDataLength);
}
248 
/// Checks that a saved session whose protocol-version byte is overwritten with `0` is
/// rejected with `DeserializeError::BadProtocolVersion`.
#[test]
fn test_save_restore_bad_protocol_version() {
    let (init_conn_ctx, server_conn_ctx) = run_cbc_handshake();
    let init_session = init_conn_ctx.save_session();
    let mut server_session = server_conn_ctx.save_session();

    // Sanity check: the untouched client session restores successfully.
    let _ = D2DConnectionContextV1::from_saved_session::<CryptoProviderImpl>(&init_session)
        .expect("failed to restore client session");

    // The first byte is the protocol version; 0 is not a valid one.
    server_session[0] = 0;
    let err = D2DConnectionContextV1::from_saved_session::<CryptoProviderImpl>(&server_session)
        .unwrap_err();
    assert_eq!(err, DeserializeError::BadProtocolVersion);
}
262 
/// Checks three properties of `get_session_unique`:
/// * it is identical on both sides of a session;
/// * it does not change after a message exchange;
/// * it differs for a context built with the same sequence numbers but default keys.
#[test]
fn test_unique_session() {
    let (mut init_conn_ctx, mut server_conn_ctx) = run_cbc_handshake();
    let init_unique_before = init_conn_ctx.get_session_unique::<CryptoProviderImpl>();
    let server_unique_before = server_conn_ctx.get_session_unique::<CryptoProviderImpl>();

    // Exchange one message between the "before" and "after" snapshots.
    let message = b"Hello World!";
    let encoded = init_conn_ctx.encode_message_to_peer::<CryptoProviderImpl, &[u8]>(message, None);
    let decoded = server_conn_ctx
        .decode_message_from_peer::<CryptoProviderImpl, &[u8]>(&encoded, None)
        .expect("Decode should be successful");
    assert_eq!(message.to_vec(), decoded);

    let init_unique_after = init_conn_ctx.get_session_unique::<CryptoProviderImpl>();
    let server_unique_after = server_conn_ctx.get_session_unique::<CryptoProviderImpl>();

    // Same sequence numbers as the real server context, but default keys.
    let bad_server_ctx = D2DConnectionContextV1::new::<CryptoProviderImpl>(
        server_conn_ctx.get_sequence_number_for_decoding(),
        server_conn_ctx.get_sequence_number_for_encoding(),
        Aes256Key::default(),
        Aes256Key::default(),
        StdRng::from_entropy(),
        NextProtocol::Aes256CbcHmacSha256,
    );
    let bad_unique = bad_server_ctx.get_session_unique::<CryptoProviderImpl>();

    assert_eq!(init_unique_before, init_unique_after);
    assert_eq!(server_unique_before, server_unique_after);
    assert_eq!(init_unique_before, server_unique_before);
    assert_ne!(server_unique_before, bad_unique);
}
288 
/// Checks `java_utils::hash_code` against known values. The expected numbers are
/// consistent with Java's `Arrays.hashCode(byte[])`: start at 1, then `h = 31 * h + b`
/// with wrapping `i32` arithmetic and signed bytes, so `0xFF` contributes -1.
#[test]
fn test_java_hashcode() {
    let cases: [(&[u8], i32); 4] = [
        (b"4", 83),
        (&[0x65, 0x47], 4163),
        (&[0x12, 0x23, 0x34, 0x45, 0x56, 0x67, 0x78], 1590192324),
        (&[0x12, 0x23, 0x34, 0x45, 0x56, 0x67, 0x78, 0xFF], 2051321787),
    ];
    for (input, expected) in cases {
        assert_eq!(java_utils::hash_code(input), expected);
    }
}
299 
300 /// A mock RNG that always returns 1 at each byte. The output from this RNG is
301 /// not changed from call to call to avoid ordering changes in code from
302 /// changing the expected output. The downside is that code that keeps looping
303 /// and generating a new random number until it fits certain criteria will hang
304 /// indefinitely.
305 #[derive(Eq, PartialEq, Clone, Debug)]
306 struct MockRng;
307 
308 impl SeedableRng for MockRng {
309     type Seed = [u8; 0];
310 
from_seed(_seed: Self::Seed) -> Self311     fn from_seed(_seed: Self::Seed) -> Self {
312         Self
313     }
314 }
315 
316 impl CryptoRng for MockRng {}
317 
318 impl RngCore for MockRng {
next_u32(&mut self) -> u32319     fn next_u32(&mut self) -> u32 {
320         let mut buf = [0_u8; 4];
321         self.fill_bytes(&mut buf);
322         u32::from_le_bytes(buf)
323     }
324 
next_u64(&mut self) -> u64325     fn next_u64(&mut self) -> u64 {
326         let mut buf = [0_u8; 8];
327         self.fill_bytes(&mut buf);
328         u64::from_le_bytes(buf)
329     }
330 
fill_bytes(&mut self, dest: &mut [u8])331     fn fill_bytes(&mut self, dest: &mut [u8]) {
332         dest.fill(1);
333     }
334 
try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error>335     fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
336         self.fill_bytes(dest);
337         Ok(())
338     }
339 }
340