//! Support for deriving the `Decode` and `Encode` traits on enums for
//! the purposes of decoding/encoding ASN.1 `ENUMERATED` types (or, with
//! `#[asn1(type = "INTEGER")]`, ASN.1 `INTEGER` types) as mapped to enum variants.
4 
5 use crate::attributes::AttrNameValue;
6 use crate::{default_lifetime, ATTR_NAME};
7 use proc_macro2::TokenStream;
8 use quote::quote;
9 use syn::{DeriveInput, Expr, ExprLit, Ident, Lit, LitInt, Variant};
10 
/// Integer types accepted inside the `#[repr(...)]` attribute of an enum
/// deriving `Enumerated`.
const REPR_TYPES: &[&str] = &[
    "u8",
    "u16",
    "u32",
];
13 
14 /// Derive the `Enumerated` trait for an enum.
15 pub(crate) struct DeriveEnumerated {
16     /// Name of the enum type.
17     ident: Ident,
18 
19     /// Value of the `repr` attribute.
20     repr: Ident,
21 
22     /// Whether or not to tag the enum as an integer
23     integer: bool,
24 
25     /// Variants of this enum.
26     variants: Vec<EnumeratedVariant>,
27 }
28 
29 impl DeriveEnumerated {
30     /// Parse [`DeriveInput`].
new(input: DeriveInput) -> syn::Result<Self>31     pub fn new(input: DeriveInput) -> syn::Result<Self> {
32         let data = match input.data {
33             syn::Data::Enum(data) => data,
34             _ => abort!(
35                 input.ident,
36                 "can't derive `Enumerated` on this type: only `enum` types are allowed",
37             ),
38         };
39 
40         // Reject `asn1` attributes, parse the `repr` attribute
41         let mut repr: Option<Ident> = None;
42         let mut integer = false;
43 
44         for attr in &input.attrs {
45             if attr.path().is_ident(ATTR_NAME) {
46                 let kvs = match AttrNameValue::parse_attribute(attr) {
47                     Ok(kvs) => kvs,
48                     Err(e) => abort!(attr, e),
49                 };
50                 for anv in kvs {
51                     if anv.name.is_ident("type") {
52                         match anv.value.value().as_str() {
53                             "ENUMERATED" => integer = false,
54                             "INTEGER" => integer = true,
55                             s => abort!(anv.value, format_args!("`type = \"{s}\"` is unsupported")),
56                         }
57                     }
58                 }
59             } else if attr.path().is_ident("repr") {
60                 if repr.is_some() {
61                     abort!(
62                         attr,
63                         "multiple `#[repr]` attributes encountered on `Enumerated`",
64                     );
65                 }
66 
67                 let r = attr.parse_args::<Ident>().map_err(|_| {
68                     syn::Error::new_spanned(attr, "error parsing `#[repr]` attribute")
69                 })?;
70 
71                 // Validate
72                 if !REPR_TYPES.contains(&r.to_string().as_str()) {
73                     abort!(
74                         attr,
75                         format_args!("invalid `#[repr]` type: allowed types are {REPR_TYPES:?}"),
76                     );
77                 }
78 
79                 repr = Some(r);
80             }
81         }
82 
83         // Parse enum variants
84         let variants = data
85             .variants
86             .iter()
87             .map(EnumeratedVariant::new)
88             .collect::<syn::Result<_>>()?;
89 
90         Ok(Self {
91             ident: input.ident.clone(),
92             repr: repr.ok_or_else(|| {
93                 syn::Error::new_spanned(
94                     &input.ident,
95                     format_args!("no `#[repr]` attribute on enum: must be one of {REPR_TYPES:?}"),
96                 )
97             })?,
98             variants,
99             integer,
100         })
101     }
102 
103     /// Lower the derived output into a [`TokenStream`].
to_tokens(&self) -> TokenStream104     pub fn to_tokens(&self) -> TokenStream {
105         let default_lifetime = default_lifetime();
106         let ident = &self.ident;
107         let repr = &self.repr;
108         let tag = match self.integer {
109             false => quote! { ::der::Tag::Enumerated },
110             true => quote! { ::der::Tag::Integer },
111         };
112 
113         let mut try_from_body = Vec::new();
114         for variant in &self.variants {
115             try_from_body.push(variant.to_try_from_tokens());
116         }
117 
118         quote! {
119             impl<#default_lifetime> ::der::DecodeValue<#default_lifetime> for #ident {
120                 fn decode_value<R: ::der::Reader<#default_lifetime>>(
121                     reader: &mut R,
122                     header: ::der::Header
123                 ) -> ::der::Result<Self> {
124                     <#repr as ::der::DecodeValue>::decode_value(reader, header)?.try_into()
125                 }
126             }
127 
128             impl ::der::EncodeValue for #ident {
129                 fn value_len(&self) -> ::der::Result<::der::Length> {
130                     ::der::EncodeValue::value_len(&(*self as #repr))
131                 }
132 
133                 fn encode_value(&self, encoder: &mut impl ::der::Writer) -> ::der::Result<()> {
134                     ::der::EncodeValue::encode_value(&(*self as #repr), encoder)
135                 }
136             }
137 
138             impl ::der::FixedTag for #ident {
139                 const TAG: ::der::Tag = #tag;
140             }
141 
142             impl TryFrom<#repr> for #ident {
143                 type Error = ::der::Error;
144 
145                 fn try_from(n: #repr) -> ::der::Result<Self> {
146                     match n {
147                         #(#try_from_body)*
148                         _ => Err(#tag.value_error())
149                     }
150                 }
151             }
152         }
153     }
154 }
155 
156 /// "IR" for a variant of a derived `Enumerated`.
157 pub struct EnumeratedVariant {
158     /// Variant name.
159     ident: Ident,
160 
161     /// Integer value that this variant corresponds to.
162     discriminant: LitInt,
163 }
164 
165 impl EnumeratedVariant {
166     /// Create a new [`ChoiceVariant`] from the input [`Variant`].
new(input: &Variant) -> syn::Result<Self>167     fn new(input: &Variant) -> syn::Result<Self> {
168         for attr in &input.attrs {
169             if attr.path().is_ident(ATTR_NAME) {
170                 abort!(
171                     attr,
172                     "`asn1` attribute is not allowed on fields of `Enumerated` types"
173                 );
174             }
175         }
176 
177         match &input.discriminant {
178             Some((
179                 _,
180                 Expr::Lit(ExprLit {
181                     lit: Lit::Int(discriminant),
182                     ..
183                 }),
184             )) => Ok(Self {
185                 ident: input.ident.clone(),
186                 discriminant: discriminant.clone(),
187             }),
188             Some((_, other)) => abort!(other, "invalid discriminant for `Enumerated`"),
189             None => abort!(input, "`Enumerated` variant has no discriminant"),
190         }
191     }
192 
193     /// Write the body for the derived [`TryFrom`] impl.
to_try_from_tokens(&self) -> TokenStream194     pub fn to_try_from_tokens(&self) -> TokenStream {
195         let ident = &self.ident;
196         let discriminant = &self.discriminant;
197         quote! {
198             #discriminant => Ok(Self::#ident),
199         }
200     }
201 }
202 
#[cfg(test)]
mod tests {
    use super::DeriveEnumerated;
    use syn::parse_quote;

