xref: /aosp_15_r20/external/open-dice/dpe-rs/src/noise.rs (revision 60b67249c2e226f42f35cc6cfe66c6048e0bae6b)
1 // Copyright 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 //! An encrypted session implementation which uses
16 //! Noise_NK_X25519_AESGCM_SHA512 and Noise_NNpsk0_X25519_AESGCM_SHA512.
17 
18 use crate::crypto::{
19     Commit, Counter, DhPrivateKey, DhPublicKey, HandshakeMessage,
20     HandshakePayload, Hash, SessionCrypto,
21 };
22 use crate::error::{DpeResult, ErrCode};
23 use crate::memory::Message;
24 use core::marker::PhantomData;
25 use log::{debug, error};
26 use noise_protocol::{HandshakeStateBuilder, Hash as NoiseHash, U8Array};
27 
28 impl From<noise_protocol::Error> for ErrCode {
from(_err: noise_protocol::Error) -> Self29     fn from(_err: noise_protocol::Error) -> Self {
30         ErrCode::InvalidArgument
31     }
32 }
33 
34 impl<NoiseHash> From<&NoiseHash> for Hash
35 where
36     NoiseHash: U8Array,
37 {
from(value: &NoiseHash) -> Self38     fn from(value: &NoiseHash) -> Self {
39         // The Noise hash size may not match HASH_SIZE.
40         Hash::from_slice_infallible(value.as_slice())
41     }
42 }
43 
44 /// A cipher state type that can be used as a
45 /// [`SessionCipherState`](crate::crypto::SessionCrypto::SessionCipherState).
46 pub struct NoiseCipherState<C: noise_protocol::Cipher> {
47     k: C::Key,
48     n: u64,
49     n_staged: u64,
50 }
51 
52 impl<C: noise_protocol::Cipher> Clone for NoiseCipherState<C> {
clone(&self) -> Self53     fn clone(&self) -> Self {
54         Self { k: self.k.clone(), n: self.n, n_staged: self.n_staged }
55     }
56 }
57 
58 impl<C: noise_protocol::Cipher> Default for NoiseCipherState<C> {
default() -> Self59     fn default() -> Self {
60         Self { k: C::Key::new(), n: 0, n_staged: 0 }
61     }
62 }
63 
64 impl<C: noise_protocol::Cipher> core::fmt::Debug for NoiseCipherState<C> {
fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result65     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
66         write!(f, "k: redacted, n: {}", self.n)?;
67         Ok(())
68     }
69 }
70 
71 impl<C: noise_protocol::Cipher> core::hash::Hash for NoiseCipherState<C> {
hash<H: core::hash::Hasher>(&self, state: &mut H)72     fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
73         self.k.as_slice().hash(state);
74         self.n.hash(state);
75         self.n_staged.hash(state);
76     }
77 }
78 
79 #[cfg(test)]
80 impl<C: noise_protocol::Cipher> PartialEq for NoiseCipherState<C> {
eq(&self, other: &Self) -> bool81     fn eq(&self, other: &Self) -> bool {
82         self.k.as_slice() == other.k.as_slice()
83             && self.n == other.n
84             && self.n_staged == other.n_staged
85     }
86 }
87 
88 #[cfg(test)]
89 impl<C: noise_protocol::Cipher> Eq for NoiseCipherState<C> {}
90 
91 impl<C: noise_protocol::Cipher> Counter for NoiseCipherState<C> {
n(&self) -> u6492     fn n(&self) -> u64 {
93         self.n
94     }
set_n(&mut self, n: u64)95     fn set_n(&mut self, n: u64) {
96         self.n = n;
97     }
98 }
99 
100 impl<C: noise_protocol::Cipher> Commit for NoiseCipherState<C> {
101     // Called when an encrypted message is finalized to commit the new cipher
102     // state.
commit(&mut self)103     fn commit(&mut self) {
104         self.n = self.n_staged;
105     }
106 }
107 
108 impl<C: noise_protocol::Cipher> From<&noise_protocol::CipherState<C>>
109     for NoiseCipherState<C>
110 {
from(cs: &noise_protocol::CipherState<C>) -> Self111     fn from(cs: &noise_protocol::CipherState<C>) -> Self {
112         let (key, counter) = cs.clone().extract();
113         NoiseCipherState { k: key, n: counter, n_staged: counter }
114     }
115 }
116 
117 /// Returns the public key corresponding to a given `dh_private_key`.
get_dh_public_key<D: noise_protocol::DH>( dh_private_key: &DhPrivateKey, ) -> DpeResult<DhPublicKey>118 pub fn get_dh_public_key<D: noise_protocol::DH>(
119     dh_private_key: &DhPrivateKey,
120 ) -> DpeResult<DhPublicKey> {
121     DhPublicKey::from_slice(
122         D::pubkey(&D::Key::from_slice(dh_private_key.as_slice())).as_slice(),
123     )
124 }
125 
126 /// A trait representing [`NoiseSessionCrypto`] dependencies.
127 pub trait NoiseCryptoDeps {
128     /// Cipher type
129     type Cipher: noise_protocol::Cipher;
130     /// DH type
131     type DH: noise_protocol::DH;
132     /// Hash type
133     type Hash: noise_protocol::Hash;
134 }
135 
136 /// A Noise implementation of the [`SessionCrypto`] trait.
137 pub struct NoiseSessionCrypto<D: NoiseCryptoDeps> {
138     #[allow(dead_code)]
139     phantom: PhantomData<D>,
140 }
141 
142 impl<D> Clone for NoiseSessionCrypto<D>
143 where
144     D: NoiseCryptoDeps,
145 {
clone(&self) -> Self146     fn clone(&self) -> Self {
147         Self { phantom: Default::default() }
148     }
149 }
150 
151 impl<D> Default for NoiseSessionCrypto<D>
152 where
153     D: NoiseCryptoDeps,
154 {
default() -> Self155     fn default() -> Self {
156         Self { phantom: Default::default() }
157     }
158 }
159 
160 impl<D> core::fmt::Debug for NoiseSessionCrypto<D>
161 where
162     D: NoiseCryptoDeps,
163 {
fmt(&self, _: &mut core::fmt::Formatter<'_>) -> core::fmt::Result164     fn fmt(&self, _: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
165         Ok(())
166     }
167 }
168 
169 impl<D> core::hash::Hash for NoiseSessionCrypto<D>
170 where
171     D: NoiseCryptoDeps,
172 {
hash<Hr: core::hash::Hasher>(&self, _: &mut Hr)173     fn hash<Hr: core::hash::Hasher>(&self, _: &mut Hr) {}
174 }
175 
176 impl<D> PartialEq for NoiseSessionCrypto<D>
177 where
178     D: NoiseCryptoDeps,
179 {
eq(&self, _: &Self) -> bool180     fn eq(&self, _: &Self) -> bool {
181         true
182     }
183 }
184 
185 impl<D> Eq for NoiseSessionCrypto<D> where D: NoiseCryptoDeps {}
186 
187 impl<D> SessionCrypto for NoiseSessionCrypto<D>
188 where
189     D: NoiseCryptoDeps,
190 {
191     type SessionCipherState = NoiseCipherState<D::Cipher>;
192 
193     /// Implements the responder role of a Noise_NK handshake.
new_session_handshake( static_dh_key: &DhPrivateKey, initiator_handshake: &HandshakeMessage, payload: &HandshakePayload, responder_handshake: &mut HandshakeMessage, decrypt_cipher_state: &mut NoiseCipherState<D::Cipher>, encrypt_cipher_state: &mut NoiseCipherState<D::Cipher>, psk_seed: &mut Hash, ) -> DpeResult<()>194     fn new_session_handshake(
195         static_dh_key: &DhPrivateKey,
196         initiator_handshake: &HandshakeMessage,
197         payload: &HandshakePayload,
198         responder_handshake: &mut HandshakeMessage,
199         decrypt_cipher_state: &mut NoiseCipherState<D::Cipher>,
200         encrypt_cipher_state: &mut NoiseCipherState<D::Cipher>,
201         psk_seed: &mut Hash,
202     ) -> DpeResult<()> {
203         #[allow(unused_results)]
204         let mut handshake: noise_protocol::HandshakeState<
205             D::DH,
206             D::Cipher,
207             D::Hash,
208         > = {
209             let mut builder = HandshakeStateBuilder::new();
210             builder.set_pattern(noise_protocol::patterns::noise_nk());
211             builder.set_is_initiator(false);
212             builder.set_prologue(&[]);
213             builder.set_s(<D::DH as noise_protocol::DH>::Key::from_slice(
214                 static_dh_key.as_slice(),
215             ));
216             builder.build_handshake_state()
217         };
218         handshake.read_message(initiator_handshake.as_slice(), &mut [])?;
219         handshake.write_message(
220             payload.as_slice(),
221             responder_handshake.as_mut_sized(
222                 handshake.get_next_message_overhead() + payload.len(),
223             )?,
224         )?;
225         assert!(handshake.completed());
226         let ciphers = handshake.get_ciphers();
227         *decrypt_cipher_state = (&ciphers.0).into();
228         *encrypt_cipher_state = (&ciphers.1).into();
229         debug!("get_hash");
230         *psk_seed = Hash::from_slice(handshake.get_hash())?;
231         Ok(())
232     }
233 
234     /// Implements the responder role of a Noise_NNpsk0 handshake.
derive_session_handshake( psk: &Hash, initiator_handshake: &HandshakeMessage, payload: &HandshakePayload, responder_handshake: &mut HandshakeMessage, decrypt_cipher_state: &mut NoiseCipherState<D::Cipher>, encrypt_cipher_state: &mut NoiseCipherState<D::Cipher>, psk_seed: &mut Hash, ) -> DpeResult<()>235     fn derive_session_handshake(
236         psk: &Hash,
237         initiator_handshake: &HandshakeMessage,
238         payload: &HandshakePayload,
239         responder_handshake: &mut HandshakeMessage,
240         decrypt_cipher_state: &mut NoiseCipherState<D::Cipher>,
241         encrypt_cipher_state: &mut NoiseCipherState<D::Cipher>,
242         psk_seed: &mut Hash,
243     ) -> DpeResult<()> {
244         #[allow(unused_results)]
245         let mut handshake: noise_protocol::HandshakeState<
246             D::DH,
247             D::Cipher,
248             D::Hash,
249         > = {
250             let mut builder = HandshakeStateBuilder::new();
251             builder.set_pattern(noise_protocol::patterns::noise_nn_psk0());
252             builder.set_is_initiator(false);
253             builder.set_prologue(&[]);
254             builder.build_handshake_state()
255         };
256         handshake
257             .push_psk(psk.as_slice().get(..32).ok_or(ErrCode::InternalError)?);
258         handshake.read_message(initiator_handshake.as_slice(), &mut [])?;
259         handshake.write_message(
260             payload.as_slice(),
261             responder_handshake.as_mut_sized(
262                 handshake.get_next_message_overhead() + payload.len(),
263             )?,
264         )?;
265         let ciphers = handshake.get_ciphers();
266         *decrypt_cipher_state = (&ciphers.0).into();
267         *encrypt_cipher_state = (&ciphers.1).into();
268         *psk_seed = Hash::from_slice(handshake.get_hash())?;
269         Ok(())
270     }
271 
272     /// Encrypts a Noise transport message in place.
session_encrypt( cipher_state: &mut NoiseCipherState<D::Cipher>, in_place_buffer: &mut Message, ) -> DpeResult<()>273     fn session_encrypt(
274         cipher_state: &mut NoiseCipherState<D::Cipher>,
275         in_place_buffer: &mut Message,
276     ) -> DpeResult<()> {
277         let mut cs = noise_protocol::CipherState::<D::Cipher>::new(
278             cipher_state.k.as_slice(),
279             cipher_state.n,
280         );
281         let plaintext_len = in_place_buffer.len();
282         let _ = cs.encrypt_in_place(
283             in_place_buffer.as_mut_sized(
284                 plaintext_len
285                     + <D::Cipher as noise_protocol::Cipher>::tag_len(),
286             )?,
287             plaintext_len,
288         );
289         // Encrypting a message is usually not the final step in preparing
290         // the message for transport. If a subsequent step fails, it is
291         // better for 'n' to remain unchanged so we don't get out of sync.
292         (_, cipher_state.n_staged) = cs.extract();
293         Ok(())
294     }
295 
296     /// Decrypts a Noise transport message in place.
session_decrypt( cipher_state: &mut NoiseCipherState<D::Cipher>, in_place_buffer: &mut Message, ) -> DpeResult<()>297     fn session_decrypt(
298         cipher_state: &mut NoiseCipherState<D::Cipher>,
299         in_place_buffer: &mut Message,
300     ) -> DpeResult<()> {
301         let mut cs = noise_protocol::CipherState::<D::Cipher>::new(
302             cipher_state.k.as_slice(),
303             cipher_state.n,
304         );
305         let ciphertext_len = in_place_buffer.len();
306         let plaintext_len = match cs
307             .decrypt_in_place(in_place_buffer.vec.as_mut(), ciphertext_len)
308         {
309             Ok(length) => length,
310             _ => {
311                 error!("Session decrypt failed");
312                 return Err(ErrCode::InvalidCommand);
313             }
314         };
315         in_place_buffer.vec.truncate(plaintext_len);
316         (_, cipher_state.n) = cs.extract();
317         Ok(())
318     }
319 
320     /// Derives a responder-side PSK.
derive_psk_from_session( psk_seed: &Hash, decrypt_cipher_state: &NoiseCipherState<D::Cipher>, encrypt_cipher_state: &NoiseCipherState<D::Cipher>, ) -> DpeResult<Hash>321     fn derive_psk_from_session(
322         psk_seed: &Hash,
323         decrypt_cipher_state: &NoiseCipherState<D::Cipher>,
324         encrypt_cipher_state: &NoiseCipherState<D::Cipher>,
325     ) -> DpeResult<Hash> {
326         let mut hasher: D::Hash = Default::default();
327         hasher.input(psk_seed.as_slice());
328         // Use the decrypt state as it was before we decrypted the current
329         // command message. This allows clients to compute the PSK using
330         // the cipher states as they are before the client sends the
331         // command.
332         hasher.input(&(decrypt_cipher_state.n() - 1).to_le_bytes());
333         hasher.input(&encrypt_cipher_state.n().to_le_bytes());
334         Ok((&hasher.result()).into())
335     }
336 }
337 
338 /// A SessionClient implements the initiator side of an encrypted session. A
339 /// DPE does not use this itself, it is useful for clients and testing.
340 pub struct SessionClient<D>
341 where
342     D: NoiseCryptoDeps,
343 {
344     handshake_state:
345         Option<noise_protocol::HandshakeState<D::DH, D::Cipher, D::Hash>>,
346     /// Cipher state for encrypting messages to a DPE.
347     pub encrypt_cipher_state: NoiseCipherState<D::Cipher>,
348     /// Cipher state for decrypting messages from a DPE.
349     pub decrypt_cipher_state: NoiseCipherState<D::Cipher>,
350     /// PSK seed for deriving sessions. See [`derive_psk`].
351     ///
352     /// [`derive_psk`]: #method.derive_psk
353     pub psk_seed: Hash,
354 }
355 
356 impl<D> Clone for SessionClient<D>
357 where
358     D: NoiseCryptoDeps,
359 {
clone(&self) -> Self360     fn clone(&self) -> Self {
361         Self {
362             handshake_state: self.handshake_state.clone(),
363             encrypt_cipher_state: self.encrypt_cipher_state.clone(),
364             decrypt_cipher_state: self.decrypt_cipher_state.clone(),
365             psk_seed: self.psk_seed.clone(),
366         }
367     }
368 }
369 
370 impl<D> Default for SessionClient<D>
371 where
372     D: NoiseCryptoDeps,
373 {
default() -> Self374     fn default() -> Self {
375         Self::new()
376     }
377 }
378 
379 impl<D> core::fmt::Debug for SessionClient<D>
380 where
381     D: NoiseCryptoDeps,
382 {
fmt(&self, _: &mut core::fmt::Formatter<'_>) -> core::fmt::Result383     fn fmt(&self, _: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
384         Ok(())
385     }
386 }
387 
388 impl<D> SessionClient<D>
389 where
390     D: NoiseCryptoDeps,
391 {
392     /// Creates a new SessionClient instance. Set up by starting and finishing a
393     /// handshake.
new() -> Self394     pub fn new() -> Self {
395         Self {
396             handshake_state: Default::default(),
397             encrypt_cipher_state: Default::default(),
398             decrypt_cipher_state: Default::default(),
399             psk_seed: Default::default(),
400         }
401     }
402 
403     /// Starts a handshake using a known `public_key` and returns a message that
404     /// works with the DPE OpenSession command.
start_handshake_with_known_public_key( &mut self, public_key: &DhPublicKey, ) -> DpeResult<HandshakeMessage>405     pub fn start_handshake_with_known_public_key(
406         &mut self,
407         public_key: &DhPublicKey,
408     ) -> DpeResult<HandshakeMessage> {
409         #[allow(unused_results)]
410         let mut handshake_state = {
411             let mut builder = HandshakeStateBuilder::new();
412             builder.set_pattern(noise_protocol::patterns::noise_nk());
413             builder.set_is_initiator(true);
414             builder.set_prologue(&[]);
415             builder.set_rs(<D::DH as noise_protocol::DH>::Pubkey::from_slice(
416                 public_key.as_slice(),
417             ));
418             builder.build_handshake_state()
419         };
420         let mut message = HandshakeMessage::new();
421         handshake_state.write_message(
422             &[],
423             message
424                 .as_mut_sized(handshake_state.get_next_message_overhead())?,
425         )?;
426         self.handshake_state = Some(handshake_state);
427         Ok(message)
428     }
429 
430     /// Starts a handshake using a `psk` and returns a message that works with
431     /// the DPE DeriveContext command. Use [`derive_psk`] to obtain this value
432     /// from an existing session.
433     ///
434     /// [`derive_psk`]: #method.derive_psk
start_handshake_with_psk( &mut self, psk: &Hash, ) -> DpeResult<HandshakeMessage>435     pub fn start_handshake_with_psk(
436         &mut self,
437         psk: &Hash,
438     ) -> DpeResult<HandshakeMessage> {
439         #[allow(unused_results)]
440         let mut handshake_state = {
441             let mut builder = HandshakeStateBuilder::new();
442             builder.set_pattern(noise_protocol::patterns::noise_nn_psk0());
443             builder.set_is_initiator(true);
444             builder.set_prologue(&[]);
445             builder.build_handshake_state()
446         };
447         handshake_state
448             .push_psk(psk.as_slice().get(..32).ok_or(ErrCode::InternalError)?);
449         let mut message = HandshakeMessage::new();
450         handshake_state.write_message(
451             &[],
452             message
453                 .as_mut_sized(handshake_state.get_next_message_overhead())?,
454         )?;
455         self.handshake_state = Some(handshake_state);
456         Ok(message)
457     }
458 
459     /// Finishes a handshake started using one of the start_handshake_* methods.
460     /// On success, returns the handshake payload from the responder and sets up
461     /// internal state for subsequent calls to encrypt and decrypt.
finish_handshake( &mut self, responder_handshake: &HandshakeMessage, ) -> DpeResult<HandshakePayload>462     pub fn finish_handshake(
463         &mut self,
464         responder_handshake: &HandshakeMessage,
465     ) -> DpeResult<HandshakePayload> {
466         match self.handshake_state {
467             None => Err(ErrCode::InvalidArgument),
468             Some(ref mut handshake) => {
469                 let mut payload = HandshakePayload::new();
470                 handshake.read_message(
471                     responder_handshake.as_slice(),
472                     payload.as_mut_sized(
473                         responder_handshake.len()
474                             - handshake.get_next_message_overhead(),
475                     )?,
476                 )?;
477                 let ciphers = handshake.get_ciphers();
478                 self.encrypt_cipher_state = (&ciphers.0).into();
479                 self.decrypt_cipher_state = (&ciphers.1).into();
480                 self.psk_seed = Hash::from_slice(handshake.get_hash())?;
481                 Ok(payload)
482             }
483         }
484     }
485 
486     /// Derives a PSK from the current session.
derive_psk(&self) -> Hash487     pub fn derive_psk(&self) -> Hash {
488         // Note this is from a client perspective so the counters are hashed
489         // encrypt first and unmodified from their current state. A DPE will
490         // reverse the order and decrement the first counter in order to derive
491         // the same value (see derive_psk_from_session).
492         let mut hasher: D::Hash = Default::default();
493         hasher.input(self.psk_seed.as_slice());
494         hasher.input(&self.encrypt_cipher_state.n().to_le_bytes());
495         hasher.input(&self.decrypt_cipher_state.n().to_le_bytes());
496         (&hasher.result()).into()
497     }
498 
499     /// Encrypts a message to send to a DPE and commits cipher state changes.
encrypt(&mut self, in_place_buffer: &mut Message) -> DpeResult<()>500     pub fn encrypt(&mut self, in_place_buffer: &mut Message) -> DpeResult<()> {
501         NoiseSessionCrypto::<D>::session_encrypt(
502             &mut self.encrypt_cipher_state,
503             in_place_buffer,
504         )?;
505         self.encrypt_cipher_state.commit();
506         Ok(())
507     }
508 
509     /// Decrypts a message from a DPE.
decrypt(&mut self, in_place_buffer: &mut Message) -> DpeResult<()>510     pub fn decrypt(&mut self, in_place_buffer: &mut Message) -> DpeResult<()> {
511         NoiseSessionCrypto::<D>::session_decrypt(
512             &mut self.decrypt_cipher_state,
513             in_place_buffer,
514         )
515     }
516 }
517 
#[cfg(test)]
mod tests {
    use super::*;

