xref: /aosp_15_r20/external/cronet/third_party/rust/chromium_crates_io/vendor/prost-derive-0.12.4/src/lib.rs (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 #![doc(html_root_url = "https://docs.rs/prost-derive/0.12.2")]
2 // The `quote!` macro requires deep recursion.
3 #![recursion_limit = "4096"]
4 
5 extern crate alloc;
6 extern crate proc_macro;
7 
8 use anyhow::{bail, Error};
9 use itertools::Itertools;
10 use proc_macro::TokenStream;
11 use proc_macro2::Span;
12 use quote::quote;
13 use syn::{
14     punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed,
15     FieldsUnnamed, Ident, Index, Variant,
16 };
17 
18 mod field;
19 use crate::field::Field;
20 
try_message(input: TokenStream) -> Result<TokenStream, Error>21 fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
22     let input: DeriveInput = syn::parse(input)?;
23 
24     let ident = input.ident;
25 
26     syn::custom_keyword!(skip_debug);
27     let skip_debug = input
28         .attrs
29         .into_iter()
30         .any(|a| a.path().is_ident("prost") && a.parse_args::<skip_debug>().is_ok());
31 
32     let variant_data = match input.data {
33         Data::Struct(variant_data) => variant_data,
34         Data::Enum(..) => bail!("Message can not be derived for an enum"),
35         Data::Union(..) => bail!("Message can not be derived for a union"),
36     };
37 
38     let generics = &input.generics;
39     let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
40 
41     let (is_struct, fields) = match variant_data {
42         DataStruct {
43             fields: Fields::Named(FieldsNamed { named: fields, .. }),
44             ..
45         } => (true, fields.into_iter().collect()),
46         DataStruct {
47             fields:
48                 Fields::Unnamed(FieldsUnnamed {
49                     unnamed: fields, ..
50                 }),
51             ..
52         } => (false, fields.into_iter().collect()),
53         DataStruct {
54             fields: Fields::Unit,
55             ..
56         } => (false, Vec::new()),
57     };
58 
59     let mut next_tag: u32 = 1;
60     let mut fields = fields
61         .into_iter()
62         .enumerate()
63         .flat_map(|(i, field)| {
64             let field_ident = field.ident.map(|x| quote!(#x)).unwrap_or_else(|| {
65                 let index = Index {
66                     index: i as u32,
67                     span: Span::call_site(),
68                 };
69                 quote!(#index)
70             });
71             match Field::new(field.attrs, Some(next_tag)) {
72                 Ok(Some(field)) => {
73                     next_tag = field.tags().iter().max().map(|t| t + 1).unwrap_or(next_tag);
74                     Some(Ok((field_ident, field)))
75                 }
76                 Ok(None) => None,
77                 Err(err) => Some(Err(
78                     err.context(format!("invalid message field {}.{}", ident, field_ident))
79                 )),
80             }
81         })
82         .collect::<Result<Vec<_>, _>>()?;
83 
84     // We want Debug to be in declaration order
85     let unsorted_fields = fields.clone();
86 
87     // Sort the fields by tag number so that fields will be encoded in tag order.
88     // TODO: This encodes oneof fields in the position of their lowest tag,
89     // regardless of the currently occupied variant, is that consequential?
90     // See: https://developers.google.com/protocol-buffers/docs/encoding#order
91     fields.sort_by_key(|&(_, ref field)| field.tags().into_iter().min().unwrap());
92     let fields = fields;
93 
94     let mut tags = fields
95         .iter()
96         .flat_map(|&(_, ref field)| field.tags())
97         .collect::<Vec<_>>();
98     let num_tags = tags.len();
99     tags.sort_unstable();
100     tags.dedup();
101     if tags.len() != num_tags {
102         bail!("message {} has fields with duplicate tags", ident);
103     }
104 
105     let encoded_len = fields
106         .iter()
107         .map(|&(ref field_ident, ref field)| field.encoded_len(quote!(self.#field_ident)));
108 
109     let encode = fields
110         .iter()
111         .map(|&(ref field_ident, ref field)| field.encode(quote!(self.#field_ident)));
112 
113     let merge = fields.iter().map(|&(ref field_ident, ref field)| {
114         let merge = field.merge(quote!(value));
115         let tags = field.tags().into_iter().map(|tag| quote!(#tag));
116         let tags = Itertools::intersperse(tags, quote!(|));
117 
118         quote! {
119             #(#tags)* => {
120                 let mut value = &mut self.#field_ident;
121                 #merge.map_err(|mut error| {
122                     error.push(STRUCT_NAME, stringify!(#field_ident));
123                     error
124                 })
125             },
126         }
127     });
128 
129     let struct_name = if fields.is_empty() {
130         quote!()
131     } else {
132         quote!(
133             const STRUCT_NAME: &'static str = stringify!(#ident);
134         )
135     };
136 
137     let clear = fields
138         .iter()
139         .map(|&(ref field_ident, ref field)| field.clear(quote!(self.#field_ident)));
140 
141     let default = if is_struct {
142         let default = fields.iter().map(|(field_ident, field)| {
143             let value = field.default();
144             quote!(#field_ident: #value,)
145         });
146         quote! {#ident {
147             #(#default)*
148         }}
149     } else {
150         let default = fields.iter().map(|(_, field)| {
151             let value = field.default();
152             quote!(#value,)
153         });
154         quote! {#ident (
155             #(#default)*
156         )}
157     };
158 
159     let methods = fields
160         .iter()
161         .flat_map(|&(ref field_ident, ref field)| field.methods(field_ident))
162         .collect::<Vec<_>>();
163     let methods = if methods.is_empty() {
164         quote!()
165     } else {
166         quote! {
167             #[allow(dead_code)]
168             impl #impl_generics #ident #ty_generics #where_clause {
169                 #(#methods)*
170             }
171         }
172     };
173 
174     let expanded = quote! {
175         impl #impl_generics ::prost::Message for #ident #ty_generics #where_clause {
176             #[allow(unused_variables)]
177             fn encode_raw<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut {
178                 #(#encode)*
179             }
180 
181             #[allow(unused_variables)]
182             fn merge_field<B>(
183                 &mut self,
184                 tag: u32,
185                 wire_type: ::prost::encoding::WireType,
186                 buf: &mut B,
187                 ctx: ::prost::encoding::DecodeContext,
188             ) -> ::core::result::Result<(), ::prost::DecodeError>
189             where B: ::prost::bytes::Buf {
190                 #struct_name
191                 match tag {
192                     #(#merge)*
193                     _ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx),
194                 }
195             }
196 
197             #[inline]
198             fn encoded_len(&self) -> usize {
199                 0 #(+ #encoded_len)*
200             }
201 
202             fn clear(&mut self) {
203                 #(#clear;)*
204             }
205         }
206 
207         impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
208             fn default() -> Self {
209                 #default
210             }
211         }
212     };
213     let expanded = if skip_debug {
214         expanded
215     } else {
216         let debugs = unsorted_fields.iter().map(|&(ref field_ident, ref field)| {
217             let wrapper = field.debug(quote!(self.#field_ident));
218             let call = if is_struct {
219                 quote!(builder.field(stringify!(#field_ident), &wrapper))
220             } else {
221                 quote!(builder.field(&wrapper))
222             };
223             quote! {
224                  let builder = {
225                      let wrapper = #wrapper;
226                      #call
227                  };
228             }
229         });
230         let debug_builder = if is_struct {
231             quote!(f.debug_struct(stringify!(#ident)))
232         } else {
233             quote!(f.debug_tuple(stringify!(#ident)))
234         };
235         quote! {
236             #expanded
237 
238             impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
239                 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
240                     let mut builder = #debug_builder;
241                     #(#debugs;)*
242                     builder.finish()
243                 }
244             }
245         }
246     };
247 
248     let expanded = quote! {
249         #expanded
250 
251         #methods
252     };
253 
254     Ok(expanded.into())
255 }
256 
257 #[proc_macro_derive(Message, attributes(prost))]
message(input: TokenStream) -> TokenStream258 pub fn message(input: TokenStream) -> TokenStream {
259     try_message(input).unwrap()
260 }
261 
try_enumeration(input: TokenStream) -> Result<TokenStream, Error>262 fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
263     let input: DeriveInput = syn::parse(input)?;
264     let ident = input.ident;
265 
266     let generics = &input.generics;
267     let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
268 
269     let punctuated_variants = match input.data {
270         Data::Enum(DataEnum { variants, .. }) => variants,
271         Data::Struct(_) => bail!("Enumeration can not be derived for a struct"),
272         Data::Union(..) => bail!("Enumeration can not be derived for a union"),
273     };
274 
275     // Map the variants into 'fields'.
276     let mut variants: Vec<(Ident, Expr)> = Vec::new();
277     for Variant {
278         ident,
279         fields,
280         discriminant,
281         ..
282     } in punctuated_variants
283     {
284         match fields {
285             Fields::Unit => (),
286             Fields::Named(_) | Fields::Unnamed(_) => {
287                 bail!("Enumeration variants may not have fields")
288             }
289         }
290 
291         match discriminant {
292             Some((_, expr)) => variants.push((ident, expr)),
293             None => bail!("Enumeration variants must have a discriminant"),
294         }
295     }
296 
297     if variants.is_empty() {
298         panic!("Enumeration must have at least one variant");
299     }
300 
301     let default = variants[0].0.clone();
302 
303     let is_valid = variants
304         .iter()
305         .map(|&(_, ref value)| quote!(#value => true));
306     let from = variants.iter().map(
307         |&(ref variant, ref value)| quote!(#value => ::core::option::Option::Some(#ident::#variant)),
308     );
309 
310     let try_from = variants.iter().map(
311         |&(ref variant, ref value)| quote!(#value => ::core::result::Result::Ok(#ident::#variant)),
312     );
313 
314     let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident);
315     let from_i32_doc = format!(
316         "Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.",
317         ident
318     );
319 
320     let expanded = quote! {
321         impl #impl_generics #ident #ty_generics #where_clause {
322             #[doc=#is_valid_doc]
323             pub fn is_valid(value: i32) -> bool {
324                 match value {
325                     #(#is_valid,)*
326                     _ => false,
327                 }
328             }
329 
330             #[deprecated = "Use the TryFrom<i32> implementation instead"]
331             #[doc=#from_i32_doc]
332             pub fn from_i32(value: i32) -> ::core::option::Option<#ident> {
333                 match value {
334                     #(#from,)*
335                     _ => ::core::option::Option::None,
336                 }
337             }
338         }
339 
340         impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
341             fn default() -> #ident {
342                 #ident::#default
343             }
344         }
345 
346         impl #impl_generics ::core::convert::From::<#ident> for i32 #ty_generics #where_clause {
347             fn from(value: #ident) -> i32 {
348                 value as i32
349             }
350         }
351 
352         impl #impl_generics ::core::convert::TryFrom::<i32> for #ident #ty_generics #where_clause {
353             type Error = ::prost::DecodeError;
354 
355             fn try_from(value: i32) -> ::core::result::Result<#ident, ::prost::DecodeError> {
356                 match value {
357                     #(#try_from,)*
358                     _ => ::core::result::Result::Err(::prost::DecodeError::new("invalid enumeration value")),
359                 }
360             }
361         }
362     };
363 
364     Ok(expanded.into())
365 }
366 
367 #[proc_macro_derive(Enumeration, attributes(prost))]
enumeration(input: TokenStream) -> TokenStream368 pub fn enumeration(input: TokenStream) -> TokenStream {
369     try_enumeration(input).unwrap()
370 }
371 
try_oneof(input: TokenStream) -> Result<TokenStream, Error>372 fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
373     let input: DeriveInput = syn::parse(input)?;
374 
375     let ident = input.ident;
376 
377     syn::custom_keyword!(skip_debug);
378     let skip_debug = input
379         .attrs
380         .into_iter()
381         .any(|a| a.path().is_ident("prost") && a.parse_args::<skip_debug>().is_ok());
382 
383     let variants = match input.data {
384         Data::Enum(DataEnum { variants, .. }) => variants,
385         Data::Struct(..) => bail!("Oneof can not be derived for a struct"),
386         Data::Union(..) => bail!("Oneof can not be derived for a union"),
387     };
388 
389     let generics = &input.generics;
390     let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
391 
392     // Map the variants into 'fields'.
393     let mut fields: Vec<(Ident, Field)> = Vec::new();
394     for Variant {
395         attrs,
396         ident: variant_ident,
397         fields: variant_fields,
398         ..
399     } in variants
400     {
401         let variant_fields = match variant_fields {
402             Fields::Unit => Punctuated::new(),
403             Fields::Named(FieldsNamed { named: fields, .. })
404             | Fields::Unnamed(FieldsUnnamed {
405                 unnamed: fields, ..
406             }) => fields,
407         };
408         if variant_fields.len() != 1 {
409             bail!("Oneof enum variants must have a single field");
410         }
411         match Field::new_oneof(attrs)? {
412             Some(field) => fields.push((variant_ident, field)),
413             None => bail!("invalid oneof variant: oneof variants may not be ignored"),
414         }
415     }
416 
417     let mut tags = fields
418         .iter()
419         .flat_map(|&(ref variant_ident, ref field)| -> Result<u32, Error> {
420             if field.tags().len() > 1 {
421                 bail!(
422                     "invalid oneof variant {}::{}: oneof variants may only have a single tag",
423                     ident,
424                     variant_ident
425                 );
426             }
427             Ok(field.tags()[0])
428         })
429         .collect::<Vec<_>>();
430     tags.sort_unstable();
431     tags.dedup();
432     if tags.len() != fields.len() {
433         panic!("invalid oneof {}: variants have duplicate tags", ident);
434     }
435 
436     let encode = fields.iter().map(|&(ref variant_ident, ref field)| {
437         let encode = field.encode(quote!(*value));
438         quote!(#ident::#variant_ident(ref value) => { #encode })
439     });
440 
441     let merge = fields.iter().map(|&(ref variant_ident, ref field)| {
442         let tag = field.tags()[0];
443         let merge = field.merge(quote!(value));
444         quote! {
445             #tag => {
446                 match field {
447                     ::core::option::Option::Some(#ident::#variant_ident(ref mut value)) => {
448                         #merge
449                     },
450                     _ => {
451                         let mut owned_value = ::core::default::Default::default();
452                         let value = &mut owned_value;
453                         #merge.map(|_| *field = ::core::option::Option::Some(#ident::#variant_ident(owned_value)))
454                     },
455                 }
456             }
457         }
458     });
459 
460     let encoded_len = fields.iter().map(|&(ref variant_ident, ref field)| {
461         let encoded_len = field.encoded_len(quote!(*value));
462         quote!(#ident::#variant_ident(ref value) => #encoded_len)
463     });
464 
465     let expanded = quote! {
466         impl #impl_generics #ident #ty_generics #where_clause {
467             /// Encodes the message to a buffer.
468             pub fn encode<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut {
469                 match *self {
470                     #(#encode,)*
471                 }
472             }
473 
474             /// Decodes an instance of the message from a buffer, and merges it into self.
475             pub fn merge<B>(
476                 field: &mut ::core::option::Option<#ident #ty_generics>,
477                 tag: u32,
478                 wire_type: ::prost::encoding::WireType,
479                 buf: &mut B,
480                 ctx: ::prost::encoding::DecodeContext,
481             ) -> ::core::result::Result<(), ::prost::DecodeError>
482             where B: ::prost::bytes::Buf {
483                 match tag {
484                     #(#merge,)*
485                     _ => unreachable!(concat!("invalid ", stringify!(#ident), " tag: {}"), tag),
486                 }
487             }
488 
489             /// Returns the encoded length of the message without a length delimiter.
490             #[inline]
491             pub fn encoded_len(&self) -> usize {
492                 match *self {
493                     #(#encoded_len,)*
494                 }
495             }
496         }
497 
498     };
499     let expanded = if skip_debug {
500         expanded
501     } else {
502         let debug = fields.iter().map(|&(ref variant_ident, ref field)| {
503             let wrapper = field.debug(quote!(*value));
504             quote!(#ident::#variant_ident(ref value) => {
505                 let wrapper = #wrapper;
506                 f.debug_tuple(stringify!(#variant_ident))
507                     .field(&wrapper)
508                     .finish()
509             })
510         });
511         quote! {
512             #expanded
513 
514             impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
515                 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
516                     match *self {
517                         #(#debug,)*
518                     }
519                 }
520             }
521         }
522     };
523 
524     Ok(expanded.into())
525 }
526 
527 #[proc_macro_derive(Oneof, attributes(prost))]
oneof(input: TokenStream) -> TokenStream528 pub fn oneof(input: TokenStream) -> TokenStream {
529     try_oneof(input).unwrap()
530 }
531