1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use std::str::FromStr;
6 
7 use darling::{
8     ast::{self, Fields},
9     FromDeriveInput, FromField, FromVariant,
10 };
11 use proc_macro2::{Literal, TokenStream};
12 use quote::quote;
use syn::{
    parse_macro_input, parse_quote, punctuated::Punctuated, Attribute, DeriveInput, Expr,
    Generics, Ident, Index, Lit, Meta, Path, Token,
};
16 
/// Which `mls_rs_codec` trait method the derive is generating code for.
///
/// This is a plain tag passed around by value, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Operation {
    /// `mls_rs_codec::MlsSize::mls_encoded_len`.
    Size,
    /// `mls_rs_codec::MlsEncode::mls_encode`.
    Encode,
    /// `mls_rs_codec::MlsDecode::mls_decode`.
    Decode,
}
22 
23 impl Operation {
path(&self) -> Path24     fn path(&self) -> Path {
25         match self {
26             Operation::Size => parse_quote! { mls_rs_codec::MlsSize },
27             Operation::Encode => parse_quote! { mls_rs_codec::MlsEncode },
28             Operation::Decode => parse_quote! { mls_rs_codec::MlsDecode },
29         }
30     }
31 
call(&self) -> TokenStream32     fn call(&self) -> TokenStream {
33         match self {
34             Operation::Size => quote! { mls_encoded_len },
35             Operation::Encode => quote! { mls_encode },
36             Operation::Decode => quote! { mls_decode },
37         }
38     }
39 
extras(&self) -> TokenStream40     fn extras(&self) -> TokenStream {
41         match self {
42             Operation::Size => quote! {},
43             Operation::Encode => quote! { , writer },
44             Operation::Decode => quote! { reader },
45         }
46     }
47 
is_result(&self) -> bool48     fn is_result(&self) -> bool {
49         match self {
50             Operation::Size => false,
51             Operation::Encode => true,
52             Operation::Decode => true,
53         }
54     }
55 }
56 
/// Per-field options collected by darling, including any `#[mls_codec(...)]` attribute.
#[derive(Debug, FromField)]
#[darling(attributes(mls_codec))]
struct MlsFieldReceiver {
    /// Field name; `None` for positional (tuple) fields.
    ident: Option<Ident>,
    /// Optional `with` override.
    ///
    /// When set, this path replaces the derived trait's path in the generated
    /// `#path::mls_encoded_len / mls_encode / mls_decode(...)` call for this field.
    with: Option<Path>,
}
63 
64 impl MlsFieldReceiver {
call_tokens(&self, index: Index) -> TokenStream65     pub fn call_tokens(&self, index: Index) -> TokenStream {
66         if let Some(ref ident) = self.ident {
67             quote! { &self.#ident }
68         } else {
69             quote! { &self.#index }
70         }
71     }
72 
name(&self, index: Index) -> TokenStream73     pub fn name(&self, index: Index) -> TokenStream {
74         if let Some(ref ident) = self.ident {
75             quote! {#ident: }
76         } else {
77             quote! { #index: }
78         }
79     }
80 }
81 
/// Per-variant data collected by darling when deriving for an enum.
#[derive(Debug, FromVariant)]
#[darling(attributes(mls_codec))]
struct MlsVariantReceiver {
    /// Variant name.
    ident: Ident,
    /// Explicit discriminant expression (`Variant = ...`).
    /// Code generation panics if it is missing.
    discriminant: Option<Expr>,
    /// The variant's fields. Code generation only supports variants with zero or one field.
    fields: ast::Fields<MlsFieldReceiver>,
}
89 
/// The whole derive input as parsed by darling.
///
/// Only `#[repr(...)]` attributes are forwarded into `attrs`. They are consulted to pick an
/// integer type suffix for unsuffixed enum discriminants.
#[derive(FromDeriveInput)]
#[darling(attributes(mls_codec), forward_attrs(repr))]
struct MlsInputReceiver {
    /// Forwarded `#[repr(...)]` attributes of the input type.
    attrs: Vec<Attribute>,
    /// Name of the type being derived for.
    ident: Ident,
    /// Generics of the input type, reused unchanged in the generated `impl`.
    generics: Generics,
    /// Either the enum's variants or the struct's fields.
    data: ast::Data<MlsVariantReceiver, MlsFieldReceiver>,
}
98 
99 impl MlsInputReceiver {
handle_input(&self, operation: Operation) -> TokenStream100     fn handle_input(&self, operation: Operation) -> TokenStream {
101         match self.data {
102             ast::Data::Struct(ref s) => struct_impl(s, operation),
103             ast::Data::Enum(ref e) => enum_impl(&self.ident, &self.attrs, e, operation),
104         }
105     }
106 }
107 
repr_ident(attrs: &[Attribute]) -> Option<Ident>108 fn repr_ident(attrs: &[Attribute]) -> Option<Ident> {
109     let repr_path = attrs
110         .iter()
111         .filter(|attr| matches!(attr.style, syn::AttrStyle::Outer))
112         .find(|attr| attr.path().is_ident("repr"))
113         .map(|repr| repr.parse_args())
114         .transpose()
115         .ok()
116         .flatten();
117 
118     let Some(Expr::Path(path)) = repr_path else {
119         return None;
120     };
121 
122     path.path
123         .segments
124         .iter()
125         .find(|s| s.ident != "C")
126         .map(|path| path.ident.clone())
127 }
128 
129 /// Provides the discriminant for a given variant. If the variant does not specify a suffix
130 /// and a `repr_ident` is provided, it will be appended to number.
discriminant_for_variant( variant: &MlsVariantReceiver, repr_ident: &Option<Ident>, ) -> TokenStream131 fn discriminant_for_variant(
132     variant: &MlsVariantReceiver,
133     repr_ident: &Option<Ident>,
134 ) -> TokenStream {
135     let discriminant = variant
136         .discriminant
137         .clone()
138         .expect("Enum discriminants must be explicitly defined");
139 
140     let Expr::Lit(lit_expr) = &discriminant else {
141         return quote! {#discriminant};
142     };
143 
144     let Lit::Int(lit_int) = &lit_expr.lit else {
145         return quote! {#discriminant};
146     };
147 
148     if lit_int.suffix().is_empty() {
149         // This is dirty and there is probably a better way of doing this but I'm way too much of a noob at
150         // proc macros to pull it off...
151         // TODO: Add proper support for correctly ignoring transparent, packed and modifiers
152         let str = format!(
153             "{}{}",
154             lit_int.base10_digits(),
155             &repr_ident.clone().expect("Expected a repr(u*) to be provided or for the variant's discriminant to be defined with suffixed literals.")
156         );
157         Literal::from_str(&str)
158             .map(|l| quote! {#l})
159             .ok()
160             .unwrap_or_else(|| quote! {#discriminant})
161     } else {
162         quote! {#discriminant}
163     }
164 }
165 
enum_impl( ident: &Ident, attrs: &[Attribute], variants: &[MlsVariantReceiver], operation: Operation, ) -> TokenStream166 fn enum_impl(
167     ident: &Ident,
168     attrs: &[Attribute],
169     variants: &[MlsVariantReceiver],
170     operation: Operation,
171 ) -> TokenStream {
172     let handle_error = operation.is_result().then_some(quote! { ? });
173     let path = operation.path();
174     let call = operation.call();
175     let extras = operation.extras();
176     let enum_name = &ident;
177     let repr_ident = repr_ident(attrs);
178     if matches!(operation, Operation::Decode) {
179         let cases = variants.iter().map(|variant| {
180             let variant_name = &variant.ident;
181 
182             let discriminant = discriminant_for_variant(variant, &repr_ident);
183 
184             // TODO: Support more than 1 field
185             match variant.fields.len() {
186                 0 => quote! { #discriminant => Ok(#enum_name::#variant_name), },
187                 1 =>{
188                     let path = variant.fields.fields[0].with.as_ref().unwrap_or(&path);
189                     quote! { #discriminant => Ok(#enum_name::#variant_name(#path::#call(#extras) #handle_error)), }
190                 },
191                 _ => panic!("Enum discriminants with more than 1 field are not currently supported")
192             }
193         });
194 
195         return quote! {
196             let discriminant = #path::#call(#extras)#handle_error;
197 
198             match discriminant {
199                 #(#cases)*
200                 _ => Err(mls_rs_codec::Error::UnsupportedEnumDiscriminant),
201             }
202         };
203     }
204 
205     let cases = variants.iter().map(|variant| {
206         let variant_name = &variant.ident;
207 
208         let discriminant = discriminant_for_variant(variant, &repr_ident);
209 
210         let (parameter, field) = if variant.fields.is_empty() {
211             (None, None)
212         } else {
213             let path = variant.fields.fields[0].with.as_ref().unwrap_or(&path);
214 
215             let start = match operation {
216                 Operation::Size => Some(quote! { + }),
217                 Operation::Encode => Some(quote! {;}),
218                 Operation::Decode => None,
219             };
220 
221             (
222                 Some(quote! {(ref val)}),
223                 Some(quote! { #start #path::#call (val #extras) #handle_error }),
224             )
225         };
226 
227         let discrim = quote! { #path::#call (&#discriminant #extras) #handle_error };
228 
229         quote! { #enum_name::#variant_name #parameter => { #discrim #field }}
230     });
231 
232     let enum_impl = quote! {
233         match self {
234             #(#cases)*
235         }
236     };
237 
238     if operation.is_result() {
239         quote! {
240             Ok(#enum_impl)
241         }
242     } else {
243         enum_impl
244     }
245 }
246 
struct_impl(s: &Fields<MlsFieldReceiver>, operation: Operation) -> TokenStream247 fn struct_impl(s: &Fields<MlsFieldReceiver>, operation: Operation) -> TokenStream {
248     let recurse = s.fields.iter().enumerate().map(|(index, field)| {
249         let (call_tokens, field_name) = match operation {
250             Operation::Size | Operation::Encode => {
251                 (field.call_tokens(Index::from(index)), quote! {})
252             }
253             Operation::Decode => (quote! {}, field.name(Index::from(index))),
254         };
255 
256         let handle_error = operation.is_result().then_some(quote! { ? });
257         let path = field.with.clone().unwrap_or(operation.path());
258         let call = operation.call();
259         let extras = operation.extras();
260 
261         quote! {
262            #field_name #path::#call (#call_tokens #extras) #handle_error
263         }
264     });
265 
266     match operation {
267         Operation::Size => quote! { 0 #(+ #recurse)* },
268         Operation::Encode => quote! { #(#recurse;)* Ok(()) },
269         Operation::Decode => quote! { Ok(Self { #(#recurse,)* }) },
270     }
271 }
272 
derive_impl<F>( input: proc_macro::TokenStream, trait_name: TokenStream, function_def: TokenStream, internals: F, ) -> proc_macro::TokenStream where F: FnOnce(&MlsInputReceiver) -> TokenStream,273 fn derive_impl<F>(
274     input: proc_macro::TokenStream,
275     trait_name: TokenStream,
276     function_def: TokenStream,
277     internals: F,
278 ) -> proc_macro::TokenStream
279 where
280     F: FnOnce(&MlsInputReceiver) -> TokenStream,
281 {
282     let input = parse_macro_input!(input as DeriveInput);
283 
284     let input = MlsInputReceiver::from_derive_input(&input).unwrap();
285 
286     let name = &input.ident;
287 
288     let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
289 
290     // Generate an expression to sum up the heap size of each field.
291     let function_impl = internals(&input);
292 
293     let expanded = quote! {
294         // The generated impl.
295         impl #impl_generics #trait_name for #name #ty_generics #where_clause {
296             #function_def {
297                 #function_impl
298             }
299         }
300     };
301 
302     // Hand the output tokens back to the compiler.
303     proc_macro::TokenStream::from(expanded)
304 }
305 
306 #[proc_macro_derive(MlsSize, attributes(mls_codec))]
derive_size(input: proc_macro::TokenStream) -> proc_macro::TokenStream307 pub fn derive_size(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
308     let trait_name = quote! { mls_rs_codec::MlsSize };
309     let function_def = quote! {fn mls_encoded_len(&self) -> usize };
310 
311     derive_impl(input, trait_name, function_def, |input| {
312         input.handle_input(Operation::Size)
313     })
314 }
315 
316 #[proc_macro_derive(MlsEncode, attributes(mls_codec))]
derive_encode(input: proc_macro::TokenStream) -> proc_macro::TokenStream317 pub fn derive_encode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
318     let trait_name = quote! { mls_rs_codec::MlsEncode };
319 
320     let function_def = quote! { fn mls_encode(&self, writer: &mut mls_rs_codec::Vec<u8>) -> Result<(), mls_rs_codec::Error> };
321 
322     derive_impl(input, trait_name, function_def, |input| {
323         input.handle_input(Operation::Encode)
324     })
325 }
326 
327 #[proc_macro_derive(MlsDecode, attributes(mls_codec))]
derive_decode(input: proc_macro::TokenStream) -> proc_macro::TokenStream328 pub fn derive_decode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
329     let trait_name = quote! { mls_rs_codec::MlsDecode };
330 
331     let function_def =
332         quote! { fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> };
333 
334     derive_impl(input, trait_name, function_def, |input| {
335         input.handle_input(Operation::Decode)
336     })
337 }
338