1 #![doc(html_root_url = "https://docs.rs/prost-derive/0.12.2")]
2 // The `quote!` macro requires deep recursion.
3 #![recursion_limit = "4096"]
4
5 extern crate alloc;
6 extern crate proc_macro;
7
8 use anyhow::{bail, Error};
9 use itertools::Itertools;
10 use proc_macro::TokenStream;
11 use proc_macro2::Span;
12 use quote::quote;
13 use syn::{
14 punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed,
15 FieldsUnnamed, Ident, Index, Variant,
16 };
17
18 mod field;
19 use crate::field::Field;
20
try_message(input: TokenStream) -> Result<TokenStream, Error>21 fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
22 let input: DeriveInput = syn::parse(input)?;
23
24 let ident = input.ident;
25
26 syn::custom_keyword!(skip_debug);
27 let skip_debug = input
28 .attrs
29 .into_iter()
30 .any(|a| a.path().is_ident("prost") && a.parse_args::<skip_debug>().is_ok());
31
32 let variant_data = match input.data {
33 Data::Struct(variant_data) => variant_data,
34 Data::Enum(..) => bail!("Message can not be derived for an enum"),
35 Data::Union(..) => bail!("Message can not be derived for a union"),
36 };
37
38 let generics = &input.generics;
39 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
40
41 let (is_struct, fields) = match variant_data {
42 DataStruct {
43 fields: Fields::Named(FieldsNamed { named: fields, .. }),
44 ..
45 } => (true, fields.into_iter().collect()),
46 DataStruct {
47 fields:
48 Fields::Unnamed(FieldsUnnamed {
49 unnamed: fields, ..
50 }),
51 ..
52 } => (false, fields.into_iter().collect()),
53 DataStruct {
54 fields: Fields::Unit,
55 ..
56 } => (false, Vec::new()),
57 };
58
59 let mut next_tag: u32 = 1;
60 let mut fields = fields
61 .into_iter()
62 .enumerate()
63 .flat_map(|(i, field)| {
64 let field_ident = field.ident.map(|x| quote!(#x)).unwrap_or_else(|| {
65 let index = Index {
66 index: i as u32,
67 span: Span::call_site(),
68 };
69 quote!(#index)
70 });
71 match Field::new(field.attrs, Some(next_tag)) {
72 Ok(Some(field)) => {
73 next_tag = field.tags().iter().max().map(|t| t + 1).unwrap_or(next_tag);
74 Some(Ok((field_ident, field)))
75 }
76 Ok(None) => None,
77 Err(err) => Some(Err(
78 err.context(format!("invalid message field {}.{}", ident, field_ident))
79 )),
80 }
81 })
82 .collect::<Result<Vec<_>, _>>()?;
83
84 // We want Debug to be in declaration order
85 let unsorted_fields = fields.clone();
86
87 // Sort the fields by tag number so that fields will be encoded in tag order.
88 // TODO: This encodes oneof fields in the position of their lowest tag,
89 // regardless of the currently occupied variant, is that consequential?
90 // See: https://developers.google.com/protocol-buffers/docs/encoding#order
91 fields.sort_by_key(|&(_, ref field)| field.tags().into_iter().min().unwrap());
92 let fields = fields;
93
94 let mut tags = fields
95 .iter()
96 .flat_map(|&(_, ref field)| field.tags())
97 .collect::<Vec<_>>();
98 let num_tags = tags.len();
99 tags.sort_unstable();
100 tags.dedup();
101 if tags.len() != num_tags {
102 bail!("message {} has fields with duplicate tags", ident);
103 }
104
105 let encoded_len = fields
106 .iter()
107 .map(|&(ref field_ident, ref field)| field.encoded_len(quote!(self.#field_ident)));
108
109 let encode = fields
110 .iter()
111 .map(|&(ref field_ident, ref field)| field.encode(quote!(self.#field_ident)));
112
113 let merge = fields.iter().map(|&(ref field_ident, ref field)| {
114 let merge = field.merge(quote!(value));
115 let tags = field.tags().into_iter().map(|tag| quote!(#tag));
116 let tags = Itertools::intersperse(tags, quote!(|));
117
118 quote! {
119 #(#tags)* => {
120 let mut value = &mut self.#field_ident;
121 #merge.map_err(|mut error| {
122 error.push(STRUCT_NAME, stringify!(#field_ident));
123 error
124 })
125 },
126 }
127 });
128
129 let struct_name = if fields.is_empty() {
130 quote!()
131 } else {
132 quote!(
133 const STRUCT_NAME: &'static str = stringify!(#ident);
134 )
135 };
136
137 let clear = fields
138 .iter()
139 .map(|&(ref field_ident, ref field)| field.clear(quote!(self.#field_ident)));
140
141 let default = if is_struct {
142 let default = fields.iter().map(|(field_ident, field)| {
143 let value = field.default();
144 quote!(#field_ident: #value,)
145 });
146 quote! {#ident {
147 #(#default)*
148 }}
149 } else {
150 let default = fields.iter().map(|(_, field)| {
151 let value = field.default();
152 quote!(#value,)
153 });
154 quote! {#ident (
155 #(#default)*
156 )}
157 };
158
159 let methods = fields
160 .iter()
161 .flat_map(|&(ref field_ident, ref field)| field.methods(field_ident))
162 .collect::<Vec<_>>();
163 let methods = if methods.is_empty() {
164 quote!()
165 } else {
166 quote! {
167 #[allow(dead_code)]
168 impl #impl_generics #ident #ty_generics #where_clause {
169 #(#methods)*
170 }
171 }
172 };
173
174 let expanded = quote! {
175 impl #impl_generics ::prost::Message for #ident #ty_generics #where_clause {
176 #[allow(unused_variables)]
177 fn encode_raw<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut {
178 #(#encode)*
179 }
180
181 #[allow(unused_variables)]
182 fn merge_field<B>(
183 &mut self,
184 tag: u32,
185 wire_type: ::prost::encoding::WireType,
186 buf: &mut B,
187 ctx: ::prost::encoding::DecodeContext,
188 ) -> ::core::result::Result<(), ::prost::DecodeError>
189 where B: ::prost::bytes::Buf {
190 #struct_name
191 match tag {
192 #(#merge)*
193 _ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx),
194 }
195 }
196
197 #[inline]
198 fn encoded_len(&self) -> usize {
199 0 #(+ #encoded_len)*
200 }
201
202 fn clear(&mut self) {
203 #(#clear;)*
204 }
205 }
206
207 impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
208 fn default() -> Self {
209 #default
210 }
211 }
212 };
213 let expanded = if skip_debug {
214 expanded
215 } else {
216 let debugs = unsorted_fields.iter().map(|&(ref field_ident, ref field)| {
217 let wrapper = field.debug(quote!(self.#field_ident));
218 let call = if is_struct {
219 quote!(builder.field(stringify!(#field_ident), &wrapper))
220 } else {
221 quote!(builder.field(&wrapper))
222 };
223 quote! {
224 let builder = {
225 let wrapper = #wrapper;
226 #call
227 };
228 }
229 });
230 let debug_builder = if is_struct {
231 quote!(f.debug_struct(stringify!(#ident)))
232 } else {
233 quote!(f.debug_tuple(stringify!(#ident)))
234 };
235 quote! {
236 #expanded
237
238 impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
239 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
240 let mut builder = #debug_builder;
241 #(#debugs;)*
242 builder.finish()
243 }
244 }
245 }
246 };
247
248 let expanded = quote! {
249 #expanded
250
251 #methods
252 };
253
254 Ok(expanded.into())
255 }
256
257 #[proc_macro_derive(Message, attributes(prost))]
message(input: TokenStream) -> TokenStream258 pub fn message(input: TokenStream) -> TokenStream {
259 try_message(input).unwrap()
260 }
261
try_enumeration(input: TokenStream) -> Result<TokenStream, Error>262 fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
263 let input: DeriveInput = syn::parse(input)?;
264 let ident = input.ident;
265
266 let generics = &input.generics;
267 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
268
269 let punctuated_variants = match input.data {
270 Data::Enum(DataEnum { variants, .. }) => variants,
271 Data::Struct(_) => bail!("Enumeration can not be derived for a struct"),
272 Data::Union(..) => bail!("Enumeration can not be derived for a union"),
273 };
274
275 // Map the variants into 'fields'.
276 let mut variants: Vec<(Ident, Expr)> = Vec::new();
277 for Variant {
278 ident,
279 fields,
280 discriminant,
281 ..
282 } in punctuated_variants
283 {
284 match fields {
285 Fields::Unit => (),
286 Fields::Named(_) | Fields::Unnamed(_) => {
287 bail!("Enumeration variants may not have fields")
288 }
289 }
290
291 match discriminant {
292 Some((_, expr)) => variants.push((ident, expr)),
293 None => bail!("Enumeration variants must have a discriminant"),
294 }
295 }
296
297 if variants.is_empty() {
298 panic!("Enumeration must have at least one variant");
299 }
300
301 let default = variants[0].0.clone();
302
303 let is_valid = variants
304 .iter()
305 .map(|&(_, ref value)| quote!(#value => true));
306 let from = variants.iter().map(
307 |&(ref variant, ref value)| quote!(#value => ::core::option::Option::Some(#ident::#variant)),
308 );
309
310 let try_from = variants.iter().map(
311 |&(ref variant, ref value)| quote!(#value => ::core::result::Result::Ok(#ident::#variant)),
312 );
313
314 let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident);
315 let from_i32_doc = format!(
316 "Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.",
317 ident
318 );
319
320 let expanded = quote! {
321 impl #impl_generics #ident #ty_generics #where_clause {
322 #[doc=#is_valid_doc]
323 pub fn is_valid(value: i32) -> bool {
324 match value {
325 #(#is_valid,)*
326 _ => false,
327 }
328 }
329
330 #[deprecated = "Use the TryFrom<i32> implementation instead"]
331 #[doc=#from_i32_doc]
332 pub fn from_i32(value: i32) -> ::core::option::Option<#ident> {
333 match value {
334 #(#from,)*
335 _ => ::core::option::Option::None,
336 }
337 }
338 }
339
340 impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
341 fn default() -> #ident {
342 #ident::#default
343 }
344 }
345
346 impl #impl_generics ::core::convert::From::<#ident> for i32 #ty_generics #where_clause {
347 fn from(value: #ident) -> i32 {
348 value as i32
349 }
350 }
351
352 impl #impl_generics ::core::convert::TryFrom::<i32> for #ident #ty_generics #where_clause {
353 type Error = ::prost::DecodeError;
354
355 fn try_from(value: i32) -> ::core::result::Result<#ident, ::prost::DecodeError> {
356 match value {
357 #(#try_from,)*
358 _ => ::core::result::Result::Err(::prost::DecodeError::new("invalid enumeration value")),
359 }
360 }
361 }
362 };
363
364 Ok(expanded.into())
365 }
366
367 #[proc_macro_derive(Enumeration, attributes(prost))]
enumeration(input: TokenStream) -> TokenStream368 pub fn enumeration(input: TokenStream) -> TokenStream {
369 try_enumeration(input).unwrap()
370 }
371
try_oneof(input: TokenStream) -> Result<TokenStream, Error>372 fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
373 let input: DeriveInput = syn::parse(input)?;
374
375 let ident = input.ident;
376
377 syn::custom_keyword!(skip_debug);
378 let skip_debug = input
379 .attrs
380 .into_iter()
381 .any(|a| a.path().is_ident("prost") && a.parse_args::<skip_debug>().is_ok());
382
383 let variants = match input.data {
384 Data::Enum(DataEnum { variants, .. }) => variants,
385 Data::Struct(..) => bail!("Oneof can not be derived for a struct"),
386 Data::Union(..) => bail!("Oneof can not be derived for a union"),
387 };
388
389 let generics = &input.generics;
390 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
391
392 // Map the variants into 'fields'.
393 let mut fields: Vec<(Ident, Field)> = Vec::new();
394 for Variant {
395 attrs,
396 ident: variant_ident,
397 fields: variant_fields,
398 ..
399 } in variants
400 {
401 let variant_fields = match variant_fields {
402 Fields::Unit => Punctuated::new(),
403 Fields::Named(FieldsNamed { named: fields, .. })
404 | Fields::Unnamed(FieldsUnnamed {
405 unnamed: fields, ..
406 }) => fields,
407 };
408 if variant_fields.len() != 1 {
409 bail!("Oneof enum variants must have a single field");
410 }
411 match Field::new_oneof(attrs)? {
412 Some(field) => fields.push((variant_ident, field)),
413 None => bail!("invalid oneof variant: oneof variants may not be ignored"),
414 }
415 }
416
417 let mut tags = fields
418 .iter()
419 .flat_map(|&(ref variant_ident, ref field)| -> Result<u32, Error> {
420 if field.tags().len() > 1 {
421 bail!(
422 "invalid oneof variant {}::{}: oneof variants may only have a single tag",
423 ident,
424 variant_ident
425 );
426 }
427 Ok(field.tags()[0])
428 })
429 .collect::<Vec<_>>();
430 tags.sort_unstable();
431 tags.dedup();
432 if tags.len() != fields.len() {
433 panic!("invalid oneof {}: variants have duplicate tags", ident);
434 }
435
436 let encode = fields.iter().map(|&(ref variant_ident, ref field)| {
437 let encode = field.encode(quote!(*value));
438 quote!(#ident::#variant_ident(ref value) => { #encode })
439 });
440
441 let merge = fields.iter().map(|&(ref variant_ident, ref field)| {
442 let tag = field.tags()[0];
443 let merge = field.merge(quote!(value));
444 quote! {
445 #tag => {
446 match field {
447 ::core::option::Option::Some(#ident::#variant_ident(ref mut value)) => {
448 #merge
449 },
450 _ => {
451 let mut owned_value = ::core::default::Default::default();
452 let value = &mut owned_value;
453 #merge.map(|_| *field = ::core::option::Option::Some(#ident::#variant_ident(owned_value)))
454 },
455 }
456 }
457 }
458 });
459
460 let encoded_len = fields.iter().map(|&(ref variant_ident, ref field)| {
461 let encoded_len = field.encoded_len(quote!(*value));
462 quote!(#ident::#variant_ident(ref value) => #encoded_len)
463 });
464
465 let expanded = quote! {
466 impl #impl_generics #ident #ty_generics #where_clause {
467 /// Encodes the message to a buffer.
468 pub fn encode<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut {
469 match *self {
470 #(#encode,)*
471 }
472 }
473
474 /// Decodes an instance of the message from a buffer, and merges it into self.
475 pub fn merge<B>(
476 field: &mut ::core::option::Option<#ident #ty_generics>,
477 tag: u32,
478 wire_type: ::prost::encoding::WireType,
479 buf: &mut B,
480 ctx: ::prost::encoding::DecodeContext,
481 ) -> ::core::result::Result<(), ::prost::DecodeError>
482 where B: ::prost::bytes::Buf {
483 match tag {
484 #(#merge,)*
485 _ => unreachable!(concat!("invalid ", stringify!(#ident), " tag: {}"), tag),
486 }
487 }
488
489 /// Returns the encoded length of the message without a length delimiter.
490 #[inline]
491 pub fn encoded_len(&self) -> usize {
492 match *self {
493 #(#encoded_len,)*
494 }
495 }
496 }
497
498 };
499 let expanded = if skip_debug {
500 expanded
501 } else {
502 let debug = fields.iter().map(|&(ref variant_ident, ref field)| {
503 let wrapper = field.debug(quote!(*value));
504 quote!(#ident::#variant_ident(ref value) => {
505 let wrapper = #wrapper;
506 f.debug_tuple(stringify!(#variant_ident))
507 .field(&wrapper)
508 .finish()
509 })
510 });
511 quote! {
512 #expanded
513
514 impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
515 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
516 match *self {
517 #(#debug,)*
518 }
519 }
520 }
521 }
522 };
523
524 Ok(expanded.into())
525 }
526
527 #[proc_macro_derive(Oneof, attributes(prost))]
oneof(input: TokenStream) -> TokenStream528 pub fn oneof(input: TokenStream) -> TokenStream {
529 try_oneof(input).unwrap()
530 }
531