1 //! Support for deriving the `Decode` and `Encode` traits on enums for 2 //! the purposes of decoding/encoding ASN.1 `CHOICE` types as mapped to 3 //! enum variants. 4 5 mod variant; 6 7 use self::variant::ChoiceVariant; 8 use crate::{default_lifetime, TypeAttrs}; 9 use proc_macro2::TokenStream; 10 use quote::quote; 11 use syn::{DeriveInput, Ident, Lifetime}; 12 13 /// Derive the `Choice` trait for an enum. 14 pub(crate) struct DeriveChoice { 15 /// Name of the enum type. 16 ident: Ident, 17 18 /// Lifetime of the type. 19 lifetime: Option<Lifetime>, 20 21 /// Variants of this `Choice`. 22 variants: Vec<ChoiceVariant>, 23 } 24 25 impl DeriveChoice { 26 /// Parse [`DeriveInput`]. new(input: DeriveInput) -> syn::Result<Self>27 pub fn new(input: DeriveInput) -> syn::Result<Self> { 28 let data = match input.data { 29 syn::Data::Enum(data) => data, 30 _ => abort!( 31 input.ident, 32 "can't derive `Choice` on this type: only `enum` types are allowed", 33 ), 34 }; 35 36 // TODO(tarcieri): properly handle multiple lifetimes 37 let lifetime = input 38 .generics 39 .lifetimes() 40 .next() 41 .map(|lt| lt.lifetime.clone()); 42 43 let type_attrs = TypeAttrs::parse(&input.attrs)?; 44 let variants = data 45 .variants 46 .iter() 47 .map(|variant| ChoiceVariant::new(variant, &type_attrs)) 48 .collect::<syn::Result<_>>()?; 49 50 Ok(Self { 51 ident: input.ident, 52 lifetime, 53 variants, 54 }) 55 } 56 57 /// Lower the derived output into a [`TokenStream`]. to_tokens(&self) -> TokenStream58 pub fn to_tokens(&self) -> TokenStream { 59 let ident = &self.ident; 60 61 let lifetime = match self.lifetime { 62 Some(ref lifetime) => quote!(#lifetime), 63 None => { 64 let lifetime = default_lifetime(); 65 quote!(#lifetime) 66 } 67 }; 68 69 // Lifetime parameters 70 // TODO(tarcieri): support multiple lifetimes 71 let lt_params = self 72 .lifetime 73 .as_ref() 74 .map(|_| lifetime.clone()) 75 .unwrap_or_default(); 76 77 let mut can_decode_body = Vec::new(); 78 let mut decode_body = Vec::new(); 79 let mut encode_body = Vec::new(); 80 let mut value_len_body = Vec::new(); 81 let mut tagged_body = Vec::new(); 82 83 for variant in &self.variants { 84 can_decode_body.push(variant.tag.to_tokens()); 85 decode_body.push(variant.to_decode_tokens()); 86 encode_body.push(variant.to_encode_value_tokens()); 87 value_len_body.push(variant.to_value_len_tokens()); 88 tagged_body.push(variant.to_tagged_tokens()); 89 } 90 91 quote! { 92 impl<#lifetime> ::der::Choice<#lifetime> for #ident<#lt_params> { 93 fn can_decode(tag: ::der::Tag) -> bool { 94 matches!(tag, #(#can_decode_body)|*) 95 } 96 } 97 98 impl<#lifetime> ::der::Decode<#lifetime> for #ident<#lt_params> { 99 fn decode<R: ::der::Reader<#lifetime>>(reader: &mut R) -> ::der::Result<Self> { 100 use der::Reader as _; 101 match reader.peek_tag()? { 102 #(#decode_body)* 103 actual => Err(der::ErrorKind::TagUnexpected { 104 expected: None, 105 actual 106 } 107 .into()), 108 } 109 } 110 } 111 112 impl<#lt_params> ::der::EncodeValue for #ident<#lt_params> { 113 fn encode_value(&self, encoder: &mut impl ::der::Writer) -> ::der::Result<()> { 114 match self { 115 #(#encode_body)* 116 } 117 } 118 119 fn value_len(&self) -> ::der::Result<::der::Length> { 120 match self { 121 #(#value_len_body)* 122 } 123 } 124 } 125 126 impl<#lt_params> ::der::Tagged for #ident<#lt_params> { 127 fn tag(&self) -> ::der::Tag { 128 match self { 129 #(#tagged_body)* 130 } 131 } 132 } 133 } 134 } 135 } 136 137 #[cfg(test)] 138 mod tests { 139 use super::DeriveChoice; 140 use crate::{Asn1Type, Tag, TagMode}; 141 use syn::parse_quote; 142 143 /// Based on `Time` as defined in RFC 5280: 144 /// <https://tools.ietf.org/html/rfc5280#page-117> 145 /// 146 /// ```text 147 /// Time ::= CHOICE { 148 /// utcTime UTCTime, 149 /// generalTime GeneralizedTime } 150 /// ``` 151 #[test] time_example()152 fn time_example() { 153 let input = parse_quote! { 154 pub enum Time { 155 #[asn1(type = "UTCTime")] 156 UtcTime(UtcTime), 157 158 #[asn1(type = "GeneralizedTime")] 159 GeneralTime(GeneralizedTime), 160 } 161 }; 162 163 let ir = DeriveChoice::new(input).unwrap(); 164 assert_eq!(ir.ident, "Time"); 165 assert_eq!(ir.lifetime, None); 166 assert_eq!(ir.variants.len(), 2); 167 168 let utc_time = &ir.variants[0]; 169 assert_eq!(utc_time.ident, "UtcTime"); 170 assert_eq!(utc_time.attrs.asn1_type, Some(Asn1Type::UtcTime)); 171 assert_eq!(utc_time.attrs.context_specific, None); 172 assert_eq!(utc_time.attrs.tag_mode, TagMode::Explicit); 173 assert_eq!(utc_time.tag, Tag::Universal(Asn1Type::UtcTime)); 174 175 let general_time = &ir.variants[1]; 176 assert_eq!(general_time.ident, "GeneralTime"); 177 assert_eq!( 178 general_time.attrs.asn1_type, 179 Some(Asn1Type::GeneralizedTime) 180 ); 181 assert_eq!(general_time.attrs.context_specific, None); 182 assert_eq!(general_time.attrs.tag_mode, TagMode::Explicit); 183 assert_eq!(general_time.tag, Tag::Universal(Asn1Type::GeneralizedTime)); 184 } 185 186 /// `IMPLICIT` tagged example 187 #[test] implicit_example()188 fn implicit_example() { 189 let input = parse_quote! { 190 #[asn1(tag_mode = "IMPLICIT")] 191 pub enum ImplicitChoice<'a> { 192 #[asn1(context_specific = "0", type = "BIT STRING")] 193 BitString(BitString<'a>), 194 195 #[asn1(context_specific = "1", type = "GeneralizedTime")] 196 Time(GeneralizedTime), 197 198 #[asn1(context_specific = "2", type = "UTF8String")] 199 Utf8String(String), 200 } 201 }; 202 203 let ir = DeriveChoice::new(input).unwrap(); 204 assert_eq!(ir.ident, "ImplicitChoice"); 205 assert_eq!(ir.lifetime.unwrap().to_string(), "'a"); 206 assert_eq!(ir.variants.len(), 3); 207 208 let bit_string = &ir.variants[0]; 209 assert_eq!(bit_string.ident, "BitString"); 210 assert_eq!(bit_string.attrs.asn1_type, Some(Asn1Type::BitString)); 211 assert_eq!( 212 bit_string.attrs.context_specific, 213 Some("0".parse().unwrap()) 214 ); 215 assert_eq!(bit_string.attrs.tag_mode, TagMode::Implicit); 216 assert_eq!( 217 bit_string.tag, 218 Tag::ContextSpecific { 219 constructed: false, 220 number: "0".parse().unwrap() 221 } 222 ); 223 224 let time = &ir.variants[1]; 225 assert_eq!(time.ident, "Time"); 226 assert_eq!(time.attrs.asn1_type, Some(Asn1Type::GeneralizedTime)); 227 assert_eq!(time.attrs.context_specific, Some("1".parse().unwrap())); 228 assert_eq!(time.attrs.tag_mode, TagMode::Implicit); 229 assert_eq!( 230 time.tag, 231 Tag::ContextSpecific { 232 constructed: false, 233 number: "1".parse().unwrap() 234 } 235 ); 236 237 let utf8_string = &ir.variants[2]; 238 assert_eq!(utf8_string.ident, "Utf8String"); 239 assert_eq!(utf8_string.attrs.asn1_type, Some(Asn1Type::Utf8String)); 240 assert_eq!( 241 utf8_string.attrs.context_specific, 242 Some("2".parse().unwrap()) 243 ); 244 assert_eq!(utf8_string.attrs.tag_mode, TagMode::Implicit); 245 assert_eq!( 246 utf8_string.tag, 247 Tag::ContextSpecific { 248 constructed: false, 249 number: "2".parse().unwrap() 250 } 251 ); 252 } 253 } 254