    /// X.509 `CRLReason`: every variant must keep its name and discriminant,
    /// including across the gap where the value 7 is unused.
    #[test]
    fn crlreason_example() {
        let input = parse_quote! {
            #[repr(u32)]
            pub enum CrlReason {
                Unspecified = 0,
                KeyCompromise = 1,
                CaCompromise = 2,
                AffiliationChanged = 3,
                Superseded = 4,
                CessationOfOperation = 5,
                CertificateHold = 6,
                RemoveFromCrl = 8,
                PrivilegeWithdrawn = 9,
                AaCompromised = 10,
            }
        };

        let ir = DeriveEnumerated::new(input).unwrap();
        assert_eq!(ir.ident, "CrlReason");
        assert_eq!(ir.repr, "u32");
        assert!(!ir.integer, "`ENUMERATED` must be the default tag");

        let expected = [
            ("Unspecified", "0"),
            ("KeyCompromise", "1"),
            ("CaCompromise", "2"),
            ("AffiliationChanged", "3"),
            ("Superseded", "4"),
            ("CessationOfOperation", "5"),
            ("CertificateHold", "6"),
            ("RemoveFromCrl", "8"),
            ("PrivilegeWithdrawn", "9"),
            ("AaCompromised", "10"),
        ];
        // Checking the length first guarantees the `zip` below visits every variant.
        assert_eq!(ir.variants.len(), expected.len());

        for (variant, (name, discriminant)) in ir.variants.iter().zip(expected) {
            assert_eq!(variant.ident, name);
            assert_eq!(variant.discriminant.to_string(), discriminant);
        }
    }

    /// `#[asn1(type = "INTEGER")]` switches the tag to ASN.1 `INTEGER`.
    #[test]
    fn integer_type_attribute() {
        let input = parse_quote! {
            #[asn1(type = "INTEGER")]
            #[repr(u8)]
            pub enum Version {
                V1 = 0,
                V2 = 1,
            }
        };

        let ir = DeriveEnumerated::new(input).unwrap();
        assert!(ir.integer);
        assert_eq!(ir.repr, "u8");
        assert_eq!(ir.variants.len(), 2);
    }

    /// A `#[repr]` attribute is mandatory.
    #[test]
    fn missing_repr_is_rejected() {
        let input = parse_quote! {
            pub enum NoRepr {
                A = 0,
            }
        };

        assert!(DeriveEnumerated::new(input).is_err());
    }
}
245