1 //! Support for deriving the `Decode` and `Encode` traits on enums for
2 //! the purposes of decoding/encoding ASN.1 `CHOICE` types as mapped to
3 //! enum variants.
4 
5 mod variant;
6 
7 use self::variant::ChoiceVariant;
8 use crate::{default_lifetime, TypeAttrs};
9 use proc_macro2::TokenStream;
10 use quote::quote;
11 use syn::{DeriveInput, Ident, Lifetime};
12 
13 /// Derive the `Choice` trait for an enum.
14 pub(crate) struct DeriveChoice {
15     /// Name of the enum type.
16     ident: Ident,
17 
18     /// Lifetime of the type.
19     lifetime: Option<Lifetime>,
20 
21     /// Variants of this `Choice`.
22     variants: Vec<ChoiceVariant>,
23 }
24 
25 impl DeriveChoice {
26     /// Parse [`DeriveInput`].
new(input: DeriveInput) -> syn::Result<Self>27     pub fn new(input: DeriveInput) -> syn::Result<Self> {
28         let data = match input.data {
29             syn::Data::Enum(data) => data,
30             _ => abort!(
31                 input.ident,
32                 "can't derive `Choice` on this type: only `enum` types are allowed",
33             ),
34         };
35 
36         // TODO(tarcieri): properly handle multiple lifetimes
37         let lifetime = input
38             .generics
39             .lifetimes()
40             .next()
41             .map(|lt| lt.lifetime.clone());
42 
43         let type_attrs = TypeAttrs::parse(&input.attrs)?;
44         let variants = data
45             .variants
46             .iter()
47             .map(|variant| ChoiceVariant::new(variant, &type_attrs))
48             .collect::<syn::Result<_>>()?;
49 
50         Ok(Self {
51             ident: input.ident,
52             lifetime,
53             variants,
54         })
55     }
56 
57     /// Lower the derived output into a [`TokenStream`].
to_tokens(&self) -> TokenStream58     pub fn to_tokens(&self) -> TokenStream {
59         let ident = &self.ident;
60 
61         let lifetime = match self.lifetime {
62             Some(ref lifetime) => quote!(#lifetime),
63             None => {
64                 let lifetime = default_lifetime();
65                 quote!(#lifetime)
66             }
67         };
68 
69         // Lifetime parameters
70         // TODO(tarcieri): support multiple lifetimes
71         let lt_params = self
72             .lifetime
73             .as_ref()
74             .map(|_| lifetime.clone())
75             .unwrap_or_default();
76 
77         let mut can_decode_body = Vec::new();
78         let mut decode_body = Vec::new();
79         let mut encode_body = Vec::new();
80         let mut value_len_body = Vec::new();
81         let mut tagged_body = Vec::new();
82 
83         for variant in &self.variants {
84             can_decode_body.push(variant.tag.to_tokens());
85             decode_body.push(variant.to_decode_tokens());
86             encode_body.push(variant.to_encode_value_tokens());
87             value_len_body.push(variant.to_value_len_tokens());
88             tagged_body.push(variant.to_tagged_tokens());
89         }
90 
91         quote! {
92             impl<#lifetime> ::der::Choice<#lifetime> for #ident<#lt_params> {
93                 fn can_decode(tag: ::der::Tag) -> bool {
94                     matches!(tag, #(#can_decode_body)|*)
95                 }
96             }
97 
98             impl<#lifetime> ::der::Decode<#lifetime> for #ident<#lt_params> {
99                 fn decode<R: ::der::Reader<#lifetime>>(reader: &mut R) -> ::der::Result<Self> {
100                     use der::Reader as _;
101                     match reader.peek_tag()? {
102                         #(#decode_body)*
103                         actual => Err(der::ErrorKind::TagUnexpected {
104                             expected: None,
105                             actual
106                         }
107                         .into()),
108                     }
109                 }
110             }
111 
112             impl<#lt_params> ::der::EncodeValue for #ident<#lt_params> {
113                 fn encode_value(&self, encoder: &mut impl ::der::Writer) -> ::der::Result<()> {
114                     match self {
115                         #(#encode_body)*
116                     }
117                 }
118 
119                 fn value_len(&self) -> ::der::Result<::der::Length> {
120                     match self {
121                         #(#value_len_body)*
122                     }
123                 }
124             }
125 
126             impl<#lt_params> ::der::Tagged for #ident<#lt_params> {
127                 fn tag(&self) -> ::der::Tag {
128                     match self {
129                         #(#tagged_body)*
130                     }
131                 }
132             }
133         }
134     }
135 }
136 
/// Unit tests for parsing `#[derive(Choice)]` input into [`DeriveChoice`].
#[cfg(test)]
mod tests {
    use super::DeriveChoice;
    use crate::{Asn1Type, Tag, TagMode};
    use syn::parse_quote;

