1 //! Derive macros for [bytemuck](https://docs.rs/bytemuck) traits.
2
3 extern crate proc_macro;
4
5 mod traits;
6
7 use proc_macro2::TokenStream;
8 use quote::quote;
9 use syn::{parse_macro_input, DeriveInput, Result};
10
11 use crate::traits::{
12 bytemuck_crate_name, AnyBitPattern, CheckedBitPattern, Contiguous, Derivable,
13 NoUninit, Pod, TransparentWrapper, Zeroable,
14 };
15
16 /// Derive the `Pod` trait for a struct
17 ///
/// The macro ensures that the struct follows all the safety requirements
19 /// for the `Pod` trait.
20 ///
21 /// The following constraints need to be satisfied for the macro to succeed
22 ///
23 /// - All fields in the struct must implement `Pod`
24 /// - The struct must be `#[repr(C)]` or `#[repr(transparent)]`
25 /// - The struct must not contain any padding bytes
26 /// - The struct contains no generic parameters, if it is not
27 /// `#[repr(transparent)]`
28 ///
29 /// ## Examples
30 ///
31 /// ```rust
32 /// # use std::marker::PhantomData;
33 /// # use bytemuck_derive::{Pod, Zeroable};
34 /// #[derive(Copy, Clone, Pod, Zeroable)]
35 /// #[repr(C)]
36 /// struct Test {
37 /// a: u16,
38 /// b: u16,
39 /// }
40 ///
41 /// #[derive(Copy, Clone, Pod, Zeroable)]
42 /// #[repr(transparent)]
43 /// struct Generic<A, B> {
44 /// a: A,
45 /// b: PhantomData<B>,
46 /// }
47 /// ```
48 ///
49 /// If the struct is generic, it must be `#[repr(transparent)]` also.
50 ///
51 /// ```compile_fail
52 /// # use bytemuck::{Pod, Zeroable};
53 /// # use std::marker::PhantomData;
54 /// #[derive(Copy, Clone, Pod, Zeroable)]
55 /// #[repr(C)] // must be `#[repr(transparent)]`
56 /// struct Generic<A> {
57 /// a: A,
58 /// }
59 /// ```
60 ///
61 /// If the struct is generic and `#[repr(transparent)]`, then it is only `Pod`
62 /// when all of its generics are `Pod`, not just its fields.
63 ///
64 /// ```
65 /// # use bytemuck::{Pod, Zeroable};
66 /// # use std::marker::PhantomData;
67 /// #[derive(Copy, Clone, Pod, Zeroable)]
68 /// #[repr(transparent)]
69 /// struct Generic<A, B> {
70 /// a: A,
71 /// b: PhantomData<B>,
72 /// }
73 ///
74 /// let _: u32 = bytemuck::cast(Generic { a: 4u32, b: PhantomData::<u32> });
75 /// ```
76 ///
77 /// ```compile_fail
78 /// # use bytemuck::{Pod, Zeroable};
79 /// # use std::marker::PhantomData;
80 /// # #[derive(Copy, Clone, Pod, Zeroable)]
81 /// # #[repr(transparent)]
82 /// # struct Generic<A, B> {
83 /// # a: A,
84 /// # b: PhantomData<B>,
85 /// # }
86 /// struct NotPod;
87 ///
88 /// let _: u32 = bytemuck::cast(Generic { a: 4u32, b: PhantomData::<NotPod> });
89 /// ```
90 #[proc_macro_derive(Pod, attributes(bytemuck))]
derive_pod(input: proc_macro::TokenStream) -> proc_macro::TokenStream91 pub fn derive_pod(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
92 let expanded =
93 derive_marker_trait::<Pod>(parse_macro_input!(input as DeriveInput));
94
95 proc_macro::TokenStream::from(expanded)
96 }
97
98 /// Derive the `AnyBitPattern` trait for a struct
99 ///
/// The macro ensures that the struct follows all the safety requirements
101 /// for the `AnyBitPattern` trait.
102 ///
103 /// The following constraints need to be satisfied for the macro to succeed
104 ///
/// - All fields in the struct must implement `AnyBitPattern`
106 #[proc_macro_derive(AnyBitPattern, attributes(bytemuck))]
derive_anybitpattern( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream107 pub fn derive_anybitpattern(
108 input: proc_macro::TokenStream,
109 ) -> proc_macro::TokenStream {
110 let expanded = derive_marker_trait::<AnyBitPattern>(parse_macro_input!(
111 input as DeriveInput
112 ));
113
114 proc_macro::TokenStream::from(expanded)
115 }
116
117 /// Derive the `Zeroable` trait for a struct
118 ///
/// The macro ensures that the struct follows all the safety requirements
120 /// for the `Zeroable` trait.
121 ///
122 /// The following constraints need to be satisfied for the macro to succeed
123 ///
/// - All fields in the struct must implement `Zeroable`
125 ///
126 /// ## Example
127 ///
128 /// ```rust
129 /// # use bytemuck_derive::{Zeroable};
130 /// #[derive(Copy, Clone, Zeroable)]
131 /// #[repr(C)]
132 /// struct Test {
133 /// a: u16,
134 /// b: u16,
135 /// }
136 /// ```
137 ///
138 /// # Custom bounds
139 ///
140 /// Custom bounds for the derived `Zeroable` impl can be given using the
141 /// `#[zeroable(bound = "")]` helper attribute.
142 ///
143 /// Using this attribute additionally opts-in to "perfect derive" semantics,
144 /// where instead of adding bounds for each generic type parameter, bounds are
145 /// added for each field's type.
146 ///
147 /// ## Examples
148 ///
149 /// ```rust
150 /// # use bytemuck::Zeroable;
151 /// # use std::marker::PhantomData;
152 /// #[derive(Clone, Zeroable)]
153 /// #[zeroable(bound = "")]
154 /// struct AlwaysZeroable<T> {
155 /// a: PhantomData<T>,
156 /// }
157 ///
158 /// AlwaysZeroable::<std::num::NonZeroU8>::zeroed();
159 /// ```
160 ///
161 /// ```rust,compile_fail
162 /// # use bytemuck::Zeroable;
163 /// # use std::marker::PhantomData;
164 /// #[derive(Clone, Zeroable)]
165 /// #[zeroable(bound = "T: Copy")]
166 /// struct ZeroableWhenTIsCopy<T> {
167 /// a: PhantomData<T>,
168 /// }
169 ///
170 /// ZeroableWhenTIsCopy::<String>::zeroed();
171 /// ```
172 ///
173 /// The restriction that all fields must be Zeroable is still applied, and this
174 /// is enforced using the mentioned "perfect derive" semantics.
175 ///
176 /// ```rust
177 /// # use bytemuck::Zeroable;
178 /// #[derive(Clone, Zeroable)]
179 /// #[zeroable(bound = "")]
180 /// struct ZeroableWhenTIsZeroable<T> {
181 /// a: T,
182 /// }
183 /// ZeroableWhenTIsZeroable::<u32>::zeroed();
184 /// ```
185 ///
186 /// ```rust,compile_fail
187 /// # use bytemuck::Zeroable;
188 /// # #[derive(Clone, Zeroable)]
189 /// # #[zeroable(bound = "")]
190 /// # struct ZeroableWhenTIsZeroable<T> {
191 /// # a: T,
192 /// # }
193 /// ZeroableWhenTIsZeroable::<String>::zeroed();
194 /// ```
195 #[proc_macro_derive(Zeroable, attributes(bytemuck, zeroable))]
derive_zeroable( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream196 pub fn derive_zeroable(
197 input: proc_macro::TokenStream,
198 ) -> proc_macro::TokenStream {
199 let expanded =
200 derive_marker_trait::<Zeroable>(parse_macro_input!(input as DeriveInput));
201
202 proc_macro::TokenStream::from(expanded)
203 }
204
205 /// Derive the `NoUninit` trait for a struct or enum
206 ///
/// The macro ensures that the type follows all the safety requirements
208 /// for the `NoUninit` trait.
209 ///
210 /// The following constraints need to be satisfied for the macro to succeed
211 /// (the rest of the constraints are guaranteed by the `NoUninit` subtrait
212 /// bounds, i.e. the type must be `Sized + Copy + 'static`):
213 ///
214 /// If applied to a struct:
215 /// - All fields in the struct must implement `NoUninit`
216 /// - The struct must be `#[repr(C)]` or `#[repr(transparent)]`
217 /// - The struct must not contain any padding bytes
218 /// - The struct must contain no generic parameters
219 ///
220 /// If applied to an enum:
/// - The enum must have an explicit `#[repr(Int)]`, `#[repr(C)]`, or both
222 /// - All variants must be fieldless
223 /// - The enum must contain no generic parameters
224 #[proc_macro_derive(NoUninit, attributes(bytemuck))]
derive_no_uninit( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream225 pub fn derive_no_uninit(
226 input: proc_macro::TokenStream,
227 ) -> proc_macro::TokenStream {
228 let expanded =
229 derive_marker_trait::<NoUninit>(parse_macro_input!(input as DeriveInput));
230
231 proc_macro::TokenStream::from(expanded)
232 }
233
234 /// Derive the `CheckedBitPattern` trait for a struct or enum.
235 ///
/// The macro ensures that the type follows all the safety requirements
237 /// for the `CheckedBitPattern` trait and derives the required `Bits` type
238 /// definition and `is_valid_bit_pattern` method for the type automatically.
239 ///
240 /// The following constraints need to be satisfied for the macro to succeed:
241 ///
242 /// If applied to a struct:
243 /// - All fields must implement `CheckedBitPattern`
244 /// - The struct must be `#[repr(C)]` or `#[repr(transparent)]`
245 /// - The struct must contain no generic parameters
246 ///
247 /// If applied to an enum:
/// - The enum must have an explicit `#[repr(Int)]`
249 /// - All fields in variants must implement `CheckedBitPattern`
250 /// - The enum must contain no generic parameters
251 #[proc_macro_derive(CheckedBitPattern)]
derive_maybe_pod( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream252 pub fn derive_maybe_pod(
253 input: proc_macro::TokenStream,
254 ) -> proc_macro::TokenStream {
255 let expanded = derive_marker_trait::<CheckedBitPattern>(parse_macro_input!(
256 input as DeriveInput
257 ));
258
259 proc_macro::TokenStream::from(expanded)
260 }
261
262 /// Derive the `TransparentWrapper` trait for a struct
263 ///
/// The macro ensures that the struct follows all the safety requirements
265 /// for the `TransparentWrapper` trait.
266 ///
267 /// The following constraints need to be satisfied for the macro to succeed
268 ///
269 /// - The struct must be `#[repr(transparent)]`
270 /// - The struct must contain the `Wrapped` type
271 /// - Any ZST fields must be [`Zeroable`][derive@Zeroable].
272 ///
273 /// If the struct only contains a single field, the `Wrapped` type will
/// automatically be determined. If there is more than one field in the struct,
275 /// you need to specify the `Wrapped` type using `#[transparent(T)]`
276 ///
277 /// ## Examples
278 ///
279 /// ```rust
280 /// # use bytemuck_derive::TransparentWrapper;
281 /// # use std::marker::PhantomData;
282 /// #[derive(Copy, Clone, TransparentWrapper)]
283 /// #[repr(transparent)]
284 /// #[transparent(u16)]
285 /// struct Test<T> {
286 /// inner: u16,
287 /// extra: PhantomData<T>,
288 /// }
289 /// ```
290 ///
291 /// If the struct contains more than one field, the `Wrapped` type must be
292 /// explicitly specified.
293 ///
294 /// ```rust,compile_fail
295 /// # use bytemuck_derive::TransparentWrapper;
296 /// # use std::marker::PhantomData;
297 /// #[derive(Copy, Clone, TransparentWrapper)]
298 /// #[repr(transparent)]
299 /// // missing `#[transparent(u16)]`
300 /// struct Test<T> {
301 /// inner: u16,
302 /// extra: PhantomData<T>,
303 /// }
304 /// ```
305 ///
306 /// Any ZST fields must be `Zeroable`.
307 ///
308 /// ```rust,compile_fail
309 /// # use bytemuck_derive::TransparentWrapper;
310 /// # use std::marker::PhantomData;
311 /// struct NonTransparentSafeZST;
312 ///
313 /// #[derive(TransparentWrapper)]
314 /// #[repr(transparent)]
315 /// #[transparent(u16)]
316 /// struct Test<T> {
317 /// inner: u16,
318 /// extra: PhantomData<T>,
319 /// another_extra: NonTransparentSafeZST, // not `Zeroable`
320 /// }
321 /// ```
322 #[proc_macro_derive(TransparentWrapper, attributes(bytemuck, transparent))]
derive_transparent( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream323 pub fn derive_transparent(
324 input: proc_macro::TokenStream,
325 ) -> proc_macro::TokenStream {
326 let expanded = derive_marker_trait::<TransparentWrapper>(parse_macro_input!(
327 input as DeriveInput
328 ));
329
330 proc_macro::TokenStream::from(expanded)
331 }
332
333 /// Derive the `Contiguous` trait for an enum
334 ///
/// The macro ensures that the enum follows all the safety requirements
336 /// for the `Contiguous` trait.
337 ///
338 /// The following constraints need to be satisfied for the macro to succeed
339 ///
340 /// - The enum must be `#[repr(Int)]`
341 /// - The enum must be fieldless
342 /// - The enum discriminants must form a contiguous range
343 ///
344 /// ## Example
345 ///
346 /// ```rust
347 /// # use bytemuck_derive::{Contiguous};
348 ///
349 /// #[derive(Copy, Clone, Contiguous)]
350 /// #[repr(u8)]
351 /// enum Test {
352 /// A = 0,
353 /// B = 1,
354 /// C = 2,
355 /// }
356 /// ```
357 #[proc_macro_derive(Contiguous)]
derive_contiguous( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream358 pub fn derive_contiguous(
359 input: proc_macro::TokenStream,
360 ) -> proc_macro::TokenStream {
361 let expanded =
362 derive_marker_trait::<Contiguous>(parse_macro_input!(input as DeriveInput));
363
364 proc_macro::TokenStream::from(expanded)
365 }
366
/// Derive the `PartialEq` and `Eq` traits for a type
368 ///
/// The macro implements `PartialEq` and `Eq` by casting both sides of the
/// comparison to a byte slice and then comparing those.
371 ///
372 /// ## Warning
373 ///
374 /// Since this implements a byte wise comparison, the behavior of floating point
375 /// numbers does not match their usual comparison behavior. Additionally other
376 /// custom comparison behaviors of the individual fields are also ignored. This
377 /// also does not implement `StructuralPartialEq` / `StructuralEq` like
378 /// `PartialEq` / `Eq` would. This means you can't pattern match on the values.
379 ///
380 /// ## Examples
381 ///
382 /// ```rust
383 /// # use bytemuck_derive::{ByteEq, NoUninit};
384 /// #[derive(Copy, Clone, NoUninit, ByteEq)]
385 /// #[repr(C)]
386 /// struct Test {
387 /// a: u32,
388 /// b: char,
389 /// c: f32,
390 /// }
391 /// ```
392 ///
393 /// ```rust
394 /// # use bytemuck_derive::ByteEq;
395 /// # use bytemuck::NoUninit;
396 /// #[derive(Copy, Clone, ByteEq)]
397 /// #[repr(C)]
398 /// struct Test<const N: usize> {
399 /// a: [u32; N],
400 /// }
401 /// unsafe impl<const N: usize> NoUninit for Test<N> {}
402 /// ```
403 #[proc_macro_derive(ByteEq)]
derive_byte_eq( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream404 pub fn derive_byte_eq(
405 input: proc_macro::TokenStream,
406 ) -> proc_macro::TokenStream {
407 let input = parse_macro_input!(input as DeriveInput);
408 let crate_name = bytemuck_crate_name(&input);
409 let ident = input.ident;
410 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
411
412 proc_macro::TokenStream::from(quote! {
413 impl #impl_generics ::core::cmp::PartialEq for #ident #ty_generics #where_clause {
414 #[inline]
415 #[must_use]
416 fn eq(&self, other: &Self) -> bool {
417 #crate_name::bytes_of(self) == #crate_name::bytes_of(other)
418 }
419 }
420 impl #impl_generics ::core::cmp::Eq for #ident #ty_generics #where_clause { }
421 })
422 }
423
424 /// Derive the `Hash` trait for a type
425 ///
426 /// The macro implements `Hash` by casting the value to a byte slice and hashing
427 /// that.
428 ///
429 /// ## Warning
430 ///
431 /// The hash does not match the standard library's `Hash` derive.
432 ///
433 /// ## Examples
434 ///
435 /// ```rust
436 /// # use bytemuck_derive::{ByteHash, NoUninit};
437 /// #[derive(Copy, Clone, NoUninit, ByteHash)]
438 /// #[repr(C)]
439 /// struct Test {
440 /// a: u32,
441 /// b: char,
442 /// c: f32,
443 /// }
444 /// ```
445 ///
446 /// ```rust
447 /// # use bytemuck_derive::ByteHash;
448 /// # use bytemuck::NoUninit;
449 /// #[derive(Copy, Clone, ByteHash)]
450 /// #[repr(C)]
451 /// struct Test<const N: usize> {
452 /// a: [u32; N],
453 /// }
454 /// unsafe impl<const N: usize> NoUninit for Test<N> {}
455 /// ```
456 #[proc_macro_derive(ByteHash)]
derive_byte_hash( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream457 pub fn derive_byte_hash(
458 input: proc_macro::TokenStream,
459 ) -> proc_macro::TokenStream {
460 let input = parse_macro_input!(input as DeriveInput);
461 let crate_name = bytemuck_crate_name(&input);
462 let ident = input.ident;
463 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
464
465 proc_macro::TokenStream::from(quote! {
466 impl #impl_generics ::core::hash::Hash for #ident #ty_generics #where_clause {
467 #[inline]
468 fn hash<H: ::core::hash::Hasher>(&self, state: &mut H) {
469 ::core::hash::Hash::hash_slice(#crate_name::bytes_of(self), state)
470 }
471
472 #[inline]
473 fn hash_slice<H: ::core::hash::Hasher>(data: &[Self], state: &mut H) {
474 ::core::hash::Hash::hash_slice(#crate_name::cast_slice::<_, u8>(data), state)
475 }
476 }
477 })
478 }
479
480 /// Basic wrapper for error handling
derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream481 fn derive_marker_trait<Trait: Derivable>(input: DeriveInput) -> TokenStream {
482 derive_marker_trait_inner::<Trait>(input)
483 .unwrap_or_else(|err| err.into_compile_error())
484 }
485
486 /// Find `#[name(key = "value")]` helper attributes on the struct, and return
487 /// their `"value"`s parsed with `parser`.
488 ///
489 /// Returns an error if any attributes with the given `name` do not match the
490 /// expected format. Returns `Ok([])` if no attributes with `name` are found.
find_and_parse_helper_attributes<P: syn::parse::Parser + Copy>( attributes: &[syn::Attribute], name: &str, key: &str, parser: P, example_value: &str, invalid_value_msg: &str, ) -> Result<Vec<P::Output>>491 fn find_and_parse_helper_attributes<P: syn::parse::Parser + Copy>(
492 attributes: &[syn::Attribute], name: &str, key: &str, parser: P,
493 example_value: &str, invalid_value_msg: &str,
494 ) -> Result<Vec<P::Output>> {
495 let invalid_format_msg =
496 format!("{name} attribute must be `{name}({key} = \"{example_value}\")`",);
497 let values_to_check = attributes.iter().filter_map(|attr| match &attr.meta {
498 // If a `Path` matches our `name`, return an error, else ignore it.
499 // e.g. `#[zeroable]`
500 syn::Meta::Path(path) => path
501 .is_ident(name)
502 .then(|| Err(syn::Error::new_spanned(path, &invalid_format_msg))),
503 // If a `NameValue` matches our `name`, return an error, else ignore it.
504 // e.g. `#[zeroable = "hello"]`
505 syn::Meta::NameValue(namevalue) => {
506 namevalue.path.is_ident(name).then(|| {
507 Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
508 })
509 }
510 // If a `List` matches our `name`, match its contents to our format, else
511 // ignore it. If its contents match our format, return the value, else
512 // return an error.
513 syn::Meta::List(list) => list.path.is_ident(name).then(|| {
514 let namevalue: syn::MetaNameValue = syn::parse2(list.tokens.clone())
515 .map_err(|_| {
516 syn::Error::new_spanned(&list.tokens, &invalid_format_msg)
517 })?;
518 if namevalue.path.is_ident(key) {
519 match namevalue.value {
520 syn::Expr::Lit(syn::ExprLit {
521 lit: syn::Lit::Str(strlit), ..
522 }) => Ok(strlit),
523 _ => {
524 Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
525 }
526 }
527 } else {
528 Err(syn::Error::new_spanned(&namevalue.path, &invalid_format_msg))
529 }
530 }),
531 });
532 // Parse each value found with the given parser, and return them if no errors
533 // occur.
534 values_to_check
535 .map(|lit| {
536 let lit = lit?;
537 lit.parse_with(parser).map_err(|err| {
538 syn::Error::new_spanned(&lit, format!("{invalid_value_msg}: {err}"))
539 })
540 })
541 .collect()
542 }
543
derive_marker_trait_inner<Trait: Derivable>( mut input: DeriveInput, ) -> Result<TokenStream>544 fn derive_marker_trait_inner<Trait: Derivable>(
545 mut input: DeriveInput,
546 ) -> Result<TokenStream> {
547 let crate_name = bytemuck_crate_name(&input);
548 let trait_ = Trait::ident(&input, &crate_name)?;
549 // If this trait allows explicit bounds, and any explicit bounds were given,
550 // then use those explicit bounds. Else, apply the default bounds (bound
551 // each generic type on this trait).
552 if let Some(name) = Trait::explicit_bounds_attribute_name() {
553 // See if any explicit bounds were given in attributes.
554 let explicit_bounds = find_and_parse_helper_attributes(
555 &input.attrs,
556 name,
557 "bound",
558 <syn::punctuated::Punctuated<syn::WherePredicate, syn::Token![,]>>::parse_terminated,
559 "Type: Trait",
560 "invalid where predicate",
561 )?;
562
563 if !explicit_bounds.is_empty() {
564 // Explicit bounds were given.
565 // Enforce explicitly given bounds, and emit "perfect derive" (i.e. add
566 // bounds for each field's type).
567 let explicit_bounds = explicit_bounds
568 .into_iter()
569 .flatten()
570 .collect::<Vec<syn::WherePredicate>>();
571
572 let predicates = &mut input.generics.make_where_clause().predicates;
573
574 predicates.extend(explicit_bounds);
575
576 let fields = match &input.data {
577 syn::Data::Struct(syn::DataStruct { fields, .. }) => fields.clone(),
578 syn::Data::Union(_) => {
579 return Err(syn::Error::new_spanned(
580 trait_,
581 &"perfect derive is not supported for unions",
582 ));
583 }
584 syn::Data::Enum(_) => {
585 return Err(syn::Error::new_spanned(
586 trait_,
587 &"perfect derive is not supported for enums",
588 ));
589 }
590 };
591
592 for field in fields {
593 let ty = field.ty;
594 predicates.push(syn::parse_quote!(
595 #ty: #trait_
596 ));
597 }
598 } else {
599 // No explicit bounds were given.
600 // Enforce trait bound on all type generics.
601 add_trait_marker(&mut input.generics, &trait_);
602 }
603 } else {
604 // This trait does not allow explicit bounds.
605 // Enforce trait bound on all type generics.
606 add_trait_marker(&mut input.generics, &trait_);
607 }
608
609 let name = &input.ident;
610
611 let (impl_generics, ty_generics, where_clause) =
612 input.generics.split_for_impl();
613
614 Trait::check_attributes(&input.data, &input.attrs)?;
615 let asserts = Trait::asserts(&input, &crate_name)?;
616 let (trait_impl_extras, trait_impl) = Trait::trait_impl(&input, &crate_name)?;
617
618 let implies_trait = if let Some(implies_trait) =
619 Trait::implies_trait(&crate_name)
620 {
621 quote!(unsafe impl #impl_generics #implies_trait for #name #ty_generics #where_clause {})
622 } else {
623 quote!()
624 };
625
626 let where_clause =
627 if Trait::requires_where_clause() { where_clause } else { None };
628
629 Ok(quote! {
630 #asserts
631
632 #trait_impl_extras
633
634 unsafe impl #impl_generics #trait_ for #name #ty_generics #where_clause {
635 #trait_impl
636 }
637
638 #implies_trait
639 })
640 }
641
642 /// Add a trait marker to the generics if it is not already present
add_trait_marker(generics: &mut syn::Generics, trait_name: &syn::Path)643 fn add_trait_marker(generics: &mut syn::Generics, trait_name: &syn::Path) {
644 // Get each generic type parameter.
645 let type_params = generics
646 .type_params()
647 .map(|param| ¶m.ident)
648 .map(|param| {
649 syn::parse_quote!(
650 #param: #trait_name
651 )
652 })
653 .collect::<Vec<syn::WherePredicate>>();
654
655 generics.make_where_clause().predicates.extend(type_params);
656 }
657