1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 #[cfg(target_has_atomic = "ptr")]
6 use alloc::sync::Arc;
7 
8 #[cfg(not(target_has_atomic = "ptr"))]
9 use portable_atomic_util::Arc;
10 
11 use core::convert::Infallible;
12 
13 #[cfg(feature = "std")]
14 use std::collections::HashMap;
15 
16 #[cfg(not(feature = "std"))]
17 use alloc::collections::BTreeMap;
18 
19 use mls_rs_core::psk::{ExternalPskId, PreSharedKey, PreSharedKeyStorage};
20 
21 #[cfg(mls_build_async)]
22 use alloc::boxed::Box;
23 #[cfg(feature = "std")]
24 use std::sync::Mutex;
25 
26 #[cfg(not(feature = "std"))]
27 use spin::Mutex;
28 
29 #[derive(Clone, Debug, Default)]
30 /// In memory pre-shared key storage backed by a HashMap.
31 ///
32 /// All clones of an instance of this type share the same underlying HashMap.
33 pub struct InMemoryPreSharedKeyStorage {
34     #[cfg(feature = "std")]
35     inner: Arc<Mutex<HashMap<ExternalPskId, PreSharedKey>>>,
36     #[cfg(not(feature = "std"))]
37     inner: Arc<Mutex<BTreeMap<ExternalPskId, PreSharedKey>>>,
38 }
39 
40 impl InMemoryPreSharedKeyStorage {
41     /// Insert a pre-shared key into storage.
insert(&mut self, id: ExternalPskId, psk: PreSharedKey)42     pub fn insert(&mut self, id: ExternalPskId, psk: PreSharedKey) {
43         #[cfg(feature = "std")]
44         let mut lock = self.inner.lock().unwrap();
45 
46         #[cfg(not(feature = "std"))]
47         let mut lock = self.inner.lock();
48 
49         lock.insert(id, psk);
50     }
51 
52     /// Get a pre-shared key by `id`.
get(&self, id: &ExternalPskId) -> Option<PreSharedKey>53     pub fn get(&self, id: &ExternalPskId) -> Option<PreSharedKey> {
54         #[cfg(feature = "std")]
55         let lock = self.inner.lock().unwrap();
56 
57         #[cfg(not(feature = "std"))]
58         let lock = self.inner.lock();
59 
60         lock.get(id).cloned()
61     }
62 
63     /// Delete a pre-shared key from storage.
delete(&mut self, id: &ExternalPskId)64     pub fn delete(&mut self, id: &ExternalPskId) {
65         #[cfg(feature = "std")]
66         let mut lock = self.inner.lock().unwrap();
67 
68         #[cfg(not(feature = "std"))]
69         let mut lock = self.inner.lock();
70 
71         lock.remove(id);
72     }
73 }
74 
75 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
76 #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
77 impl PreSharedKeyStorage for InMemoryPreSharedKeyStorage {
78     type Error = Infallible;
79 
get(&self, id: &ExternalPskId) -> Result<Option<PreSharedKey>, Self::Error>80     async fn get(&self, id: &ExternalPskId) -> Result<Option<PreSharedKey>, Self::Error> {
81         Ok(self.get(id))
82     }
83 }
84