1 use pin_project_lite::pin_project;
2 use std::cell::RefCell;
3 use std::error::Error;
4 use std::future::Future;
5 use std::marker::PhantomPinned;
6 use std::pin::Pin;
7 use std::task::{Context, Poll};
8 use std::{fmt, mem, thread};
9 
/// Declares a new task-local key of type [`tokio::task::LocalKey`].
///
/// # Syntax
///
/// The macro wraps any number of static declarations and makes them local to the current task.
/// The visibility and attributes of each static are preserved. Declarations are separated by
/// semicolons; the semicolon after the last declaration is optional.
///
/// # Examples
///
/// ```
/// # use tokio::task_local;
/// task_local! {
///     pub static ONE: u32;
///
///     #[allow(unused)]
///     static TWO: f32;
/// }
/// # fn main() {}
/// ```
///
/// See [`LocalKey` documentation][`tokio::task::LocalKey`] for more
/// information.
///
/// [`tokio::task::LocalKey`]: struct@crate::task::LocalKey
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
macro_rules! task_local {
    // empty (base case for the recursion)
    () => {};

    // One declaration terminated by `;`, followed by the remaining input:
    // expand the first declaration, then recurse on the rest.
    ($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty; $($rest:tt)*) => {
        $crate::__task_local_inner!($(#[$attr])* $vis $name, $t);
        $crate::task_local!($($rest)*);
    };

    // A final declaration written without a trailing semicolon.
    ($(#[$attr:meta])* $vis:vis static $name:ident: $t:ty) => {
        $crate::__task_local_inner!($(#[$attr])* $vis $name, $t);
    }
}
49 
/// Implementation detail of [`task_local!`]: expands a single declaration.
///
/// Produces a `static` of type `$crate::task::LocalKey<$t>` backed by a
/// thread-local `RefCell<Option<$t>>` that starts out as `None`.
///
/// Every standard-library item is named by its fully-qualified `::std::…`
/// path. `macro_rules!` resolves item paths at the *call site*, so a bare
/// `Option` or `std` would be captured by a caller-defined `Option` type or
/// `std` module and break the expansion.
#[doc(hidden)]
#[macro_export]
macro_rules! __task_local_inner {
    ($(#[$attr:meta])* $vis:vis $name:ident, $t:ty) => {
        $(#[$attr])*
        $vis static $name: $crate::task::LocalKey<$t> = {
            ::std::thread_local! {
                static __KEY: ::std::cell::RefCell<::std::option::Option<$t>> =
                    const { ::std::cell::RefCell::new(::std::option::Option::None) };
            }

            $crate::task::LocalKey { inner: __KEY }
        };
    };
}
64 
/// A key for task-local data.
///
/// This type is generated by the [`task_local!`] macro.
///
/// Unlike [`std::thread::LocalKey`], `tokio::task::LocalKey` will
/// _not_ lazily initialize the value on first access. Instead, a value must be
/// supplied explicitly with [`LocalKey::scope`] (for a future) or
/// [`LocalKey::sync_scope`] (for a closure). The value is only reachable
/// through the key while that scope is active. For `scope`, that means while
/// the returned future is being polled by a futures executor, like Tokio.
///
/// # Examples
///
/// ```
/// # async fn dox() {
/// tokio::task_local! {
///     static NUMBER: u32;
/// }
///
/// NUMBER.scope(1, async move {
///     assert_eq!(NUMBER.get(), 1);
/// }).await;
///
/// NUMBER.scope(2, async move {
///     assert_eq!(NUMBER.get(), 2);
///
///     NUMBER.scope(3, async move {
///         assert_eq!(NUMBER.get(), 3);
///     }).await;
/// }).await;
/// # }
/// ```
///
/// [`std::thread::LocalKey`]: struct@std::thread::LocalKey
/// [`task_local!`]: ../macro.task_local.html
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
pub struct LocalKey<T: 'static> {
    // Backing thread-local storage. It is `pub` (but hidden from the docs)
    // only so that the `task_local!` macro can construct a `LocalKey` in the
    // caller's crate. It is initialized to `None`, and a value is swapped in
    // only for the duration of an active scope.
    #[doc(hidden)]
    pub inner: thread::LocalKey<RefCell<Option<T>>>,
}
103 
impl<T: 'static> LocalKey<T> {
    /// Sets a value `T` as the task-local value for the future `F`.
    ///
    /// The value is installed into the task-local only while the returned
    /// [`TaskLocalFuture`] is being polled, and is swapped back out between
    /// polls. The value is owned by the returned future and is dropped
    /// together with it, unless it was moved out earlier with
    /// [`TaskLocalFuture::take_value`].
    ///
    /// ### Panics
    ///
    /// If you poll the returned future inside a call to [`with`] or
    /// [`try_with`] on the same `LocalKey`, then the call to `poll` will panic.
    /// Polling also panics if the underlying thread-local is being or has been
    /// destroyed.
    ///
    /// ### Examples
    ///
    /// ```
    /// # async fn dox() {
    /// tokio::task_local! {
    ///     static NUMBER: u32;
    /// }
    ///
    /// NUMBER.scope(1, async move {
    ///     println!("task local value: {}", NUMBER.get());
    /// }).await;
    /// # }
    /// ```
    ///
    /// [`with`]: fn@Self::with
    /// [`try_with`]: fn@Self::try_with
    pub fn scope<F>(&'static self, value: T, f: F) -> TaskLocalFuture<T, F>
    where
        F: Future,
    {
        TaskLocalFuture {
            local: self,
            slot: Some(value),
            future: Some(f),
            _pinned: PhantomPinned,
        }
    }

    /// Sets a value `T` as the task-local value for the closure `F`.
    ///
    /// The value is visible through this key for the duration of `f`. Whatever
    /// the task-local held before is restored afterwards, even if `f` panics.
    /// On completion of `sync_scope`, the value will be dropped.
    ///
    /// ### Panics
    ///
    /// This method panics if called inside a call to [`with`] or [`try_with`]
    /// on the same `LocalKey`, or if the underlying thread-local is being or
    /// has been destroyed.
    ///
    /// ### Examples
    ///
    /// ```
    /// # fn dox() {
    /// tokio::task_local! {
    ///     static NUMBER: u32;
    /// }
    ///
    /// NUMBER.sync_scope(1, || {
    ///     println!("task local value: {}", NUMBER.get());
    /// });
    /// # }
    /// ```
    ///
    /// [`with`]: fn@Self::with
    /// [`try_with`]: fn@Self::try_with
    #[track_caller]
    pub fn sync_scope<F, R>(&'static self, value: T, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        let mut value = Some(value);
        match self.scope_inner(&mut value, f) {
            Ok(res) => res,
            Err(err) => err.panic(),
        }
    }

    /// Runs `f` with the contents of `slot` installed as this key's value.
    ///
    /// The contents of `slot` are swapped into the thread-local `RefCell`
    /// before `f` runs and swapped back afterwards. On return, `slot` again
    /// holds the value, and the thread-local again holds whatever it held
    /// before, which is how nested scopes restore the outer value. The
    /// swap-back is done in `Guard::drop`, so it also happens if `f` panics.
    ///
    /// # Errors
    ///
    /// * `ScopeInnerErr::AccessError` if the underlying thread-local has been
    ///   destroyed.
    /// * `ScopeInnerErr::BorrowError` if the `RefCell` is currently borrowed,
    ///   e.g. because this is reached from inside `with`/`try_with`.
    ///
    /// In both error cases `f` is not called and `slot` is left untouched.
    fn scope_inner<F, R>(&'static self, slot: &mut Option<T>, f: F) -> Result<R, ScopeInnerErr>
    where
        F: FnOnce() -> R,
    {
        // Restores the thread-local's previous contents on drop, including
        // during unwinding.
        struct Guard<'a, T: 'static> {
            local: &'static LocalKey<T>,
            slot: &'a mut Option<T>,
        }

        impl<'a, T: 'static> Drop for Guard<'a, T> {
            fn drop(&mut self) {
                // This should not panic.
                //
                // We know that the RefCell was not borrowed before the call to
                // `scope_inner`, so the only way for this to panic is if the
                // closure has created but not destroyed a RefCell guard.
                // However, we never give user-code access to the guards, so
                // there's no way for user-code to forget to destroy a guard.
                //
                // The call to `with` also should not panic, since the
                // thread-local wasn't destroyed when we first called
                // `scope_inner`, and it shouldn't have gotten destroyed since
                // then.
                self.local.inner.with(|inner| {
                    let mut ref_mut = inner.borrow_mut();
                    mem::swap(self.slot, &mut *ref_mut);
                });
            }
        }

        // Swap the caller's value in. The outer `?` handles a destroyed
        // thread-local (`std::thread::AccessError`). The inner `?` handles an
        // already-borrowed `RefCell` (`BorrowMutError`). Both convert into
        // `ScopeInnerErr` via its `From` impls.
        self.inner.try_with(|inner| {
            inner
                .try_borrow_mut()
                .map(|mut ref_mut| mem::swap(slot, &mut *ref_mut))
        })??;

        // The guard is created only after the swap succeeded, so it never
        // swaps back a value that was never swapped in.
        let guard = Guard { local: self, slot };

        let res = f();

        drop(guard);

        Ok(res)
    }

    /// Accesses the current task-local and runs the provided closure with a
    /// reference to its value.
    ///
    /// This is the panicking counterpart of [`try_with`].
    ///
    /// # Panics
    ///
    /// This function will panic if the task local doesn't have a value set.
    ///
    /// [`try_with`]: fn@Self::try_with
    #[track_caller]
    pub fn with<F, R>(&'static self, f: F) -> R
    where
        F: FnOnce(&T) -> R,
    {
        match self.try_with(f) {
            Ok(res) => res,
            Err(_) => panic!("cannot access a task-local storage value without setting it first"),
        }
    }

    /// Accesses the current task-local and runs the provided closure with a
    /// reference to its value.
    ///
    /// For a panicking variant, see [`with`].
    ///
    /// # Errors
    ///
    /// Returns an [`AccessError`] (and does not call `f`) if no value is set
    /// for this key on the current thread, i.e. the call is not inside a
    /// `scope`/`sync_scope` for this key. It also returns one if the
    /// underlying thread-local has already been destroyed.
    ///
    /// [`with`]: fn@Self::with
    pub fn try_with<F, R>(&'static self, f: F) -> Result<R, AccessError>
    where
        F: FnOnce(&T) -> R,
    {
        // If called after the thread-local storing the task-local is destroyed,
        // then we are outside of a closure where the task-local is set.
        //
        // Therefore, it is correct to return an AccessError if `try_with`
        // returns an error.
        let try_with_res = self.inner.try_with(|v| {
            // This call to `borrow` cannot panic because no user-defined code
            // runs while a `borrow_mut` call is active.
            v.borrow().as_ref().map(f)
        });

        match try_with_res {
            Ok(Some(res)) => Ok(res),
            Ok(None) | Err(_) => Err(AccessError { _private: () }),
        }
    }
}
266 
267 impl<T: Clone + 'static> LocalKey<T> {
268     /// Returns a copy of the task-local value
269     /// if the task-local value implements `Clone`.
270     ///
271     /// # Panics
272     ///
273     /// This function will panic if the task local doesn't have a value set.
274     #[track_caller]
get(&'static self) -> T275     pub fn get(&'static self) -> T {
276         self.with(|v| v.clone())
277     }
278 }
279 
280 impl<T: 'static> fmt::Debug for LocalKey<T> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result281     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
282         f.pad("LocalKey { .. }")
283     }
284 }
285 
286 pin_project! {
287     /// A future that sets a value `T` of a task local for the future `F` during
288     /// its execution.
289     ///
290     /// The value of the task-local must be `'static` and will be dropped on the
291     /// completion of the future.
292     ///
293     /// Created by the function [`LocalKey::scope`](self::LocalKey::scope).
294     ///
295     /// ### Examples
296     ///
297     /// ```
298     /// # async fn dox() {
299     /// tokio::task_local! {
300     ///     static NUMBER: u32;
301     /// }
302     ///
303     /// NUMBER.scope(1, async move {
304     ///     println!("task local value: {}", NUMBER.get());
305     /// }).await;
306     /// # }
307     /// ```
308     pub struct TaskLocalFuture<T, F>
309     where
310         T: 'static,
311     {
312         local: &'static LocalKey<T>,
313         slot: Option<T>,
314         #[pin]
315         future: Option<F>,
316         #[pin]
317         _pinned: PhantomPinned,
318     }
319 
320     impl<T: 'static, F> PinnedDrop for TaskLocalFuture<T, F> {
321         fn drop(this: Pin<&mut Self>) {
322             let this = this.project();
323             if mem::needs_drop::<F>() && this.future.is_some() {
324                 // Drop the future while the task-local is set, if possible. Otherwise
325                 // the future is dropped normally when the `Option<F>` field drops.
326                 let mut future = this.future;
327                 let _ = this.local.scope_inner(this.slot, || {
328                     future.set(None);
329                 });
330             }
331         }
332     }
333 }
334 
335 impl<T, F> TaskLocalFuture<T, F>
336 where
337     T: 'static,
338 {
339     /// Returns the value stored in the task local by this `TaskLocalFuture`.
340     ///
341     /// The function returns:
342     ///
343     /// * `Some(T)` if the task local value exists.
344     /// * `None` if the task local value has already been taken.
345     ///
346     /// Note that this function attempts to take the task local value even if
347     /// the future has not yet completed. In that case, the value will no longer
348     /// be available via the task local after the call to `take_value`.
349     ///
350     /// # Examples
351     ///
352     /// ```
353     /// # async fn dox() {
354     /// tokio::task_local! {
355     ///     static KEY: u32;
356     /// }
357     ///
358     /// let fut = KEY.scope(42, async {
359     ///     // Do some async work
360     /// });
361     ///
362     /// let mut pinned = Box::pin(fut);
363     ///
364     /// // Complete the TaskLocalFuture
365     /// let _ = pinned.as_mut().await;
366     ///
367     /// // And here, we can take task local value
368     /// let value = pinned.as_mut().take_value();
369     ///
370     /// assert_eq!(value, Some(42));
371     /// # }
372     /// ```
take_value(self: Pin<&mut Self>) -> Option<T>373     pub fn take_value(self: Pin<&mut Self>) -> Option<T> {
374         let this = self.project();
375         this.slot.take()
376     }
377 }
378 
379 impl<T: 'static, F: Future> Future for TaskLocalFuture<T, F> {
380     type Output = F::Output;
381 
382     #[track_caller]
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>383     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
384         let this = self.project();
385         let mut future_opt = this.future;
386 
387         let res = this
388             .local
389             .scope_inner(this.slot, || match future_opt.as_mut().as_pin_mut() {
390                 Some(fut) => {
391                     let res = fut.poll(cx);
392                     if res.is_ready() {
393                         future_opt.set(None);
394                     }
395                     Some(res)
396                 }
397                 None => None,
398             });
399 
400         match res {
401             Ok(Some(res)) => res,
402             Ok(None) => panic!("`TaskLocalFuture` polled after completion"),
403             Err(err) => err.panic(),
404         }
405     }
406 }
407 
408 impl<T: 'static, F> fmt::Debug for TaskLocalFuture<T, F>
409 where
410     T: fmt::Debug,
411 {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result412     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
413         /// Format the Option without Some.
414         struct TransparentOption<'a, T> {
415             value: &'a Option<T>,
416         }
417         impl<'a, T: fmt::Debug> fmt::Debug for TransparentOption<'a, T> {
418             fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
419                 match self.value.as_ref() {
420                     Some(value) => value.fmt(f),
421                     // Hitting the None branch should not be possible.
422                     None => f.pad("<missing>"),
423                 }
424             }
425         }
426 
427         f.debug_struct("TaskLocalFuture")
428             .field("value", &TransparentOption { value: &self.slot })
429             .finish()
430     }
431 }
432 
/// Error returned by [`LocalKey::try_with`](method@LocalKey::try_with) when
/// the task-local value cannot be accessed.
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct AccessError {
    // Private so that the error can only be constructed inside this module.
    _private: (),
}

