1 // Copyright 2024, The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 //! This file provides async utility APIs used by GBL.
16 //!
//! They are mainly bare-bones APIs for busy-waiting and polling Futures. There is no support for
//! sleep/wake or threading.
19
20 #![cfg_attr(not(test), no_std)]
21
22 use core::{
23 future::Future,
24 ops::DerefMut,
25 pin::{pin, Pin},
26 ptr::null,
27 task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
28 };
29
30 /// Clone method for `NOOP_VTABLE`.
noop_clone(_: *const ()) -> RawWaker31 fn noop_clone(_: *const ()) -> RawWaker {
32 noop_raw_waker()
33 }
34
/// Shared `wake`, `wake_by_ref` and `drop` entry of `NOOP_VTABLE`; it ignores its argument and
/// does nothing.
fn noop_wake_method(_data: *const ()) {}
37
38 /// A noop `RawWakerVTable`
39 const NOOP_VTABLE: RawWakerVTable =
40 RawWakerVTable::new(noop_clone, noop_wake_method, noop_wake_method, noop_wake_method);
41
42 /// Creates a noop instance that does nothing.
noop_raw_waker() -> RawWaker43 fn noop_raw_waker() -> RawWaker {
44 RawWaker::new(null(), &NOOP_VTABLE)
45 }
46
47 /// Repetitively polls and blocks until a future completes.
block_on<O>(fut: impl Future<Output = O>) -> O48 pub fn block_on<O>(fut: impl Future<Output = O>) -> O {
49 let mut fut = pin!(fut);
50 loop {
51 match poll(&mut fut) {
52 Some(res) => return res,
53 _ => {}
54 }
55 }
56 }
57
58 /// Polls a Future.
59 ///
60 /// Returns Some(_) if ready, None otherwise.
poll<O, F: Future<Output = O> + ?Sized>( fut: &mut Pin<impl DerefMut<Target = F>>, ) -> Option<O>61 pub fn poll<O, F: Future<Output = O> + ?Sized>(
62 fut: &mut Pin<impl DerefMut<Target = F>>,
63 ) -> Option<O> {
64 // SAFETY:
65 // * All methods for noop_raw_waker() are either noop or have no shared state. Thus they are
66 // thread-safe.
67 let waker = unsafe { Waker::from_raw(noop_raw_waker()) };
68 let mut context = Context::from_waker(&waker);
69 match fut.as_mut().poll(&mut context) {
70 Poll::Pending => None,
71 Poll::Ready(res) => Some(res),
72 }
73 }
74
75 /// Polls the given future for up to `n` times.
poll_n_times<O, F: Future<Output = O> + ?Sized>( fut: &mut Pin<impl DerefMut<Target = F>>, n: usize, ) -> Option<O>76 pub fn poll_n_times<O, F: Future<Output = O> + ?Sized>(
77 fut: &mut Pin<impl DerefMut<Target = F>>,
78 n: usize,
79 ) -> Option<O> {
80 (0..n).find_map(|_| poll(fut))
81 }
82
/// `Yield` implements a simple API for yielding control once to the executor.
///
/// Each poll toggles the wrapped flag. Starting from `Yield(false)`, the first poll returns
/// `Poll::Pending` and the second returns `Poll::Ready(())`.
struct Yield(bool);

