1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use crate::{
6     client::MlsError,
7     client_config::ClientConfig,
8     group::{
9         key_schedule::KeySchedule, CommitGeneration, ConfirmationTag, Group, GroupContext,
10         GroupState, InterimTranscriptHash, ReInitProposal, TreeKemPublic,
11     },
12     tree_kem::TreeKemPrivate,
13 };
14 
15 #[cfg(feature = "by_ref_proposal")]
16 use crate::{
17     crypto::{HpkePublicKey, HpkeSecretKey},
18     group::ProposalRef,
19 };
20 
21 #[cfg(feature = "by_ref_proposal")]
22 use super::proposal_cache::{CachedProposal, ProposalCache};
23 
24 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
25 
26 use mls_rs_core::crypto::SignatureSecretKey;
27 #[cfg(feature = "tree_index")]
28 use mls_rs_core::identity::IdentityProvider;
29 
30 #[cfg(all(feature = "std", feature = "by_ref_proposal"))]
31 use std::collections::HashMap;
32 
33 #[cfg(all(feature = "by_ref_proposal", not(feature = "std")))]
34 use alloc::vec::Vec;
35 
36 use super::{cipher_suite_provider, epoch::EpochSecrets, state_repo::GroupStateRepository};
37 
/// Captured state of a [`Group`], produced by `Group::snapshot` and turned back
/// into a group by `Group::from_snapshot`.
///
/// Apart from `version`, every field is cloned from the corresponding field of
/// the group when the snapshot is taken (the group state goes through
/// `RawGroupState::export`). The type derives the MLS codec traits and, with the
/// `serde` feature, serde's `Serialize`/`Deserialize`.
///
/// NOTE(review): presumably the codec derive encodes fields in declaration
/// order, so reordering or inserting fields would change the encoding of
/// previously stored snapshots — confirm before editing this struct.
///
/// NOTE(review): the `HpkeSecretKey` / `SignatureSecretKey` fields presumably
/// hold secret key material; treat encoded snapshots as sensitive.
#[derive(Debug, PartialEq, Clone, MlsEncode, MlsDecode, MlsSize)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct Snapshot {
    /// Snapshot format version. `Group::snapshot` always writes `1`;
    /// `Group::from_snapshot` does not currently check it.
    version: u16,
    /// Group context, public tree, transcript hash, pending ReInit and
    /// confirmation tag (plus cached proposals with `by_ref_proposal`).
    pub(crate) state: RawGroupState,
    /// Cloned from the group's `private_tree`.
    private_tree: TreeKemPrivate,
    /// Cloned from the group's `epoch_secrets`.
    epoch_secrets: EpochSecrets,
    /// Cloned from the group's `key_schedule`.
    key_schedule: KeySchedule,
    /// Pending updates keyed by HPKE public key; each value holds the HPKE
    /// secret key and an optional signature secret key.
    #[cfg(all(feature = "std", feature = "by_ref_proposal"))]
    pending_updates: HashMap<HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>)>,
    /// Same content as the `std` variant, stored as a vector of key/value
    /// pairs because `HashMap` is unavailable without `std`.
    #[cfg(all(not(feature = "std"), feature = "by_ref_proposal"))]
    pending_updates: Vec<(HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>))>,
    /// Cloned from the group's `pending_commit`, if any.
    pending_commit: Option<CommitGeneration>,
    /// Cloned from the group's `signer`.
    signer: SignatureSecretKey,
}
53 
/// Storable form of `GroupState`, embedded in [`Snapshot`].
///
/// Created from a live `GroupState` by `RawGroupState::export` and converted
/// back with `RawGroupState::import`.
///
/// NOTE(review): presumably the codec derive encodes fields in declaration
/// order; keep the field order stable to stay compatible with stored snapshots.
#[derive(Debug, MlsEncode, MlsDecode, MlsSize, PartialEq, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct RawGroupState {
    /// Group context; on import its protocol version and group id are also used
    /// to rebuild the proposal cache.
    pub(crate) context: GroupContext,
    /// Cached proposals keyed by `ProposalRef`, loaded into a `ProposalCache`
    /// on import.
    #[cfg(all(feature = "std", feature = "by_ref_proposal"))]
    pub(crate) proposals: HashMap<ProposalRef, CachedProposal>,
    /// Same content as the `std` variant, stored as key/value pairs because
    /// `HashMap` is unavailable without `std`.
    #[cfg(all(not(feature = "std"), feature = "by_ref_proposal"))]
    pub(crate) proposals: Vec<(ProposalRef, CachedProposal)>,
    /// Public ratchet tree. Without `tree_index`, export stores a fresh tree
    /// holding only a copy of the original tree's `nodes`.
    pub(crate) public_tree: TreeKemPublic,
    pub(crate) interim_transcript_hash: InterimTranscriptHash,
    pub(crate) pending_reinit: Option<ReInitProposal>,
    pub(crate) confirmation_tag: ConfirmationTag,
}
67 
68 impl RawGroupState {
export(state: &GroupState) -> Self69     pub(crate) fn export(state: &GroupState) -> Self {
70         #[cfg(feature = "tree_index")]
71         let public_tree = state.public_tree.clone();
72 
73         #[cfg(not(feature = "tree_index"))]
74         let public_tree = {
75             let mut tree = TreeKemPublic::new();
76             tree.nodes = state.public_tree.nodes.clone();
77             tree
78         };
79 
80         Self {
81             context: state.context.clone(),
82             #[cfg(feature = "by_ref_proposal")]
83             proposals: state.proposals.proposals.clone(),
84             public_tree,
85             interim_transcript_hash: state.interim_transcript_hash.clone(),
86             pending_reinit: state.pending_reinit.clone(),
87             confirmation_tag: state.confirmation_tag.clone(),
88         }
89     }
90 
91     #[cfg(feature = "tree_index")]
92     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
import<C>(self, identity_provider: &C) -> Result<GroupState, MlsError> where C: IdentityProvider,93     pub(crate) async fn import<C>(self, identity_provider: &C) -> Result<GroupState, MlsError>
94     where
95         C: IdentityProvider,
96     {
97         let context = self.context;
98 
99         #[cfg(feature = "by_ref_proposal")]
100         let proposals = ProposalCache::import(
101             context.protocol_version,
102             context.group_id.clone(),
103             self.proposals,
104         );
105 
106         let mut public_tree = self.public_tree;
107 
108         public_tree
109             .initialize_index_if_necessary(identity_provider, &context.extensions)
110             .await?;
111 
112         Ok(GroupState {
113             #[cfg(feature = "by_ref_proposal")]
114             proposals,
115             context,
116             public_tree,
117             interim_transcript_hash: self.interim_transcript_hash,
118             pending_reinit: self.pending_reinit,
119             confirmation_tag: self.confirmation_tag,
120         })
121     }
122 
123     #[cfg(not(feature = "tree_index"))]
124     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
import(self) -> Result<GroupState, MlsError>125     pub(crate) async fn import(self) -> Result<GroupState, MlsError> {
126         let context = self.context;
127 
128         #[cfg(feature = "by_ref_proposal")]
129         let proposals = ProposalCache::import(
130             context.protocol_version,
131             context.group_id.clone(),
132             self.proposals,
133         );
134 
135         Ok(GroupState {
136             #[cfg(feature = "by_ref_proposal")]
137             proposals,
138             context,
139             public_tree: self.public_tree,
140             interim_transcript_hash: self.interim_transcript_hash,
141             pending_reinit: self.pending_reinit,
142             confirmation_tag: self.confirmation_tag,
143         })
144     }
145 }
146 
147 impl<C> Group<C>
148 where
149     C: ClientConfig + Clone,
150 {
151     /// Write the current state of the group to the
152     /// [`GroupStorageProvider`](crate::GroupStateStorage)
153     /// that is currently in use by the group.
154     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
write_to_storage(&mut self) -> Result<(), MlsError>155     pub async fn write_to_storage(&mut self) -> Result<(), MlsError> {
156         self.state_repo.write_to_storage(self.snapshot()).await
157     }
158 
snapshot(&self) -> Snapshot159     pub(crate) fn snapshot(&self) -> Snapshot {
160         Snapshot {
161             state: RawGroupState::export(&self.state),
162             private_tree: self.private_tree.clone(),
163             key_schedule: self.key_schedule.clone(),
164             #[cfg(feature = "by_ref_proposal")]
165             pending_updates: self.pending_updates.clone(),
166             pending_commit: self.pending_commit.clone(),
167             epoch_secrets: self.epoch_secrets.clone(),
168             version: 1,
169             signer: self.signer.clone(),
170         }
171     }
172 
173     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
from_snapshot(config: C, snapshot: Snapshot) -> Result<Self, MlsError>174     pub(crate) async fn from_snapshot(config: C, snapshot: Snapshot) -> Result<Self, MlsError> {
175         let cipher_suite_provider = cipher_suite_provider(
176             config.crypto_provider(),
177             snapshot.state.context.cipher_suite,
178         )?;
179 
180         #[cfg(feature = "tree_index")]
181         let identity_provider = config.identity_provider();
182 
183         let state_repo = GroupStateRepository::new(
184             #[cfg(feature = "prior_epoch")]
185             snapshot.state.context.group_id.clone(),
186             config.group_state_storage(),
187             config.key_package_repo(),
188             None,
189         )?;
190 
191         Ok(Group {
192             config,
193             state: snapshot
194                 .state
195                 .import(
196                     #[cfg(feature = "tree_index")]
197                     &identity_provider,
198                 )
199                 .await?,
200             private_tree: snapshot.private_tree,
201             key_schedule: snapshot.key_schedule,
202             #[cfg(feature = "by_ref_proposal")]
203             pending_updates: snapshot.pending_updates,
204             pending_commit: snapshot.pending_commit,
205             #[cfg(test)]
206             commit_modifiers: Default::default(),
207             epoch_secrets: snapshot.epoch_secrets,
208             state_repo,
209             cipher_suite_provider,
210             #[cfg(feature = "psk")]
211             previous_psk: None,
212             signer: snapshot.signer,
213         })
214     }
215 }
216 
#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::vec;

    use crate::{
        cipher_suite::CipherSuite,
        crypto::test_utils::test_cipher_suite_provider,
        group::{
            confirmation_tag::ConfirmationTag, epoch::test_utils::get_test_epoch_secrets,
            key_schedule::test_utils::get_test_key_schedule, test_utils::get_test_group_context,
            transcript_hash::InterimTranscriptHash,
        },
        tree_kem::{node::LeafIndex, TreeKemPrivate},
    };

    use super::{RawGroupState, Snapshot};

    /// Build a minimal `Snapshot` fixture for `cipher_suite` at `epoch_id`.
    ///
    /// The context comes from `get_test_group_context`. The public tree, cached
    /// proposals and pending updates use their defaults. The transcript hash and
    /// signer are empty, and the confirmation tag is `ConfirmationTag::empty`.
    /// The private tree is for leaf 0, and there is no pending commit or ReInit.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn get_test_snapshot(cipher_suite: CipherSuite, epoch_id: u64) -> Snapshot {
        let context = get_test_group_context(epoch_id, cipher_suite).await;

        let suite_provider = test_cipher_suite_provider(cipher_suite);
        let confirmation_tag = ConfirmationTag::empty(&suite_provider).await;

        let raw_state = RawGroupState {
            context,
            #[cfg(feature = "by_ref_proposal")]
            proposals: Default::default(),
            public_tree: Default::default(),
            interim_transcript_hash: InterimTranscriptHash::from(vec![]),
            pending_reinit: None,
            confirmation_tag,
        };

        Snapshot {
            version: 1,
            state: raw_state,
            private_tree: TreeKemPrivate::new(LeafIndex(0)),
            epoch_secrets: get_test_epoch_secrets(cipher_suite),
            key_schedule: get_test_key_schedule(cipher_suite),
            #[cfg(feature = "by_ref_proposal")]
            pending_updates: Default::default(),
            pending_commit: None,
            signer: vec![].into(),
        }
    }
}
258 
#[cfg(test)]
mod tests {
    use alloc::vec;

