1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use crate::{ 6 client::MlsError, 7 client_config::ClientConfig, 8 group::{ 9 key_schedule::KeySchedule, CommitGeneration, ConfirmationTag, Group, GroupContext, 10 GroupState, InterimTranscriptHash, ReInitProposal, TreeKemPublic, 11 }, 12 tree_kem::TreeKemPrivate, 13 }; 14 15 #[cfg(feature = "by_ref_proposal")] 16 use crate::{ 17 crypto::{HpkePublicKey, HpkeSecretKey}, 18 group::ProposalRef, 19 }; 20 21 #[cfg(feature = "by_ref_proposal")] 22 use super::proposal_cache::{CachedProposal, ProposalCache}; 23 24 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; 25 26 use mls_rs_core::crypto::SignatureSecretKey; 27 #[cfg(feature = "tree_index")] 28 use mls_rs_core::identity::IdentityProvider; 29 30 #[cfg(all(feature = "std", feature = "by_ref_proposal"))] 31 use std::collections::HashMap; 32 33 #[cfg(all(feature = "by_ref_proposal", not(feature = "std")))] 34 use alloc::vec::Vec; 35 36 use super::{cipher_suite_provider, epoch::EpochSecrets, state_repo::GroupStateRepository}; 37 38 #[derive(Debug, PartialEq, Clone, MlsEncode, MlsDecode, MlsSize)] 39 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 40 pub(crate) struct Snapshot { 41 version: u16, 42 pub(crate) state: RawGroupState, 43 private_tree: TreeKemPrivate, 44 epoch_secrets: EpochSecrets, 45 key_schedule: KeySchedule, 46 #[cfg(all(feature = "std", feature = "by_ref_proposal"))] 47 pending_updates: HashMap<HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>)>, 48 #[cfg(all(not(feature = "std"), feature = "by_ref_proposal"))] 49 pending_updates: Vec<(HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>))>, 50 pending_commit: Option<CommitGeneration>, 51 signer: SignatureSecretKey, 52 } 53 54 #[derive(Debug, MlsEncode, MlsDecode, MlsSize, PartialEq, Clone)] 55 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] 56 pub(crate) struct RawGroupState { 57 pub(crate) context: GroupContext, 58 #[cfg(all(feature = "std", feature = "by_ref_proposal"))] 59 pub(crate) proposals: HashMap<ProposalRef, CachedProposal>, 60 #[cfg(all(not(feature = "std"), feature = "by_ref_proposal"))] 61 pub(crate) proposals: Vec<(ProposalRef, CachedProposal)>, 62 pub(crate) public_tree: TreeKemPublic, 63 pub(crate) interim_transcript_hash: InterimTranscriptHash, 64 pub(crate) pending_reinit: Option<ReInitProposal>, 65 pub(crate) confirmation_tag: ConfirmationTag, 66 } 67 68 impl RawGroupState { export(state: &GroupState) -> Self69 pub(crate) fn export(state: &GroupState) -> Self { 70 #[cfg(feature = "tree_index")] 71 let public_tree = state.public_tree.clone(); 72 73 #[cfg(not(feature = "tree_index"))] 74 let public_tree = { 75 let mut tree = TreeKemPublic::new(); 76 tree.nodes = state.public_tree.nodes.clone(); 77 tree 78 }; 79 80 Self { 81 context: state.context.clone(), 82 #[cfg(feature = "by_ref_proposal")] 83 proposals: state.proposals.proposals.clone(), 84 public_tree, 85 interim_transcript_hash: state.interim_transcript_hash.clone(), 86 pending_reinit: state.pending_reinit.clone(), 87 confirmation_tag: state.confirmation_tag.clone(), 88 } 89 } 90 91 #[cfg(feature = "tree_index")] 92 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] import<C>(self, identity_provider: &C) -> Result<GroupState, MlsError> where C: IdentityProvider,93 pub(crate) async fn import<C>(self, identity_provider: &C) -> Result<GroupState, MlsError> 94 where 95 C: IdentityProvider, 96 { 97 let context = self.context; 98 99 #[cfg(feature = "by_ref_proposal")] 100 let proposals = ProposalCache::import( 101 context.protocol_version, 102 context.group_id.clone(), 103 self.proposals, 104 ); 105 106 let mut public_tree = self.public_tree; 107 108 public_tree 109 .initialize_index_if_necessary(identity_provider, &context.extensions) 110 .await?; 111 112 Ok(GroupState { 113 #[cfg(feature = "by_ref_proposal")] 114 proposals, 115 context, 116 public_tree, 117 interim_transcript_hash: self.interim_transcript_hash, 118 pending_reinit: self.pending_reinit, 119 confirmation_tag: self.confirmation_tag, 120 }) 121 } 122 123 #[cfg(not(feature = "tree_index"))] 124 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] import(self) -> Result<GroupState, MlsError>125 pub(crate) async fn import(self) -> Result<GroupState, MlsError> { 126 let context = self.context; 127 128 #[cfg(feature = "by_ref_proposal")] 129 let proposals = ProposalCache::import( 130 context.protocol_version, 131 context.group_id.clone(), 132 self.proposals, 133 ); 134 135 Ok(GroupState { 136 #[cfg(feature = "by_ref_proposal")] 137 proposals, 138 context, 139 public_tree: self.public_tree, 140 interim_transcript_hash: self.interim_transcript_hash, 141 pending_reinit: self.pending_reinit, 142 confirmation_tag: self.confirmation_tag, 143 }) 144 } 145 } 146 147 impl<C> Group<C> 148 where 149 C: ClientConfig + Clone, 150 { 151 /// Write the current state of the group to the 152 /// [`GroupStorageProvider`](crate::GroupStateStorage) 153 /// that is currently in use by the group. 154 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] write_to_storage(&mut self) -> Result<(), MlsError>155 pub async fn write_to_storage(&mut self) -> Result<(), MlsError> { 156 self.state_repo.write_to_storage(self.snapshot()).await 157 } 158 snapshot(&self) -> Snapshot159 pub(crate) fn snapshot(&self) -> Snapshot { 160 Snapshot { 161 state: RawGroupState::export(&self.state), 162 private_tree: self.private_tree.clone(), 163 key_schedule: self.key_schedule.clone(), 164 #[cfg(feature = "by_ref_proposal")] 165 pending_updates: self.pending_updates.clone(), 166 pending_commit: self.pending_commit.clone(), 167 epoch_secrets: self.epoch_secrets.clone(), 168 version: 1, 169 signer: self.signer.clone(), 170 } 171 } 172 173 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] from_snapshot(config: C, snapshot: Snapshot) -> Result<Self, MlsError>174 pub(crate) async fn from_snapshot(config: C, snapshot: Snapshot) -> Result<Self, MlsError> { 175 let cipher_suite_provider = cipher_suite_provider( 176 config.crypto_provider(), 177 snapshot.state.context.cipher_suite, 178 )?; 179 180 #[cfg(feature = "tree_index")] 181 let identity_provider = config.identity_provider(); 182 183 let state_repo = GroupStateRepository::new( 184 #[cfg(feature = "prior_epoch")] 185 snapshot.state.context.group_id.clone(), 186 config.group_state_storage(), 187 config.key_package_repo(), 188 None, 189 )?; 190 191 Ok(Group { 192 config, 193 state: snapshot 194 .state 195 .import( 196 #[cfg(feature = "tree_index")] 197 &identity_provider, 198 ) 199 .await?, 200 private_tree: snapshot.private_tree, 201 key_schedule: snapshot.key_schedule, 202 #[cfg(feature = "by_ref_proposal")] 203 pending_updates: snapshot.pending_updates, 204 pending_commit: snapshot.pending_commit, 205 #[cfg(test)] 206 commit_modifiers: Default::default(), 207 epoch_secrets: snapshot.epoch_secrets, 208 state_repo, 209 cipher_suite_provider, 210 #[cfg(feature = "psk")] 211 previous_psk: None, 212 signer: snapshot.signer, 213 }) 214 } 215 } 216 217 #[cfg(test)] 218 pub(crate) mod test_utils { 219 use alloc::vec; 220 221 use crate::{ 222 cipher_suite::CipherSuite, 223 crypto::test_utils::test_cipher_suite_provider, 224 group::{ 225 confirmation_tag::ConfirmationTag, epoch::test_utils::get_test_epoch_secrets, 226 key_schedule::test_utils::get_test_key_schedule, test_utils::get_test_group_context, 227 transcript_hash::InterimTranscriptHash, 228 }, 229 tree_kem::{node::LeafIndex, TreeKemPrivate}, 230 }; 231 232 use super::{RawGroupState, Snapshot}; 233 234 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] get_test_snapshot(cipher_suite: CipherSuite, epoch_id: u64) -> Snapshot235 pub(crate) async fn get_test_snapshot(cipher_suite: CipherSuite, epoch_id: u64) -> Snapshot { 236 Snapshot { 237 state: RawGroupState { 238 context: get_test_group_context(epoch_id, cipher_suite).await, 239 #[cfg(feature = "by_ref_proposal")] 240 proposals: Default::default(), 241 public_tree: Default::default(), 242 interim_transcript_hash: InterimTranscriptHash::from(vec![]), 243 pending_reinit: None, 244 confirmation_tag: ConfirmationTag::empty(&test_cipher_suite_provider(cipher_suite)) 245 .await, 246 }, 247 private_tree: TreeKemPrivate::new(LeafIndex(0)), 248 epoch_secrets: get_test_epoch_secrets(cipher_suite), 249 key_schedule: get_test_key_schedule(cipher_suite), 250 #[cfg(feature = "by_ref_proposal")] 251 pending_updates: Default::default(), 252 pending_commit: None, 253 version: 1, 254 signer: vec![].into(), 255 } 256 } 257 } 258 259 #[cfg(test)] 260 mod tests { 261 use alloc::vec; 262 263 use crate::{ 264 client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION}, 265 group::{ 266 test_utils::{test_group, TestGroup}, 267 Group, 268 }, 269 }; 270 271 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] snapshot_restore(group: TestGroup)272 async fn snapshot_restore(group: TestGroup) { 273 let snapshot = group.group.snapshot(); 274 275 let group_restored = Group::from_snapshot(group.group.config.clone(), snapshot) 276 .await 277 .unwrap(); 278 279 assert!(Group::equal_group_state(&group.group, &group_restored)); 280 281 #[cfg(feature = "tree_index")] 282 assert!(group_restored 283 .state 284 .public_tree 285 .equal_internals(&group.group.state.public_tree)) 286 } 287 288 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] snapshot_with_pending_commit_can_be_serialized_to_json()289 async fn snapshot_with_pending_commit_can_be_serialized_to_json() { 290 let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; 291 group.group.commit(vec![]).await.unwrap(); 292 293 snapshot_restore(group).await 294 } 295 296 #[cfg(feature = "by_ref_proposal")] 297 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] snapshot_with_pending_updates_can_be_serialized_to_json()298 async fn snapshot_with_pending_updates_can_be_serialized_to_json() { 299 let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; 300 301 // Creating the update proposal will add it to pending updates 302 let update_proposal = group.update_proposal().await; 303 304 // This will insert the proposal into the internal proposal cache 305 let _ = group.group.proposal_message(update_proposal, vec![]).await; 306 307 snapshot_restore(group).await 308 } 309 310 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] snapshot_can_be_serialized_to_json_with_internals()311 async fn snapshot_can_be_serialized_to_json_with_internals() { 312 let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await; 313 314 snapshot_restore(group).await 315 } 316 317 #[cfg(feature = "serde")] 318 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] serde()319 async fn serde() { 320 let snapshot = super::test_utils::get_test_snapshot(TEST_CIPHER_SUITE, 5).await; 321 let json = serde_json::to_string_pretty(&snapshot).unwrap(); 322 let recovered = serde_json::from_str(&json).unwrap(); 323 assert_eq!(snapshot, recovered); 324 } 325 } 326