1 //! Attribute-related types used by the proc macro
2 
3 use crate::{Asn1Type, Tag, TagMode, TagNumber};
4 use proc_macro2::{Span, TokenStream};
5 use quote::quote;
6 use std::{fmt::Debug, str::FromStr};
7 use syn::punctuated::Punctuated;
8 use syn::{parse::Parse, parse::ParseStream, Attribute, Ident, LitStr, Path, Token};
9 
/// Attribute name.
///
/// Only attributes whose path is exactly this identifier (i.e. `#[asn1(...)]`) are
/// parsed by [`AttrNameValue::from_attributes`]; all other attributes are skipped.
pub(crate) const ATTR_NAME: &str = "asn1";
12 
/// Attributes on a `struct` or `enum` type.
///
/// Populated from the type-level `#[asn1(...)]` attributes; `tag_mode` is the only
/// option recognized at this level.
#[derive(Clone, Debug, Default)]
pub(crate) struct TypeAttrs {
    /// Tagging mode for this type: `EXPLICIT` or `IMPLICIT`, supplied as
    /// `#[asn1(tag_mode = "...")]`.
    ///
    /// The default value is `EXPLICIT`.
    pub tag_mode: TagMode,
}
22 
23 impl TypeAttrs {
24     /// Parse attributes from a struct field or enum variant.
parse(attrs: &[Attribute]) -> syn::Result<Self>25     pub fn parse(attrs: &[Attribute]) -> syn::Result<Self> {
26         let mut tag_mode = None;
27 
28         let mut parsed_attrs = Vec::new();
29         AttrNameValue::from_attributes(attrs, &mut parsed_attrs)?;
30 
31         for attr in parsed_attrs {
32             // `tag_mode = "..."` attribute
33             let mode = attr.parse_value("tag_mode")?.ok_or_else(|| {
34                 syn::Error::new_spanned(
35                     &attr.name,
36                     "invalid `asn1` attribute (valid options are `tag_mode`)",
37                 )
38             })?;
39 
40             if tag_mode.is_some() {
41                 return Err(syn::Error::new_spanned(
42                     &attr.name,
43                     "duplicate ASN.1 `tag_mode` attribute",
44                 ));
45             }
46 
47             tag_mode = Some(mode);
48         }
49 
50         Ok(Self {
51             tag_mode: tag_mode.unwrap_or_default(),
52         })
53     }
54 }
55 
/// Field-level attributes.
///
/// Populated from the `#[asn1(...)]` attributes on a struct field or enum variant.
#[derive(Clone, Debug, Default)]
pub(crate) struct FieldAttrs {
    /// Value of the `#[asn1(type = "...")]` attribute if provided.
    pub asn1_type: Option<Asn1Type>,

    /// Value of the `#[asn1(context_specific = "...")]` attribute if provided.
    pub context_specific: Option<TagNumber>,

    /// Path to a function that supplies the field's default value, given as
    /// `#[asn1(default = "...")]`.
    ///
    /// When decoding, the function is called if the field is absent. Per DER, a
    /// value equal to the default is expected to be omitted when encoding.
    pub default: Option<Path>,

    /// Is this field "extensible", i.e. preceded by the `...` extensibility marker?
    pub extensible: bool,

    /// Is this field `OPTIONAL`?
    pub optional: bool,

    /// Tagging mode for this type: `EXPLICIT` or `IMPLICIT`, supplied as
    /// `#[asn1(tag_mode = "...")]`.
    ///
    /// Inherits from the type-level tagging mode if specified, or otherwise
    /// defaults to `EXPLICIT`.
    pub tag_mode: TagMode,

