1 // vim: tw=80
2 use super::*;
3 
4 use quote::ToTokens;
5 use std::collections::HashSet;
6 
7 use crate::{
8     mock_function::MockFunction,
9     mock_trait::MockTrait
10 };
11 
phantom_default_inits(generics: &Generics) -> Vec<TokenStream>12 fn phantom_default_inits(generics: &Generics) -> Vec<TokenStream> {
13     generics.params
14     .iter()
15     .enumerate()
16     .map(|(count, _param)| {
17         let phident = format_ident!("_t{count}");
18         quote!(#phident: ::std::marker::PhantomData)
19     }).collect()
20 }
21 
22 /// Generate any PhantomData field definitions
phantom_fields(generics: &Generics) -> Vec<TokenStream>23 fn phantom_fields(generics: &Generics) -> Vec<TokenStream> {
24     generics.params
25     .iter()
26     .enumerate()
27     .filter_map(|(count, param)| {
28         let phident = format_ident!("_t{count}");
29         match param {
30             syn::GenericParam::Lifetime(l) => {
31                 if !l.bounds.is_empty() {
32                     compile_error(l.bounds.span(),
33                         "#automock does not yet support lifetime bounds on structs");
34                 }
35                 let lifetime = &l.lifetime;
36                 Some(
37                 quote!(#phident: ::std::marker::PhantomData<&#lifetime ()>)
38                 )
39             },
40             syn::GenericParam::Type(tp) => {
41                 let ty = &tp.ident;
42                 Some(
43                 quote!(#phident: ::std::marker::PhantomData<#ty>)
44                 )
45             },
46             syn::GenericParam::Const(_) => {
47                 compile_error(param.span(),
48                     "#automock does not yet support generic constants");
49                 None
50             }
51         }
52     }).collect()
53 }
54 
55 /// Filter out multiple copies of the same trait, even if they're implemented on
56 /// different types.  But allow them if they have different attributes, which
57 /// probably indicates that they aren't meant to be compiled together.
unique_trait_iter<'a, I: Iterator<Item = &'a MockTrait>>(i: I) -> impl Iterator<Item = &'a MockTrait>58 fn unique_trait_iter<'a, I: Iterator<Item = &'a MockTrait>>(i: I)
59     -> impl Iterator<Item = &'a MockTrait>
60 {
61     let mut hs = HashSet::<(Path, Vec<Attribute>)>::default();
62     i.filter(move |mt| {
63         let impl_attrs = AttrFormatter::new(&mt.attrs)
64             .async_trait(false)
65             .doc(false)
66             .format();
67         let key = (mt.trait_path.clone(), impl_attrs);
68         if hs.contains(&key) {
69             false
70         } else {
71             hs.insert(key);
72             true
73         }
74     })
75 }
76 
/// A collection of methods defined in one spot
///
/// A thin wrapper around a `Vec<MockFunction>`.  It is used both for the
/// inherent methods of a mock struct and for the methods of one mocked trait.
struct Methods(Vec<MockFunction>);
79 
80 impl Methods {
81     /// Are all of these methods static?
all_static(&self) -> bool82     fn all_static(&self) -> bool {
83         self.0.iter()
84             .all(|meth| meth.is_static())
85     }
86 
checkpoints(&self) -> Vec<impl ToTokens>87     fn checkpoints(&self) -> Vec<impl ToTokens> {
88         self.0.iter()
89             .filter(|meth| !meth.is_static())
90             .map(|meth| meth.checkpoint())
91             .collect::<Vec<_>>()
92     }
93 
94     /// Return a fragment of code to initialize struct fields during default()
default_inits(&self) -> Vec<TokenStream>95     fn default_inits(&self) -> Vec<TokenStream> {
96         self.0.iter()
97             .filter(|meth| !meth.is_static())
98             .map(|meth| {
99                 let name = meth.name();
100                 let attrs = AttrFormatter::new(&meth.attrs)
101                     .doc(false)
102                     .format();
103                 quote!(#(#attrs)* #name: Default::default())
104             }).collect::<Vec<_>>()
105     }
106 
field_definitions(&self, modname: &Ident) -> Vec<TokenStream>107     fn field_definitions(&self, modname: &Ident) -> Vec<TokenStream> {
108         self.0.iter()
109             .filter(|meth| !meth.is_static())
110             .map(|meth| meth.field_definition(Some(modname)))
111             .collect::<Vec<_>>()
112     }
113 
priv_mods(&self) -> Vec<impl ToTokens>114     fn priv_mods(&self) -> Vec<impl ToTokens> {
115         self.0.iter()
116             .map(|meth| meth.priv_module())
117             .collect::<Vec<_>>()
118     }
119 }
120 
/// Everything needed to generate the code for one mock struct: the struct
/// itself, its inherent methods, and the traits it implements.
pub(crate) struct MockItemStruct {
    /// Attributes to emit on the mock struct, after being filtered through
    /// `AttrFormatter`
    attrs: Vec<Attribute>,
    /// Associated constants, emitted in the mock struct's inherent impl
    consts: Vec<ImplItemConst>,
    /// Generic parameters of the mock struct
    generics: Generics,
    /// Should Mockall generate a Debug implementation?
    auto_debug: bool,
    /// Does the original struct have a `new` method?
    has_new: bool,
    /// Inherent methods of the mock struct
    methods: Methods,
    /// Name of the overall module that holds all of the mock stuff
    modname: Ident,
    /// Name of the generated mock struct
    name: Ident,
    /// Traits implemented by the mock struct.  Each one gets a trait impl;
    /// after deduplication each also gets an expectations substructure.
    traits: Vec<MockTrait>,
    /// Visibility of the generated mock struct
    vis: Visibility,
}
138 
139 impl MockItemStruct {
debug_impl(&self) -> impl ToTokens140     fn debug_impl(&self) -> impl ToTokens {
141         if self.auto_debug {
142             let (ig, tg, wc) = self.generics.split_for_impl();
143             let struct_name = &self.name;
144             let struct_name_str = format!("{}", self.name);
145             quote!(
146                 impl #ig ::std::fmt::Debug for #struct_name #tg #wc {
147                     fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>)
148                         -> ::std::result::Result<(), std::fmt::Error>
149                     {
150                         f.debug_struct(#struct_name_str).finish()
151                     }
152                 }
153             )
154         } else {
155             quote!()
156         }
157     }
158 
new_method(&self) -> impl ToTokens159     fn new_method(&self) -> impl ToTokens {
160         if self.has_new {
161             TokenStream::new()
162         } else {
163             quote!(
164                 /// Create a new mock object with no expectations.
165                 ///
166                 /// This method will not be generated if the real struct
167                 /// already has a `new` method.  However, it *will* be
168                 /// generated if the struct implements a trait with a `new`
169                 /// method.  The trait's `new` method can still be called
170                 /// like `<MockX as TraitY>::new`
171                 pub fn new() -> Self {
172                     Self::default()
173                 }
174             )
175         }
176     }
177 
phantom_default_inits(&self) -> Vec<TokenStream>178     fn phantom_default_inits(&self) -> Vec<TokenStream> {
179         phantom_default_inits(&self.generics)
180     }
181 
phantom_fields(&self) -> Vec<TokenStream>182     fn phantom_fields(&self) -> Vec<TokenStream> {
183         phantom_fields(&self.generics)
184     }
185 }
186 
187 impl From<MockableStruct> for MockItemStruct {
from(mockable: MockableStruct) -> MockItemStruct188     fn from(mockable: MockableStruct) -> MockItemStruct {
189         let auto_debug = mockable.derives_debug();
190         let modname = gen_mod_ident(&mockable.name, None);
191         let generics = mockable.generics.clone();
192         let struct_name = &mockable.name;
193         let vis = mockable.vis;
194         let has_new = mockable.methods.iter()
195             .any(|meth| meth.sig.ident == "new") ||
196             mockable.impls.iter()
197             .any(|impl_|
198                 impl_.items.iter()
199                     .any(|ii| if let ImplItem::Fn(iif) = ii {
200                             iif.sig.ident == "new"
201                         } else {
202                             false
203                         }
204                     )
205             );
206         let methods = Methods(mockable.methods.into_iter()
207             .map(|meth|
208                 mock_function::Builder::new(&meth.sig, &meth.vis)
209                     .attrs(&meth.attrs)
210                     .struct_(struct_name)
211                     .struct_generics(&generics)
212                     .levels(2)
213                     .call_levels(0)
214                     .build()
215             ).collect::<Vec<_>>());
216         let structname = &mockable.name;
217         let traits = mockable.impls.into_iter()
218             .map(|i| MockTrait::new(structname, &generics, i, &vis))
219             .collect();
220 
221         MockItemStruct {
222             attrs: mockable.attrs,
223             auto_debug,
224             consts: mockable.consts,
225             generics,
226             has_new,
227             methods,
228             modname,
229             name: mockable.name,
230             traits,
231             vis
232         }
233     }
234 }
235 
236 impl ToTokens for MockItemStruct {
to_tokens(&self, tokens: &mut TokenStream)237     fn to_tokens(&self, tokens: &mut TokenStream) {
238         let attrs = AttrFormatter::new(&self.attrs)
239             .async_trait(false)
240             .format();
241         let consts = &self.consts;
242         let debug_impl = self.debug_impl();
243         let struct_name = &self.name;
244         let (ig, tg, wc) = self.generics.split_for_impl();
245         let modname = &self.modname;
246         let calls = self.methods.0.iter()
247             .map(|meth| meth.call(Some(modname)))
248             .collect::<Vec<_>>();
249         let contexts = self.methods.0.iter()
250             .filter(|meth| meth.is_static())
251             .map(|meth| meth.context_fn(Some(modname)))
252             .collect::<Vec<_>>();
253         let expects = self.methods.0.iter()
254             .filter(|meth| !meth.is_static())
255             .map(|meth| meth.expect(modname, None))
256             .collect::<Vec<_>>();
257         let method_checkpoints = self.methods.checkpoints();
258         let new_method = self.new_method();
259         let priv_mods = self.methods.priv_mods();
260         let substructs = unique_trait_iter(self.traits.iter())
261             .map(|trait_| {
262                 MockItemTraitImpl {
263                     attrs: trait_.attrs.clone(),
264                     generics: self.generics.clone(),
265                     fieldname: format_ident!("{}_expectations",
266                                              trait_.ss_name()),
267                     methods: Methods(trait_.methods.clone()),
268                     modname: format_ident!("{}_{}", &self.modname,
269                                            trait_.ss_name()),
270                     name: format_ident!("{}_{}", &self.name, trait_.ss_name()),
271                 }
272             }).collect::<Vec<_>>();
273         let substruct_expectations = substructs.iter()
274             .filter(|ss| !ss.all_static())
275             .map(|ss| {
276                 let attrs = AttrFormatter::new(&ss.attrs)
277                     .async_trait(false)
278                     .doc(false)
279                     .format();
280                 let fieldname = &ss.fieldname;
281                 quote!(#(#attrs)* self.#fieldname.checkpoint();)
282             }).collect::<Vec<_>>();
283         let mut field_definitions = substructs.iter()
284             .filter(|ss| !ss.all_static())
285             .map(|ss| {
286                 let attrs = AttrFormatter::new(&ss.attrs)
287                     .async_trait(false)
288                     .doc(false)
289                     .format();
290                 let fieldname = &ss.fieldname;
291                 let tyname = &ss.name;
292                 quote!(#(#attrs)* #fieldname: #tyname #tg)
293             }).collect::<Vec<_>>();
294         field_definitions.extend(self.methods.field_definitions(modname));
295         field_definitions.extend(self.phantom_fields());
296         let mut default_inits = substructs.iter()
297             .filter(|ss| !ss.all_static())
298             .map(|ss| {
299                 let attrs = AttrFormatter::new(&ss.attrs)
300                     .async_trait(false)
301                     .doc(false)
302                     .format();
303                 let fieldname = &ss.fieldname;
304                 quote!(#(#attrs)* #fieldname: Default::default())
305             }).collect::<Vec<_>>();
306         default_inits.extend(self.methods.default_inits());
307         default_inits.extend(self.phantom_default_inits());
308         let trait_impls = self.traits.iter()
309             .map(|trait_| {
310                 let modname = format_ident!("{}_{}", &self.modname,
311                                             trait_.ss_name());
312                 trait_.trait_impl(&modname)
313             }).collect::<Vec<_>>();
314         let vis = &self.vis;
315         quote!(
316             #[allow(non_snake_case)]
317             #[allow(missing_docs)]
318             pub mod #modname {
319                 use super::*;
320                 #(#priv_mods)*
321             }
322             #[allow(non_camel_case_types)]
323             #[allow(non_snake_case)]
324             #[allow(missing_docs)]
325             #(#attrs)*
326             #vis struct #struct_name #ig #wc
327             {
328                 #(#field_definitions),*
329             }
330             #debug_impl
331             impl #ig ::std::default::Default for #struct_name #tg #wc {
332                 #[allow(clippy::default_trait_access)]
333                 fn default() -> Self {
334                     Self {
335                         #(#default_inits),*
336                     }
337                 }
338             }
339             #(#substructs)*
340             impl #ig #struct_name #tg #wc {
341                 #(#consts)*
342                 #(#calls)*
343                 #(#contexts)*
344                 #(#expects)*
345                 /// Validate that all current expectations for all methods have
346                 /// been satisfied, and discard them.
347                 pub fn checkpoint(&mut self) {
348                     #(#substruct_expectations)*
349                     #(#method_checkpoints)*
350                 }
351                 #new_method
352             }
353             #(#trait_impls)*
354         ).to_tokens(tokens);
355     }
356 }
357 
/// The expectations substructure generated for one trait that a mock struct
/// implements.  The parent mock struct holds it in a field, unless all of
/// the trait's methods are static.
pub(crate) struct MockItemTraitImpl {
    /// Attributes copied from the trait; filtered through `AttrFormatter`
    /// and repeated on every generated item
    attrs: Vec<Attribute>,
    /// Generic parameters, copied from the parent mock struct
    generics: Generics,
    /// Methods of the mocked trait
    methods: Methods,
    /// Name of the module that holds this substructure's private modules
    modname: Ident,
    /// Name of the generated substructure
    name: Ident,
    /// Name of the field of this type in the parent's structure
    fieldname: Ident,
}
369 
impl MockItemTraitImpl {
    /// Are all of this trait's methods static?
    fn all_static(&self) -> bool {
        self.methods.all_static()
    }

    /// Initializers for this substructure's PhantomData fields
    fn phantom_default_inits(&self) -> Vec<TokenStream> {
        phantom_default_inits(&self.generics)
    }

    /// Definitions of this substructure's PhantomData fields
    fn phantom_fields(&self) -> Vec<TokenStream> {
        phantom_fields(&self.generics)
    }
}
384 
385 impl ToTokens for MockItemTraitImpl {
to_tokens(&self, tokens: &mut TokenStream)386     fn to_tokens(&self, tokens: &mut TokenStream) {
387         let attrs = AttrFormatter::new(&self.attrs)
388             .async_trait(false)
389             .doc(false)
390             .format();
391         let struct_name = &self.name;
392         let (ig, tg, wc) = self.generics.split_for_impl();
393         let modname = &self.modname;
394         let method_checkpoints = self.methods.checkpoints();
395         let mut default_inits = self.methods.default_inits();
396         default_inits.extend(self.phantom_default_inits());
397         let mut field_definitions = self.methods.field_definitions(modname);
398         field_definitions.extend(self.phantom_fields());
399         let priv_mods = self.methods.priv_mods();
400         quote!(
401             #[allow(non_snake_case)]
402             #[allow(missing_docs)]
403             #(#attrs)*
404             pub mod #modname {
405                 use super::*;
406                 #(#priv_mods)*
407             }
408             #[allow(non_camel_case_types)]
409             #[allow(non_snake_case)]
410             #[allow(missing_docs)]
411             #(#attrs)*
412             struct #struct_name #ig #wc
413             {
414                 #(#field_definitions),*
415             }
416             #(#attrs)*
417             impl #ig ::std::default::Default for #struct_name #tg #wc {
418                 fn default() -> Self {
419                     Self {
420                         #(#default_inits),*
421                     }
422                 }
423             }
424             #(#attrs)*
425             impl #ig #struct_name #tg #wc {
426                 /// Validate that all current expectations for all methods have
427                 /// been satisfied, and discard them.
428                 pub fn checkpoint(&mut self) {
429                     #(#method_checkpoints)*
430                 }
431             }
432         ).to_tokens(tokens);
433     }
434 }
435