1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 use core::ops::{Deref, DerefMut};
16 use lock_adapter::stdlib::{RwLock, RwLockReadGuard, RwLockWriteGuard};
17 use lock_adapter::RwLock as _;
18 use std::collections::hash_map::Entry::{Occupied, Vacant};
19 use std::collections::HashMap;
20 use std::marker::PhantomData;
21 use std::sync::atomic::{AtomicU32, Ordering};
22 
23 use crate::guard::{
24     ObjectReadGuardImpl, ObjectReadGuardMapping, ObjectReadWriteGuardImpl,
25     ObjectReadWriteGuardMapping,
26 };
27 use crate::{Handle, HandleNotPresentError};
28 
// Type aliases for the map, lock, and guard types used throughout the shard code,
// so that the signatures below stay short and readable.
/// The map stored inside each shard, keyed by handle.
type ShardMapType<T> = HashMap<Handle, T>;
/// The lock protecting a shard's map.
type ShardReadWriteLock<T> = RwLock<ShardMapType<T>>;
/// A shared (read) guard over a shard's map.
type ShardReadGuard<'a, T> = RwLockReadGuard<'a, ShardMapType<T>>;
/// An exclusive (write) guard over a shard's map.
type ShardReadWriteGuard<'a, T> = RwLockWriteGuard<'a, ShardMapType<T>>;
34 
/// Internal error enum for failed allocations into a given shard.
///
/// This is the error type returned by `HandleMapShard::try_allocate`.
pub(crate) enum ShardAllocationError<T, E, F: FnOnce() -> Result<T, E>> {
    /// The shard already holds an entry for the requested handle.
    /// The object-provider is handed back uncalled so that the caller
    /// can try again with a new handle-id.
    EntryOccupied(F),
    /// Allocating would exceed the maximum number of active allocations.
    ExceedsAllocationLimit,
    /// The object-provider was called and returned this error.
    ValueProviderFailed(E),
}
46 
47 /// An individual handle-map shard, which is ultimately
48 /// just a hash-map behind a lock.
49 pub(crate) struct HandleMapShard<T: Send + Sync> {
50     data: RwLock<ShardMapType<T>>,
51 }
52 
53 impl<T: Send + Sync> Default for HandleMapShard<T> {
default() -> Self54     fn default() -> Self {
55         Self {
56             data: RwLock::new(HashMap::new()),
57         }
58     }
59 }
60 
61 impl<T: Send + Sync> HandleMapShard<T> {
get(&self, handle: Handle) -> Result<ObjectReadGuardImpl<T>, HandleNotPresentError>62     pub fn get(&self, handle: Handle) -> Result<ObjectReadGuardImpl<T>, HandleNotPresentError> {
63         let map_read_guard = ShardReadWriteLock::<T>::read(&self.data);
64         let read_only_map_ref = map_read_guard.deref();
65         if read_only_map_ref.contains_key(&handle) {
66             let object_read_guard = ShardReadGuard::<T>::map(
67                 map_read_guard,
68                 ObjectReadGuardMapping {
69                     handle,
70                     _marker: PhantomData,
71                 },
72             );
73             Ok(ObjectReadGuardImpl {
74                 guard: object_read_guard,
75             })
76         } else {
77             // Auto-drop the read guard, and return an error
78             Err(HandleNotPresentError)
79         }
80     }
81     /// Gets a read-write guard on the entire shard map if an entry for the given
82     /// handle exists, but if not, yield [`HandleNotPresentError`].
get_read_write_guard_if_entry_exists( &self, handle: Handle, ) -> Result<ShardReadWriteGuard<T>, HandleNotPresentError>83     fn get_read_write_guard_if_entry_exists(
84         &self,
85         handle: Handle,
86     ) -> Result<ShardReadWriteGuard<T>, HandleNotPresentError> {
87         let contains_key = {
88             let map_ref = self.data.read();
89             map_ref.contains_key(&handle)
90         };
91         if contains_key {
92             // If we know that the entry exists, and we're currently
93             // holding a read-lock, we know that we're safe to request
94             // an upgrade to a write lock, since only one write or
95             // upgradable read lock can be outstanding at any one time.
96             let write_guard = self.data.write();
97             Ok(write_guard)
98         } else {
99             // Auto-drop the read guard, we don't need to allow a write.
100             Err(HandleNotPresentError)
101         }
102     }
103 
get_mut( &self, handle: Handle, ) -> Result<ObjectReadWriteGuardImpl<T>, HandleNotPresentError>104     pub fn get_mut(
105         &self,
106         handle: Handle,
107     ) -> Result<ObjectReadWriteGuardImpl<T>, HandleNotPresentError> {
108         let map_read_write_guard = self.get_read_write_guard_if_entry_exists(handle)?;
109         // Expose only the pointed-to object with a mapped read-write guard
110         let object_read_write_guard = ShardReadWriteGuard::<T>::map(
111             map_read_write_guard,
112             ObjectReadWriteGuardMapping {
113                 handle,
114                 _marker: PhantomData,
115             },
116         );
117         Ok(ObjectReadWriteGuardImpl {
118             guard: object_read_write_guard,
119         })
120     }
121 
deallocate( &self, handle: Handle, outstanding_allocations_counter: &AtomicU32, ) -> Result<T, HandleNotPresentError>122     pub fn deallocate(
123         &self,
124         handle: Handle,
125         outstanding_allocations_counter: &AtomicU32,
126     ) -> Result<T, HandleNotPresentError> {
127         let mut map_read_write_guard = self.get_read_write_guard_if_entry_exists(handle)?;
128         // We don't need to worry about double-decrements, since the above call
129         // got us an upgradable read guard for our read, which means it's the only
130         // outstanding upgradeable guard on the shard. See `spin` documentation.
131         // Remove the pointed-to object from the map, and return it,
132         // releasing the lock when the guard goes out of scope.
133         #[allow(clippy::expect_used)]
134         let removed_object = map_read_write_guard
135             .deref_mut()
136             .remove(&handle)
137             .expect("existence of handle is checked above");
138         // Decrement the allocations counter. Release ordering because we want
139         // to ensure that clearing the map entry never gets re-ordered to after when
140         // this counter gets decremented.
141         let _ = outstanding_allocations_counter.fetch_sub(1, Ordering::Release);
142         Ok(removed_object)
143     }
144 
try_allocate<E, F>( &self, handle: Handle, object_provider: F, outstanding_allocations_counter: &AtomicU32, max_active_handles: u32, ) -> Result<(), ShardAllocationError<T, E, F>> where F: FnOnce() -> Result<T, E>,145     pub fn try_allocate<E, F>(
146         &self,
147         handle: Handle,
148         object_provider: F,
149         outstanding_allocations_counter: &AtomicU32,
150         max_active_handles: u32,
151     ) -> Result<(), ShardAllocationError<T, E, F>>
152     where
153         F: FnOnce() -> Result<T, E>,
154     {
155         let mut read_write_guard = self.data.write();
156         match read_write_guard.entry(handle) {
157             Occupied(_) => {
158                 // We've already allocated for that handle-id, so yield
159                 // the object provider back to the caller.
160                 Err(ShardAllocationError::EntryOccupied(object_provider))
161             }
162             Vacant(vacant_entry) => {
163                 // An entry is open, but we haven't yet checked the allocations count.
164                 // Try to increment the total allocations count atomically.
165                 // Use acquire ordering on a successful bump, because we don't want
166                 // to invoke the allocation closure before we have a guaranteed slot.
167                 // On the other hand, upon failure, we don't care about ordering
168                 // of surrounding operations, and so we use a relaxed ordering there.
169                 let allocation_count_bump_result = outstanding_allocations_counter.fetch_update(
170                     Ordering::Acquire,
171                     Ordering::Relaxed,
172                     |old_total_allocations| {
173                         if old_total_allocations >= max_active_handles {
174                             None
175                         } else {
176                             Some(old_total_allocations + 1)
177                         }
178                     },
179                 );
180                 match allocation_count_bump_result {
181                     Ok(_) => {
182                         // We're good to actually allocate,
183                         // so attempt to call the value-provider.
184                         match object_provider() {
185                             Ok(object) => {
186                                 // Successfully obtained the initial value,
187                                 // so insert it into the vacant entry.
188                                 let _ = vacant_entry.insert(object);
189                                 Ok(())
190                             }
191                             Err(e) => Err(ShardAllocationError::ValueProviderFailed(e)),
192                         }
193                     }
194                     Err(_) => {
195                         // The allocation would cause us to exceed the allowed allocations,
196                         // so release all locks and error.
197                         Err(ShardAllocationError::ExceedsAllocationLimit)
198                     }
199                 }
200             }
201         }
202     }
203 }
204