impl fmt::Debug for AccessError {
    /// Prints just the type name; there are no public fields to show.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Formatter::debug_struct(f, "AccessError").finish()
    }
}

impl fmt::Display for AccessError {
    /// Prints a fixed, human-readable message, honoring width/fill flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad("task-local value not set")
    }
}

impl Error for AccessError {}
452 
/// Reasons why a task-local scope could not be entered.
enum ScopeInnerErr {
    /// The task-local's `RefCell` was already borrowed when entering the scope.
    BorrowError,
    /// The underlying thread-local was being or had been destroyed.
    AccessError,
}

impl ScopeInnerErr {
    /// Panics with a message describing why the scope could not be entered.
    ///
    /// `#[track_caller]` attributes the panic to the caller's location.
    #[track_caller]
    fn panic(&self) -> ! {
        match *self {
            ScopeInnerErr::AccessError => panic!("cannot enter a task-local scope during or after destruction of the underlying thread-local"),
            ScopeInnerErr::BorrowError => panic!("cannot enter a task-local scope while the task-local storage is borrowed"),
        }
    }
}

impl From<std::cell::BorrowMutError> for ScopeInnerErr {
    fn from(_borrow_err: std::cell::BorrowMutError) -> Self {
        ScopeInnerErr::BorrowError
    }
}

impl From<std::thread::AccessError> for ScopeInnerErr {
    fn from(_access_err: std::thread::AccessError) -> Self {
        ScopeInnerErr::AccessError
    }
}
479