    struct DepsForTesting {}
    impl NoiseCryptoDeps for DepsForTesting {
        type Cipher = noise_rust_crypto::Aes256Gcm;
        type DH = noise_rust_crypto::X25519;
        type Hash = noise_rust_crypto::Sha512;
    }

    type SessionCryptoForTesting = NoiseSessionCrypto<DepsForTesting>;

    type SessionClientForTesting = SessionClient<DepsForTesting>;

    type CipherStateForTesting = NoiseCipherState<noise_rust_crypto::Aes256Gcm>;

    type DhForTesting = <DepsForTesting as NoiseCryptoDeps>::DH;

    /// DPE-side state produced by a successful handshake.
    #[derive(Default)]
    struct DpeSession {
        decrypt_cs: CipherStateForTesting,
        encrypt_cs: CipherStateForTesting,
        psk_seed: Hash,
    }

    /// Runs a Noise_NK handshake between `client` and a simulated DPE that
    /// uses a default static key.
    ///
    /// Asserts that the client receives `payload`. Returns the DPE-side
    /// state together with the responder message buffer.
    fn open_session(
        client: &mut SessionClientForTesting,
        payload: &HandshakePayload,
    ) -> (DpeSession, HandshakeMessage) {
        let dh_key = DhPrivateKey::default();
        let dh_public_key =
            get_dh_public_key::<DhForTesting>(&dh_key).unwrap();
        let initiator_msg = client
            .start_handshake_with_known_public_key(&dh_public_key)
            .unwrap();
        let mut dpe = DpeSession::default();
        let mut responder_msg = HandshakeMessage::default();
        SessionCryptoForTesting::new_session_handshake(
            &dh_key,
            &initiator_msg,
            payload,
            &mut responder_msg,
            &mut dpe.decrypt_cs,
            &mut dpe.encrypt_cs,
            &mut dpe.psk_seed,
        )
        .unwrap();
        assert_eq!(*payload, client.finish_handshake(&responder_msg).unwrap());
        (dpe, responder_msg)
    }

