1 //! Sequence field IR and lowerings
2 
3 use crate::{Asn1Type, FieldAttrs, TagMode, TagNumber, TypeAttrs};
4 use proc_macro2::TokenStream;
5 use quote::quote;
6 use syn::{Field, Ident, Path, Type};
7 
8 /// "IR" for a field of a derived `Sequence`.
9 pub(super) struct SequenceField {
10     /// Variant name.
11     pub(super) ident: Ident,
12 
13     /// Field-level attributes.
14     pub(super) attrs: FieldAttrs,
15 
16     /// Field type
17     pub(super) field_type: Type,
18 }
19 
20 impl SequenceField {
21     /// Create a new [`SequenceField`] from the input [`Field`].
new(field: &Field, type_attrs: &TypeAttrs) -> syn::Result<Self>22     pub(super) fn new(field: &Field, type_attrs: &TypeAttrs) -> syn::Result<Self> {
23         let ident = field.ident.as_ref().cloned().ok_or_else(|| {
24             syn::Error::new_spanned(
25                 field,
26                 "no name on struct field i.e. tuple structs unsupported",
27             )
28         })?;
29 
30         let attrs = FieldAttrs::parse(&field.attrs, type_attrs)?;
31 
32         if attrs.asn1_type.is_some() && attrs.default.is_some() {
33             return Err(syn::Error::new_spanned(
34                 ident,
35                 "ASN.1 `type` and `default` options cannot be combined",
36             ));
37         }
38 
39         if attrs.default.is_some() && attrs.optional {
40             return Err(syn::Error::new_spanned(
41                 ident,
42                 "`optional` and `default` field qualifiers are mutually exclusive",
43             ));
44         }
45 
46         Ok(Self {
47             ident,
48             attrs,
49             field_type: field.ty.clone(),
50         })
51     }
52 
53     /// Derive code for decoding a field of a sequence.
to_decode_tokens(&self) -> TokenStream54     pub(super) fn to_decode_tokens(&self) -> TokenStream {
55         let mut lowerer = LowerFieldDecoder::new(&self.attrs);
56 
57         if self.attrs.asn1_type.is_some() {
58             lowerer.apply_asn1_type(self.attrs.optional);
59         }
60 
61         if let Some(default) = &self.attrs.default {
62             // TODO(tarcieri): default in conjunction with ASN.1 types?
63             debug_assert!(
64                 self.attrs.asn1_type.is_none(),
65                 "`type` and `default` are mutually exclusive"
66             );
67 
68             // TODO(tarcieri): support for context-specific fields with defaults?
69             if self.attrs.context_specific.is_none() {
70                 lowerer.apply_default(default, &self.field_type);
71             }
72         }
73 
74         lowerer.into_tokens(&self.ident)
75     }
76 
77     /// Derive code for encoding a field of a sequence.
to_encode_tokens(&self) -> TokenStream78     pub(super) fn to_encode_tokens(&self) -> TokenStream {
79         let mut lowerer = LowerFieldEncoder::new(&self.ident);
80         let attrs = &self.attrs;
81 
82         if let Some(ty) = &attrs.asn1_type {
83             // TODO(tarcieri): default in conjunction with ASN.1 types?
84             debug_assert!(
85                 attrs.default.is_none(),
86                 "`type` and `default` are mutually exclusive"
87             );
88             lowerer.apply_asn1_type(ty, attrs.optional);
89         }
90 
91         if let Some(tag_number) = &attrs.context_specific {
92             lowerer.apply_context_specific(tag_number, &attrs.tag_mode, attrs.optional);
93         }
94 
95         if let Some(default) = &attrs.default {
96             debug_assert!(
97                 !attrs.optional,
98                 "`default`, and `optional` are mutually exclusive"
99             );
100             lowerer.apply_default(&self.ident, default, &self.field_type);
101         }
102 
103         lowerer.into_tokens()
104     }
105 }
106 
107 /// AST lowerer for field decoders.
108 struct LowerFieldDecoder {
109     /// Decoder-in-progress.
110     decoder: TokenStream,
111 }
112 
113 impl LowerFieldDecoder {
114     /// Create a new field decoder lowerer.
new(attrs: &FieldAttrs) -> Self115     fn new(attrs: &FieldAttrs) -> Self {
116         Self {
117             decoder: attrs.decoder(),
118         }
119     }
120 
121     ///  the field decoder to tokens.
into_tokens(self, ident: &Ident) -> TokenStream122     fn into_tokens(self, ident: &Ident) -> TokenStream {
123         let decoder = self.decoder;
124 
125         quote! {
126             let #ident = #decoder;
127         }
128     }
129 
130     /// Apply the ASN.1 type (if defined).
apply_asn1_type(&mut self, optional: bool)131     fn apply_asn1_type(&mut self, optional: bool) {
132         let decoder = &self.decoder;
133 
134         self.decoder = if optional {
135             quote! {
136                 #decoder.map(TryInto::try_into).transpose()?
137             }
138         } else {
139             quote! {
140                 #decoder.try_into()?
141             }
142         }
143     }
144 
145     /// Handle default value for a type.
apply_default(&mut self, default: &Path, field_type: &Type)146     fn apply_default(&mut self, default: &Path, field_type: &Type) {
147         self.decoder = quote! {
148             Option::<#field_type>::decode(reader)?.unwrap_or_else(#default);
149         };
150     }
151 }
152 
153 /// AST lowerer for field encoders.
154 struct LowerFieldEncoder {
155     /// Encoder-in-progress.
156     encoder: TokenStream,
157 }
158 
159 impl LowerFieldEncoder {
160     /// Create a new field encoder lowerer.
new(ident: &Ident) -> Self161     fn new(ident: &Ident) -> Self {
162         Self {
163             encoder: quote!(self.#ident),
164         }
165     }
166 
167     ///  the field encoder to tokens.
into_tokens(self) -> TokenStream168     fn into_tokens(self) -> TokenStream {
169         self.encoder
170     }
171 
172     /// Apply the ASN.1 type (if defined).
apply_asn1_type(&mut self, asn1_type: &Asn1Type, optional: bool)173     fn apply_asn1_type(&mut self, asn1_type: &Asn1Type, optional: bool) {
174         let binding = &self.encoder;
175 
176         self.encoder = if optional {
177             let map_arg = quote!(field);
178             let encoder = asn1_type.encoder(&map_arg);
179 
180             quote! {
181                 #binding.as_ref().map(|#map_arg| {
182                     der::Result::Ok(#encoder)
183                 }).transpose()?
184             }
185         } else {
186             let encoder = asn1_type.encoder(binding);
187             quote!(#encoder)
188         };
189     }
190 
191     /// Handle default value for a type.
apply_default(&mut self, ident: &Ident, default: &Path, field_type: &Type)192     fn apply_default(&mut self, ident: &Ident, default: &Path, field_type: &Type) {
193         let encoder = &self.encoder;
194 
195         self.encoder = quote! {
196             {
197                 let default_value: #field_type = #default();
198                 if &self.#ident == &default_value {
199                     None
200                 } else {
201                     Some(#encoder)
202                 }
203             }
204         };
205     }
206 
207     /// Make this field context-specific.
apply_context_specific( &mut self, tag_number: &TagNumber, tag_mode: &TagMode, optional: bool, )208     fn apply_context_specific(
209         &mut self,
210         tag_number: &TagNumber,
211         tag_mode: &TagMode,
212         optional: bool,
213     ) {
214         let encoder = &self.encoder;
215         let number_tokens = tag_number.to_tokens();
216         let mode_tokens = tag_mode.to_tokens();
217 
218         if optional {
219             self.encoder = quote! {
220                 #encoder.as_ref().map(|field| {
221                     ::der::asn1::ContextSpecificRef {
222                         tag_number: #number_tokens,
223                         tag_mode: #mode_tokens,
224                         value: field,
225                     }
226                 })
227             };
228         } else {
229             self.encoder = quote! {
230                 ::der::asn1::ContextSpecificRef {
231                     tag_number: #number_tokens,
232                     tag_mode: #mode_tokens,
233                     value: &#encoder,
234                 }
235             };
236         }
237     }
238 }
239 
#[cfg(test)]
mod tests {
    use super::SequenceField;
    use crate::{FieldAttrs, TagMode, TagNumber};
    use proc_macro2::Span;
    use quote::quote;
    use syn::{Ident, Path, Type, TypePath};

