1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 use crate::client::MlsError; 6 use crate::key_package::KeyPackageRef; 7 8 use alloc::vec::Vec; 9 use mls_rs_codec::MlsEncode; 10 use mls_rs_core::{ 11 error::IntoAnyError, 12 group::{GroupState, GroupStateStorage}, 13 key_package::KeyPackageStorage, 14 }; 15 16 use super::snapshot::Snapshot; 17 18 #[derive(Debug, Clone)] 19 pub(crate) struct GroupStateRepository<S, K> 20 where 21 S: GroupStateStorage, 22 K: KeyPackageStorage, 23 { 24 pending_key_package_removal: Option<KeyPackageRef>, 25 storage: S, 26 key_package_repo: K, 27 } 28 29 impl<S, K> GroupStateRepository<S, K> 30 where 31 S: GroupStateStorage, 32 K: KeyPackageStorage, 33 { new( storage: S, key_package_repo: K, key_package_to_remove: Option<KeyPackageRef>, ) -> Result<GroupStateRepository<S, K>, MlsError>34 pub fn new( 35 storage: S, 36 key_package_repo: K, 37 // Set to `None` if restoring from snapshot; set to `Some` when joining a group. 38 key_package_to_remove: Option<KeyPackageRef>, 39 ) -> Result<GroupStateRepository<S, K>, MlsError> { 40 Ok(GroupStateRepository { 41 storage, 42 pending_key_package_removal: key_package_to_remove, 43 key_package_repo, 44 }) 45 } 46 47 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] write_to_storage(&mut self, group_snapshot: Snapshot) -> Result<(), MlsError>48 pub async fn write_to_storage(&mut self, group_snapshot: Snapshot) -> Result<(), MlsError> { 49 let group_state = GroupState { 50 data: group_snapshot.mls_encode_to_vec()?, 51 id: group_snapshot.state.context.group_id, 52 }; 53 54 self.storage 55 .write(group_state, Vec::new(), Vec::new()) 56 .await 57 .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?; 58 59 if let Some(ref key_package_ref) = self.pending_key_package_removal { 60 self.key_package_repo 61 .delete(key_package_ref) 62 .await 63 .map_err(|e| MlsError::KeyPackageRepoError(e.into_any_error()))?; 64 } 65 66 Ok(()) 67 } 68 } 69 70 #[cfg(test)] 71 mod tests { 72 use crate::{ 73 client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION}, 74 group::{ 75 snapshot::{test_utils::get_test_snapshot, Snapshot}, 76 test_utils::{test_member, TEST_GROUP}, 77 }, 78 storage_provider::in_memory::{InMemoryGroupStateStorage, InMemoryKeyPackageStorage}, 79 }; 80 81 use alloc::vec; 82 83 use super::GroupStateRepository; 84 85 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] test_snapshot(epoch_id: u64) -> Snapshot86 async fn test_snapshot(epoch_id: u64) -> Snapshot { 87 get_test_snapshot(TEST_CIPHER_SUITE, epoch_id).await 88 } 89 90 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] test_stored_groups_list()91 async fn test_stored_groups_list() { 92 let mut test_repo = GroupStateRepository::new( 93 InMemoryGroupStateStorage::default(), 94 InMemoryKeyPackageStorage::default(), 95 None, 96 ) 97 .unwrap(); 98 99 test_repo 100 .write_to_storage(test_snapshot(0).await) 101 .await 102 .unwrap(); 103 104 assert_eq!(test_repo.storage.stored_groups(), vec![TEST_GROUP]) 105 } 106 107 #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] used_key_package_is_deleted()108 async fn used_key_package_is_deleted() { 109 let key_package_repo = InMemoryKeyPackageStorage::default(); 110 111 let key_package = test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"member") 112 .await 113 .0; 114 115 let (id, data) = key_package.to_storage().unwrap(); 116 117 key_package_repo.insert(id, data); 118 119 let mut repo = GroupStateRepository::new( 120 InMemoryGroupStateStorage::default(), 121 key_package_repo, 122 Some(key_package.reference.clone()), 123 ) 124 .unwrap(); 125 126 repo.key_package_repo.get(&key_package.reference).unwrap(); 127 128 repo.write_to_storage(test_snapshot(4).await).await.unwrap(); 129 130 assert!(repo.key_package_repo.get(&key_package.reference).is_none()); 131 } 132 } 133