1 //! The futures-rs `select!` macro implementation.
2 
3 use proc_macro::TokenStream;
4 use proc_macro2::Span;
5 use quote::{format_ident, quote};
6 use syn::parse::{Parse, ParseStream};
7 use syn::{parse_quote, Expr, Ident, Pat, Token};
8 
// Custom keywords recognised by the `select!` parser.
mod kw {
    // `complete`, peeked and parsed as `kw::complete` at the head of a
    // `complete => ...` arm.
    syn::custom_keyword!(complete);
}
12 
// Parsed form of a `select!` invocation.
struct Select {
    // handler expression after `complete =>`, if such an arm was given
    complete: Option<Expr>,
    // handler expression after `default =>`, if such an arm was given
    default: Option<Expr>,
    // future expressions of the `<pat> = <expr> => ...` arms, in source order
    normal_fut_exprs: Vec<Expr>,
    // `(pattern, handler)` of each `<pat> = <expr> => <handler>` arm; the entry at
    // index `i` belongs to `normal_fut_exprs[i]`
    normal_fut_handlers: Vec<(Pat, Expr)>,
}
20 
// Kind of arm recognised while parsing, before its `=> <expr>` handler is read.
// `Normal` carries a whole pattern and expression while the other variants are
// empty, hence the `large_enum_variant` allowance.
#[allow(clippy::large_enum_variant)]
enum CaseKind {
    // `complete => ...`
    Complete,
    // `default => ...`
    Default,
    // `<pat> = <expr> => ...`: the pattern and the future expression
    Normal(Pat, Expr),
}
27 
28 impl Parse for Select {
parse(input: ParseStream<'_>) -> syn::Result<Self>29     fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
30         let mut select = Self {
31             complete: None,
32             default: None,
33             normal_fut_exprs: vec![],
34             normal_fut_handlers: vec![],
35         };
36 
37         while !input.is_empty() {
38             let case_kind = if input.peek(kw::complete) {
39                 // `complete`
40                 if select.complete.is_some() {
41                     return Err(input.error("multiple `complete` cases found, only one allowed"));
42                 }
43                 input.parse::<kw::complete>()?;
44                 CaseKind::Complete
45             } else if input.peek(Token![default]) {
46                 // `default`
47                 if select.default.is_some() {
48                     return Err(input.error("multiple `default` cases found, only one allowed"));
49                 }
50                 input.parse::<Ident>()?;
51                 CaseKind::Default
52             } else {
53                 // `<pat> = <expr>`
54                 let pat = Pat::parse_multi_with_leading_vert(input)?;
55                 input.parse::<Token![=]>()?;
56                 let expr = input.parse()?;
57                 CaseKind::Normal(pat, expr)
58             };
59 
60             // `=> <expr>`
61             input.parse::<Token![=>]>()?;
62             let expr = Expr::parse_with_earlier_boundary_rule(input)?;
63 
64             // Commas after the expression are only optional if it's a `Block`
65             // or it is the last branch in the `match`.
66             let is_block = match expr {
67                 Expr::Block(_) => true,
68                 _ => false,
69             };
70             if is_block || input.is_empty() {
71                 input.parse::<Option<Token![,]>>()?;
72             } else {
73                 input.parse::<Token![,]>()?;
74             }
75 
76             match case_kind {
77                 CaseKind::Complete => select.complete = Some(expr),
78                 CaseKind::Default => select.default = Some(expr),
79                 CaseKind::Normal(pat, fut_expr) => {
80                     select.normal_fut_exprs.push(fut_expr);
81                     select.normal_fut_handlers.push((pat, expr));
82                 }
83             }
84         }
85 
86         Ok(select)
87     }
88 }
89 
90 // Enum over all the cases in which the `select!` waiting has completed and the result
91 // can be processed.
92 //
93 // `enum __PrivResult<_1, _2, ...> { _1(_1), _2(_2), ..., Complete }`
declare_result_enum( result_ident: Ident, variants: usize, complete: bool, span: Span, ) -> (Vec<Ident>, syn::ItemEnum)94 fn declare_result_enum(
95     result_ident: Ident,
96     variants: usize,
97     complete: bool,
98     span: Span,
99 ) -> (Vec<Ident>, syn::ItemEnum) {
100     // "_0", "_1", "_2"
101     let variant_names: Vec<Ident> =
102         (0..variants).map(|num| format_ident!("_{}", num, span = span)).collect();
103 
104     let type_parameters = &variant_names;
105     let variants = &variant_names;
106 
107     let complete_variant = if complete { Some(quote!(Complete)) } else { None };
108 
109     let enum_item = parse_quote! {
110         enum #result_ident<#(#type_parameters,)*> {
111             #(
112                 #variants(#type_parameters),
113             )*
114             #complete_variant
115         }
116     };
117 
118     (variant_names, enum_item)
119 }
120 
121 /// The `select!` macro.
select(input: TokenStream) -> TokenStream122 pub(crate) fn select(input: TokenStream) -> TokenStream {
123     select_inner(input, true)
124 }
125 
126 /// The `select_biased!` macro.
select_biased(input: TokenStream) -> TokenStream127 pub(crate) fn select_biased(input: TokenStream) -> TokenStream {
128     select_inner(input, false)
129 }
130 
select_inner(input: TokenStream, random: bool) -> TokenStream131 fn select_inner(input: TokenStream, random: bool) -> TokenStream {
132     let parsed = syn::parse_macro_input!(input as Select);
133 
134     // should be def_site, but that's unstable
135     let span = Span::call_site();
136 
137     let enum_ident = Ident::new("__PrivResult", span);
138 
139     let (variant_names, enum_item) = declare_result_enum(
140         enum_ident.clone(),
141         parsed.normal_fut_exprs.len(),
142         parsed.complete.is_some(),
143         span,
144     );
145 
146     // bind non-`Ident` future exprs w/ `let`
147     let mut future_let_bindings = Vec::with_capacity(parsed.normal_fut_exprs.len());
148     let bound_future_names: Vec<_> = parsed
149         .normal_fut_exprs
150         .into_iter()
151         .zip(variant_names.iter())
152         .map(|(expr, variant_name)| {
153             match expr {
154                 syn::Expr::Path(path) => {
155                     // Don't bind futures that are already a path.
156                     // This prevents creating redundant stack space
157                     // for them.
158                     // Passing Futures by path requires those Futures to implement Unpin.
159                     // We check for this condition here in order to be able to
160                     // safely use Pin::new_unchecked(&mut #path) later on.
161                     future_let_bindings.push(quote! {
162                         __futures_crate::async_await::assert_fused_future(&#path);
163                         __futures_crate::async_await::assert_unpin(&#path);
164                     });
165                     path
166                 }
167                 _ => {
168                     // Bind and pin the resulting Future on the stack. This is
169                     // necessary to support direct select! calls on !Unpin
170                     // Futures. The Future is not explicitly pinned here with
171                     // a Pin call, but assumed as pinned. The actual Pin is
172                     // created inside the poll() function below to defer the
173                     // creation of the temporary pointer, which would otherwise
174                     // increase the size of the generated Future.
175                     // Safety: This is safe since the lifetime of the Future
176                     // is totally constraint to the lifetime of the select!
177                     // expression, and the Future can't get moved inside it
178                     // (it is shadowed).
179                     future_let_bindings.push(quote! {
180                         let mut #variant_name = #expr;
181                     });
182                     parse_quote! { #variant_name }
183                 }
184             }
185         })
186         .collect();
187 
188     // For each future, make an `&mut dyn FnMut(&mut Context<'_>) -> Option<Poll<__PrivResult<...>>`
189     // to use for polling that individual future. These will then be put in an array.
190     let poll_functions = bound_future_names.iter().zip(variant_names.iter()).map(
191         |(bound_future_name, variant_name)| {
192             // Below we lazily create the Pin on the Future below.
193             // This is done in order to avoid allocating memory in the generator
194             // for the Pin variable.
195             // Safety: This is safe because one of the following condition applies:
196             // 1. The Future is passed by the caller by name, and we assert that
197             //    it implements Unpin.
198             // 2. The Future is created in scope of the select! function and will
199             //    not be moved for the duration of it. It is thereby stack-pinned
200             quote! {
201                 let mut #variant_name = |__cx: &mut __futures_crate::task::Context<'_>| {
202                     let mut #bound_future_name = unsafe {
203                         __futures_crate::Pin::new_unchecked(&mut #bound_future_name)
204                     };
205                     if __futures_crate::future::FusedFuture::is_terminated(&#bound_future_name) {
206                         __futures_crate::None
207                     } else {
208                         __futures_crate::Some(__futures_crate::future::FutureExt::poll_unpin(
209                             &mut #bound_future_name,
210                             __cx,
211                         ).map(#enum_ident::#variant_name))
212                     }
213                 };
214                 let #variant_name: &mut dyn FnMut(
215                     &mut __futures_crate::task::Context<'_>
216                 ) -> __futures_crate::Option<__futures_crate::task::Poll<_>> = &mut #variant_name;
217             }
218         },
219     );
220 
221     let none_polled = if parsed.complete.is_some() {
222         quote! {
223             __futures_crate::task::Poll::Ready(#enum_ident::Complete)
224         }
225     } else {
226         quote! {
227             panic!("all futures in select! were completed,\
228                     but no `complete =>` handler was provided")
229         }
230     };
231 
232     let branches = parsed.normal_fut_handlers.into_iter().zip(variant_names.iter()).map(
233         |((pat, expr), variant_name)| {
234             quote! {
235                 #enum_ident::#variant_name(#pat) => #expr,
236             }
237         },
238     );
239     let branches = quote! { #( #branches )* };
240 
241     let complete_branch = parsed.complete.map(|complete_expr| {
242         quote! {
243             #enum_ident::Complete => { #complete_expr },
244         }
245     });
246 
247     let branches = quote! {
248         #branches
249         #complete_branch
250     };
251 
252     let await_select_fut = if parsed.default.is_some() {
253         // For select! with default this returns the Poll result
254         quote! {
255             __poll_fn(&mut __futures_crate::task::Context::from_waker(
256                 __futures_crate::task::noop_waker_ref()
257             ))
258         }
259     } else {
260         quote! {
261             __futures_crate::future::poll_fn(__poll_fn).await
262         }
263     };
264 
265     let execute_result_expr = if let Some(default_expr) = &parsed.default {
266         // For select! with default __select_result is a Poll, otherwise not
267         quote! {
268             match __select_result {
269                 __futures_crate::task::Poll::Ready(result) => match result {
270                     #branches
271                 },
272                 _ => #default_expr
273             }
274         }
275     } else {
276         quote! {
277             match __select_result {
278                 #branches
279             }
280         }
281     };
282 
283     let shuffle = if random {
284         quote! {
285             __futures_crate::async_await::shuffle(&mut __select_arr);
286         }
287     } else {
288         quote!()
289     };
290 
291     TokenStream::from(quote! { {
292         #enum_item
293 
294         let __select_result = {
295             #( #future_let_bindings )*
296 
297             let mut __poll_fn = |__cx: &mut __futures_crate::task::Context<'_>| {
298                 let mut __any_polled = false;
299 
300                 #( #poll_functions )*
301 
302                 let mut __select_arr = [#( #variant_names ),*];
303                 #shuffle
304                 for poller in &mut __select_arr {
305                     let poller: &mut &mut dyn FnMut(
306                         &mut __futures_crate::task::Context<'_>
307                     ) -> __futures_crate::Option<__futures_crate::task::Poll<_>> = poller;
308                     match poller(__cx) {
309                         __futures_crate::Some(x @ __futures_crate::task::Poll::Ready(_)) =>
310                             return x,
311                         __futures_crate::Some(__futures_crate::task::Poll::Pending) => {
312                             __any_polled = true;
313                         }
314                         __futures_crate::None => {}
315                     }
316                 }
317 
318                 if !__any_polled {
319                     #none_polled
320                 } else {
321                     __futures_crate::task::Poll::Pending
322                 }
323             };
324 
325             #await_select_fut
326         };
327 
328         #execute_result_expr
329     } })
330 }
331