xref: /aosp_15_r20/bootable/libbootloader/gbl/libasync/src/lib.rs (revision 5225e6b173e52d2efc6bcf950c27374fd72adabc)
1 // Copyright 2024, The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 //! This file provides async utility APIs used by GBL.
16 //!
17 //! They are mainly barebone APIs for busy waiting and polling Futures. There is no support for
18 //! sleep/wake or threading.
19 
20 #![cfg_attr(not(test), no_std)]
21 
22 use core::{
23     future::Future,
24     ops::DerefMut,
25     pin::{pin, Pin},
26     ptr::null,
27     task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
28 };
29 
/// Vtable shared by every noop waker: cloning produces another noop waker, while `wake`,
/// `wake_by_ref` and `drop` all do nothing.
const NOOP_VTABLE: RawWakerVTable = RawWakerVTable::new(
    noop_clone,       // clone
    noop_wake_method, // wake
    noop_wake_method, // wake_by_ref
    noop_wake_method, // drop
);

/// Builds a `RawWaker` that carries no data (null pointer) and uses `NOOP_VTABLE`.
fn noop_raw_waker() -> RawWaker {
    let no_data: *const () = null();
    RawWaker::new(no_data, &NOOP_VTABLE)
}

/// `clone` entry of `NOOP_VTABLE`; the data pointer is ignored and a fresh noop waker is returned.
fn noop_clone(_data: *const ()) -> RawWaker {
    noop_raw_waker()
}

/// Shared `wake`, `wake_by_ref` and `drop` entry of `NOOP_VTABLE`; intentionally does nothing.
fn noop_wake_method(_data: *const ()) {}
46 
47 /// Repetitively polls and blocks until a future completes.
block_on<O>(fut: impl Future<Output = O>) -> O48 pub fn block_on<O>(fut: impl Future<Output = O>) -> O {
49     let mut fut = pin!(fut);
50     loop {
51         match poll(&mut fut) {
52             Some(res) => return res,
53             _ => {}
54         }
55     }
56 }
57 
58 /// Polls a Future.
59 ///
60 /// Returns Some(_) if ready, None otherwise.
poll<O, F: Future<Output = O> + ?Sized>( fut: &mut Pin<impl DerefMut<Target = F>>, ) -> Option<O>61 pub fn poll<O, F: Future<Output = O> + ?Sized>(
62     fut: &mut Pin<impl DerefMut<Target = F>>,
63 ) -> Option<O> {
64     // SAFETY:
65     // * All methods for noop_raw_waker() are either noop or have no shared state. Thus they are
66     //   thread-safe.
67     let waker = unsafe { Waker::from_raw(noop_raw_waker()) };
68     let mut context = Context::from_waker(&waker);
69     match fut.as_mut().poll(&mut context) {
70         Poll::Pending => None,
71         Poll::Ready(res) => Some(res),
72     }
73 }
74 
75 /// Polls the given future for up to `n` times.
poll_n_times<O, F: Future<Output = O> + ?Sized>( fut: &mut Pin<impl DerefMut<Target = F>>, n: usize, ) -> Option<O>76 pub fn poll_n_times<O, F: Future<Output = O> + ?Sized>(
77     fut: &mut Pin<impl DerefMut<Target = F>>,
78     n: usize,
79 ) -> Option<O> {
80     (0..n).find_map(|_| poll(fut))
81 }
82 
/// `Yield` implements a simple API for yielding control once to the executor.
///
/// The flag is flipped on every poll: starting from `false`, the first poll reports
/// `Poll::Pending` and the following one reports `Poll::Ready`.
struct Yield(bool);

impl Future for Yield {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
        // The waker is not used; the pollers in this crate re-poll without waiting for a wake-up.
        self.0 = !self.0;
        if self.0 {
            Poll::Pending
        } else {
            Poll::Ready(())
        }
    }
}

/// Yield the execution once.
///
/// The returned future is pending on its first poll and completes on the second.
pub async fn yield_now() {
    Yield(false).await
}

/// `YieldCounter` tracks caller-reported progress and yields control to the executor once the
/// accumulated amount reaches or exceeds a given threshold. When that happens the counter is reset
/// and the carry over is discarded.
pub struct YieldCounter {
    /// Amount of progress allowed between two yields.
    threshold: u64,
    /// Progress still allowed before the next yield; counts down towards 0.
    current: u64,
}

impl YieldCounter {
    /// Creates an instance with a given threshold.
    ///
    /// A threshold of 0 makes every call to [`YieldCounter::increment`] yield.
    pub fn new(threshold: u64) -> Self {
        // Start with a full budget. Starting at 0 would make the very first `increment()` yield
        // regardless of how small the increment is.
        Self { threshold, current: threshold }
    }