    /// Expected tag for a primitive (non-constructed) context-specific
    /// alternative with the given tag number.
    fn primitive_context_tag(number: &str) -> Tag {
        Tag::ContextSpecific {
            constructed: false,
            number: number.parse().unwrap(),
        }
    }

    /// Based on `Time` as defined in RFC 5280:
    /// <https://tools.ietf.org/html/rfc5280#page-117>
    ///
    /// ```text
    /// Time ::= CHOICE {
    ///      utcTime        UTCTime,
    ///      generalTime    GeneralizedTime }
    /// ```
    #[test]
    fn time_example() {
        let derived = DeriveChoice::new(parse_quote! {
            pub enum Time {
                #[asn1(type = "UTCTime")]
                UtcTime(UtcTime),

                #[asn1(type = "GeneralizedTime")]
                GeneralTime(GeneralizedTime),
            }
        })
        .unwrap();

        assert_eq!(derived.ident, "Time");
        assert_eq!(derived.lifetime, None);
        assert_eq!(derived.variants.len(), 2);

        let mut alternatives = derived.variants.iter();

        let first = alternatives.next().unwrap();
        assert_eq!(first.ident, "UtcTime");
        assert_eq!(first.attrs.asn1_type, Some(Asn1Type::UtcTime));
        assert_eq!(first.attrs.context_specific, None);
        assert_eq!(first.attrs.tag_mode, TagMode::Explicit);
        assert_eq!(first.tag, Tag::Universal(Asn1Type::UtcTime));

        let second = alternatives.next().unwrap();
        assert_eq!(second.ident, "GeneralTime");
        assert_eq!(second.attrs.asn1_type, Some(Asn1Type::GeneralizedTime));
        assert_eq!(second.attrs.context_specific, None);
        assert_eq!(second.attrs.tag_mode, TagMode::Explicit);
        assert_eq!(second.tag, Tag::Universal(Asn1Type::GeneralizedTime));
    }

    /// `IMPLICIT` tagged example
    #[test]
    fn implicit_example() {
        let derived = DeriveChoice::new(parse_quote! {
            #[asn1(tag_mode = "IMPLICIT")]
            pub enum ImplicitChoice<'a> {
                #[asn1(context_specific = "0", type = "BIT STRING")]
                BitString(BitString<'a>),

                #[asn1(context_specific = "1", type = "GeneralizedTime")]
                Time(GeneralizedTime),

                #[asn1(context_specific = "2", type = "UTF8String")]
                Utf8String(String),
            }
        })
        .unwrap();

        assert_eq!(derived.ident, "ImplicitChoice");
        assert_eq!(derived.lifetime.as_ref().unwrap().to_string(), "'a");
        assert_eq!(derived.variants.len(), 3);

        let mut alternatives = derived.variants.iter();

        let bit_string = alternatives.next().unwrap();
        assert_eq!(bit_string.ident, "BitString");
        assert_eq!(bit_string.attrs.asn1_type, Some(Asn1Type::BitString));
        assert_eq!(bit_string.attrs.context_specific, Some("0".parse().unwrap()));
        assert_eq!(bit_string.attrs.tag_mode, TagMode::Implicit);
        assert_eq!(bit_string.tag, primitive_context_tag("0"));

        let time = alternatives.next().unwrap();
        assert_eq!(time.ident, "Time");
        assert_eq!(time.attrs.asn1_type, Some(Asn1Type::GeneralizedTime));
        assert_eq!(time.attrs.context_specific, Some("1".parse().unwrap()));
        assert_eq!(time.attrs.tag_mode, TagMode::Implicit);
        assert_eq!(time.tag, primitive_context_tag("1"));

        let utf8_string = alternatives.next().unwrap();
        assert_eq!(utf8_string.ident, "Utf8String");
        assert_eq!(utf8_string.attrs.asn1_type, Some(Asn1Type::Utf8String));
        assert_eq!(utf8_string.attrs.context_specific, Some("2".parse().unwrap()));
        assert_eq!(utf8_string.attrs.tag_mode, TagMode::Implicit);
        assert_eq!(utf8_string.tag, primitive_context_tag("2"));
    }
}
254