    /// Sends "message" from `client` to the DPE and back, checking that it
    /// arrives intact in both directions.
    fn assert_round_trip(
        client: &mut SessionClientForTesting,
        dpe: &mut DpeSession,
    ) {
        let mut buffer = Message::from_slice("message".as_bytes()).unwrap();
        client.encrypt(&mut buffer).unwrap();
        SessionCryptoForTesting::session_decrypt(
            &mut dpe.decrypt_cs,
            &mut buffer,
        )
        .unwrap();
        assert_eq!("message".as_bytes(), buffer.as_slice());
        SessionCryptoForTesting::session_encrypt(
            &mut dpe.encrypt_cs,
            &mut buffer,
        )
        .unwrap();
        dpe.encrypt_cs.commit();
        client.decrypt(&mut buffer).unwrap();
        assert_eq!("message".as_bytes(), buffer.as_slice());
    }

    #[test]
    fn end_to_end_session() {
        let payload = HandshakePayload::from_slice("pay".as_bytes()).unwrap();
        let mut client = SessionClientForTesting::new();
        let (mut dpe, _) = open_session(&mut client, &payload);

        // Check that the session works, then do it again to check the
        // session state still works.
        for _ in 0..2 {
            assert_round_trip(&mut client, &mut dpe);
        }
    }

