1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5 use std::str::FromStr;
6
7 use darling::{
8 ast::{self, Fields},
9 FromDeriveInput, FromField, FromVariant,
10 };
11 use proc_macro2::{Literal, TokenStream};
12 use quote::quote;
13 use syn::{
14 parse_macro_input, parse_quote, Attribute, DeriveInput, Expr, Generics, Ident, Index, Lit, Path,
15 };
16
/// The trait method a derive macro is generating a body for.
///
/// Fieldless and cheap to copy, so it derives `Copy` and can be passed around and compared
/// freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Operation {
    /// `mls_rs_codec::MlsSize::mls_encoded_len`
    Size,
    /// `mls_rs_codec::MlsEncode::mls_encode`
    Encode,
    /// `mls_rs_codec::MlsDecode::mls_decode`
    Decode,
}
22
23 impl Operation {
path(&self) -> Path24 fn path(&self) -> Path {
25 match self {
26 Operation::Size => parse_quote! { mls_rs_codec::MlsSize },
27 Operation::Encode => parse_quote! { mls_rs_codec::MlsEncode },
28 Operation::Decode => parse_quote! { mls_rs_codec::MlsDecode },
29 }
30 }
31
call(&self) -> TokenStream32 fn call(&self) -> TokenStream {
33 match self {
34 Operation::Size => quote! { mls_encoded_len },
35 Operation::Encode => quote! { mls_encode },
36 Operation::Decode => quote! { mls_decode },
37 }
38 }
39
extras(&self) -> TokenStream40 fn extras(&self) -> TokenStream {
41 match self {
42 Operation::Size => quote! {},
43 Operation::Encode => quote! { , writer },
44 Operation::Decode => quote! { reader },
45 }
46 }
47
is_result(&self) -> bool48 fn is_result(&self) -> bool {
49 match self {
50 Operation::Size => false,
51 Operation::Encode => true,
52 Operation::Decode => true,
53 }
54 }
55 }
56
/// A struct field or enum-variant field as parsed by darling, together with its
/// `mls_codec` attribute options.
#[derive(Debug, FromField)]
#[darling(attributes(mls_codec))]
struct MlsFieldReceiver {
    /// Field name; `None` for tuple (unnamed) fields, which are addressed by index instead.
    ident: Option<Ident>,
    /// Optional `with` path from the field's `mls_codec` attribute. When set, the generated
    /// code calls `<with>::mls_encoded_len` / `mls_encode` / `mls_decode` for this field
    /// instead of the corresponding `mls_rs_codec` trait.
    with: Option<Path>,
}
63
64 impl MlsFieldReceiver {
call_tokens(&self, index: Index) -> TokenStream65 pub fn call_tokens(&self, index: Index) -> TokenStream {
66 if let Some(ref ident) = self.ident {
67 quote! { &self.#ident }
68 } else {
69 quote! { &self.#index }
70 }
71 }
72
name(&self, index: Index) -> TokenStream73 pub fn name(&self, index: Index) -> TokenStream {
74 if let Some(ref ident) = self.ident {
75 quote! {#ident: }
76 } else {
77 quote! { #index: }
78 }
79 }
80 }
81
/// An enum variant as parsed by darling.
#[derive(Debug, FromVariant)]
#[darling(attributes(mls_codec))]
struct MlsVariantReceiver {
    /// Variant name.
    ident: Ident,
    /// Explicit discriminant expression (e.g. `= 1` or `= 2u16`). The codec derives require
    /// it; code generation panics when it is missing.
    discriminant: Option<Expr>,
    /// The variant's fields; code generation supports at most one.
    fields: ast::Fields<MlsFieldReceiver>,
}
89
/// The type a codec derive is applied to, as parsed by darling.
#[derive(FromDeriveInput)]
#[darling(attributes(mls_codec), forward_attrs(repr))]
struct MlsInputReceiver {
    /// Forwarded `#[repr(...)]` attributes. They are used to suffix unsuffixed enum
    /// discriminant literals with the repr integer type.
    attrs: Vec<Attribute>,
    /// Name of the type the derive is applied to.
    ident: Ident,
    /// The type's generics, reproduced on the generated `impl`.
    generics: Generics,
    /// Struct fields or enum variants.
    data: ast::Data<MlsVariantReceiver, MlsFieldReceiver>,
}
98
99 impl MlsInputReceiver {
handle_input(&self, operation: Operation) -> TokenStream100 fn handle_input(&self, operation: Operation) -> TokenStream {
101 match self.data {
102 ast::Data::Struct(ref s) => struct_impl(s, operation),
103 ast::Data::Enum(ref e) => enum_impl(&self.ident, &self.attrs, e, operation),
104 }
105 }
106 }
107
repr_ident(attrs: &[Attribute]) -> Option<Ident>108 fn repr_ident(attrs: &[Attribute]) -> Option<Ident> {
109 let repr_path = attrs
110 .iter()
111 .filter(|attr| matches!(attr.style, syn::AttrStyle::Outer))
112 .find(|attr| attr.path().is_ident("repr"))
113 .map(|repr| repr.parse_args())
114 .transpose()
115 .ok()
116 .flatten();
117
118 let Some(Expr::Path(path)) = repr_path else {
119 return None;
120 };
121
122 path.path
123 .segments
124 .iter()
125 .find(|s| s.ident != "C")
126 .map(|path| path.ident.clone())
127 }
128
129 /// Provides the discriminant for a given variant. If the variant does not specify a suffix
130 /// and a `repr_ident` is provided, it will be appended to number.
discriminant_for_variant( variant: &MlsVariantReceiver, repr_ident: &Option<Ident>, ) -> TokenStream131 fn discriminant_for_variant(
132 variant: &MlsVariantReceiver,
133 repr_ident: &Option<Ident>,
134 ) -> TokenStream {
135 let discriminant = variant
136 .discriminant
137 .clone()
138 .expect("Enum discriminants must be explicitly defined");
139
140 let Expr::Lit(lit_expr) = &discriminant else {
141 return quote! {#discriminant};
142 };
143
144 let Lit::Int(lit_int) = &lit_expr.lit else {
145 return quote! {#discriminant};
146 };
147
148 if lit_int.suffix().is_empty() {
149 // This is dirty and there is probably a better way of doing this but I'm way too much of a noob at
150 // proc macros to pull it off...
151 // TODO: Add proper support for correctly ignoring transparent, packed and modifiers
152 let str = format!(
153 "{}{}",
154 lit_int.base10_digits(),
155 &repr_ident.clone().expect("Expected a repr(u*) to be provided or for the variant's discriminant to be defined with suffixed literals.")
156 );
157 Literal::from_str(&str)
158 .map(|l| quote! {#l})
159 .ok()
160 .unwrap_or_else(|| quote! {#discriminant})
161 } else {
162 quote! {#discriminant}
163 }
164 }
165
enum_impl( ident: &Ident, attrs: &[Attribute], variants: &[MlsVariantReceiver], operation: Operation, ) -> TokenStream166 fn enum_impl(
167 ident: &Ident,
168 attrs: &[Attribute],
169 variants: &[MlsVariantReceiver],
170 operation: Operation,
171 ) -> TokenStream {
172 let handle_error = operation.is_result().then_some(quote! { ? });
173 let path = operation.path();
174 let call = operation.call();
175 let extras = operation.extras();
176 let enum_name = &ident;
177 let repr_ident = repr_ident(attrs);
178 if matches!(operation, Operation::Decode) {
179 let cases = variants.iter().map(|variant| {
180 let variant_name = &variant.ident;
181
182 let discriminant = discriminant_for_variant(variant, &repr_ident);
183
184 // TODO: Support more than 1 field
185 match variant.fields.len() {
186 0 => quote! { #discriminant => Ok(#enum_name::#variant_name), },
187 1 =>{
188 let path = variant.fields.fields[0].with.as_ref().unwrap_or(&path);
189 quote! { #discriminant => Ok(#enum_name::#variant_name(#path::#call(#extras) #handle_error)), }
190 },
191 _ => panic!("Enum discriminants with more than 1 field are not currently supported")
192 }
193 });
194
195 return quote! {
196 let discriminant = #path::#call(#extras)#handle_error;
197
198 match discriminant {
199 #(#cases)*
200 _ => Err(mls_rs_codec::Error::UnsupportedEnumDiscriminant),
201 }
202 };
203 }
204
205 let cases = variants.iter().map(|variant| {
206 let variant_name = &variant.ident;
207
208 let discriminant = discriminant_for_variant(variant, &repr_ident);
209
210 let (parameter, field) = if variant.fields.is_empty() {
211 (None, None)
212 } else {
213 let path = variant.fields.fields[0].with.as_ref().unwrap_or(&path);
214
215 let start = match operation {
216 Operation::Size => Some(quote! { + }),
217 Operation::Encode => Some(quote! {;}),
218 Operation::Decode => None,
219 };
220
221 (
222 Some(quote! {(ref val)}),
223 Some(quote! { #start #path::#call (val #extras) #handle_error }),
224 )
225 };
226
227 let discrim = quote! { #path::#call (&#discriminant #extras) #handle_error };
228
229 quote! { #enum_name::#variant_name #parameter => { #discrim #field }}
230 });
231
232 let enum_impl = quote! {
233 match self {
234 #(#cases)*
235 }
236 };
237
238 if operation.is_result() {
239 quote! {
240 Ok(#enum_impl)
241 }
242 } else {
243 enum_impl
244 }
245 }
246
struct_impl(s: &Fields<MlsFieldReceiver>, operation: Operation) -> TokenStream247 fn struct_impl(s: &Fields<MlsFieldReceiver>, operation: Operation) -> TokenStream {
248 let recurse = s.fields.iter().enumerate().map(|(index, field)| {
249 let (call_tokens, field_name) = match operation {
250 Operation::Size | Operation::Encode => {
251 (field.call_tokens(Index::from(index)), quote! {})
252 }
253 Operation::Decode => (quote! {}, field.name(Index::from(index))),
254 };
255
256 let handle_error = operation.is_result().then_some(quote! { ? });
257 let path = field.with.clone().unwrap_or(operation.path());
258 let call = operation.call();
259 let extras = operation.extras();
260
261 quote! {
262 #field_name #path::#call (#call_tokens #extras) #handle_error
263 }
264 });
265
266 match operation {
267 Operation::Size => quote! { 0 #(+ #recurse)* },
268 Operation::Encode => quote! { #(#recurse;)* Ok(()) },
269 Operation::Decode => quote! { Ok(Self { #(#recurse,)* }) },
270 }
271 }
272
derive_impl<F>( input: proc_macro::TokenStream, trait_name: TokenStream, function_def: TokenStream, internals: F, ) -> proc_macro::TokenStream where F: FnOnce(&MlsInputReceiver) -> TokenStream,273 fn derive_impl<F>(
274 input: proc_macro::TokenStream,
275 trait_name: TokenStream,
276 function_def: TokenStream,
277 internals: F,
278 ) -> proc_macro::TokenStream
279 where
280 F: FnOnce(&MlsInputReceiver) -> TokenStream,
281 {
282 let input = parse_macro_input!(input as DeriveInput);
283
284 let input = MlsInputReceiver::from_derive_input(&input).unwrap();
285
286 let name = &input.ident;
287
288 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
289
290 // Generate an expression to sum up the heap size of each field.
291 let function_impl = internals(&input);
292
293 let expanded = quote! {
294 // The generated impl.
295 impl #impl_generics #trait_name for #name #ty_generics #where_clause {
296 #function_def {
297 #function_impl
298 }
299 }
300 };
301
302 // Hand the output tokens back to the compiler.
303 proc_macro::TokenStream::from(expanded)
304 }
305
306 #[proc_macro_derive(MlsSize, attributes(mls_codec))]
derive_size(input: proc_macro::TokenStream) -> proc_macro::TokenStream307 pub fn derive_size(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
308 let trait_name = quote! { mls_rs_codec::MlsSize };
309 let function_def = quote! {fn mls_encoded_len(&self) -> usize };
310
311 derive_impl(input, trait_name, function_def, |input| {
312 input.handle_input(Operation::Size)
313 })
314 }
315
316 #[proc_macro_derive(MlsEncode, attributes(mls_codec))]
derive_encode(input: proc_macro::TokenStream) -> proc_macro::TokenStream317 pub fn derive_encode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
318 let trait_name = quote! { mls_rs_codec::MlsEncode };
319
320 let function_def = quote! { fn mls_encode(&self, writer: &mut mls_rs_codec::Vec<u8>) -> Result<(), mls_rs_codec::Error> };
321
322 derive_impl(input, trait_name, function_def, |input| {
323 input.handle_input(Operation::Encode)
324 })
325 }
326
327 #[proc_macro_derive(MlsDecode, attributes(mls_codec))]
derive_decode(input: proc_macro::TokenStream) -> proc_macro::TokenStream328 pub fn derive_decode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
329 let trait_name = quote! { mls_rs_codec::MlsDecode };
330
331 let function_def =
332 quote! { fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> };
333
334 derive_impl(input, trait_name, function_def, |input| {
335 input.handle_input(Operation::Decode)
336 })
337 }
338