1 //! Support for deriving the `Decode` and `Encode` traits on enums for 2 //! the purposes of decoding/encoding ASN.1 `ENUMERATED` types as mapped to 3 //! enum variants. 4 5 use crate::attributes::AttrNameValue; 6 use crate::{default_lifetime, ATTR_NAME}; 7 use proc_macro2::TokenStream; 8 use quote::quote; 9 use syn::{DeriveInput, Expr, ExprLit, Ident, Lit, LitInt, Variant}; 10 11 /// Valid options for the `#[repr]` attribute on `Enumerated` types. 12 const REPR_TYPES: &[&str] = &["u8", "u16", "u32"]; 13 14 /// Derive the `Enumerated` trait for an enum. 15 pub(crate) struct DeriveEnumerated { 16 /// Name of the enum type. 17 ident: Ident, 18 19 /// Value of the `repr` attribute. 20 repr: Ident, 21 22 /// Whether or not to tag the enum as an integer 23 integer: bool, 24 25 /// Variants of this enum. 26 variants: Vec<EnumeratedVariant>, 27 } 28 29 impl DeriveEnumerated { 30 /// Parse [`DeriveInput`]. new(input: DeriveInput) -> syn::Result<Self>31 pub fn new(input: DeriveInput) -> syn::Result<Self> { 32 let data = match input.data { 33 syn::Data::Enum(data) => data, 34 _ => abort!( 35 input.ident, 36 "can't derive `Enumerated` on this type: only `enum` types are allowed", 37 ), 38 }; 39 40 // Reject `asn1` attributes, parse the `repr` attribute 41 let mut repr: Option<Ident> = None; 42 let mut integer = false; 43 44 for attr in &input.attrs { 45 if attr.path().is_ident(ATTR_NAME) { 46 let kvs = match AttrNameValue::parse_attribute(attr) { 47 Ok(kvs) => kvs, 48 Err(e) => abort!(attr, e), 49 }; 50 for anv in kvs { 51 if anv.name.is_ident("type") { 52 match anv.value.value().as_str() { 53 "ENUMERATED" => integer = false, 54 "INTEGER" => integer = true, 55 s => abort!(anv.value, format_args!("`type = \"{s}\"` is unsupported")), 56 } 57 } 58 } 59 } else if attr.path().is_ident("repr") { 60 if repr.is_some() { 61 abort!( 62 attr, 63 "multiple `#[repr]` attributes encountered on `Enumerated`", 64 ); 65 } 66 67 let r = attr.parse_args::<Ident>().map_err(|_| { 68 syn::Error::new_spanned(attr, "error parsing `#[repr]` attribute") 69 })?; 70 71 // Validate 72 if !REPR_TYPES.contains(&r.to_string().as_str()) { 73 abort!( 74 attr, 75 format_args!("invalid `#[repr]` type: allowed types are {REPR_TYPES:?}"), 76 ); 77 } 78 79 repr = Some(r); 80 } 81 } 82 83 // Parse enum variants 84 let variants = data 85 .variants 86 .iter() 87 .map(EnumeratedVariant::new) 88 .collect::<syn::Result<_>>()?; 89 90 Ok(Self { 91 ident: input.ident.clone(), 92 repr: repr.ok_or_else(|| { 93 syn::Error::new_spanned( 94 &input.ident, 95 format_args!("no `#[repr]` attribute on enum: must be one of {REPR_TYPES:?}"), 96 ) 97 })?, 98 variants, 99 integer, 100 }) 101 } 102 103 /// Lower the derived output into a [`TokenStream`]. to_tokens(&self) -> TokenStream104 pub fn to_tokens(&self) -> TokenStream { 105 let default_lifetime = default_lifetime(); 106 let ident = &self.ident; 107 let repr = &self.repr; 108 let tag = match self.integer { 109 false => quote! { ::der::Tag::Enumerated }, 110 true => quote! { ::der::Tag::Integer }, 111 }; 112 113 let mut try_from_body = Vec::new(); 114 for variant in &self.variants { 115 try_from_body.push(variant.to_try_from_tokens()); 116 } 117 118 quote! { 119 impl<#default_lifetime> ::der::DecodeValue<#default_lifetime> for #ident { 120 fn decode_value<R: ::der::Reader<#default_lifetime>>( 121 reader: &mut R, 122 header: ::der::Header 123 ) -> ::der::Result<Self> { 124 <#repr as ::der::DecodeValue>::decode_value(reader, header)?.try_into() 125 } 126 } 127 128 impl ::der::EncodeValue for #ident { 129 fn value_len(&self) -> ::der::Result<::der::Length> { 130 ::der::EncodeValue::value_len(&(*self as #repr)) 131 } 132 133 fn encode_value(&self, encoder: &mut impl ::der::Writer) -> ::der::Result<()> { 134 ::der::EncodeValue::encode_value(&(*self as #repr), encoder) 135 } 136 } 137 138 impl ::der::FixedTag for #ident { 139 const TAG: ::der::Tag = #tag; 140 } 141 142 impl TryFrom<#repr> for #ident { 143 type Error = ::der::Error; 144 145 fn try_from(n: #repr) -> ::der::Result<Self> { 146 match n { 147 #(#try_from_body)* 148 _ => Err(#tag.value_error()) 149 } 150 } 151 } 152 } 153 } 154 } 155 156 /// "IR" for a variant of a derived `Enumerated`. 157 pub struct EnumeratedVariant { 158 /// Variant name. 159 ident: Ident, 160 161 /// Integer value that this variant corresponds to. 162 discriminant: LitInt, 163 } 164 165 impl EnumeratedVariant { 166 /// Create a new [`ChoiceVariant`] from the input [`Variant`]. new(input: &Variant) -> syn::Result<Self>167 fn new(input: &Variant) -> syn::Result<Self> { 168 for attr in &input.attrs { 169 if attr.path().is_ident(ATTR_NAME) { 170 abort!( 171 attr, 172 "`asn1` attribute is not allowed on fields of `Enumerated` types" 173 ); 174 } 175 } 176 177 match &input.discriminant { 178 Some(( 179 _, 180 Expr::Lit(ExprLit { 181 lit: Lit::Int(discriminant), 182 .. 183 }), 184 )) => Ok(Self { 185 ident: input.ident.clone(), 186 discriminant: discriminant.clone(), 187 }), 188 Some((_, other)) => abort!(other, "invalid discriminant for `Enumerated`"), 189 None => abort!(input, "`Enumerated` variant has no discriminant"), 190 } 191 } 192 193 /// Write the body for the derived [`TryFrom`] impl. to_try_from_tokens(&self) -> TokenStream194 pub fn to_try_from_tokens(&self) -> TokenStream { 195 let ident = &self.ident; 196 let discriminant = &self.discriminant; 197 quote! { 198 #discriminant => Ok(Self::#ident), 199 } 200 } 201 } 202 203 #[cfg(test)] 204 mod tests { 205 use super::DeriveEnumerated; 206 use syn::parse_quote; 207 208 /// X.509 `CRLReason`. 209 #[test] crlreason_example()210 fn crlreason_example() { 211 let input = parse_quote! { 212 #[repr(u32)] 213 pub enum CrlReason { 214 Unspecified = 0, 215 KeyCompromise = 1, 216 CaCompromise = 2, 217 AffiliationChanged = 3, 218 Superseded = 4, 219 CessationOfOperation = 5, 220 CertificateHold = 6, 221 RemoveFromCrl = 8, 222 PrivilegeWithdrawn = 9, 223 AaCompromised = 10, 224 } 225 }; 226 227 let ir = DeriveEnumerated::new(input).unwrap(); 228 assert_eq!(ir.ident, "CrlReason"); 229 assert_eq!(ir.repr, "u32"); 230 assert_eq!(ir.variants.len(), 10); 231 232 let unspecified = &ir.variants[0]; 233 assert_eq!(unspecified.ident, "Unspecified"); 234 assert_eq!(unspecified.discriminant.to_string(), "0"); 235 236 let key_compromise = &ir.variants[1]; 237 assert_eq!(key_compromise.ident, "KeyCompromise"); 238 assert_eq!(key_compromise.discriminant.to_string(), "1"); 239 240 let key_compromise = &ir.variants[2]; 241 assert_eq!(key_compromise.ident, "CaCompromise"); 242 assert_eq!(key_compromise.discriminant.to_string(), "2"); 243 } 244 } 245