1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 #[cfg(target_has_atomic = "ptr")]
6 use alloc::sync::Arc;
7 
8 #[cfg(not(target_has_atomic = "ptr"))]
9 use portable_atomic_util::Arc;
10 
11 use core::{
12     convert::Infallible,
13     fmt::{self, Debug},
14 };
15 
16 #[cfg(feature = "std")]
17 use std::collections::HashMap;
18 
19 #[cfg(not(feature = "std"))]
20 use alloc::collections::BTreeMap;
21 use alloc::vec::Vec;
22 use mls_rs_core::key_package::{KeyPackageData, KeyPackageStorage};
23 
24 #[cfg(feature = "std")]
25 use std::sync::Mutex;
26 
27 #[cfg(mls_build_async)]
28 use alloc::boxed::Box;
29 #[cfg(not(feature = "std"))]
30 use spin::Mutex;
31 
32 #[derive(Clone, Default)]
33 /// In memory key package storage backed by a HashMap.
34 ///
35 /// All clones of an instance of this type share the same underlying HashMap.
36 pub struct InMemoryKeyPackageStorage {
37     #[cfg(feature = "std")]
38     inner: Arc<Mutex<HashMap<Vec<u8>, KeyPackageData>>>,
39     #[cfg(not(feature = "std"))]
40     inner: Arc<Mutex<BTreeMap<Vec<u8>, KeyPackageData>>>,
41 }
42 
43 impl Debug for InMemoryKeyPackageStorage {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result44     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45         f.debug_struct("InMemoryKeyPackageStorage")
46             .field(
47                 "inner",
48                 &mls_rs_core::debug::pretty_with(|f| {
49                     f.debug_map()
50                         .entries(
51                             self.lock()
52                                 .iter()
53                                 .map(|(k, v)| (mls_rs_core::debug::pretty_bytes(k), v)),
54                         )
55                         .finish()
56                 }),
57             )
58             .finish()
59     }
60 }
61 
62 impl InMemoryKeyPackageStorage {
63     /// Create an empty key package storage.
new() -> Self64     pub fn new() -> Self {
65         Default::default()
66     }
67 
68     /// Insert key package data.
insert(&self, id: Vec<u8>, pkg: KeyPackageData)69     pub fn insert(&self, id: Vec<u8>, pkg: KeyPackageData) {
70         self.lock().insert(id, pkg);
71     }
72 
73     /// Get a key package data by `id`.
get(&self, id: &[u8]) -> Option<KeyPackageData>74     pub fn get(&self, id: &[u8]) -> Option<KeyPackageData> {
75         self.lock().get(id).cloned()
76     }
77 
78     /// Delete key package data by `id`.
delete(&self, id: &[u8])79     pub fn delete(&self, id: &[u8]) {
80         self.lock().remove(id);
81     }
82 
83     /// Get all key packages that are currently stored.
key_packages(&self) -> Vec<(Vec<u8>, KeyPackageData)>84     pub fn key_packages(&self) -> Vec<(Vec<u8>, KeyPackageData)> {
85         self.lock()
86             .iter()
87             .map(|(k, v)| (k.clone(), v.clone()))
88             .collect()
89     }
90 
91     #[cfg(feature = "std")]
lock(&self) -> std::sync::MutexGuard<'_, HashMap<Vec<u8>, KeyPackageData>>92     fn lock(&self) -> std::sync::MutexGuard<'_, HashMap<Vec<u8>, KeyPackageData>> {
93         self.inner.lock().unwrap()
94     }
95 
96     #[cfg(not(feature = "std"))]
lock(&self) -> spin::mutex::MutexGuard<'_, BTreeMap<Vec<u8>, KeyPackageData>>97     fn lock(&self) -> spin::mutex::MutexGuard<'_, BTreeMap<Vec<u8>, KeyPackageData>> {
98         self.inner.lock()
99     }
100 }
101 
102 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
103 #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
104 impl KeyPackageStorage for InMemoryKeyPackageStorage {
105     type Error = Infallible;
106 
delete(&mut self, id: &[u8]) -> Result<(), Self::Error>107     async fn delete(&mut self, id: &[u8]) -> Result<(), Self::Error> {
108         (*self).delete(id);
109         Ok(())
110     }
111 
insert(&mut self, id: Vec<u8>, pkg: KeyPackageData) -> Result<(), Self::Error>112     async fn insert(&mut self, id: Vec<u8>, pkg: KeyPackageData) -> Result<(), Self::Error> {
113         (*self).insert(id, pkg);
114         Ok(())
115     }
116 
get(&self, id: &[u8]) -> Result<Option<KeyPackageData>, Self::Error>117     async fn get(&self, id: &[u8]) -> Result<Option<KeyPackageData>, Self::Error> {
118         Ok(self.get(id))
119     }
120 }
121