1 // Copyright 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
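//! An encrypted session implementation built on the Noise_NK and
//! Noise_NNpsk0 handshake patterns, e.g. Noise_NK_X25519_AESGCM_SHA512 and
//! Noise_NNpsk0_X25519_AESGCM_SHA512. The concrete cipher, DH, and hash
//! primitives are supplied through [`NoiseCryptoDeps`].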
17
18 use crate::crypto::{
19 Commit, Counter, DhPrivateKey, DhPublicKey, HandshakeMessage,
20 HandshakePayload, Hash, SessionCrypto,
21 };
22 use crate::error::{DpeResult, ErrCode};
23 use crate::memory::Message;
24 use core::marker::PhantomData;
25 use log::{debug, error};
26 use noise_protocol::{HandshakeStateBuilder, Hash as NoiseHash, U8Array};
27
28 impl From<noise_protocol::Error> for ErrCode {
from(_err: noise_protocol::Error) -> Self29 fn from(_err: noise_protocol::Error) -> Self {
30 ErrCode::InvalidArgument
31 }
32 }
33
34 impl<NoiseHash> From<&NoiseHash> for Hash
35 where
36 NoiseHash: U8Array,
37 {
from(value: &NoiseHash) -> Self38 fn from(value: &NoiseHash) -> Self {
39 // The Noise hash size may not match HASH_SIZE.
40 Hash::from_slice_infallible(value.as_slice())
41 }
42 }
43
/// A cipher state type that can be used as a
/// [`SessionCipherState`](crate::crypto::SessionCrypto::SessionCipherState).
///
/// Holds the transport key and nonce counter taken from a
/// `noise_protocol::CipherState`, plus a staged counter so that an
/// encryption only advances the committed counter once `commit` is called.
pub struct NoiseCipherState<C: noise_protocol::Cipher> {
    /// Transport key for cipher `C`.
    k: C::Key,
    /// Committed nonce counter; encryption and decryption start from this
    /// value.
    n: u64,
    /// Counter value left by the most recent `session_encrypt`; copied into
    /// `n` by `commit`.
    n_staged: u64,
}
51
52 impl<C: noise_protocol::Cipher> Clone for NoiseCipherState<C> {
clone(&self) -> Self53 fn clone(&self) -> Self {
54 Self { k: self.k.clone(), n: self.n, n_staged: self.n_staged }
55 }
56 }
57
58 impl<C: noise_protocol::Cipher> Default for NoiseCipherState<C> {
default() -> Self59 fn default() -> Self {
60 Self { k: C::Key::new(), n: 0, n_staged: 0 }
61 }
62 }
63
64 impl<C: noise_protocol::Cipher> core::fmt::Debug for NoiseCipherState<C> {
fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result65 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
66 write!(f, "k: redacted, n: {}", self.n)?;
67 Ok(())
68 }
69 }
70
71 impl<C: noise_protocol::Cipher> core::hash::Hash for NoiseCipherState<C> {
hash<H: core::hash::Hasher>(&self, state: &mut H)72 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
73 self.k.as_slice().hash(state);
74 self.n.hash(state);
75 self.n_staged.hash(state);
76 }
77 }
78
/// Test-only equality: key bytes and both counters must match.
#[cfg(test)]
impl<C> PartialEq for NoiseCipherState<C>
where
    C: noise_protocol::Cipher,
{
    fn eq(&self, other: &Self) -> bool {
        let lhs = (self.k.as_slice(), self.n, self.n_staged);
        let rhs = (other.k.as_slice(), other.n, other.n_staged);
        lhs == rhs
    }
}
87
// Test-only marker: the `PartialEq` impl above compares key bytes and
// counters, which is a total equivalence.
#[cfg(test)]
impl<C: noise_protocol::Cipher> Eq for NoiseCipherState<C> {}
90
91 impl<C: noise_protocol::Cipher> Counter for NoiseCipherState<C> {
n(&self) -> u6492 fn n(&self) -> u64 {
93 self.n
94 }
set_n(&mut self, n: u64)95 fn set_n(&mut self, n: u64) {
96 self.n = n;
97 }
98 }
99
100 impl<C: noise_protocol::Cipher> Commit for NoiseCipherState<C> {
101 // Called when an encrypted message is finalized to commit the new cipher
102 // state.
commit(&mut self)103 fn commit(&mut self) {
104 self.n = self.n_staged;
105 }
106 }
107
108 impl<C: noise_protocol::Cipher> From<&noise_protocol::CipherState<C>>
109 for NoiseCipherState<C>
110 {
from(cs: &noise_protocol::CipherState<C>) -> Self111 fn from(cs: &noise_protocol::CipherState<C>) -> Self {
112 let (key, counter) = cs.clone().extract();
113 NoiseCipherState { k: key, n: counter, n_staged: counter }
114 }
115 }
116
117 /// Returns the public key corresponding to a given `dh_private_key`.
get_dh_public_key<D: noise_protocol::DH>( dh_private_key: &DhPrivateKey, ) -> DpeResult<DhPublicKey>118 pub fn get_dh_public_key<D: noise_protocol::DH>(
119 dh_private_key: &DhPrivateKey,
120 ) -> DpeResult<DhPublicKey> {
121 DhPublicKey::from_slice(
122 D::pubkey(&D::Key::from_slice(dh_private_key.as_slice())).as_slice(),
123 )
124 }
125
/// A trait representing [`NoiseSessionCrypto`] dependencies.
///
/// Bundles the Noise primitive implementations so that the session code can
/// be generic over them; implementors only name the three associated types.
pub trait NoiseCryptoDeps {
    /// AEAD cipher used for handshake payloads and transport messages.
    type Cipher: noise_protocol::Cipher;
    /// Diffie-Hellman function used by the handshake patterns.
    type DH: noise_protocol::DH;
    /// Hash function used by the handshake and for PSK derivation.
    type Hash: noise_protocol::Hash;
}
135
/// A Noise implementation of the [`SessionCrypto`] trait.
///
/// Carries no data of its own; all per-session state lives in the
/// `NoiseCipherState` values passed to the trait functions.
pub struct NoiseSessionCrypto<D: NoiseCryptoDeps> {
    // Type marker binding the chosen primitives `D`; never read.
    #[allow(dead_code)]
    phantom: PhantomData<D>,
}
141
142 impl<D> Clone for NoiseSessionCrypto<D>
143 where
144 D: NoiseCryptoDeps,
145 {
clone(&self) -> Self146 fn clone(&self) -> Self {
147 Self { phantom: Default::default() }
148 }
149 }
150
151 impl<D> Default for NoiseSessionCrypto<D>
152 where
153 D: NoiseCryptoDeps,
154 {
default() -> Self155 fn default() -> Self {
156 Self { phantom: Default::default() }
157 }
158 }
159
160 impl<D> core::fmt::Debug for NoiseSessionCrypto<D>
161 where
162 D: NoiseCryptoDeps,
163 {
fmt(&self, _: &mut core::fmt::Formatter<'_>) -> core::fmt::Result164 fn fmt(&self, _: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
165 Ok(())
166 }
167 }
168
169 impl<D> core::hash::Hash for NoiseSessionCrypto<D>
170 where
171 D: NoiseCryptoDeps,
172 {
hash<Hr: core::hash::Hasher>(&self, _: &mut Hr)173 fn hash<Hr: core::hash::Hasher>(&self, _: &mut Hr) {}
174 }
175
176 impl<D> PartialEq for NoiseSessionCrypto<D>
177 where
178 D: NoiseCryptoDeps,
179 {
eq(&self, _: &Self) -> bool180 fn eq(&self, _: &Self) -> bool {
181 true
182 }
183 }
184
// `PartialEq` always returns true, which is trivially a total equivalence.
impl<D> Eq for NoiseSessionCrypto<D> where D: NoiseCryptoDeps {}
186
187 impl<D> SessionCrypto for NoiseSessionCrypto<D>
188 where
189 D: NoiseCryptoDeps,
190 {
191 type SessionCipherState = NoiseCipherState<D::Cipher>;
192
193 /// Implements the responder role of a Noise_NK handshake.
new_session_handshake( static_dh_key: &DhPrivateKey, initiator_handshake: &HandshakeMessage, payload: &HandshakePayload, responder_handshake: &mut HandshakeMessage, decrypt_cipher_state: &mut NoiseCipherState<D::Cipher>, encrypt_cipher_state: &mut NoiseCipherState<D::Cipher>, psk_seed: &mut Hash, ) -> DpeResult<()>194 fn new_session_handshake(
195 static_dh_key: &DhPrivateKey,
196 initiator_handshake: &HandshakeMessage,
197 payload: &HandshakePayload,
198 responder_handshake: &mut HandshakeMessage,
199 decrypt_cipher_state: &mut NoiseCipherState<D::Cipher>,
200 encrypt_cipher_state: &mut NoiseCipherState<D::Cipher>,
201 psk_seed: &mut Hash,
202 ) -> DpeResult<()> {
203 #[allow(unused_results)]
204 let mut handshake: noise_protocol::HandshakeState<
205 D::DH,
206 D::Cipher,
207 D::Hash,
208 > = {
209 let mut builder = HandshakeStateBuilder::new();
210 builder.set_pattern(noise_protocol::patterns::noise_nk());
211 builder.set_is_initiator(false);
212 builder.set_prologue(&[]);
213 builder.set_s(<D::DH as noise_protocol::DH>::Key::from_slice(
214 static_dh_key.as_slice(),
215 ));
216 builder.build_handshake_state()
217 };
218 handshake.read_message(initiator_handshake.as_slice(), &mut [])?;
219 handshake.write_message(
220 payload.as_slice(),
221 responder_handshake.as_mut_sized(
222 handshake.get_next_message_overhead() + payload.len(),
223 )?,
224 )?;
225 assert!(handshake.completed());
226 let ciphers = handshake.get_ciphers();
227 *decrypt_cipher_state = (&ciphers.0).into();
228 *encrypt_cipher_state = (&ciphers.1).into();
229 debug!("get_hash");
230 *psk_seed = Hash::from_slice(handshake.get_hash())?;
231 Ok(())
232 }
233
234 /// Implements the responder role of a Noise_NNpsk0 handshake.
derive_session_handshake( psk: &Hash, initiator_handshake: &HandshakeMessage, payload: &HandshakePayload, responder_handshake: &mut HandshakeMessage, decrypt_cipher_state: &mut NoiseCipherState<D::Cipher>, encrypt_cipher_state: &mut NoiseCipherState<D::Cipher>, psk_seed: &mut Hash, ) -> DpeResult<()>235 fn derive_session_handshake(
236 psk: &Hash,
237 initiator_handshake: &HandshakeMessage,
238 payload: &HandshakePayload,
239 responder_handshake: &mut HandshakeMessage,
240 decrypt_cipher_state: &mut NoiseCipherState<D::Cipher>,
241 encrypt_cipher_state: &mut NoiseCipherState<D::Cipher>,
242 psk_seed: &mut Hash,
243 ) -> DpeResult<()> {
244 #[allow(unused_results)]
245 let mut handshake: noise_protocol::HandshakeState<
246 D::DH,
247 D::Cipher,
248 D::Hash,
249 > = {
250 let mut builder = HandshakeStateBuilder::new();
251 builder.set_pattern(noise_protocol::patterns::noise_nn_psk0());
252 builder.set_is_initiator(false);
253 builder.set_prologue(&[]);
254 builder.build_handshake_state()
255 };
256 handshake
257 .push_psk(psk.as_slice().get(..32).ok_or(ErrCode::InternalError)?);
258 handshake.read_message(initiator_handshake.as_slice(), &mut [])?;
259 handshake.write_message(
260 payload.as_slice(),
261 responder_handshake.as_mut_sized(
262 handshake.get_next_message_overhead() + payload.len(),
263 )?,
264 )?;
265 let ciphers = handshake.get_ciphers();
266 *decrypt_cipher_state = (&ciphers.0).into();
267 *encrypt_cipher_state = (&ciphers.1).into();
268 *psk_seed = Hash::from_slice(handshake.get_hash())?;
269 Ok(())
270 }
271
272 /// Encrypts a Noise transport message in place.
session_encrypt( cipher_state: &mut NoiseCipherState<D::Cipher>, in_place_buffer: &mut Message, ) -> DpeResult<()>273 fn session_encrypt(
274 cipher_state: &mut NoiseCipherState<D::Cipher>,
275 in_place_buffer: &mut Message,
276 ) -> DpeResult<()> {
277 let mut cs = noise_protocol::CipherState::<D::Cipher>::new(
278 cipher_state.k.as_slice(),
279 cipher_state.n,
280 );
281 let plaintext_len = in_place_buffer.len();
282 let _ = cs.encrypt_in_place(
283 in_place_buffer.as_mut_sized(
284 plaintext_len
285 + <D::Cipher as noise_protocol::Cipher>::tag_len(),
286 )?,
287 plaintext_len,
288 );
289 // Encrypting a message is usually not the final step in preparing
290 // the message for transport. If a subsequent step fails, it is
291 // better for 'n' to remain unchanged so we don't get out of sync.
292 (_, cipher_state.n_staged) = cs.extract();
293 Ok(())
294 }
295
296 /// Decrypts a Noise transport message in place.
session_decrypt( cipher_state: &mut NoiseCipherState<D::Cipher>, in_place_buffer: &mut Message, ) -> DpeResult<()>297 fn session_decrypt(
298 cipher_state: &mut NoiseCipherState<D::Cipher>,
299 in_place_buffer: &mut Message,
300 ) -> DpeResult<()> {
301 let mut cs = noise_protocol::CipherState::<D::Cipher>::new(
302 cipher_state.k.as_slice(),
303 cipher_state.n,
304 );
305 let ciphertext_len = in_place_buffer.len();
306 let plaintext_len = match cs
307 .decrypt_in_place(in_place_buffer.vec.as_mut(), ciphertext_len)
308 {
309 Ok(length) => length,
310 _ => {
311 error!("Session decrypt failed");
312 return Err(ErrCode::InvalidCommand);
313 }
314 };
315 in_place_buffer.vec.truncate(plaintext_len);
316 (_, cipher_state.n) = cs.extract();
317 Ok(())
318 }
319
320 /// Derives a responder-side PSK.
derive_psk_from_session( psk_seed: &Hash, decrypt_cipher_state: &NoiseCipherState<D::Cipher>, encrypt_cipher_state: &NoiseCipherState<D::Cipher>, ) -> DpeResult<Hash>321 fn derive_psk_from_session(
322 psk_seed: &Hash,
323 decrypt_cipher_state: &NoiseCipherState<D::Cipher>,
324 encrypt_cipher_state: &NoiseCipherState<D::Cipher>,
325 ) -> DpeResult<Hash> {
326 let mut hasher: D::Hash = Default::default();
327 hasher.input(psk_seed.as_slice());
328 // Use the decrypt state as it was before we decrypted the current
329 // command message. This allows clients to compute the PSK using
330 // the cipher states as they are before the client sends the
331 // command.
332 hasher.input(&(decrypt_cipher_state.n() - 1).to_le_bytes());
333 hasher.input(&encrypt_cipher_state.n().to_le_bytes());
334 Ok((&hasher.result()).into())
335 }
336 }
337
/// A SessionClient implements the initiator side of an encrypted session. A
/// DPE does not use this itself, it is useful for clients and testing.
///
/// Typical use: call one of the `start_handshake_*` methods, send the
/// returned message to the DPE, pass the DPE's reply to `finish_handshake`,
/// then use `encrypt` / `decrypt` for transport messages.
pub struct SessionClient<D>
where
    D: NoiseCryptoDeps,
{
    // Handshake in progress, if any; set by the `start_handshake_*` methods.
    handshake_state:
        Option<noise_protocol::HandshakeState<D::DH, D::Cipher, D::Hash>>,
    /// Cipher state for encrypting messages to a DPE.
    pub encrypt_cipher_state: NoiseCipherState<D::Cipher>,
    /// Cipher state for decrypting messages from a DPE.
    pub decrypt_cipher_state: NoiseCipherState<D::Cipher>,
    /// PSK seed for deriving sessions; set to the handshake hash when a
    /// handshake finishes. See [`derive_psk`](Self::derive_psk).
    pub psk_seed: Hash,
}
355
356 impl<D> Clone for SessionClient<D>
357 where
358 D: NoiseCryptoDeps,
359 {
clone(&self) -> Self360 fn clone(&self) -> Self {
361 Self {
362 handshake_state: self.handshake_state.clone(),
363 encrypt_cipher_state: self.encrypt_cipher_state.clone(),
364 decrypt_cipher_state: self.decrypt_cipher_state.clone(),
365 psk_seed: self.psk_seed.clone(),
366 }
367 }
368 }
369
370 impl<D> Default for SessionClient<D>
371 where
372 D: NoiseCryptoDeps,
373 {
default() -> Self374 fn default() -> Self {
375 Self::new()
376 }
377 }
378
379 impl<D> core::fmt::Debug for SessionClient<D>
380 where
381 D: NoiseCryptoDeps,
382 {
fmt(&self, _: &mut core::fmt::Formatter<'_>) -> core::fmt::Result383 fn fmt(&self, _: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
384 Ok(())
385 }
386 }
387
388 impl<D> SessionClient<D>
389 where
390 D: NoiseCryptoDeps,
391 {
392 /// Creates a new SessionClient instance. Set up by starting and finishing a
393 /// handshake.
new() -> Self394 pub fn new() -> Self {
395 Self {
396 handshake_state: Default::default(),
397 encrypt_cipher_state: Default::default(),
398 decrypt_cipher_state: Default::default(),
399 psk_seed: Default::default(),
400 }
401 }
402
403 /// Starts a handshake using a known `public_key` and returns a message that
404 /// works with the DPE OpenSession command.
start_handshake_with_known_public_key( &mut self, public_key: &DhPublicKey, ) -> DpeResult<HandshakeMessage>405 pub fn start_handshake_with_known_public_key(
406 &mut self,
407 public_key: &DhPublicKey,
408 ) -> DpeResult<HandshakeMessage> {
409 #[allow(unused_results)]
410 let mut handshake_state = {
411 let mut builder = HandshakeStateBuilder::new();
412 builder.set_pattern(noise_protocol::patterns::noise_nk());
413 builder.set_is_initiator(true);
414 builder.set_prologue(&[]);
415 builder.set_rs(<D::DH as noise_protocol::DH>::Pubkey::from_slice(
416 public_key.as_slice(),
417 ));
418 builder.build_handshake_state()
419 };
420 let mut message = HandshakeMessage::new();
421 handshake_state.write_message(
422 &[],
423 message
424 .as_mut_sized(handshake_state.get_next_message_overhead())?,
425 )?;
426 self.handshake_state = Some(handshake_state);
427 Ok(message)
428 }
429
430 /// Starts a handshake using a `psk` and returns a message that works with
431 /// the DPE DeriveContext command. Use [`derive_psk`] to obtain this value
432 /// from an existing session.
433 ///
434 /// [`derive_psk`]: #method.derive_psk
start_handshake_with_psk( &mut self, psk: &Hash, ) -> DpeResult<HandshakeMessage>435 pub fn start_handshake_with_psk(
436 &mut self,
437 psk: &Hash,
438 ) -> DpeResult<HandshakeMessage> {
439 #[allow(unused_results)]
440 let mut handshake_state = {
441 let mut builder = HandshakeStateBuilder::new();
442 builder.set_pattern(noise_protocol::patterns::noise_nn_psk0());
443 builder.set_is_initiator(true);
444 builder.set_prologue(&[]);
445 builder.build_handshake_state()
446 };
447 handshake_state
448 .push_psk(psk.as_slice().get(..32).ok_or(ErrCode::InternalError)?);
449 let mut message = HandshakeMessage::new();
450 handshake_state.write_message(
451 &[],
452 message
453 .as_mut_sized(handshake_state.get_next_message_overhead())?,
454 )?;
455 self.handshake_state = Some(handshake_state);
456 Ok(message)
457 }
458
459 /// Finishes a handshake started using one of the start_handshake_* methods.
460 /// On success, returns the handshake payload from the responder and sets up
461 /// internal state for subsequent calls to encrypt and decrypt.
finish_handshake( &mut self, responder_handshake: &HandshakeMessage, ) -> DpeResult<HandshakePayload>462 pub fn finish_handshake(
463 &mut self,
464 responder_handshake: &HandshakeMessage,
465 ) -> DpeResult<HandshakePayload> {
466 match self.handshake_state {
467 None => Err(ErrCode::InvalidArgument),
468 Some(ref mut handshake) => {
469 let mut payload = HandshakePayload::new();
470 handshake.read_message(
471 responder_handshake.as_slice(),
472 payload.as_mut_sized(
473 responder_handshake.len()
474 - handshake.get_next_message_overhead(),
475 )?,
476 )?;
477 let ciphers = handshake.get_ciphers();
478 self.encrypt_cipher_state = (&ciphers.0).into();
479 self.decrypt_cipher_state = (&ciphers.1).into();
480 self.psk_seed = Hash::from_slice(handshake.get_hash())?;
481 Ok(payload)
482 }
483 }
484 }
485
486 /// Derives a PSK from the current session.
derive_psk(&self) -> Hash487 pub fn derive_psk(&self) -> Hash {
488 // Note this is from a client perspective so the counters are hashed
489 // encrypt first and unmodified from their current state. A DPE will
490 // reverse the order and decrement the first counter in order to derive
491 // the same value (see derive_psk_from_session).
492 let mut hasher: D::Hash = Default::default();
493 hasher.input(self.psk_seed.as_slice());
494 hasher.input(&self.encrypt_cipher_state.n().to_le_bytes());
495 hasher.input(&self.decrypt_cipher_state.n().to_le_bytes());
496 (&hasher.result()).into()
497 }
498
499 /// Encrypts a message to send to a DPE and commits cipher state changes.
encrypt(&mut self, in_place_buffer: &mut Message) -> DpeResult<()>500 pub fn encrypt(&mut self, in_place_buffer: &mut Message) -> DpeResult<()> {
501 NoiseSessionCrypto::<D>::session_encrypt(
502 &mut self.encrypt_cipher_state,
503 in_place_buffer,
504 )?;
505 self.encrypt_cipher_state.commit();
506 Ok(())
507 }
508
509 /// Decrypts a message from a DPE.
decrypt(&mut self, in_place_buffer: &mut Message) -> DpeResult<()>510 pub fn decrypt(&mut self, in_place_buffer: &mut Message) -> DpeResult<()> {
511 NoiseSessionCrypto::<D>::session_decrypt(
512 &mut self.decrypt_cipher_state,
513 in_place_buffer,
514 )
515 }
516 }
517
// Round-trip tests pairing a `SessionClient` (initiator) with the
// `NoiseSessionCrypto` responder functions (DPE side).
#[cfg(test)]
mod tests {
    use super::*;

    // Concrete Noise primitives used by every test in this module.
    struct DepsForTesting {}
    impl NoiseCryptoDeps for DepsForTesting {
        type Cipher = noise_rust_crypto::Aes256Gcm;
        type DH = noise_rust_crypto::X25519;
        type Hash = noise_rust_crypto::Sha512;
    }

    // Responder (DPE) side.
    type SessionCryptoForTesting = NoiseSessionCrypto<DepsForTesting>;

    // Initiator (client) side.
    type SessionClientForTesting = SessionClient<DepsForTesting>;

    type CipherStateForTesting = NoiseCipherState<noise_rust_crypto::Aes256Gcm>;

    // Establishes a Noise_NK session and exchanges two rounds of messages in
    // both directions.
    #[test]
    fn end_to_end_session() {
        let mut client = SessionClientForTesting::new();
        // Default-constructed static key for the DPE side.
        let dh_key: DhPrivateKey = Default::default();
        let dh_public_key = get_dh_public_key::<
            <DepsForTesting as NoiseCryptoDeps>::DH,
        >(&dh_key)
        .unwrap();
        let handshake1 = client
            .start_handshake_with_known_public_key(&dh_public_key)
            .unwrap();
        let mut dpe_decrypt_cs: CipherStateForTesting = Default::default();
        let mut dpe_encrypt_cs: CipherStateForTesting = Default::default();
        let mut psk_seed = Default::default();
        let mut handshake2 = Default::default();
        let payload = HandshakePayload::from_slice("pay".as_bytes()).unwrap();
        SessionCryptoForTesting::new_session_handshake(
            &dh_key,
            &handshake1,
            &payload,
            &mut handshake2,
            &mut dpe_decrypt_cs,
            &mut dpe_encrypt_cs,
            &mut psk_seed,
        )
        .unwrap();
        // The responder's payload must round-trip to the client.
        assert_eq!(payload, client.finish_handshake(&handshake2).unwrap());

        // Check that the session works.
        let mut buffer = Message::from_slice("message".as_bytes()).unwrap();
        client.encrypt(&mut buffer).unwrap();
        SessionCryptoForTesting::session_decrypt(
            &mut dpe_decrypt_cs,
            &mut buffer,
        )
        .unwrap();
        assert_eq!("message".as_bytes(), buffer.as_slice());
        SessionCryptoForTesting::session_encrypt(
            &mut dpe_encrypt_cs,
            &mut buffer,
        )
        .unwrap();
        // DPE-side encryption only stages the counter; commit it explicitly.
        dpe_encrypt_cs.commit();
        client.decrypt(&mut buffer).unwrap();
        assert_eq!("message".as_bytes(), buffer.as_slice());

        // Do it again to check session state still works.
        client.encrypt(&mut buffer).unwrap();
        SessionCryptoForTesting::session_decrypt(
            &mut dpe_decrypt_cs,
            &mut buffer,
        )
        .unwrap();
        assert_eq!("message".as_bytes(), buffer.as_slice());
        SessionCryptoForTesting::session_encrypt(
            &mut dpe_encrypt_cs,
            &mut buffer,
        )
        .unwrap();
        dpe_encrypt_cs.commit();
        client.decrypt(&mut buffer).unwrap();
        assert_eq!("message".as_bytes(), buffer.as_slice());
    }

    // Derives a Noise_NNpsk0 session from an existing Noise_NK session and
    // checks that both sessions keep working independently.
    #[test]
    fn derived_session() {
        // Set up a session from which to derive.
        let mut client = SessionClientForTesting::new();
        let dh_key: DhPrivateKey = Default::default();
        let dh_public_key = get_dh_public_key::<
            <DepsForTesting as NoiseCryptoDeps>::DH,
        >(&dh_key)
        .unwrap();
        let handshake1 = client
            .start_handshake_with_known_public_key(&dh_public_key)
            .unwrap();
        let mut dpe_decrypt_cs = Default::default();
        let mut dpe_encrypt_cs = Default::default();
        let mut psk_seed = Default::default();
        let mut handshake2 = Default::default();
        let payload = HandshakePayload::from_slice("pay".as_bytes()).unwrap();
        SessionCryptoForTesting::new_session_handshake(
            &dh_key,
            &handshake1,
            &payload,
            &mut handshake2,
            &mut dpe_decrypt_cs,
            &mut dpe_encrypt_cs,
            &mut psk_seed,
        )
        .unwrap();
        assert_eq!(payload, client.finish_handshake(&handshake2).unwrap());

        // Derive a second session.
        let mut client2 = SessionClientForTesting::new();
        // The client derives its PSK *before* sending the command.
        let client_psk = client.derive_psk();
        // Simulate the session state after command decryption on the DPE side
        // as expected by the DPE PSK logic.
        let mut buffer = Message::from_slice("message".as_bytes()).unwrap();
        client.encrypt(&mut buffer).unwrap();
        SessionCryptoForTesting::session_decrypt(
            &mut dpe_decrypt_cs,
            &mut buffer,
        )
        .unwrap();
        // The DPE derives its PSK *after* decrypting that command.
        let dpe_psk = SessionCryptoForTesting::derive_psk_from_session(
            &psk_seed,
            &dpe_decrypt_cs,
            &dpe_encrypt_cs,
        )
        .unwrap();
        let handshake1 = client2.start_handshake_with_psk(&client_psk).unwrap();
        let mut dpe_decrypt_cs2 = Default::default();
        let mut dpe_encrypt_cs2 = Default::default();
        let mut psk_seed2 = Default::default();
        // `handshake2` is reused as the output buffer for the new reply.
        SessionCryptoForTesting::derive_session_handshake(
            &dpe_psk,
            &handshake1,
            &payload,
            &mut handshake2,
            &mut dpe_decrypt_cs2,
            &mut dpe_encrypt_cs2,
            &mut psk_seed2,
        )
        .unwrap();
        // Succeeds only if both sides derived the same PSK.
        assert_eq!(payload, client2.finish_handshake(&handshake2).unwrap());

        // Check that the second session works.
        let mut buffer = Message::from_slice("message".as_bytes()).unwrap();
        client2.encrypt(&mut buffer).unwrap();
        SessionCryptoForTesting::session_decrypt(
            &mut dpe_decrypt_cs2,
            &mut buffer,
        )
        .unwrap();
        assert_eq!("message".as_bytes(), buffer.as_slice());
        SessionCryptoForTesting::session_encrypt(
            &mut dpe_encrypt_cs2,
            &mut buffer,
        )
        .unwrap();
        dpe_encrypt_cs2.commit();
        client2.decrypt(&mut buffer).unwrap();
        assert_eq!("message".as_bytes(), buffer.as_slice());

        // Check that the first session also still works.
        let mut buffer = Message::from_slice("message".as_bytes()).unwrap();
        client.encrypt(&mut buffer).unwrap();
        SessionCryptoForTesting::session_decrypt(
            &mut dpe_decrypt_cs,
            &mut buffer,
        )
        .unwrap();
        assert_eq!("message".as_bytes(), buffer.as_slice());
        SessionCryptoForTesting::session_encrypt(
            &mut dpe_encrypt_cs,
            &mut buffer,
        )
        .unwrap();
        dpe_encrypt_cs.commit();
        client.decrypt(&mut buffer).unwrap();
        assert_eq!("message".as_bytes(), buffer.as_slice());
    }
}
699