1 //! The futures-rs `join!` macro implementation.
2 
3 use proc_macro::TokenStream;
4 use proc_macro2::{Span, TokenStream as TokenStream2};
5 use quote::{format_ident, quote};
6 use syn::parse::{Parse, ParseStream};
7 use syn::{Expr, Ident, Token};
8 
9 #[derive(Default)]
10 struct Join {
11     fut_exprs: Vec<Expr>,
12 }
13 
14 impl Parse for Join {
parse(input: ParseStream<'_>) -> syn::Result<Self>15     fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
16         let mut join = Self::default();
17 
18         while !input.is_empty() {
19             join.fut_exprs.push(input.parse::<Expr>()?);
20 
21             if !input.is_empty() {
22                 input.parse::<Token![,]>()?;
23             }
24         }
25 
26         Ok(join)
27     }
28 }
29 
bind_futures(fut_exprs: Vec<Expr>, span: Span) -> (Vec<TokenStream2>, Vec<Ident>)30 fn bind_futures(fut_exprs: Vec<Expr>, span: Span) -> (Vec<TokenStream2>, Vec<Ident>) {
31     let mut future_let_bindings = Vec::with_capacity(fut_exprs.len());
32     let future_names: Vec<_> = fut_exprs
33         .into_iter()
34         .enumerate()
35         .map(|(i, expr)| {
36             let name = format_ident!("_fut{}", i, span = span);
37             future_let_bindings.push(quote! {
38                 // Move future into a local so that it is pinned in one place and
39                 // is no longer accessible by the end user.
40                 let mut #name = __futures_crate::future::maybe_done(#expr);
41                 let mut #name = unsafe { __futures_crate::Pin::new_unchecked(&mut #name) };
42             });
43             name
44         })
45         .collect();
46 
47     (future_let_bindings, future_names)
48 }
49 
50 /// The `join!` macro.
join(input: TokenStream) -> TokenStream51 pub(crate) fn join(input: TokenStream) -> TokenStream {
52     let parsed = syn::parse_macro_input!(input as Join);
53 
54     // should be def_site, but that's unstable
55     let span = Span::call_site();
56 
57     let (future_let_bindings, future_names) = bind_futures(parsed.fut_exprs, span);
58 
59     let poll_futures = future_names.iter().map(|fut| {
60         quote! {
61             __all_done &= __futures_crate::future::Future::poll(
62                 #fut.as_mut(), __cx).is_ready();
63         }
64     });
65     let take_outputs = future_names.iter().map(|fut| {
66         quote! {
67             #fut.as_mut().take_output().unwrap(),
68         }
69     });
70 
71     TokenStream::from(quote! { {
72         #( #future_let_bindings )*
73 
74         __futures_crate::future::poll_fn(move |__cx: &mut __futures_crate::task::Context<'_>| {
75             let mut __all_done = true;
76             #( #poll_futures )*
77             if __all_done {
78                 __futures_crate::task::Poll::Ready((
79                     #( #take_outputs )*
80                 ))
81             } else {
82                 __futures_crate::task::Poll::Pending
83             }
84         }).await
85     } })
86 }
87 
88 /// The `try_join!` macro.
try_join(input: TokenStream) -> TokenStream89 pub(crate) fn try_join(input: TokenStream) -> TokenStream {
90     let parsed = syn::parse_macro_input!(input as Join);
91 
92     // should be def_site, but that's unstable
93     let span = Span::call_site();
94 
95     let (future_let_bindings, future_names) = bind_futures(parsed.fut_exprs, span);
96 
97     let poll_futures = future_names.iter().map(|fut| {
98         quote! {
99             if __futures_crate::future::Future::poll(
100                 #fut.as_mut(), __cx).is_pending()
101             {
102                 __all_done = false;
103             } else if #fut.as_mut().output_mut().unwrap().is_err() {
104                 // `.err().unwrap()` rather than `.unwrap_err()` so that we don't introduce
105                 // a `T: Debug` bound.
106                 // Also, for an error type of ! any code after `err().unwrap()` is unreachable.
107                 #[allow(unreachable_code)]
108                 return __futures_crate::task::Poll::Ready(
109                     __futures_crate::Err(
110                         #fut.as_mut().take_output().unwrap().err().unwrap()
111                     )
112                 );
113             }
114         }
115     });
116     let take_outputs = future_names.iter().map(|fut| {
117         quote! {
118             // `.ok().unwrap()` rather than `.unwrap()` so that we don't introduce
119             // an `E: Debug` bound.
120             // Also, for an ok type of ! any code after `ok().unwrap()` is unreachable.
121             #[allow(unreachable_code)]
122             #fut.as_mut().take_output().unwrap().ok().unwrap(),
123         }
124     });
125 
126     TokenStream::from(quote! { {
127         #( #future_let_bindings )*
128 
129         #[allow(clippy::diverging_sub_expression)]
130         __futures_crate::future::poll_fn(move |__cx: &mut __futures_crate::task::Context<'_>| {
131             let mut __all_done = true;
132             #( #poll_futures )*
133             if __all_done {
134                 __futures_crate::task::Poll::Ready(
135                     __futures_crate::Ok((
136                         #( #take_outputs )*
137                     ))
138                 )
139             } else {
140                 __futures_crate::task::Poll::Pending
141             }
142         }).await
143     } })
144 }
145