    /// Is the inner type constructed? Supplied as `#[asn1(constructed = "...")]`.
    ///
    /// Used as the `constructed` bit of the context-specific tag.
    pub constructed: bool,
}
85 
86 impl FieldAttrs {
87     /// Return true when either an optional or default ASN.1 attribute is associated
88     /// with a field. Default signifies optionality due to omission of default values in
89     /// DER encodings.
is_optional(&self) -> bool90     fn is_optional(&self) -> bool {
91         self.optional || self.default.is_some()
92     }
93 
94     /// Parse attributes from a struct field or enum variant.
parse(attrs: &[Attribute], type_attrs: &TypeAttrs) -> syn::Result<Self>95     pub fn parse(attrs: &[Attribute], type_attrs: &TypeAttrs) -> syn::Result<Self> {
96         let mut asn1_type = None;
97         let mut context_specific = None;
98         let mut default = None;
99         let mut extensible = None;
100         let mut optional = None;
101         let mut tag_mode = None;
102         let mut constructed = None;
103 
104         let mut parsed_attrs = Vec::new();
105         AttrNameValue::from_attributes(attrs, &mut parsed_attrs)?;
106 
107         for attr in parsed_attrs {
108             // `context_specific = "..."` attribute
109             if let Some(tag_number) = attr.parse_value("context_specific")? {
110                 if context_specific.is_some() {
111                     abort!(attr.name, "duplicate ASN.1 `context_specific` attribute");
112                 }
113 
114                 context_specific = Some(tag_number);
115             // `default` attribute
116             } else if attr.parse_value::<String>("default")?.is_some() {
117                 if default.is_some() {
118                     abort!(attr.name, "duplicate ASN.1 `default` attribute");
119                 }
120 
121                 default = Some(attr.value.parse().map_err(|e| {
122                     syn::Error::new_spanned(
123                         attr.value,
124                         format_args!("error parsing ASN.1 `default` attribute: {e}"),
125                     )
126                 })?);
127             // `extensible` attribute
128             } else if let Some(ext) = attr.parse_value("extensible")? {
129                 if extensible.is_some() {
130                     abort!(attr.name, "duplicate ASN.1 `extensible` attribute");
131                 }
132 
133                 extensible = Some(ext);
134             // `optional` attribute
135             } else if let Some(opt) = attr.parse_value("optional")? {
136                 if optional.is_some() {
137                     abort!(attr.name, "duplicate ASN.1 `optional` attribute");
138                 }
139 
140                 optional = Some(opt);
141             // `tag_mode` attribute
142             } else if let Some(mode) = attr.parse_value("tag_mode")? {
143                 if tag_mode.is_some() {
144                     abort!(attr.name, "duplicate ASN.1 `tag_mode` attribute");
145                 }
146 
147                 tag_mode = Some(mode);
148             // `type = "..."` attribute
149             } else if let Some(ty) = attr.parse_value("type")? {
150                 if asn1_type.is_some() {
151                     abort!(attr.name, "duplicate ASN.1 `type` attribute");
152                 }
153 
154                 asn1_type = Some(ty);
155             // `constructed = "..."` attribute
156             } else if let Some(ty) = attr.parse_value("constructed")? {
157                 if constructed.is_some() {
158                     abort!(attr.name, "duplicate ASN.1 `constructed` attribute");
159                 }
160 
161                 constructed = Some(ty);
162             } else {
163                 abort!(
164                     attr.name,
165                     "unknown field-level `asn1` attribute \
166                     (valid options are `context_specific`, `type`)",
167                 );
168             }
169         }
170 
171         Ok(Self {
172             asn1_type,
173             context_specific,
174             default,
175             extensible: extensible.unwrap_or_default(),
176             optional: optional.unwrap_or_default(),
177             tag_mode: tag_mode.unwrap_or(type_attrs.tag_mode),
178             constructed: constructed.unwrap_or_default(),
179         })
180     }
181 
182     /// Get the expected [`Tag`] for this field.
tag(&self) -> syn::Result<Option<Tag>>183     pub fn tag(&self) -> syn::Result<Option<Tag>> {
184         match self.context_specific {
185             Some(tag_number) => Ok(Some(Tag::ContextSpecific {
186                 constructed: self.constructed,
187                 number: tag_number,
188             })),
189 
190             None => match self.tag_mode {
191                 TagMode::Explicit => Ok(self.asn1_type.map(Tag::Universal)),
192                 TagMode::Implicit => Err(syn::Error::new(
193                     Span::call_site(),
194                     "implicit tagging requires a `tag_number`",
195                 )),
196             },
197         }
198     }
199 
200     /// Get a `der::Decoder` object which respects these field attributes.
decoder(&self) -> TokenStream201     pub fn decoder(&self) -> TokenStream {
202         if let Some(tag_number) = self.context_specific {
203             let type_params = self.asn1_type.map(|ty| ty.type_path()).unwrap_or_default();
204             let tag_number = tag_number.to_tokens();
205 
206             let context_specific = match self.tag_mode {
207                 TagMode::Explicit => {
208                     if self.extensible || self.is_optional() {
209                         quote! {
210                             ::der::asn1::ContextSpecific::<#type_params>::decode_explicit(
211                                 reader,
212                                 #tag_number
213                             )?
214                         }
215                     } else {
216                         quote! {
217                             match ::der::asn1::ContextSpecific::<#type_params>::decode(reader)? {
218                                 field if field.tag_number == #tag_number => Some(field),
219                                 _ => None
220                             }
221                         }
222                     }
223                 }
224                 TagMode::Implicit => {
225                     quote! {
226                         ::der::asn1::ContextSpecific::<#type_params>::decode_implicit(
227                             reader,
228                             #tag_number
229                         )?
230                     }
231                 }
232             };
233 
234             if self.is_optional() {
235                 if let Some(default) = &self.default {
236                     quote!(#context_specific.map(|cs| cs.value).unwrap_or_else(#default))
237                 } else {
238                     quote!(#context_specific.map(|cs| cs.value))
239                 }
240             } else {
241                 // TODO(tarcieri): better error handling?
242                 let constructed = self.constructed;
243                 quote! {
244                     #context_specific.ok_or_else(|| {
245                         der::Tag::ContextSpecific {
246                             number: #tag_number,
247                             constructed: #constructed
248                         }.value_error()
249                     })?.value
250                 }
251             }
252         } else if let Some(default) = &self.default {
253             let type_params = self.asn1_type.map(|ty| ty.type_path()).unwrap_or_default();
254             self.asn1_type.map(|ty| ty.decoder()).unwrap_or_else(|| {
255                 quote! {
256                     Option::<#type_params>::decode(reader)?.unwrap_or_else(#default),
257                 }
258             })
259         } else {
260             self.asn1_type
261                 .map(|ty| ty.decoder())
262                 .unwrap_or_else(|| quote!(reader.decode()?))
263         }
264     }
265 
266     /// Get tokens to encode the binding using `::der::EncodeValue`.
value_encode(&self, binding: &TokenStream) -> TokenStream267     pub fn value_encode(&self, binding: &TokenStream) -> TokenStream {
268         match self.context_specific {
269             Some(tag_number) => {
270                 let tag_number = tag_number.to_tokens();
271                 let tag_mode = self.tag_mode.to_tokens();
272                 quote! {
273                     ::der::asn1::ContextSpecificRef {
274                         tag_number: #tag_number,
275                         tag_mode: #tag_mode,
276                         value: #binding,
277                     }.encode_value(encoder)
278                 }
279             }
280 
281             None => self
282                 .asn1_type
283                 .map(|ty| {
284                     let encoder_obj = ty.encoder(binding);
285                     quote!(#encoder_obj.encode_value(encoder))
286                 })
287                 .unwrap_or_else(|| quote!(#binding.encode_value(encoder))),
288         }
289     }
290 }
291 
/// Name/value pair attribute.
///
/// One `name = "value"` entry from inside an `#[asn1(...)]` attribute, e.g.
/// `type = "..."`. Several entries in one attribute are separated by commas.
pub(crate) struct AttrNameValue {
    /// Attribute name.
    ///
    /// Normally an ordinary path; the `type` keyword is also accepted and stored as a
    /// single-segment path named `type`.
    pub name: Path,

    /// Attribute value, always a string literal.
    pub value: LitStr,
}
300 
301 impl Parse for AttrNameValue {
parse(input: ParseStream<'_>) -> syn::Result<Self>302     fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
303         let name = match input.parse() {
304             Ok(name) => name,
305             // If it doesn't parse as a path, check if it's the keyword `type`
306             // The asn1 macro uses this even though Path cannot technically contain
307             // non-identifiers, so it needs to be forced in.
308             Err(e) => {
309                 if let Ok(tok) = input.parse::<Token![type]>() {
310                     Path::from(Ident::new("type", tok.span))
311                 } else {
312                     // If it still doesn't parse, report the original error rather than the
313                     // one produced by the workaround.
314                     return Err(e);
315                 }
316             }
317         };
318         input.parse::<Token![=]>()?;
319         let value = input.parse()?;
320         Ok(Self { name, value })
321     }
322 }
323 
324 impl AttrNameValue {
parse_attribute(attr: &Attribute) -> syn::Result<impl IntoIterator<Item = Self>>325     pub fn parse_attribute(attr: &Attribute) -> syn::Result<impl IntoIterator<Item = Self>> {
326         attr.parse_args_with(Punctuated::<Self, Token![,]>::parse_terminated)
327     }
328 
329     /// Parse a slice of attributes.
from_attributes(attrs: &[Attribute], out: &mut Vec<Self>) -> syn::Result<()>330     pub fn from_attributes(attrs: &[Attribute], out: &mut Vec<Self>) -> syn::Result<()> {
331         for attr in attrs {
332             if !attr.path().is_ident(ATTR_NAME) {
333                 continue;
334             }
335 
336             match Self::parse_attribute(attr) {
337                 Ok(parsed) => out.extend(parsed),
338                 Err(e) => abort!(attr, e),
339             }
340         }
341 
342         Ok(())
343     }
344 
345     /// Parse an attribute value if the name matches the specified one.
parse_value<T>(&self, name: &str) -> syn::Result<Option<T>> where T: FromStr + Debug, T::Err: Debug,346     pub fn parse_value<T>(&self, name: &str) -> syn::Result<Option<T>>
347     where
348         T: FromStr + Debug,
349         T::Err: Debug,
350     {
351         Ok(if self.name.is_ident(name) {
352             Some(
353                 self.value
354                     .value()
355                     .parse()
356                     .map_err(|_| syn::Error::new_spanned(&self.name, "error parsing attribute"))?,
357             )
358         } else {
359             None
360         })
361     }
362 }
363