1 use crate::signal::os::{OsExtraData, OsStorage};
2 use crate::sync::watch;
3 use crate::util::once_cell::OnceCell;
4
5 use std::ops;
6 use std::sync::atomic::{AtomicBool, Ordering};
7
8 pub(crate) type EventId = usize;
9
10 /// State for a specific event, whether a notification is pending delivery,
11 /// and what listeners are registered.
12 #[derive(Debug)]
13 pub(crate) struct EventInfo {
14 pending: AtomicBool,
15 tx: watch::Sender<()>,
16 }
17
18 impl Default for EventInfo {
default() -> Self19 fn default() -> Self {
20 let (tx, _rx) = watch::channel(());
21
22 Self {
23 pending: AtomicBool::new(false),
24 tx,
25 }
26 }
27 }
28
29 /// An interface for retrieving the `EventInfo` for a particular `eventId`.
30 pub(crate) trait Storage {
31 /// Gets the `EventInfo` for `id` if it exists.
event_info(&self, id: EventId) -> Option<&EventInfo>32 fn event_info(&self, id: EventId) -> Option<&EventInfo>;
33
34 /// Invokes `f` once for each defined `EventInfo` in this storage.
for_each<'a, F>(&'a self, f: F) where F: FnMut(&'a EventInfo)35 fn for_each<'a, F>(&'a self, f: F)
36 where
37 F: FnMut(&'a EventInfo);
38 }
39
40 impl Storage for Vec<EventInfo> {
event_info(&self, id: EventId) -> Option<&EventInfo>41 fn event_info(&self, id: EventId) -> Option<&EventInfo> {
42 self.get(id)
43 }
44
for_each<'a, F>(&'a self, f: F) where F: FnMut(&'a EventInfo),45 fn for_each<'a, F>(&'a self, f: F)
46 where
47 F: FnMut(&'a EventInfo),
48 {
49 self.iter().for_each(f);
50 }
51 }
52
/// Types that can build themselves without arguments. This is useful when a
/// configured instance cannot be passed in through another type's constructor.
pub(crate) trait Init {
    /// Constructs a fresh instance of `Self`.
    fn init() -> Self;
}
58
/// Records events and later fans them out to whichever listeners registered
/// for them.
///
/// The storage type is a generic parameter so that each platform can pick a
/// layout suited to its event ids (for example, ids that may not be
/// contiguous).
#[derive(Debug)]
pub(crate) struct Registry<S> {
    /// Backing store of per-event state.
    storage: S,
}
67
68 impl<S> Registry<S> {
new(storage: S) -> Self69 fn new(storage: S) -> Self {
70 Self { storage }
71 }
72 }
73
74 impl<S: Storage> Registry<S> {
75 /// Registers a new listener for `event_id`.
register_listener(&self, event_id: EventId) -> watch::Receiver<()>76 fn register_listener(&self, event_id: EventId) -> watch::Receiver<()> {
77 self.storage
78 .event_info(event_id)
79 .unwrap_or_else(|| panic!("invalid event_id: {event_id}"))
80 .tx
81 .subscribe()
82 }
83
84 /// Marks `event_id` as having been delivered, without broadcasting it to
85 /// any listeners.
record_event(&self, event_id: EventId)86 fn record_event(&self, event_id: EventId) {
87 if let Some(event_info) = self.storage.event_info(event_id) {
88 event_info.pending.store(true, Ordering::SeqCst);
89 }
90 }
91
92 /// Broadcasts all previously recorded events to their respective listeners.
93 ///
94 /// Returns `true` if an event was delivered to at least one listener.
broadcast(&self) -> bool95 fn broadcast(&self) -> bool {
96 let mut did_notify = false;
97 self.storage.for_each(|event_info| {
98 // Any signal of this kind arrived since we checked last?
99 if !event_info.pending.swap(false, Ordering::SeqCst) {
100 return;
101 }
102
103 // Ignore errors if there are no listeners
104 if event_info.tx.send(()).is_ok() {
105 did_notify = true;
106 }
107 });
108
109 did_notify
110 }
111 }
112
113 pub(crate) struct Globals {
114 extra: OsExtraData,
115 registry: Registry<OsStorage>,
116 }
117
118 impl ops::Deref for Globals {
119 type Target = OsExtraData;
120
deref(&self) -> &Self::Target121 fn deref(&self) -> &Self::Target {
122 &self.extra
123 }
124 }
125
126 impl Globals {
127 /// Registers a new listener for `event_id`.
register_listener(&self, event_id: EventId) -> watch::Receiver<()>128 pub(crate) fn register_listener(&self, event_id: EventId) -> watch::Receiver<()> {
129 self.registry.register_listener(event_id)
130 }
131
132 /// Marks `event_id` as having been delivered, without broadcasting it to
133 /// any listeners.
record_event(&self, event_id: EventId)134 pub(crate) fn record_event(&self, event_id: EventId) {
135 self.registry.record_event(event_id);
136 }
137
138 /// Broadcasts all previously recorded events to their respective listeners.
139 ///
140 /// Returns `true` if an event was delivered to at least one listener.
broadcast(&self) -> bool141 pub(crate) fn broadcast(&self) -> bool {
142 self.registry.broadcast()
143 }
144
145 #[cfg(unix)]
storage(&self) -> &OsStorage146 pub(crate) fn storage(&self) -> &OsStorage {
147 &self.registry.storage
148 }
149 }
150
globals_init() -> Globals where OsExtraData: 'static + Send + Sync + Init, OsStorage: 'static + Send + Sync + Init,151 fn globals_init() -> Globals
152 where
153 OsExtraData: 'static + Send + Sync + Init,
154 OsStorage: 'static + Send + Sync + Init,
155 {
156 Globals {
157 extra: OsExtraData::init(),
158 registry: Registry::new(OsStorage::init()),
159 }
160 }
161
globals() -> &'static Globals where OsExtraData: 'static + Send + Sync + Init, OsStorage: 'static + Send + Sync + Init,162 pub(crate) fn globals() -> &'static Globals
163 where
164 OsExtraData: 'static + Send + Sync + Init,
165 OsStorage: 'static + Send + Sync + Init,
166 {
167 static GLOBALS: OnceCell<Globals> = OnceCell::new();
168
169 GLOBALS.get(globals_init)
170 }
171
#[cfg(all(test, not(loom)))]
mod tests {
    use super::*;
    use crate::runtime::{self, Runtime};
    use crate::sync::{oneshot, watch};

    use futures::future;

    /// End-to-end check of the registry. Repeated `record_event` calls before
    /// a `broadcast` produce a single notification, and each listener only
    /// sees broadcasts for its own event id.
    #[test]
    fn smoke() {
        let rt = rt();
        rt.block_on(async move {
            let registry = Registry::new(vec![
                EventInfo::default(),
                EventInfo::default(),
                EventInfo::default(),
            ]);

            let first = registry.register_listener(0);
            let second = registry.register_listener(1);
            let third = registry.register_listener(2);

            // The spawned task waits on this oneshot before recording any events.
            let (fire, wait) = oneshot::channel();

            crate::spawn(async {
                wait.await.expect("wait failed");

                // Record some events which should get coalesced
                registry.record_event(0);
                registry.record_event(0);
                registry.record_event(1);
                registry.record_event(1);
                registry.broadcast();

                // Yield so the previous broadcast can get received
                //
                // This yields many times since the block_on task is only polled every 61
                // ticks.
                //
                // A watch receiver only observes the latest change. Without
                // these yields, the two broadcasts to event 0 could be seen
                // as one.
                for _ in 0..100 {
                    crate::task::yield_now().await;
                }

                // Send subsequent signal
                registry.record_event(0);
                registry.broadcast();

                // Dropping the registry drops every watch sender. That makes
                // `changed()` return an error and ends the `collect` loops.
                drop(registry);
            });

            let _ = fire.send(());
            let all = future::join3(collect(first), collect(second), collect(third));

            // Event 0 was broadcast twice, event 1 once (its two records
            // coalesced), event 2 never.
            let (first_results, second_results, third_results) = all.await;
            assert_eq!(2, first_results.len());
            assert_eq!(1, second_results.len());
            assert_eq!(0, third_results.len());
        });
    }

    /// Subscribing to an id with no storage entry panics with the id in the message.
    #[test]
    #[should_panic = "invalid event_id: 1"]
    fn register_panics_on_invalid_input() {
        let registry = Registry::new(vec![EventInfo::default()]);

        registry.register_listener(1);
    }

    /// Recording an id with no storage entry is a silent no-op.
    #[test]
    fn record_invalid_event_does_nothing() {
        let registry = Registry::new(vec![EventInfo::default()]);
        registry.record_event(1302);
    }

    /// `broadcast` returns `true` only when a pending event reached at least
    /// one live receiver.
    #[test]
    fn broadcast_returns_if_at_least_one_event_fired() {
        let registry = Registry::new(vec![EventInfo::default(), EventInfo::default()]);

        // Pending but no listeners yet: the send fails, so nothing was notified.
        registry.record_event(0);
        assert!(!registry.broadcast());

        let first = registry.register_listener(0);
        let second = registry.register_listener(1);

        registry.record_event(0);
        assert!(registry.broadcast());

        // Event 0 has no receivers left. `second` listens on a different id,
        // so it does not count.
        drop(first);
        registry.record_event(0);
        assert!(!registry.broadcast());

        drop(second);
    }

    /// Builds a current-thread runtime with the time driver enabled.
    fn rt() -> Runtime {
        runtime::Builder::new_current_thread()
            .enable_time()
            .build()
            .unwrap()
    }

    /// Collects one entry per observed change until the sender side is dropped.
    async fn collect(mut rx: watch::Receiver<()>) -> Vec<()> {
        let mut ret = vec![];

        while let Ok(v) = rx.changed().await {
            ret.push(v);
        }

        ret
    }
}
282