1 use super::attr::AttrsHelper;
2 use proc_macro2::{Span, TokenStream};
3 use quote::{format_ident, quote};
4 use syn::{
5     punctuated::Punctuated,
6     token::{Colon, Comma, PathSep, Plus, Where},
7     Data, DataEnum, DataStruct, DeriveInput, Error, Fields, Generics, Ident, Path, PathArguments,
8     PathSegment, PredicateType, Result, TraitBound, TraitBoundModifier, Type, TypeParam,
9     TypeParamBound, TypePath, WhereClause, WherePredicate,
10 };
11 
12 use std::collections::HashMap;
13 
derive(input: &DeriveInput) -> Result<TokenStream>14 pub(crate) fn derive(input: &DeriveInput) -> Result<TokenStream> {
15     let impls = match &input.data {
16         Data::Struct(data) => impl_struct(input, data),
17         Data::Enum(data) => impl_enum(input, data),
18         Data::Union(_) => Err(Error::new_spanned(input, "Unions are not supported")),
19     }?;
20 
21     let helpers = specialization();
22     Ok(quote! {
23         #[allow(non_upper_case_globals, unused_attributes, unused_qualifications)]
24         const _: () = {
25             #helpers
26             #impls
27         };
28     })
29 }
30 
#[cfg(feature = "std")]
/// Helper traits emitted into every expansion when the `std` feature is enabled.
///
/// `DisplayToDisplayDoc` gives every `&T` with `T: Display` an identity `__displaydoc_display()`
/// method. `PathToDisplayDoc` gives `std::path::Path` and `std::path::PathBuf` a
/// `__displaydoc_display()` that returns their `.display()` adapter.
///
/// NOTE(review): the format code that actually calls `__displaydoc_display()` is generated
/// elsewhere (presumably in `attr.rs`). Method resolution between the two traits is what lets
/// path fields take the `.display()` route. Confirm against that caller before changing either
/// trait.
fn specialization() -> TokenStream {
    // Generic route: any `&T` where `T: Display` hands itself back unchanged.
    let display_route = quote! {
        trait DisplayToDisplayDoc {
            fn __displaydoc_display(&self) -> Self;
        }

        impl<T: ::core::fmt::Display> DisplayToDisplayDoc for &T {
            fn __displaydoc_display(&self) -> Self {
                self
            }
        }
    };

    // Path route. `extern crate std;` keeps `std` nameable inside the expansion even when the
    // crate deriving `Display` is `no_std`. Enabling this feature already means `std` is being
    // compiled in.
    let path_route = quote! {
        extern crate std;

        trait PathToDisplayDoc {
            fn __displaydoc_display(&self) -> std::path::Display<'_>;
        }

        impl PathToDisplayDoc for std::path::Path {
            fn __displaydoc_display(&self) -> std::path::Display<'_> {
                self.display()
            }
        }

        impl PathToDisplayDoc for std::path::PathBuf {
            fn __displaydoc_display(&self) -> std::path::Display<'_> {
                self.display()
            }
        }
    };

    quote! {
        #display_route
        #path_route
    }
}
67 
68 #[cfg(not(feature = "std"))]
specialization() -> TokenStream69 fn specialization() -> TokenStream {
70     quote! {}
71 }
72 
impl_struct(input: &DeriveInput, data: &DataStruct) -> Result<TokenStream>73 fn impl_struct(input: &DeriveInput, data: &DataStruct) -> Result<TokenStream> {
74     let ty = &input.ident;
75     let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
76     let where_clause = generate_where_clause(&input.generics, where_clause);
77 
78     let helper = AttrsHelper::new(&input.attrs);
79 
80     let display = helper.display(&input.attrs)?.map(|display| {
81         let pat = match &data.fields {
82             Fields::Named(fields) => {
83                 let var = fields.named.iter().map(|field| &field.ident);
84                 quote!(Self { #(#var),* })
85             }
86             Fields::Unnamed(fields) => {
87                 let var = (0..fields.unnamed.len()).map(|i| format_ident!("_{}", i));
88                 quote!(Self(#(#var),*))
89             }
90             Fields::Unit => quote!(_),
91         };
92         quote! {
93             impl #impl_generics ::core::fmt::Display for #ty #ty_generics #where_clause {
94                 fn fmt(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
95                     // NB: This destructures the fields of `self` into named variables (for unnamed
96                     // fields, it uses _0, _1, etc as above). The `#[allow(unused_variables)]`
97                     // section means it doesn't have to parse the individual field references out of
98                     // the docstring.
99                     #[allow(unused_variables)]
100                     let #pat = self;
101                     #display
102                 }
103             }
104         }
105     });
106 
107     Ok(quote! { #display })
108 }
109 
110 /// Create a `where` predicate for `ident`, without any [bound][TypeParamBound]s yet.
new_empty_where_type_predicate(ident: Ident) -> PredicateType111 fn new_empty_where_type_predicate(ident: Ident) -> PredicateType {
112     let mut path_segments = Punctuated::<PathSegment, PathSep>::new();
113     path_segments.push_value(PathSegment {
114         ident,
115         arguments: PathArguments::None,
116     });
117     PredicateType {
118         lifetimes: None,
119         bounded_ty: Type::Path(TypePath {
120             qself: None,
121             path: Path {
122                 leading_colon: None,
123                 segments: path_segments,
124             },
125         }),
126         colon_token: Colon {
127             spans: [Span::call_site()],
128         },
129         bounds: Punctuated::<TypeParamBound, Plus>::new(),
130     }
131 }
132 
133 /// Create a `where` clause that we can add [WherePredicate]s to.
new_empty_where_clause() -> WhereClause134 fn new_empty_where_clause() -> WhereClause {
135     WhereClause {
136         where_token: Where {
137             span: Span::call_site(),
138         },
139         predicates: Punctuated::<WherePredicate, Comma>::new(),
140     }
141 }
142 
/// Whether a path built by [`join_paths`] starts with a leading `::`.
enum UseGlobalPrefix {
    /// Emit an absolute path such as `::core::fmt::Display`.
    LeadingColon,
    /// Emit a relative path such as `core::fmt::Display`.
    // No caller in this module builds relative paths yet; the variant is kept for completeness.
    #[allow(dead_code)]
    NoLeadingColon,
}
148 
149 /// Create a path with segments composed of [Idents] *without* any [PathArguments].
join_paths(name_segments: &[&str], use_global_prefix: UseGlobalPrefix) -> Path150 fn join_paths(name_segments: &[&str], use_global_prefix: UseGlobalPrefix) -> Path {
151     let mut segments = Punctuated::<PathSegment, PathSep>::new();
152     assert!(!name_segments.is_empty());
153     segments.push_value(PathSegment {
154         ident: Ident::new(name_segments[0], Span::call_site()),
155         arguments: PathArguments::None,
156     });
157     for name in name_segments[1..].iter() {
158         segments.push_punct(PathSep {
159             spans: [Span::call_site(), Span::mixed_site()],
160         });
161         segments.push_value(PathSegment {
162             ident: Ident::new(name, Span::call_site()),
163             arguments: PathArguments::None,
164         });
165     }
166     Path {
167         leading_colon: match use_global_prefix {
168             UseGlobalPrefix::LeadingColon => Some(PathSep {
169                 spans: [Span::call_site(), Span::mixed_site()],
170             }),
171             UseGlobalPrefix::NoLeadingColon => None,
172         },
173         segments,
174     }
175 }
176 
177 /// Push `new_type_predicate` onto the end of `where_clause`.
append_where_clause_type_predicate( where_clause: &mut WhereClause, new_type_predicate: PredicateType, )178 fn append_where_clause_type_predicate(
179     where_clause: &mut WhereClause,
180     new_type_predicate: PredicateType,
181 ) {
182     // Push a comma at the end if there are already any `where` predicates.
183     if !where_clause.predicates.is_empty() {
184         where_clause.predicates.push_punct(Comma {
185             spans: [Span::call_site()],
186         });
187     }
188     where_clause
189         .predicates
190         .push_value(WherePredicate::Type(new_type_predicate));
191 }
192 
193 /// Add a requirement for [core::fmt::Display] to a `where` predicate for some type.
add_display_constraint_to_type_predicate( predicate_that_needs_a_display_impl: &mut PredicateType, )194 fn add_display_constraint_to_type_predicate(
195     predicate_that_needs_a_display_impl: &mut PredicateType,
196 ) {
197     // Create a `Path` of `::core::fmt::Display`.
198     let display_path = join_paths(&["core", "fmt", "Display"], UseGlobalPrefix::LeadingColon);
199 
200     let display_bound = TypeParamBound::Trait(TraitBound {
201         paren_token: None,
202         modifier: TraitBoundModifier::None,
203         lifetimes: None,
204         path: display_path,
205     });
206     if !predicate_that_needs_a_display_impl.bounds.is_empty() {
207         predicate_that_needs_a_display_impl.bounds.push_punct(Plus {
208             spans: [Span::call_site()],
209         });
210     }
211 
212     predicate_that_needs_a_display_impl
213         .bounds
214         .push_value(display_bound);
215 }
216 
217 /// Map each declared generic type parameter to the set of all trait boundaries declared on it.
218 ///
219 /// These boundaries may come from the declaration site:
220 ///     pub enum E<T: MyTrait> { ... }
221 /// or a `where` clause after the parameter declarations:
222 ///     pub enum E<T> where T: MyTrait { ... }
223 /// This method will return the boundaries from both of those cases.
extract_trait_constraints_from_source( where_clause: &WhereClause, type_params: &[&TypeParam], ) -> HashMap<Ident, Vec<TraitBound>>224 fn extract_trait_constraints_from_source(
225     where_clause: &WhereClause,
226     type_params: &[&TypeParam],
227 ) -> HashMap<Ident, Vec<TraitBound>> {
228     // Add trait bounds provided at the declaration site of type parameters for the struct/enum.
229     let mut param_constraint_mapping: HashMap<Ident, Vec<TraitBound>> = type_params
230         .iter()
231         .map(|type_param| {
232             let trait_bounds: Vec<TraitBound> = type_param
233                 .bounds
234                 .iter()
235                 .flat_map(|bound| match bound {
236                     TypeParamBound::Trait(trait_bound) => Some(trait_bound),
237                     _ => None,
238                 })
239                 .cloned()
240                 .collect();
241             (type_param.ident.clone(), trait_bounds)
242         })
243         .collect();
244 
245     // Add trait bounds from `where` clauses, which may be type parameters or types containing
246     // those parameters.
247     for predicate in where_clause.predicates.iter() {
248         // We only care about type and not lifetime constraints here.
249         if let WherePredicate::Type(ref pred_ty) = predicate {
250             let ident = match &pred_ty.bounded_ty {
251                 Type::Path(TypePath { path, qself: None }) => match path.get_ident() {
252                     None => continue,
253                     Some(ident) => ident,
254                 },
255                 _ => continue,
256             };
257             // We ignore any type constraints that aren't direct references to type
258             // parameters of the current enum of struct definition. No types can be
259             // constrained in a `where` clause unless they are a type parameter or a generic
260             // type instantiated with one of the type parameters, so by only allowing single
261             // identifiers, we can be sure that the constrained type is a type parameter
262             // that is contained in `param_constraint_mapping`.
263             if let Some((_, ref mut known_bounds)) = param_constraint_mapping
264                 .iter_mut()
265                 .find(|(id, _)| *id == ident)
266             {
267                 for bound in pred_ty.bounds.iter() {
268                     // We only care about trait bounds here.
269                     if let TypeParamBound::Trait(ref bound) = bound {
270                         known_bounds.push(bound.clone());
271                     }
272                 }
273             }
274         }
275     }
276 
277     param_constraint_mapping
278 }
279 
280 /// Hygienically add `where _: Display` to the set of [TypeParamBound]s for `ident`, creating such
281 /// a set if necessary.
ensure_display_in_where_clause_for_type(where_clause: &mut WhereClause, ident: Ident)282 fn ensure_display_in_where_clause_for_type(where_clause: &mut WhereClause, ident: Ident) {
283     for pred_ty in where_clause
284         .predicates
285         .iter_mut()
286         // Find the `where` predicate constraining the current type param, if it exists.
287         .flat_map(|predicate| match predicate {
288             WherePredicate::Type(pred_ty) => Some(pred_ty),
289             // We're looking through type constraints, not lifetime constraints.
290             _ => None,
291         })
292     {
293         // Do a complicated destructuring in order to check if the type being constrained in this
294         // `where` clause is the type we're looking for, so we can use the mutable reference to
295         // `pred_ty` if so.
296         let matches_desired_type = matches!(
297             &pred_ty.bounded_ty,
298             Type::Path(TypePath { path, .. }) if Some(&ident) == path.get_ident());
299         if matches_desired_type {
300             add_display_constraint_to_type_predicate(pred_ty);
301             return;
302         }
303     }
304 
305     // If there is no `where` predicate for the current type param, we will construct one.
306     let mut new_type_predicate = new_empty_where_type_predicate(ident);
307     add_display_constraint_to_type_predicate(&mut new_type_predicate);
308     append_where_clause_type_predicate(where_clause, new_type_predicate);
309 }
310 
311 /// For all declared type parameters, add a [core::fmt::Display] constraint, unless the type
312 /// parameter already has any type constraint.
ensure_where_clause_has_display_for_all_unconstrained_members( where_clause: &mut WhereClause, type_params: &[&TypeParam], )313 fn ensure_where_clause_has_display_for_all_unconstrained_members(
314     where_clause: &mut WhereClause,
315     type_params: &[&TypeParam],
316 ) {
317     let param_constraint_mapping = extract_trait_constraints_from_source(where_clause, type_params);
318 
319     for (ident, known_bounds) in param_constraint_mapping.into_iter() {
320         // If the type parameter has any constraints already, we don't want to touch it, to avoid
321         // breaking use cases where a type parameter only needs to impl `Debug`, for example.
322         if known_bounds.is_empty() {
323             ensure_display_in_where_clause_for_type(where_clause, ident);
324         }
325     }
326 }
327 
328 /// Generate a `where` clause that ensures all generic type parameters `impl`
329 /// [core::fmt::Display] unless already constrained.
330 ///
331 /// This approach allows struct/enum definitions deriving [crate::Display] to avoid hardcoding
332 /// a [core::fmt::Display] constraint into every type parameter.
333 ///
334 /// If the type parameter isn't already constrained, we add a `where _: Display` clause to our
335 /// display implementation to expect to be able to format every enum case or struct member.
336 ///
337 /// In fact, we would preferably only require `where _: Display` or `where _: Debug` where the
338 /// format string actually requires it. However, while [`std::fmt` defines a formal syntax for
339 /// `format!()`][format syntax], it *doesn't* expose the actual logic to parse the format string,
340 /// which appears to live in [`rustc_parse_format`]. While we use the [`syn`] crate to parse rust
341 /// syntax, it also doesn't currently provide any method to introspect a `format!()` string. It
342 /// would be nice to contribute this upstream in [`syn`].
343 ///
344 /// [format syntax]: std::fmt#syntax
345 /// [`rustc_parse_format`]: https://doc.rust-lang.org/nightly/nightly-rustc/rustc_parse_format/index.html
generate_where_clause(generics: &Generics, where_clause: Option<&WhereClause>) -> WhereClause346 fn generate_where_clause(generics: &Generics, where_clause: Option<&WhereClause>) -> WhereClause {
347     let mut where_clause = where_clause.cloned().unwrap_or_else(new_empty_where_clause);
348     let type_params: Vec<&TypeParam> = generics.type_params().collect();
349     ensure_where_clause_has_display_for_all_unconstrained_members(&mut where_clause, &type_params);
350     where_clause
351 }
352 
impl_enum(input: &DeriveInput, data: &DataEnum) -> Result<TokenStream>353 fn impl_enum(input: &DeriveInput, data: &DataEnum) -> Result<TokenStream> {
354     let ty = &input.ident;
355     let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
356     let where_clause = generate_where_clause(&input.generics, where_clause);
357 
358     let helper = AttrsHelper::new(&input.attrs);
359 
360     let displays = data
361         .variants
362         .iter()
363         .map(|variant| helper.display_with_input(&input.attrs, &variant.attrs))
364         .collect::<Result<Vec<_>>>()?;
365 
366     if data.variants.is_empty() {
367         Ok(quote! {
368             impl #impl_generics ::core::fmt::Display for #ty #ty_generics #where_clause {
369                 fn fmt(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
370                     unreachable!("empty enums cannot be instantiated and thus cannot be printed")
371                 }
372             }
373         })
374     } else if displays.iter().any(Option::is_some) {
375         let arms = data
376             .variants
377             .iter()
378             .zip(displays)
379             .map(|(variant, display)| {
380                 let display =
381                     display.ok_or_else(|| Error::new_spanned(variant, "missing doc comment"))?;
382                 let ident = &variant.ident;
383                 Ok(match &variant.fields {
384                     Fields::Named(fields) => {
385                         let var = fields.named.iter().map(|field| &field.ident);
386                         quote!(Self::#ident { #(#var),* } => { #display })
387                     }
388                     Fields::Unnamed(fields) => {
389                         let var = (0..fields.unnamed.len()).map(|i| format_ident!("_{}", i));
390                         quote!(Self::#ident(#(#var),*) => { #display })
391                     }
392                     Fields::Unit => quote!(Self::#ident => { #display }),
393                 })
394             })
395             .collect::<Result<Vec<_>>>()?;
396         Ok(quote! {
397             impl #impl_generics ::core::fmt::Display for #ty #ty_generics #where_clause {
398                 fn fmt(&self, formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
399                     #[allow(unused_variables)]
400                     match self {
401                         #(#arms,)*
402                     }
403                 }
404             }
405         })
406     } else {
407         Err(Error::new_spanned(input, "Missing doc comments"))
408     }
409 }
410