    /// Adds `inc` to the accumulated progress and yields execution once if the total accumulated
    /// since the last yield reaches or exceeds the threshold.
    pub async fn increment(&mut self, inc: u64) {
        // The budget counts down; saturating keeps it at 0 when `inc` overshoots, which is how the
        // carry over gets discarded.
        self.current = self.current.saturating_sub(inc);
        if self.current == 0 {
            self.current = self.threshold;
            yield_now().await;
        }
    }
}
125 
126 /// Repetitively polls two futures until both of them finish.
join<L, LO, R, RO>(fut_lhs: L, fut_rhs: R) -> (LO, RO) where L: Future<Output = LO>, R: Future<Output = RO>,127 pub async fn join<L, LO, R, RO>(fut_lhs: L, fut_rhs: R) -> (LO, RO)
128 where
129     L: Future<Output = LO>,
130     R: Future<Output = RO>,
131 {
132     let fut_lhs = &mut pin!(fut_lhs);
133     let fut_rhs = &mut pin!(fut_rhs);
134     let mut out_lhs = poll(fut_lhs);
135     let mut out_rhs = poll(fut_rhs);
136     while out_lhs.is_none() || out_rhs.is_none() {
137         yield_now().await;
138         if out_lhs.is_none() {
139             out_lhs = poll(fut_lhs);
140         }
141 
142         if out_rhs.is_none() {
143             out_rhs = poll(fut_rhs);
144         }
145     }
146     (out_lhs.unwrap(), out_rhs.unwrap())
147 }
148 
149 /// Waits until either of the given two futures completes.
select<L, LO, R, RO>(fut_lhs: L, fut_rhs: R) -> (Option<LO>, Option<RO>) where L: Future<Output = LO>, R: Future<Output = RO>,150 pub async fn select<L, LO, R, RO>(fut_lhs: L, fut_rhs: R) -> (Option<LO>, Option<RO>)
151 where
152     L: Future<Output = LO>,
153     R: Future<Output = RO>,
154 {
155     let fut_lhs = &mut pin!(fut_lhs);
156     let fut_rhs = &mut pin!(fut_rhs);
157     let mut out_lhs = poll(fut_lhs);
158     let mut out_rhs = poll(fut_rhs);
159     while out_lhs.is_none() && out_rhs.is_none() {
160         yield_now().await;
161         out_lhs = poll(fut_lhs);
162         out_rhs = poll(fut_rhs);
163     }
164     (out_lhs, out_rhs)
165 }
166 
/// Awaits a [Future] and asserts that it actually runs to completion.
///
/// A guard with a `Drop` assertion is kept alive across the await. It is only marked as satisfied
/// after the inner future has produced its output, so dropping the returned future before that
/// point fails the assertion and panics.
pub async fn assert_return<O>(fut: impl Future<Output = O>) -> O {
    /// Drop guard; holds `true` once the wrapped future has returned.
    struct CompletionGuard(bool);

    impl Drop for CompletionGuard {
        fn drop(&mut self) {
            assert!(self.0)
        }
    }

    let mut guard = CompletionGuard(false);
    let output = fut.await;
    guard.0 = true;
    output
}
182 
#[cfg(test)]
mod test {
    use super::*;
    use std::sync::Mutex;

    /// Adds one to `counter` and then yields, `rounds` times in a row.
    async fn bump_and_yield(counter: &Mutex<i32>, rounds: usize) {
        for _ in 0..rounds {
            *counter.try_lock().unwrap() += 1;
            yield_now().await;
        }
    }

    /// Reads the current value of `counter`.
    fn value(counter: &Mutex<i32>) -> i32 {
        *counter.try_lock().unwrap()
    }

    #[test]
    fn test() {
        // Each increment of 2 exceeds the threshold of 1, so each call yields once.
        let mut counter = YieldCounter::new(1);
        let mut fut = pin!(async move {
            counter.increment(2).await;
            counter.increment(2).await;
        });

        assert!(poll(&mut fut).is_none());
        assert!(poll(&mut fut).is_none());
        assert!(poll(&mut fut).is_some());
    }

    #[test]
    fn test_join() {
        let val1 = Mutex::new(0);
        let val2 = Mutex::new(1);

        // The left side needs two rounds and the right side three.
        let mut join_fut = pin!(join(bump_and_yield(&val1, 2), bump_and_yield(&val2, 3)));

        // Expected (val1, val2) after each poll that leaves the join pending. The left side
        // finishes one round before the right side, so its value stops changing first.
        for (expected1, expected2) in [(1, 2), (2, 3), (2, 4)] {
            assert!(poll(&mut join_fut).is_none());
            assert_eq!(value(&val1), expected1);
            assert_eq!(value(&val2), expected2);
        }

        assert!(poll(&mut join_fut).is_some());
    }

    #[test]
    fn test_select() {
        let val1 = Mutex::new(0);
        let val2 = Mutex::new(1);

        // The left side needs two rounds and the right side three, so the left side wins.
        let mut select_fut = pin!(select(bump_and_yield(&val1, 2), bump_and_yield(&val2, 3)));

        // Expected (val1, val2) after each poll that leaves the select pending.
        for (expected1, expected2) in [(1, 2), (2, 3)] {
            assert!(poll(&mut select_fut).is_none());
            assert_eq!(value(&val1), expected1);
            assert_eq!(value(&val2), expected2);
        }

        let (lhs, rhs) = poll(&mut select_fut).unwrap();
        assert!(lhs.is_some());
        assert!(rhs.is_none());
    }

    #[test]
    fn test_assert_return() {
        // The wrapped future runs to completion, so the drop assertion holds.
        block_on(assert_return(async { yield_now().await }));
    }

    #[test]
    #[should_panic]
    fn test_assert_return_panics() {
        let mut fut = pin!(assert_return(async { yield_now().await }));
        // One more poll would be needed to finish, so dropping the future at the end of this
        // scope must trip the assertion.
        assert!(poll(&mut fut).is_none());
    }
}
287