    use crate::{
        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
        group::{
            test_utils::{test_group, TestGroup},
            Group,
        },
    };

    /// Snapshot `group` and restore a new `Group` from it with the same config.
    /// Then assert that the restored group's state equals the original, and, with
    /// `tree_index`, that the public tree internals match as well.
    ///
    /// Note: this round-trips in memory; it does not go through JSON despite the
    /// names of the tests that use it.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn snapshot_restore(group: TestGroup) {
        let snapshot = group.group.snapshot();

        let group_restored = Group::from_snapshot(group.group.config.clone(), snapshot)
            .await
            .unwrap();

        assert!(Group::equal_group_state(&group.group, &group_restored));

        #[cfg(feature = "tree_index")]
        assert!(group_restored
            .state
            .public_tree
            .equal_internals(&group.group.state.public_tree))
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn snapshot_with_pending_commit_can_be_serialized_to_json() {
        let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
        group.group.commit(vec![]).await.unwrap();

        snapshot_restore(group).await
    }

    #[cfg(feature = "by_ref_proposal")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn snapshot_with_pending_updates_can_be_serialized_to_json() {
        let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;

        // Creating the update proposal will add it to pending updates
        let update_proposal = group.update_proposal().await;

        // This will insert the proposal into the internal proposal cache. Unwrap
        // so that a failure here fails the test, instead of silently restoring a
        // snapshot that lacks the cached proposal.
        group
            .group
            .proposal_message(update_proposal, vec![])
            .await
            .unwrap();

        snapshot_restore(group).await
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn snapshot_can_be_serialized_to_json_with_internals() {
        let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;

        snapshot_restore(group).await
    }

    /// Round-trip a fixture snapshot through `serde_json` and check equality.
    #[cfg(feature = "serde")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn serde() {
        let snapshot = super::test_utils::get_test_snapshot(TEST_CIPHER_SUITE, 5).await;
        let json = serde_json::to_string_pretty(&snapshot).unwrap();
        let recovered = serde_json::from_str(&json).unwrap();
        assert_eq!(snapshot, recovered);
    }
}
326