    /// Create a [`Type::Path`] consisting of the single path segment `ident`.
    pub fn type_path(ident: Ident) -> Type {
        Type::Path(TypePath {
            qself: None,
            path: Path::from(ident),
        })
    }

    /// Build a `String`-typed [`SequenceField`] called `name`.
    ///
    /// Every qualifier is disabled apart from the given tagging options.
    fn string_field(
        name: &str,
        context_specific: Option<TagNumber>,
        tag_mode: TagMode,
    ) -> SequenceField {
        let span = Span::call_site();

        SequenceField {
            ident: Ident::new(name, span),
            attrs: FieldAttrs {
                asn1_type: None,
                context_specific,
                default: None,
                extensible: false,
                optional: false,
                tag_mode,
                constructed: false,
            },
            field_type: type_path(Ident::new("String", span)),
        }
    }

    #[test]
    fn simple() {
        let field = string_field("example_field", None, TagMode::Explicit);

        let expected_decode = quote! {
            let example_field = reader.decode()?;
        };
        assert_eq!(
            field.to_decode_tokens().to_string(),
            expected_decode.to_string()
        );

        let expected_encode = quote! {
            self.example_field
        };
        assert_eq!(
            field.to_encode_tokens().to_string(),
            expected_encode.to_string()
        );
    }

    #[test]
    fn implicit() {
        let field = string_field("implicit_field", Some(TagNumber(0)), TagMode::Implicit);

        let expected_decode = quote! {
            let implicit_field = ::der::asn1::ContextSpecific::<>::decode_implicit(
                    reader,
                    ::der::TagNumber::N0
                )?
                .ok_or_else(|| {
                    der::Tag::ContextSpecific {
                        number: ::der::TagNumber::N0,
                        constructed: false
                    }
                    .value_error()
                })?
                .value;
        };
        assert_eq!(
            field.to_decode_tokens().to_string(),
            expected_decode.to_string()
        );

        let expected_encode = quote! {
            ::der::asn1::ContextSpecificRef {
                tag_number: ::der::TagNumber::N0,
                tag_mode: ::der::TagMode::Implicit,
                value: &self.implicit_field,
            }
        };
        assert_eq!(
            field.to_encode_tokens().to_string(),
            expected_encode.to_string()
        );
    }
}
360