1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 // Copyright by contributors to this project. 3 // SPDX-License-Identifier: (Apache-2.0 OR MIT) 4 5 #[cfg(target_has_atomic = "ptr")] 6 use alloc::sync::Arc; 7 8 #[cfg(not(target_has_atomic = "ptr"))] 9 use portable_atomic_util::Arc; 10 11 use core::{ 12 convert::Infallible, 13 fmt::{self, Debug}, 14 }; 15 16 #[cfg(feature = "std")] 17 use std::collections::HashMap; 18 19 #[cfg(not(feature = "std"))] 20 use alloc::collections::BTreeMap; 21 use alloc::vec::Vec; 22 use mls_rs_core::key_package::{KeyPackageData, KeyPackageStorage}; 23 24 #[cfg(feature = "std")] 25 use std::sync::Mutex; 26 27 #[cfg(mls_build_async)] 28 use alloc::boxed::Box; 29 #[cfg(not(feature = "std"))] 30 use spin::Mutex; 31 32 #[derive(Clone, Default)] 33 /// In memory key package storage backed by a HashMap. 34 /// 35 /// All clones of an instance of this type share the same underlying HashMap. 36 pub struct InMemoryKeyPackageStorage { 37 #[cfg(feature = "std")] 38 inner: Arc<Mutex<HashMap<Vec<u8>, KeyPackageData>>>, 39 #[cfg(not(feature = "std"))] 40 inner: Arc<Mutex<BTreeMap<Vec<u8>, KeyPackageData>>>, 41 } 42 43 impl Debug for InMemoryKeyPackageStorage { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result44 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 45 f.debug_struct("InMemoryKeyPackageStorage") 46 .field( 47 "inner", 48 &mls_rs_core::debug::pretty_with(|f| { 49 f.debug_map() 50 .entries( 51 self.lock() 52 .iter() 53 .map(|(k, v)| (mls_rs_core::debug::pretty_bytes(k), v)), 54 ) 55 .finish() 56 }), 57 ) 58 .finish() 59 } 60 } 61 62 impl InMemoryKeyPackageStorage { 63 /// Create an empty key package storage. new() -> Self64 pub fn new() -> Self { 65 Default::default() 66 } 67 68 /// Insert key package data. insert(&self, id: Vec<u8>, pkg: KeyPackageData)69 pub fn insert(&self, id: Vec<u8>, pkg: KeyPackageData) { 70 self.lock().insert(id, pkg); 71 } 72 73 /// Get a key package data by `id`. get(&self, id: &[u8]) -> Option<KeyPackageData>74 pub fn get(&self, id: &[u8]) -> Option<KeyPackageData> { 75 self.lock().get(id).cloned() 76 } 77 78 /// Delete key package data by `id`. delete(&self, id: &[u8])79 pub fn delete(&self, id: &[u8]) { 80 self.lock().remove(id); 81 } 82 83 /// Get all key packages that are currently stored. key_packages(&self) -> Vec<(Vec<u8>, KeyPackageData)>84 pub fn key_packages(&self) -> Vec<(Vec<u8>, KeyPackageData)> { 85 self.lock() 86 .iter() 87 .map(|(k, v)| (k.clone(), v.clone())) 88 .collect() 89 } 90 91 #[cfg(feature = "std")] lock(&self) -> std::sync::MutexGuard<'_, HashMap<Vec<u8>, KeyPackageData>>92 fn lock(&self) -> std::sync::MutexGuard<'_, HashMap<Vec<u8>, KeyPackageData>> { 93 self.inner.lock().unwrap() 94 } 95 96 #[cfg(not(feature = "std"))] lock(&self) -> spin::mutex::MutexGuard<'_, BTreeMap<Vec<u8>, KeyPackageData>>97 fn lock(&self) -> spin::mutex::MutexGuard<'_, BTreeMap<Vec<u8>, KeyPackageData>> { 98 self.inner.lock() 99 } 100 } 101 102 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] 103 #[cfg_attr(mls_build_async, maybe_async::must_be_async)] 104 impl KeyPackageStorage for InMemoryKeyPackageStorage { 105 type Error = Infallible; 106 delete(&mut self, id: &[u8]) -> Result<(), Self::Error>107 async fn delete(&mut self, id: &[u8]) -> Result<(), Self::Error> { 108 (*self).delete(id); 109 Ok(()) 110 } 111 insert(&mut self, id: Vec<u8>, pkg: KeyPackageData) -> Result<(), Self::Error>112 async fn insert(&mut self, id: Vec<u8>, pkg: KeyPackageData) -> Result<(), Self::Error> { 113 (*self).insert(id, pkg); 114 Ok(()) 115 } 116 get(&self, id: &[u8]) -> Result<Option<KeyPackageData>, Self::Error>117 async fn get(&self, id: &[u8]) -> Result<Option<KeyPackageData>, Self::Error> { 118 Ok(self.get(id)) 119 } 120 } 121