impl Future for Yield {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        self.0 = !self.0;
        if self.0 {
            // The `Future` contract requires a future that returns `Pending` to arrange for its
            // waker to be notified. This future can make progress immediately, so wake right
            // away. With this crate's noop waker this does nothing. Under a waker-driven executor
            // it prevents `yield_now()` from stalling forever.
            cx.waker().wake_by_ref();
            Poll::Pending
        } else {
            Poll::Ready(())
        }
    }
}
97
98 /// Yield the execution once.
yield_now()99 pub async fn yield_now() {
100 Yield(false).await
101 }
102
103 /// `YieldCounter` maintains a counter and yield control to executor once it overflows a given
104 /// threshold. When overflow occurs, the counter value is reset and the carry over is discarded.
105 pub struct YieldCounter {
106 threshold: u64,
107 current: u64,
108 }
109
110 impl YieldCounter {
111 /// Creates an instance with a given threshold.
new(threshold: u64) -> Self112 pub fn new(threshold: u64) -> Self {
113 Self { threshold, current: 0 }
114 }
115
116 /// Increments the current counter and yield execution if the value overflows the threshold.
increment(&mut self, inc: u64)117 pub async fn increment(&mut self, inc: u64) {
118 self.current = self.current.saturating_sub(inc);
119 if self.current == 0 {
120 self.current = self.threshold;
121 yield_now().await;
122 }
123 }
124 }
125
126 /// Repetitively polls two futures until both of them finish.
join<L, LO, R, RO>(fut_lhs: L, fut_rhs: R) -> (LO, RO) where L: Future<Output = LO>, R: Future<Output = RO>,127 pub async fn join<L, LO, R, RO>(fut_lhs: L, fut_rhs: R) -> (LO, RO)
128 where
129 L: Future<Output = LO>,
130 R: Future<Output = RO>,
131 {
132 let fut_lhs = &mut pin!(fut_lhs);
133 let fut_rhs = &mut pin!(fut_rhs);
134 let mut out_lhs = poll(fut_lhs);
135 let mut out_rhs = poll(fut_rhs);
136 while out_lhs.is_none() || out_rhs.is_none() {
137 yield_now().await;
138 if out_lhs.is_none() {
139 out_lhs = poll(fut_lhs);
140 }
141
142 if out_rhs.is_none() {
143 out_rhs = poll(fut_rhs);
144 }
145 }
146 (out_lhs.unwrap(), out_rhs.unwrap())
147 }
148
149 /// Waits until either of the given two futures completes.
select<L, LO, R, RO>(fut_lhs: L, fut_rhs: R) -> (Option<LO>, Option<RO>) where L: Future<Output = LO>, R: Future<Output = RO>,150 pub async fn select<L, LO, R, RO>(fut_lhs: L, fut_rhs: R) -> (Option<LO>, Option<RO>)
151 where
152 L: Future<Output = LO>,
153 R: Future<Output = RO>,
154 {
155 let fut_lhs = &mut pin!(fut_lhs);
156 let fut_rhs = &mut pin!(fut_rhs);
157 let mut out_lhs = poll(fut_lhs);
158 let mut out_rhs = poll(fut_rhs);
159 while out_lhs.is_none() && out_rhs.is_none() {
160 yield_now().await;
161 out_lhs = poll(fut_lhs);
162 out_rhs = poll(fut_rhs);
163 }
164 (out_lhs, out_rhs)
165 }
166
/// Awaits `fut` and returns its output, asserting that the wrapping future actually runs to
/// completion.
///
/// A drop guard lives across the `.await`. If the returned future is dropped before `fut`
/// finishes, the guard is still disarmed when it is dropped, and its `assert!` panics.
pub async fn assert_return<O>(fut: impl Future<Output = O>) -> O {
    /// Panics on drop unless `completed` has been set.
    struct CompletionGuard {
        completed: bool,
    }

    impl Drop for CompletionGuard {
        fn drop(&mut self) {
            assert!(self.completed)
        }
    }

    let mut guard = CompletionGuard { completed: false };
    let output = fut.await;
    // Reached only once `fut` has resolved; disarm the guard.
    guard.completed = true;
    output
}
182
#[cfg(test)]
mod test {
    use super::*;
    use std::sync::Mutex;

    /// Adds one to the value behind `m`. Uses `try_lock` so that unexpected contention fails the
    /// test immediately instead of deadlocking.
    fn bump(m: &Mutex<i32>) {
        *m.try_lock().unwrap() += 1;
    }

    /// Returns a copy of the value behind `m`, failing the test on contention.
    fn read_value(m: &Mutex<i32>) -> i32 {
        *m.try_lock().unwrap()
    }

    #[test]
    fn test() {
        // Threshold 1 with increments of 2: each of the two `increment` calls yields once.
        let mut counter = YieldCounter::new(1);
        let mut fut = pin!(async move {
            for _ in 0..2 {
                counter.increment(2).await;
            }
        });

        for _ in 0..2 {
            assert!(poll(&mut fut).is_none());
        }
        assert!(poll(&mut fut).is_some());
    }

    #[test]
    fn test_join() {
        let val1 = Mutex::new(0);
        let val2 = Mutex::new(1);

        let mut join_fut = pin!(join(
            async {
                for _ in 0..2 {
                    bump(&val1);
                    yield_now().await;
                }
            },
            async {
                for _ in 0..3 {
                    bump(&val2);
                    yield_now().await;
                }
            }
        ));

        // Expected (val1, val2) after each round in which `join` is still pending.
        for (lhs, rhs) in [(1, 2), (2, 3), (2, 4)] {
            assert!(poll(&mut join_fut).is_none());
            assert_eq!(read_value(&val1), lhs);
            assert_eq!(read_value(&val2), rhs);
        }
        assert!(poll(&mut join_fut).is_some());
    }

    #[test]
    fn test_select() {
        let val1 = Mutex::new(0);
        let val2 = Mutex::new(1);

        let mut select_fut = pin!(select(
            async {
                for _ in 0..2 {
                    bump(&val1);
                    yield_now().await;
                }
            },
            async {
                for _ in 0..3 {
                    bump(&val2);
                    yield_now().await;
                }
            }
        ));

        // Expected (val1, val2) after each round in which `select` is still pending.
        for (lhs, rhs) in [(1, 2), (2, 3)] {
            assert!(poll(&mut select_fut).is_none());
            assert_eq!(read_value(&val1), lhs);
            assert_eq!(read_value(&val2), rhs);
        }

        // The shorter left-hand future wins; the right-hand one is still pending.
        let (lhs, rhs) = poll(&mut select_fut).unwrap();
        assert!(lhs.is_some());
        assert!(rhs.is_none());
    }

    #[test]
    fn test_assert_return() {
        // The wrapped future runs to completion, so the completion guard stays quiet.
        let guarded = assert_return(async { yield_now().await });
        block_on(guarded);
    }

    #[test]
    #[should_panic]
    fn test_assert_return_panics() {
        let mut fut = pin!(assert_return(async { yield_now().await }));
        // A single poll only reaches the yield point. The future is then dropped unfinished at
        // the end of scope, which trips the guard.
        assert!(poll(&mut fut).is_none());
    }
}
287