xref: /aosp_15_r20/external/crosvm/base/base_event_token_derive/src/lib.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2018 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #![recursion_limit = "128"]
6 
7 extern crate proc_macro;
8 
9 use proc_macro2::Ident;
10 use proc_macro2::TokenStream;
11 use quote::quote;
12 use syn::parse_macro_input;
13 use syn::Data;
14 use syn::DeriveInput;
15 use syn::Field;
16 use syn::Fields;
17 use syn::Index;
18 use syn::Member;
19 use syn::Variant;
20 
21 #[cfg(test)]
22 mod tests;
23 
// The method for packing an enum into a u64 is as follows:
// 1) Reserve the lowest "ceil(log_2(x))" bits, where x is the number of enum variants.
// 2) Store the enum variant's index (0-based, in order of the enum definition) in the
//    reserved bits.
// 3) If the enum variant has data, store the data in the remaining bits.
// The method for unpacking is as follows:
// 1) Mask the raw token to just the reserved bits.
// 2) Match the reserved bits against the enum variant indices.
// 3) If the indicated enum variant has data, extract it from the unreserved bits.
33 
34 // Calculates the number of bits needed to store the variant index. Essentially the log base 2
35 // of the number of variants, rounded up.
variant_bits(variants: &[Variant]) -> u3236 fn variant_bits(variants: &[Variant]) -> u32 {
37     if variants.is_empty() {
38         // The degenerate case of no variants.
39         0
40     } else {
41         variants.len().next_power_of_two().trailing_zeros()
42     }
43 }
44 
45 // Name of the field if it has one, otherwise 0 assuming this is the zeroth
46 // field of a tuple variant.
field_member(field: &Field) -> Member47 fn field_member(field: &Field) -> Member {
48     match &field.ident {
49         Some(name) => Member::Named(name.clone()),
50         None => Member::Unnamed(Index::from(0)),
51     }
52 }
53 
54 // Generates the function body for `as_raw_token`.
generate_as_raw_token(enum_name: &Ident, variants: &[Variant]) -> TokenStream55 fn generate_as_raw_token(enum_name: &Ident, variants: &[Variant]) -> TokenStream {
56     let variant_bits = variant_bits(variants);
57 
58     // Each iteration corresponds to one variant's match arm.
59     let cases = variants.iter().enumerate().map(|(index, variant)| {
60         let variant_name = &variant.ident;
61         let index = index as u64;
62 
63         // The capture string is for everything between the variant identifier and the `=>` in
64         // the match arm: the variant's data capture.
65         let capture = variant.fields.iter().next().map(|field| {
66             let member = field_member(field);
67             quote!({ #member: data })
68         });
69 
70         // The modifier string ORs the variant index with extra bits from the variant data
71         // field.
72         let modifier = match variant.fields {
73             Fields::Named(_) | Fields::Unnamed(_) => Some(quote! {
74                 | ((data as u64) << #variant_bits)
75             }),
76             Fields::Unit => None,
77         };
78 
79         // Assembly of the match arm.
80         quote! {
81             #enum_name::#variant_name #capture => #index #modifier
82         }
83     });
84 
85     quote! {
86         match *self {
87             #(
88                 #cases,
89             )*
90         }
91     }
92 }
93 
94 // Generates the function body for `from_raw_token`.
generate_from_raw_token(enum_name: &Ident, variants: &[Variant]) -> TokenStream95 fn generate_from_raw_token(enum_name: &Ident, variants: &[Variant]) -> TokenStream {
96     let variant_bits = variant_bits(variants);
97     let variant_mask = ((1 << variant_bits) - 1) as u64;
98 
99     // Each iteration corresponds to one variant's match arm.
100     let cases = variants.iter().enumerate().map(|(index, variant)| {
101         let variant_name = &variant.ident;
102         let index = index as u64;
103 
104         // The data string is for extracting the enum variant's data bits out of the raw token
105         // data, which includes both variant index and data bits.
106         let data = variant.fields.iter().next().map(|field| {
107             let member = field_member(field);
108             let ty = &field.ty;
109             quote!({ #member: (data >> #variant_bits) as #ty })
110         });
111 
112         // Assembly of the match arm.
113         quote! {
114             #index => #enum_name::#variant_name #data
115         }
116     });
117 
118     quote! {
119         // The match expression only matches the bits for the variant index.
120         match data & #variant_mask {
121             #(
122                 #cases,
123             )*
124             _ => unreachable!(),
125         }
126     }
127 }
128 
129 // The proc_macro::TokenStream type can only be constructed from within a
130 // procedural macro, meaning that unit tests are not able to invoke `fn
131 // event_token` below as an ordinary Rust function. We factor out the logic into
132 // a signature that deals with Syn and proc-macro2 types only which are not
133 // restricted to a procedural macro invocation.
event_token_inner(input: DeriveInput) -> TokenStream134 fn event_token_inner(input: DeriveInput) -> TokenStream {
135     let variants: Vec<Variant> = match input.data {
136         Data::Enum(data) => data.variants.into_iter().collect(),
137         Data::Struct(_) | Data::Union(_) => panic!("input must be an enum"),
138     };
139 
140     for variant in &variants {
141         assert!(variant.fields.iter().count() <= 1);
142     }
143 
144     // Given our basic model of a user given enum that is suitable as a token, we generate the
145     // implementation. The implementation is NOT always well formed, such as when a variant's data
146     // type is not bit shiftable or castable to u64, but we let Rust generate such errors as it
147     // would be difficult to detect every kind of error. Importantly, every implementation that we
148     // generate here and goes on to compile succesfully is sound.
149 
150     let enum_name = input.ident;
151     let as_raw_token = generate_as_raw_token(&enum_name, &variants);
152     let from_raw_token = generate_from_raw_token(&enum_name, &variants);
153 
154     quote! {
155         impl EventToken for #enum_name {
156             fn as_raw_token(&self) -> u64 {
157                 #as_raw_token
158             }
159 
160             fn from_raw_token(data: u64) -> Self {
161                 #from_raw_token
162             }
163         }
164     }
165 }
166 
167 /// Implements the EventToken trait for a given `enum`.
168 ///
169 /// There are limitations on what `enum`s this custom derive will work on:
170 ///
171 /// * Each variant must be a unit variant (no data), or have a single (un)named data field.
172 /// * If a variant has data, it must be a primitive type castable to and from a `u64`.
173 /// * If a variant data has size greater than or equal to a `u64`, its most significant bits must be
174 ///   zero. The number of bits truncated is equal to the number of bits used to store the variant
175 ///   index plus the number of bits above 64.
176 #[proc_macro_derive(EventToken)]
event_token(input: proc_macro::TokenStream) -> proc_macro::TokenStream177 pub fn event_token(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
178     let input = parse_macro_input!(input as DeriveInput);
179     event_token_inner(input).into()
180 }
181