1 //! Derive macros for [bytemuck](https://docs.rs/bytemuck) traits.
2 
3 extern crate proc_macro;
4 
5 mod traits;
6 
7 use proc_macro2::TokenStream;
8 use quote::quote;
9 use syn::{parse_macro_input, DeriveInput, Result};
10 
11 use crate::traits::{
12   bytemuck_crate_name, AnyBitPattern, CheckedBitPattern, Contiguous, Derivable,
13   NoUninit, Pod, TransparentWrapper, Zeroable,
14 };
15 
16 /// Derive the `Pod` trait for a struct
17 ///
/// The macro ensures that the struct follows all the safety requirements
19 /// for the `Pod` trait.
20 ///
21 /// The following constraints need to be satisfied for the macro to succeed
22 ///
23 /// - All fields in the struct must implement `Pod`
24 /// - The struct must be `#[repr(C)]` or `#[repr(transparent)]`
25 /// - The struct must not contain any padding bytes
26 /// - The struct contains no generic parameters, if it is not
27 ///   `#[repr(transparent)]`
28 ///
29 /// ## Examples
30 ///
31 /// ```rust
32 /// # use std::marker::PhantomData;
33 /// # use bytemuck_derive::{Pod, Zeroable};
34 /// #[derive(Copy, Clone, Pod, Zeroable)]
35 /// #[repr(C)]
36 /// struct Test {
37 ///   a: u16,
38 ///   b: u16,
39 /// }
40 ///
41 /// #[derive(Copy, Clone, Pod, Zeroable)]
42 /// #[repr(transparent)]
43 /// struct Generic<A, B> {
44 ///   a: A,
45 ///   b: PhantomData<B>,
46 /// }
47 /// ```
48 ///
49 /// If the struct is generic, it must be `#[repr(transparent)]` also.
50 ///
51 /// ```compile_fail
52 /// # use bytemuck::{Pod, Zeroable};
53 /// # use std::marker::PhantomData;
54 /// #[derive(Copy, Clone, Pod, Zeroable)]
55 /// #[repr(C)] // must be `#[repr(transparent)]`
56 /// struct Generic<A> {
57 ///   a: A,
58 /// }
59 /// ```
60 ///
61 /// If the struct is generic and `#[repr(transparent)]`, then it is only `Pod`
62 /// when all of its generics are `Pod`, not just its fields.
63 ///
64 /// ```
65 /// # use bytemuck::{Pod, Zeroable};
66 /// # use std::marker::PhantomData;
67 /// #[derive(Copy, Clone, Pod, Zeroable)]
68 /// #[repr(transparent)]
69 /// struct Generic<A, B> {
70 ///   a: A,
71 ///   b: PhantomData<B>,
72 /// }
73 ///
74 /// let _: u32 = bytemuck::cast(Generic { a: 4u32, b: PhantomData::<u32> });
75 /// ```
76 ///
77 /// ```compile_fail
78 /// # use bytemuck::{Pod, Zeroable};
79 /// # use std::marker::PhantomData;
80 /// # #[derive(Copy, Clone, Pod, Zeroable)]
81 /// # #[repr(transparent)]
82 /// # struct Generic<A, B> {
83 /// #   a: A,
84 /// #   b: PhantomData<B>,
85 /// # }
86 /// struct NotPod;
87 ///
88 /// let _: u32 = bytemuck::cast(Generic { a: 4u32, b: PhantomData::<NotPod> });
89 /// ```
90 #[proc_macro_derive(Pod, attributes(bytemuck))]
derive_pod(input: proc_macro::TokenStream) -> proc_macro::TokenStream91 pub fn derive_pod(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
92   let expanded =
93     derive_marker_trait::<Pod>(parse_macro_input!(input as DeriveInput));
94 
95   proc_macro::TokenStream::from(expanded)
96 }
97 
98 /// Derive the `AnyBitPattern` trait for a struct
99 ///
/// The macro ensures that the struct follows all the safety requirements
/// for the `AnyBitPattern` trait.
///
/// The following constraints need to be satisfied for the macro to succeed:
///
/// - All fields in the struct must implement `AnyBitPattern`
106 #[proc_macro_derive(AnyBitPattern, attributes(bytemuck))]
derive_anybitpattern( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream107 pub fn derive_anybitpattern(
108   input: proc_macro::TokenStream,
109 ) -> proc_macro::TokenStream {
110   let expanded = derive_marker_trait::<AnyBitPattern>(parse_macro_input!(
111     input as DeriveInput
112   ));
113 
114   proc_macro::TokenStream::from(expanded)
115 }
116 
117 /// Derive the `Zeroable` trait for a struct
118 ///
/// The macro ensures that the struct follows all the safety requirements
/// for the `Zeroable` trait.
///
/// The following constraints need to be satisfied for the macro to succeed:
///
/// - All fields in the struct must implement `Zeroable`
125 ///
126 /// ## Example
127 ///
128 /// ```rust
129 /// # use bytemuck_derive::{Zeroable};
130 /// #[derive(Copy, Clone, Zeroable)]
131 /// #[repr(C)]
132 /// struct Test {
133 ///   a: u16,
134 ///   b: u16,
135 /// }
136 /// ```
137 ///
138 /// # Custom bounds
139 ///
140 /// Custom bounds for the derived `Zeroable` impl can be given using the
141 /// `#[zeroable(bound = "")]` helper attribute.
142 ///
143 /// Using this attribute additionally opts-in to "perfect derive" semantics,
144 /// where instead of adding bounds for each generic type parameter, bounds are
145 /// added for each field's type.
146 ///
147 /// ## Examples
148 ///
149 /// ```rust
150 /// # use bytemuck::Zeroable;
151 /// # use std::marker::PhantomData;
152 /// #[derive(Clone, Zeroable)]
153 /// #[zeroable(bound = "")]
154 /// struct AlwaysZeroable<T> {
155 ///   a: PhantomData<T>,
156 /// }
157 ///
158 /// AlwaysZeroable::<std::num::NonZeroU8>::zeroed();
159 /// ```
160 ///
161 /// ```rust,compile_fail
162 /// # use bytemuck::Zeroable;
163 /// # use std::marker::PhantomData;
164 /// #[derive(Clone, Zeroable)]
165 /// #[zeroable(bound = "T: Copy")]
166 /// struct ZeroableWhenTIsCopy<T> {
167 ///   a: PhantomData<T>,
168 /// }
169 ///
170 /// ZeroableWhenTIsCopy::<String>::zeroed();
171 /// ```
172 ///
173 /// The restriction that all fields must be Zeroable is still applied, and this
174 /// is enforced using the mentioned "perfect derive" semantics.
175 ///
176 /// ```rust
177 /// # use bytemuck::Zeroable;
178 /// #[derive(Clone, Zeroable)]
179 /// #[zeroable(bound = "")]
180 /// struct ZeroableWhenTIsZeroable<T> {
181 ///   a: T,
182 /// }
183 /// ZeroableWhenTIsZeroable::<u32>::zeroed();
184 /// ```
185 ///
186 /// ```rust,compile_fail
187 /// # use bytemuck::Zeroable;
188 /// # #[derive(Clone, Zeroable)]
189 /// # #[zeroable(bound = "")]
190 /// # struct ZeroableWhenTIsZeroable<T> {
191 /// #   a: T,
192 /// # }
193 /// ZeroableWhenTIsZeroable::<String>::zeroed();
194 /// ```
195 #[proc_macro_derive(Zeroable, attributes(bytemuck, zeroable))]
derive_zeroable( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream196 pub fn derive_zeroable(
197   input: proc_macro::TokenStream,
198 ) -> proc_macro::TokenStream {
199   let expanded =
200     derive_marker_trait::<Zeroable>(parse_macro_input!(input as DeriveInput));
201 
202   proc_macro::TokenStream::from(expanded)
203 }
204 
205 /// Derive the `NoUninit` trait for a struct or enum
206 ///
/// The macro ensures that the type follows all the safety requirements
208 /// for the `NoUninit` trait.
209 ///
210 /// The following constraints need to be satisfied for the macro to succeed
211 /// (the rest of the constraints are guaranteed by the `NoUninit` subtrait
212 /// bounds, i.e. the type must be `Sized + Copy + 'static`):
213 ///
214 /// If applied to a struct:
215 /// - All fields in the struct must implement `NoUninit`
216 /// - The struct must be `#[repr(C)]` or `#[repr(transparent)]`
217 /// - The struct must not contain any padding bytes
218 /// - The struct must contain no generic parameters
219 ///
220 /// If applied to an enum:
/// - The enum must have an explicit `#[repr(Int)]`, `#[repr(C)]`, or both
222 /// - All variants must be fieldless
223 /// - The enum must contain no generic parameters
224 #[proc_macro_derive(NoUninit, attributes(bytemuck))]
derive_no_uninit( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream225 pub fn derive_no_uninit(
226   input: proc_macro::TokenStream,
227 ) -> proc_macro::TokenStream {
228   let expanded =
229     derive_marker_trait::<NoUninit>(parse_macro_input!(input as DeriveInput));
230 
231   proc_macro::TokenStream::from(expanded)
232 }
233 
234 /// Derive the `CheckedBitPattern` trait for a struct or enum.
235 ///
/// The macro ensures that the type follows all the safety requirements
237 /// for the `CheckedBitPattern` trait and derives the required `Bits` type
238 /// definition and `is_valid_bit_pattern` method for the type automatically.
239 ///
240 /// The following constraints need to be satisfied for the macro to succeed:
241 ///
242 /// If applied to a struct:
243 /// - All fields must implement `CheckedBitPattern`
244 /// - The struct must be `#[repr(C)]` or `#[repr(transparent)]`
245 /// - The struct must contain no generic parameters
246 ///
247 /// If applied to an enum:
/// - The enum must have an explicit `#[repr(Int)]`
249 /// - All fields in variants must implement `CheckedBitPattern`
250 /// - The enum must contain no generic parameters
251 #[proc_macro_derive(CheckedBitPattern)]
derive_maybe_pod( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream252 pub fn derive_maybe_pod(
253   input: proc_macro::TokenStream,
254 ) -> proc_macro::TokenStream {
255   let expanded = derive_marker_trait::<CheckedBitPattern>(parse_macro_input!(
256     input as DeriveInput
257   ));
258 
259   proc_macro::TokenStream::from(expanded)
260 }
261 
262 /// Derive the `TransparentWrapper` trait for a struct
263 ///
/// The macro ensures that the struct follows all the safety requirements
265 /// for the `TransparentWrapper` trait.
266 ///
267 /// The following constraints need to be satisfied for the macro to succeed
268 ///
269 /// - The struct must be `#[repr(transparent)]`
270 /// - The struct must contain the `Wrapped` type
271 /// - Any ZST fields must be [`Zeroable`][derive@Zeroable].
272 ///
273 /// If the struct only contains a single field, the `Wrapped` type will
/// automatically be determined. If there is more than one field in the struct,
/// you need to specify the `Wrapped` type using `#[transparent(T)]`.
276 ///
277 /// ## Examples
278 ///
279 /// ```rust
280 /// # use bytemuck_derive::TransparentWrapper;
281 /// # use std::marker::PhantomData;
282 /// #[derive(Copy, Clone, TransparentWrapper)]
283 /// #[repr(transparent)]
284 /// #[transparent(u16)]
285 /// struct Test<T> {
286 ///   inner: u16,
287 ///   extra: PhantomData<T>,
288 /// }
289 /// ```
290 ///
291 /// If the struct contains more than one field, the `Wrapped` type must be
292 /// explicitly specified.
293 ///
294 /// ```rust,compile_fail
295 /// # use bytemuck_derive::TransparentWrapper;
296 /// # use std::marker::PhantomData;
297 /// #[derive(Copy, Clone, TransparentWrapper)]
298 /// #[repr(transparent)]
299 /// // missing `#[transparent(u16)]`
300 /// struct Test<T> {
301 ///   inner: u16,
302 ///   extra: PhantomData<T>,
303 /// }
304 /// ```
305 ///
306 /// Any ZST fields must be `Zeroable`.
307 ///
308 /// ```rust,compile_fail
309 /// # use bytemuck_derive::TransparentWrapper;
310 /// # use std::marker::PhantomData;
311 /// struct NonTransparentSafeZST;
312 ///
313 /// #[derive(TransparentWrapper)]
314 /// #[repr(transparent)]
315 /// #[transparent(u16)]
316 /// struct Test<T> {
317 ///   inner: u16,
318 ///   extra: PhantomData<T>,
319 ///   another_extra: NonTransparentSafeZST, // not `Zeroable`
320 /// }
321 /// ```
322 #[proc_macro_derive(TransparentWrapper, attributes(bytemuck, transparent))]
derive_transparent( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream323 pub fn derive_transparent(
324   input: proc_macro::TokenStream,
325 ) -> proc_macro::TokenStream {
326   let expanded = derive_marker_trait::<TransparentWrapper>(parse_macro_input!(
327     input as DeriveInput
328   ));
329 
330   proc_macro::TokenStream::from(expanded)
331 }
332 
333 /// Derive the `Contiguous` trait for an enum
334 ///
/// The macro ensures that the enum follows all the safety requirements
336 /// for the `Contiguous` trait.
337 ///
338 /// The following constraints need to be satisfied for the macro to succeed
339 ///
340 /// - The enum must be `#[repr(Int)]`
341 /// - The enum must be fieldless
342 /// - The enum discriminants must form a contiguous range
343 ///
344 /// ## Example
345 ///
346 /// ```rust
347 /// # use bytemuck_derive::{Contiguous};
348 ///
349 /// #[derive(Copy, Clone, Contiguous)]
350 /// #[repr(u8)]
351 /// enum Test {
352 ///   A = 0,
353 ///   B = 1,
354 ///   C = 2,
355 /// }
356 /// ```
357 #[proc_macro_derive(Contiguous)]
derive_contiguous( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream358 pub fn derive_contiguous(
359   input: proc_macro::TokenStream,
360 ) -> proc_macro::TokenStream {
361   let expanded =
362     derive_marker_trait::<Contiguous>(parse_macro_input!(input as DeriveInput));
363 
364   proc_macro::TokenStream::from(expanded)
365 }
366 
367 /// Derive the `PartialEq` and `Eq` trait for a type
368 ///
369 /// The macro implements `PartialEq` and `Eq` by casting both sides of the
370 /// comparison to a byte slice and then compares those.
371 ///
372 /// ## Warning
373 ///
374 /// Since this implements a byte wise comparison, the behavior of floating point
375 /// numbers does not match their usual comparison behavior. Additionally other
376 /// custom comparison behaviors of the individual fields are also ignored. This
377 /// also does not implement `StructuralPartialEq` / `StructuralEq` like
378 /// `PartialEq` / `Eq` would. This means you can't pattern match on the values.
379 ///
380 /// ## Examples
381 ///
382 /// ```rust
383 /// # use bytemuck_derive::{ByteEq, NoUninit};
384 /// #[derive(Copy, Clone, NoUninit, ByteEq)]
385 /// #[repr(C)]
386 /// struct Test {
387 ///   a: u32,
388 ///   b: char,
389 ///   c: f32,
390 /// }
391 /// ```
392 ///
393 /// ```rust
394 /// # use bytemuck_derive::ByteEq;
395 /// # use bytemuck::NoUninit;
396 /// #[derive(Copy, Clone, ByteEq)]
397 /// #[repr(C)]
398 /// struct Test<const N: usize> {
399 ///   a: [u32; N],
400 /// }
401 /// unsafe impl<const N: usize> NoUninit for Test<N> {}
402 /// ```
403 #[proc_macro_derive(ByteEq)]
derive_byte_eq( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream404 pub fn derive_byte_eq(
405   input: proc_macro::TokenStream,
406 ) -> proc_macro::TokenStream {
407   let input = parse_macro_input!(input as DeriveInput);
408   let crate_name = bytemuck_crate_name(&input);
409   let ident = input.ident;
410   let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
411 
412   proc_macro::TokenStream::from(quote! {
413     impl #impl_generics ::core::cmp::PartialEq for #ident #ty_generics #where_clause {
414       #[inline]
415       #[must_use]
416       fn eq(&self, other: &Self) -> bool {
417         #crate_name::bytes_of(self) == #crate_name::bytes_of(other)
418       }
419     }
420     impl #impl_generics ::core::cmp::Eq for #ident #ty_generics #where_clause { }
421   })
422 }
423 
424 /// Derive the `Hash` trait for a type
425 ///
426 /// The macro implements `Hash` by casting the value to a byte slice and hashing
427 /// that.
428 ///
429 /// ## Warning
430 ///
431 /// The hash does not match the standard library's `Hash` derive.
432 ///
433 /// ## Examples
434 ///
435 /// ```rust
436 /// # use bytemuck_derive::{ByteHash, NoUninit};
437 /// #[derive(Copy, Clone, NoUninit, ByteHash)]
438 /// #[repr(C)]
439 /// struct Test {
440 ///   a: u32,
441 ///   b: char,
442 ///   c: f32,
443 /// }
444 /// ```
445 ///
446 /// ```rust
447 /// # use bytemuck_derive::ByteHash;
448 /// # use bytemuck::NoUninit;
449 /// #[derive(Copy, Clone, ByteHash)]
450 /// #[repr(C)]
451 /// struct Test<const N: usize> {
452 ///   a: [u32; N],
453 /// }
454 /// unsafe impl<const N: usize> NoUninit for Test<N> {}
455 /// ```
456 #[proc_macro_derive(ByteHash)]
derive_byte_hash( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream457 pub fn derive_byte_hash(
458   input: proc_macro::TokenStream,
459 ) -> proc_macro::TokenStream {
460   let input = parse_macro_input!(input as DeriveInput);
461   let crate_name = bytemuck_crate_name(&input);
462   let ident = input.ident;
463   let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
464 
465   proc_macro::TokenStream::from(quote! {
466     impl #impl_generics ::core::hash::Hash for #ident #ty_generics #where_clause {
467       #[inline]
468       fn hash<H: ::core::hash::Hasher>(&self, state: &mut H) {
469         ::core::hash::Hash::hash_slice(#crate_name::bytes_of(self), state)
470       }
471 
472       #[inline]
473       fn hash_slice<H: ::core::hash::Hasher>(data: &[Self], state: &mut H) {
474         ::core::hash::Hash::hash_slice(#crate_name::cast_slice::<_, u8>(data), state)
475       }
476     }
477   })
478 }
479 
480 /// Basic wrapper for error handling
derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream481 fn derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream {
482   derive_marker_trait_inner::<Trait>(input)
483     .unwrap_or_else(|err| err.into_compile_error())
484 }
485 
486 /// Find `#[name(key = "value")]` helper attributes on the struct, and return
487 /// their `"value"`s parsed with `parser`.
488 ///
489 /// Returns an error if any attributes with the given `name` do not match the
490 /// expected format. Returns `Ok([])` if no attributes with `name` are found.
find_and_parse_helper_attributes<P: syn::parse::Parser + Copy>( attributes: &[syn::Attribute], name: &str, key: &str, parser: P, example_value: &str, invalid_value_msg: &str, ) -> Result<Vec<P::Output>>491 fn find_and_parse_helper_attributes<P: syn::parse::Parser + Copy>(
492   attributes: &[syn::Attribute], name: &str, key: &str, parser: P,
493   example_value: &str, invalid_value_msg: &str,
494 ) -> Result<Vec<P::Output>> {
495   let invalid_format_msg =
496     format!("{name} attribute must be `{name}({key} = \"{example_value}\")`",);
497   let values_to_check = attributes.iter().filter_map(|attr| match &attr.meta {
498     // If a `Path` matches our `name`, return an error, else ignore it.
499     // e.g. `#[zeroable]`
500     syn::Meta::Path(path) => path
501       .is_ident(name)
502       .then(|| Err(syn::Error::new_spanned(path, &invalid_format_msg))),
503     // If a `NameValue` matches our `name`, return an error, else ignore it.
504     // e.g. `#[zeroable = "hello"]`
505     syn::Meta::NameValue(namevalue) => {
506       namevalue.path.is_ident(name).then(|| {
507         Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
508       })
509     }
510     // If a `List` matches our `name`, match its contents to our format, else
511     // ignore it. If its contents match our format, return the value, else
512     // return an error.
513     syn::Meta::List(list) => list.path.is_ident(name).then(|| {
514       let namevalue: syn::MetaNameValue = syn::parse2(list.tokens.clone())
515         .map_err(|_| {
516           syn::Error::new_spanned(&list.tokens, &invalid_format_msg)
517         })?;
518       if namevalue.path.is_ident(key) {
519         match namevalue.value {
520           syn::Expr::Lit(syn::ExprLit {
521             lit: syn::Lit::Str(strlit), ..
522           }) => Ok(strlit),
523           _ => {
524             Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
525           }
526         }
527       } else {
528         Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
529       }
530     }),
531   });
532   // Parse each value found with the given parser, and return them if no errors
533   // occur.
534   values_to_check
535     .map(|lit| {
536       let lit = lit?;
537       lit.parse_with(parser).map_err(|err| {
538         syn::Error::new_spanned(&lit, format!("{invalid_value_msg}: {err}"))
539       })
540     })
541     .collect()
542 }
543 
derive_marker_trait_inner<Trait: Derivable>( mut input: DeriveInput, ) -> Result<TokenStream>544 fn derive_marker_trait_inner<Trait: Derivable>(
545   mut input: DeriveInput,
546 ) -> Result<TokenStream> {
547   let crate_name = bytemuck_crate_name(&input);
548   let trait_ = Trait::ident(&input, &crate_name)?;
549   // If this trait allows explicit bounds, and any explicit bounds were given,
550   // then use those explicit bounds. Else, apply the default bounds (bound
551   // each generic type on this trait).
552   if let Some(name) = Trait::explicit_bounds_attribute_name() {
553     // See if any explicit bounds were given in attributes.
554     let explicit_bounds = find_and_parse_helper_attributes(
555       &input.attrs,
556       name,
557       "bound",
558       <syn::punctuated::Punctuated<syn::WherePredicate, syn::Token![,]>>::parse_terminated,
559       "Type: Trait",
560       "invalid where predicate",
561     )?;
562 
563     if !explicit_bounds.is_empty() {
564       // Explicit bounds were given.
565       // Enforce explicitly given bounds, and emit "perfect derive" (i.e. add
566       // bounds for each field's type).
567       let explicit_bounds = explicit_bounds
568         .into_iter()
569         .flatten()
570         .collect::<Vec<syn::WherePredicate>>();
571 
572       let predicates = &mut input.generics.make_where_clause().predicates;
573 
574       predicates.extend(explicit_bounds);
575 
576       let fields = match &input.data {
577         syn::Data::Struct(syn::DataStruct { fields, .. }) => fields.clone(),
578         syn::Data::Union(_) => {
579           return Err(syn::Error::new_spanned(
580             trait_,
581             &"perfect derive is not supported for unions",
582           ));
583         }
584         syn::Data::Enum(_) => {
585           return Err(syn::Error::new_spanned(
586             trait_,
587             &"perfect derive is not supported for enums",
588           ));
589         }
590       };
591 
592       for field in fields {
593         let ty = field.ty;
594         predicates.push(syn::parse_quote!(
595           #ty: #trait_
596         ));
597       }
598     } else {
599       // No explicit bounds were given.
600       // Enforce trait bound on all type generics.
601       add_trait_marker(&mut input.generics, &trait_);
602     }
603   } else {
604     // This trait does not allow explicit bounds.
605     // Enforce trait bound on all type generics.
606     add_trait_marker(&mut input.generics, &trait_);
607   }
608 
609   let name = &input.ident;
610 
611   let (impl_generics, ty_generics, where_clause) =
612     input.generics.split_for_impl();
613 
614   Trait::check_attributes(&input.data, &input.attrs)?;
615   let asserts = Trait::asserts(&input, &crate_name)?;
616   let (trait_impl_extras, trait_impl) = Trait::trait_impl(&input, &crate_name)?;
617 
618   let implies_trait = if let Some(implies_trait) =
619     Trait::implies_trait(&crate_name)
620   {
621     quote!(unsafe impl #impl_generics #implies_trait for #name #ty_generics #where_clause {})
622   } else {
623     quote!()
624   };
625 
626   let where_clause =
627     if Trait::requires_where_clause() { where_clause } else { None };
628 
629   Ok(quote! {
630     #asserts
631 
632     #trait_impl_extras
633 
634     unsafe impl #impl_generics #trait_ for #name #ty_generics #where_clause {
635       #trait_impl
636     }
637 
638     #implies_trait
639   })
640 }
641 
642 /// Add a trait marker to the generics if it is not already present
add_trait_marker(generics: &mut syn::Generics, trait_name: &syn::Path)643 fn add_trait_marker(generics: &mut syn::Generics, trait_name: &syn::Path) {
644   // Get each generic type parameter.
645   let type_params = generics
646     .type_params()
647     .map(|param| &param.ident)
648     .map(|param| {
649       syn::parse_quote!(
650         #param: #trait_name
651       )
652     })
653     .collect::<Vec<syn::WherePredicate>>();
654 
655   generics.make_where_clause().predicates.extend(type_params);
656 }
657