// Copyright 2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); you may not // use this file except in compliance with the License. You may obtain a copy of // the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the // License for the specific language governing permissions and limitations under // the License. //! An encrypted session implementation which uses //! Noise_NK_X25519_AESGCM_SHA512 and Noise_NNpsk0_X25519_AESGCM_SHA512. use crate::crypto::{ Commit, Counter, DhPrivateKey, DhPublicKey, HandshakeMessage, HandshakePayload, Hash, SessionCrypto, }; use crate::error::{DpeResult, ErrCode}; use crate::memory::Message; use core::marker::PhantomData; use log::{debug, error}; use noise_protocol::{HandshakeStateBuilder, Hash as NoiseHash, U8Array}; impl From for ErrCode { fn from(_err: noise_protocol::Error) -> Self { ErrCode::InvalidArgument } } impl From<&NoiseHash> for Hash where NoiseHash: U8Array, { fn from(value: &NoiseHash) -> Self { // The Noise hash size may not match HASH_SIZE. Hash::from_slice_infallible(value.as_slice()) } } /// A cipher state type that can be used as a /// [`SessionCipherState`](crate::crypto::SessionCrypto::SessionCipherState). pub struct NoiseCipherState { k: C::Key, n: u64, n_staged: u64, } impl Clone for NoiseCipherState { fn clone(&self) -> Self { Self { k: self.k.clone(), n: self.n, n_staged: self.n_staged } } } impl Default for NoiseCipherState { fn default() -> Self { Self { k: C::Key::new(), n: 0, n_staged: 0 } } } impl core::fmt::Debug for NoiseCipherState { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!(f, "k: redacted, n: {}", self.n)?; Ok(()) } } impl core::hash::Hash for NoiseCipherState { fn hash(&self, state: &mut H) { self.k.as_slice().hash(state); self.n.hash(state); self.n_staged.hash(state); } } #[cfg(test)] impl PartialEq for NoiseCipherState { fn eq(&self, other: &Self) -> bool { self.k.as_slice() == other.k.as_slice() && self.n == other.n && self.n_staged == other.n_staged } } #[cfg(test)] impl Eq for NoiseCipherState {} impl Counter for NoiseCipherState { fn n(&self) -> u64 { self.n } fn set_n(&mut self, n: u64) { self.n = n; } } impl Commit for NoiseCipherState { // Called when an encrypted message is finalized to commit the new cipher // state. fn commit(&mut self) { self.n = self.n_staged; } } impl From<&noise_protocol::CipherState> for NoiseCipherState { fn from(cs: &noise_protocol::CipherState) -> Self { let (key, counter) = cs.clone().extract(); NoiseCipherState { k: key, n: counter, n_staged: counter } } } /// Returns the public key corresponding to a given `dh_private_key`. pub fn get_dh_public_key( dh_private_key: &DhPrivateKey, ) -> DpeResult { DhPublicKey::from_slice( D::pubkey(&D::Key::from_slice(dh_private_key.as_slice())).as_slice(), ) } /// A trait representing [`NoiseSessionCrypto`] dependencies. pub trait NoiseCryptoDeps { /// Cipher type type Cipher: noise_protocol::Cipher; /// DH type type DH: noise_protocol::DH; /// Hash type type Hash: noise_protocol::Hash; } /// A Noise implementation of the [`SessionCrypto`] trait. pub struct NoiseSessionCrypto { #[allow(dead_code)] phantom: PhantomData, } impl Clone for NoiseSessionCrypto where D: NoiseCryptoDeps, { fn clone(&self) -> Self { Self { phantom: Default::default() } } } impl Default for NoiseSessionCrypto where D: NoiseCryptoDeps, { fn default() -> Self { Self { phantom: Default::default() } } } impl core::fmt::Debug for NoiseSessionCrypto where D: NoiseCryptoDeps, { fn fmt(&self, _: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { Ok(()) } } impl core::hash::Hash for NoiseSessionCrypto where D: NoiseCryptoDeps, { fn hash(&self, _: &mut Hr) {} } impl PartialEq for NoiseSessionCrypto where D: NoiseCryptoDeps, { fn eq(&self, _: &Self) -> bool { true } } impl Eq for NoiseSessionCrypto where D: NoiseCryptoDeps {} impl SessionCrypto for NoiseSessionCrypto where D: NoiseCryptoDeps, { type SessionCipherState = NoiseCipherState; /// Implements the responder role of a Noise_NK handshake. fn new_session_handshake( static_dh_key: &DhPrivateKey, initiator_handshake: &HandshakeMessage, payload: &HandshakePayload, responder_handshake: &mut HandshakeMessage, decrypt_cipher_state: &mut NoiseCipherState, encrypt_cipher_state: &mut NoiseCipherState, psk_seed: &mut Hash, ) -> DpeResult<()> { #[allow(unused_results)] let mut handshake: noise_protocol::HandshakeState< D::DH, D::Cipher, D::Hash, > = { let mut builder = HandshakeStateBuilder::new(); builder.set_pattern(noise_protocol::patterns::noise_nk()); builder.set_is_initiator(false); builder.set_prologue(&[]); builder.set_s(::Key::from_slice( static_dh_key.as_slice(), )); builder.build_handshake_state() }; handshake.read_message(initiator_handshake.as_slice(), &mut [])?; handshake.write_message( payload.as_slice(), responder_handshake.as_mut_sized( handshake.get_next_message_overhead() + payload.len(), )?, )?; assert!(handshake.completed()); let ciphers = handshake.get_ciphers(); *decrypt_cipher_state = (&ciphers.0).into(); *encrypt_cipher_state = (&ciphers.1).into(); debug!("get_hash"); *psk_seed = Hash::from_slice(handshake.get_hash())?; Ok(()) } /// Implements the responder role of a Noise_NNpsk0 handshake. fn derive_session_handshake( psk: &Hash, initiator_handshake: &HandshakeMessage, payload: &HandshakePayload, responder_handshake: &mut HandshakeMessage, decrypt_cipher_state: &mut NoiseCipherState, encrypt_cipher_state: &mut NoiseCipherState, psk_seed: &mut Hash, ) -> DpeResult<()> { #[allow(unused_results)] let mut handshake: noise_protocol::HandshakeState< D::DH, D::Cipher, D::Hash, > = { let mut builder = HandshakeStateBuilder::new(); builder.set_pattern(noise_protocol::patterns::noise_nn_psk0()); builder.set_is_initiator(false); builder.set_prologue(&[]); builder.build_handshake_state() }; handshake .push_psk(psk.as_slice().get(..32).ok_or(ErrCode::InternalError)?); handshake.read_message(initiator_handshake.as_slice(), &mut [])?; handshake.write_message( payload.as_slice(), responder_handshake.as_mut_sized( handshake.get_next_message_overhead() + payload.len(), )?, )?; let ciphers = handshake.get_ciphers(); *decrypt_cipher_state = (&ciphers.0).into(); *encrypt_cipher_state = (&ciphers.1).into(); *psk_seed = Hash::from_slice(handshake.get_hash())?; Ok(()) } /// Encrypts a Noise transport message in place. fn session_encrypt( cipher_state: &mut NoiseCipherState, in_place_buffer: &mut Message, ) -> DpeResult<()> { let mut cs = noise_protocol::CipherState::::new( cipher_state.k.as_slice(), cipher_state.n, ); let plaintext_len = in_place_buffer.len(); let _ = cs.encrypt_in_place( in_place_buffer.as_mut_sized( plaintext_len + ::tag_len(), )?, plaintext_len, ); // Encrypting a message is usually not the final step in preparing // the message for transport. If a subsequent step fails, it is // better for 'n' to remain unchanged so we don't get out of sync. (_, cipher_state.n_staged) = cs.extract(); Ok(()) } /// Decrypts a Noise transport message in place. fn session_decrypt( cipher_state: &mut NoiseCipherState, in_place_buffer: &mut Message, ) -> DpeResult<()> { let mut cs = noise_protocol::CipherState::::new( cipher_state.k.as_slice(), cipher_state.n, ); let ciphertext_len = in_place_buffer.len(); let plaintext_len = match cs .decrypt_in_place(in_place_buffer.vec.as_mut(), ciphertext_len) { Ok(length) => length, _ => { error!("Session decrypt failed"); return Err(ErrCode::InvalidCommand); } }; in_place_buffer.vec.truncate(plaintext_len); (_, cipher_state.n) = cs.extract(); Ok(()) } /// Derives a responder-side PSK. fn derive_psk_from_session( psk_seed: &Hash, decrypt_cipher_state: &NoiseCipherState, encrypt_cipher_state: &NoiseCipherState, ) -> DpeResult { let mut hasher: D::Hash = Default::default(); hasher.input(psk_seed.as_slice()); // Use the decrypt state as it was before we decrypted the current // command message. This allows clients to compute the PSK using // the cipher states as they are before the client sends the // command. hasher.input(&(decrypt_cipher_state.n() - 1).to_le_bytes()); hasher.input(&encrypt_cipher_state.n().to_le_bytes()); Ok((&hasher.result()).into()) } } /// A SessionClient implements the initiator side of an encrypted session. A /// DPE does not use this itself, it is useful for clients and testing. pub struct SessionClient where D: NoiseCryptoDeps, { handshake_state: Option>, /// Cipher state for encrypting messages to a DPE. pub encrypt_cipher_state: NoiseCipherState, /// Cipher state for decrypting messages from a DPE. pub decrypt_cipher_state: NoiseCipherState, /// PSK seed for deriving sessions. See [`derive_psk`]. /// /// [`derive_psk`]: #method.derive_psk pub psk_seed: Hash, } impl Clone for SessionClient where D: NoiseCryptoDeps, { fn clone(&self) -> Self { Self { handshake_state: self.handshake_state.clone(), encrypt_cipher_state: self.encrypt_cipher_state.clone(), decrypt_cipher_state: self.decrypt_cipher_state.clone(), psk_seed: self.psk_seed.clone(), } } } impl Default for SessionClient where D: NoiseCryptoDeps, { fn default() -> Self { Self::new() } } impl core::fmt::Debug for SessionClient where D: NoiseCryptoDeps, { fn fmt(&self, _: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { Ok(()) } } impl SessionClient where D: NoiseCryptoDeps, { /// Creates a new SessionClient instance. Set up by starting and finishing a /// handshake. pub fn new() -> Self { Self { handshake_state: Default::default(), encrypt_cipher_state: Default::default(), decrypt_cipher_state: Default::default(), psk_seed: Default::default(), } } /// Starts a handshake using a known `public_key` and returns a message that /// works with the DPE OpenSession command. pub fn start_handshake_with_known_public_key( &mut self, public_key: &DhPublicKey, ) -> DpeResult { #[allow(unused_results)] let mut handshake_state = { let mut builder = HandshakeStateBuilder::new(); builder.set_pattern(noise_protocol::patterns::noise_nk()); builder.set_is_initiator(true); builder.set_prologue(&[]); builder.set_rs(::Pubkey::from_slice( public_key.as_slice(), )); builder.build_handshake_state() }; let mut message = HandshakeMessage::new(); handshake_state.write_message( &[], message .as_mut_sized(handshake_state.get_next_message_overhead())?, )?; self.handshake_state = Some(handshake_state); Ok(message) } /// Starts a handshake using a `psk` and returns a message that works with /// the DPE DeriveContext command. Use [`derive_psk`] to obtain this value /// from an existing session. /// /// [`derive_psk`]: #method.derive_psk pub fn start_handshake_with_psk( &mut self, psk: &Hash, ) -> DpeResult { #[allow(unused_results)] let mut handshake_state = { let mut builder = HandshakeStateBuilder::new(); builder.set_pattern(noise_protocol::patterns::noise_nn_psk0()); builder.set_is_initiator(true); builder.set_prologue(&[]); builder.build_handshake_state() }; handshake_state .push_psk(psk.as_slice().get(..32).ok_or(ErrCode::InternalError)?); let mut message = HandshakeMessage::new(); handshake_state.write_message( &[], message .as_mut_sized(handshake_state.get_next_message_overhead())?, )?; self.handshake_state = Some(handshake_state); Ok(message) } /// Finishes a handshake started using one of the start_handshake_* methods. /// On success, returns the handshake payload from the responder and sets up /// internal state for subsequent calls to encrypt and decrypt. pub fn finish_handshake( &mut self, responder_handshake: &HandshakeMessage, ) -> DpeResult { match self.handshake_state { None => Err(ErrCode::InvalidArgument), Some(ref mut handshake) => { let mut payload = HandshakePayload::new(); handshake.read_message( responder_handshake.as_slice(), payload.as_mut_sized( responder_handshake.len() - handshake.get_next_message_overhead(), )?, )?; let ciphers = handshake.get_ciphers(); self.encrypt_cipher_state = (&ciphers.0).into(); self.decrypt_cipher_state = (&ciphers.1).into(); self.psk_seed = Hash::from_slice(handshake.get_hash())?; Ok(payload) } } } /// Derives a PSK from the current session. pub fn derive_psk(&self) -> Hash { // Note this is from a client perspective so the counters are hashed // encrypt first and unmodified from their current state. A DPE will // reverse the order and decrement the first counter in order to derive // the same value (see derive_psk_from_session). let mut hasher: D::Hash = Default::default(); hasher.input(self.psk_seed.as_slice()); hasher.input(&self.encrypt_cipher_state.n().to_le_bytes()); hasher.input(&self.decrypt_cipher_state.n().to_le_bytes()); (&hasher.result()).into() } /// Encrypts a message to send to a DPE and commits cipher state changes. pub fn encrypt(&mut self, in_place_buffer: &mut Message) -> DpeResult<()> { NoiseSessionCrypto::::session_encrypt( &mut self.encrypt_cipher_state, in_place_buffer, )?; self.encrypt_cipher_state.commit(); Ok(()) } /// Decrypts a message from a DPE. pub fn decrypt(&mut self, in_place_buffer: &mut Message) -> DpeResult<()> { NoiseSessionCrypto::::session_decrypt( &mut self.decrypt_cipher_state, in_place_buffer, ) } } #[cfg(test)] mod tests { use super::*; struct DepsForTesting {} impl NoiseCryptoDeps for DepsForTesting { type Cipher = noise_rust_crypto::Aes256Gcm; type DH = noise_rust_crypto::X25519; type Hash = noise_rust_crypto::Sha512; } type SessionCryptoForTesting = NoiseSessionCrypto; type SessionClientForTesting = SessionClient; type CipherStateForTesting = NoiseCipherState; #[test] fn end_to_end_session() { let mut client = SessionClientForTesting::new(); let dh_key: DhPrivateKey = Default::default(); let dh_public_key = get_dh_public_key::< ::DH, >(&dh_key) .unwrap(); let handshake1 = client .start_handshake_with_known_public_key(&dh_public_key) .unwrap(); let mut dpe_decrypt_cs: CipherStateForTesting = Default::default(); let mut dpe_encrypt_cs: CipherStateForTesting = Default::default(); let mut psk_seed = Default::default(); let mut handshake2 = Default::default(); let payload = HandshakePayload::from_slice("pay".as_bytes()).unwrap(); SessionCryptoForTesting::new_session_handshake( &dh_key, &handshake1, &payload, &mut handshake2, &mut dpe_decrypt_cs, &mut dpe_encrypt_cs, &mut psk_seed, ) .unwrap(); assert_eq!(payload, client.finish_handshake(&handshake2).unwrap()); // Check that the session works. let mut buffer = Message::from_slice("message".as_bytes()).unwrap(); client.encrypt(&mut buffer).unwrap(); SessionCryptoForTesting::session_decrypt( &mut dpe_decrypt_cs, &mut buffer, ) .unwrap(); assert_eq!("message".as_bytes(), buffer.as_slice()); SessionCryptoForTesting::session_encrypt( &mut dpe_encrypt_cs, &mut buffer, ) .unwrap(); dpe_encrypt_cs.commit(); client.decrypt(&mut buffer).unwrap(); assert_eq!("message".as_bytes(), buffer.as_slice()); // Do it again to check session state still works. client.encrypt(&mut buffer).unwrap(); SessionCryptoForTesting::session_decrypt( &mut dpe_decrypt_cs, &mut buffer, ) .unwrap(); assert_eq!("message".as_bytes(), buffer.as_slice()); SessionCryptoForTesting::session_encrypt( &mut dpe_encrypt_cs, &mut buffer, ) .unwrap(); dpe_encrypt_cs.commit(); client.decrypt(&mut buffer).unwrap(); assert_eq!("message".as_bytes(), buffer.as_slice()); } #[test] fn derived_session() { // Set up a session from which to derive. let mut client = SessionClientForTesting::new(); let dh_key: DhPrivateKey = Default::default(); let dh_public_key = get_dh_public_key::< ::DH, >(&dh_key) .unwrap(); let handshake1 = client .start_handshake_with_known_public_key(&dh_public_key) .unwrap(); let mut dpe_decrypt_cs = Default::default(); let mut dpe_encrypt_cs = Default::default(); let mut psk_seed = Default::default(); let mut handshake2 = Default::default(); let payload = HandshakePayload::from_slice("pay".as_bytes()).unwrap(); SessionCryptoForTesting::new_session_handshake( &dh_key, &handshake1, &payload, &mut handshake2, &mut dpe_decrypt_cs, &mut dpe_encrypt_cs, &mut psk_seed, ) .unwrap(); assert_eq!(payload, client.finish_handshake(&handshake2).unwrap()); // Derive a second session. let mut client2 = SessionClientForTesting::new(); let client_psk = client.derive_psk(); // Simulate the session state after command decryption on the DPE side // as expected by the DPE PSK logic. let mut buffer = Message::from_slice("message".as_bytes()).unwrap(); client.encrypt(&mut buffer).unwrap(); SessionCryptoForTesting::session_decrypt( &mut dpe_decrypt_cs, &mut buffer, ) .unwrap(); let dpe_psk = SessionCryptoForTesting::derive_psk_from_session( &psk_seed, &dpe_decrypt_cs, &dpe_encrypt_cs, ) .unwrap(); let handshake1 = client2.start_handshake_with_psk(&client_psk).unwrap(); let mut dpe_decrypt_cs2 = Default::default(); let mut dpe_encrypt_cs2 = Default::default(); let mut psk_seed2 = Default::default(); SessionCryptoForTesting::derive_session_handshake( &dpe_psk, &handshake1, &payload, &mut handshake2, &mut dpe_decrypt_cs2, &mut dpe_encrypt_cs2, &mut psk_seed2, ) .unwrap(); assert_eq!(payload, client2.finish_handshake(&handshake2).unwrap()); // Check that the second session works. let mut buffer = Message::from_slice("message".as_bytes()).unwrap(); client2.encrypt(&mut buffer).unwrap(); SessionCryptoForTesting::session_decrypt( &mut dpe_decrypt_cs2, &mut buffer, ) .unwrap(); assert_eq!("message".as_bytes(), buffer.as_slice()); SessionCryptoForTesting::session_encrypt( &mut dpe_encrypt_cs2, &mut buffer, ) .unwrap(); dpe_encrypt_cs2.commit(); client2.decrypt(&mut buffer).unwrap(); assert_eq!("message".as_bytes(), buffer.as_slice()); // Check that the first session also still works. let mut buffer = Message::from_slice("message".as_bytes()).unwrap(); client.encrypt(&mut buffer).unwrap(); SessionCryptoForTesting::session_decrypt( &mut dpe_decrypt_cs, &mut buffer, ) .unwrap(); assert_eq!("message".as_bytes(), buffer.as_slice()); SessionCryptoForTesting::session_encrypt( &mut dpe_encrypt_cs, &mut buffer, ) .unwrap(); dpe_encrypt_cs.commit(); client.decrypt(&mut buffer).unwrap(); assert_eq!("message".as_bytes(), buffer.as_slice()); } }