1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use crate::client::MlsError;
6 use crate::key_package::KeyPackageRef;
7 
8 use alloc::vec::Vec;
9 use mls_rs_codec::MlsEncode;
10 use mls_rs_core::{
11     error::IntoAnyError,
12     group::{GroupState, GroupStateStorage},
13     key_package::KeyPackageStorage,
14 };
15 
16 use super::snapshot::Snapshot;
17 
18 #[derive(Debug, Clone)]
19 pub(crate) struct GroupStateRepository<S, K>
20 where
21     S: GroupStateStorage,
22     K: KeyPackageStorage,
23 {
24     pending_key_package_removal: Option<KeyPackageRef>,
25     storage: S,
26     key_package_repo: K,
27 }
28 
29 impl<S, K> GroupStateRepository<S, K>
30 where
31     S: GroupStateStorage,
32     K: KeyPackageStorage,
33 {
new( storage: S, key_package_repo: K, key_package_to_remove: Option<KeyPackageRef>, ) -> Result<GroupStateRepository<S, K>, MlsError>34     pub fn new(
35         storage: S,
36         key_package_repo: K,
37         // Set to `None` if restoring from snapshot; set to `Some` when joining a group.
38         key_package_to_remove: Option<KeyPackageRef>,
39     ) -> Result<GroupStateRepository<S, K>, MlsError> {
40         Ok(GroupStateRepository {
41             storage,
42             pending_key_package_removal: key_package_to_remove,
43             key_package_repo,
44         })
45     }
46 
47     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
write_to_storage(&mut self, group_snapshot: Snapshot) -> Result<(), MlsError>48     pub async fn write_to_storage(&mut self, group_snapshot: Snapshot) -> Result<(), MlsError> {
49         let group_state = GroupState {
50             data: group_snapshot.mls_encode_to_vec()?,
51             id: group_snapshot.state.context.group_id,
52         };
53 
54         self.storage
55             .write(group_state, Vec::new(), Vec::new())
56             .await
57             .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?;
58 
59         if let Some(ref key_package_ref) = self.pending_key_package_removal {
60             self.key_package_repo
61                 .delete(key_package_ref)
62                 .await
63                 .map_err(|e| MlsError::KeyPackageRepoError(e.into_any_error()))?;
64         }
65 
66         Ok(())
67     }
68 }
69 
#[cfg(test)]
mod tests {
    use crate::{
        client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
        group::{
            snapshot::{test_utils::get_test_snapshot, Snapshot},
            test_utils::{test_member, TEST_GROUP},
        },
        storage_provider::in_memory::{InMemoryGroupStateStorage, InMemoryKeyPackageStorage},
    };

    use alloc::vec;

    use super::GroupStateRepository;

    /// Builds a test snapshot for `TEST_CIPHER_SUITE` at the given epoch.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn test_snapshot(epoch_id: u64) -> Snapshot {
        get_test_snapshot(TEST_CIPHER_SUITE, epoch_id).await
    }

    /// Writing a snapshot makes its group appear in the storage's list of stored groups.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_stored_groups_list() {
        let mut repo = GroupStateRepository::new(
            InMemoryGroupStateStorage::default(),
            InMemoryKeyPackageStorage::default(),
            None,
        )
        .unwrap();

        let snapshot = test_snapshot(0).await;
        repo.write_to_storage(snapshot).await.unwrap();

        let stored = repo.storage.stored_groups();
        assert_eq!(stored, vec![TEST_GROUP])
    }

    /// The key package passed to `new` is present beforehand and gone after a write.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn used_key_package_is_deleted() {
        let kp_storage = InMemoryKeyPackageStorage::default();

        let member = test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"member")
            .await
            .0;

        let (kp_id, kp_data) = member.to_storage().unwrap();
        kp_storage.insert(kp_id, kp_data);

        let kp_ref = &member.reference;

        let mut repo = GroupStateRepository::new(
            InMemoryGroupStateStorage::default(),
            kp_storage,
            Some(kp_ref.clone()),
        )
        .unwrap();

        // Precondition: the key package is stored before any group state is written.
        assert!(repo.key_package_repo.get(kp_ref).is_some());

        let snapshot = test_snapshot(4).await;
        repo.write_to_storage(snapshot).await.unwrap();

        assert!(repo.key_package_repo.get(kp_ref).is_none());
    }
}
133