    #[test]
    fn derived_session() {
        let payload = HandshakePayload::from_slice("pay".as_bytes()).unwrap();

        // Set up a session from which to derive.
        let mut client = SessionClientForTesting::new();
        let (mut dpe, mut responder_msg) = open_session(&mut client, &payload);

        // Derive a second session.
        let client_psk = client.derive_psk();
        // Simulate the session state after command decryption on the DPE side
        // as expected by the DPE PSK logic.
        let mut command = Message::from_slice("message".as_bytes()).unwrap();
        client.encrypt(&mut command).unwrap();
        SessionCryptoForTesting::session_decrypt(
            &mut dpe.decrypt_cs,
            &mut command,
        )
        .unwrap();
        let dpe_psk = SessionCryptoForTesting::derive_psk_from_session(
            &dpe.psk_seed,
            &dpe.decrypt_cs,
            &dpe.encrypt_cs,
        )
        .unwrap();

        let mut client2 = SessionClientForTesting::new();
        let initiator_msg =
            client2.start_handshake_with_psk(&client_psk).unwrap();
        let mut dpe2 = DpeSession::default();
        SessionCryptoForTesting::derive_session_handshake(
            &dpe_psk,
            &initiator_msg,
            &payload,
            &mut responder_msg,
            &mut dpe2.decrypt_cs,
            &mut dpe2.encrypt_cs,
            &mut dpe2.psk_seed,
        )
        .unwrap();
        assert_eq!(payload, client2.finish_handshake(&responder_msg).unwrap());

        // Check that the second session works, and that the first session
        // also still works.
        assert_round_trip(&mut client2, &mut dpe2);
        assert_round_trip(&mut client, &mut dpe);
    }
}
699