1 // Copyright 2022 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 use quote::quote;
16 use syn::{parse_macro_input, Attribute, ItemFn, ReturnType};
17 
18 /// Marks a test to be run by the Google Rust test runner.
19 ///
20 /// Annotate tests the same way ordinary Rust tests are annotated:
21 ///
22 /// ```ignore
23 /// #[googletest::test]
24 /// fn should_work() {
25 ///     ...
26 /// }
27 /// ```
28 ///
29 /// The test function is not required to have a return type. If it does have a
30 /// return type, that type must be [`googletest::Result`]. One may do this if
31 /// one wishes to use both fatal and non-fatal assertions in the same test. For
32 /// example:
33 ///
34 /// ```ignore
35 /// #[googletest::test]
36 /// fn should_work() -> googletest::Result<()> {
37 ///     let value = 2;
38 ///     expect_that!(value, gt(0));
39 ///     verify_that!(value, eq(2))
40 /// }
41 /// ```
42 ///
43 /// This macro can be used with `#[should_panic]` to indicate that the test is
44 /// expected to panic. For example:
45 ///
46 /// ```ignore
47 /// #[googletest::test]
48 /// #[should_panic]
49 /// fn passes_due_to_should_panic() {
50 ///     let value = 2;
51 ///     expect_that!(value, gt(0));
52 ///     panic!("This panics");
53 /// }
54 /// ```
55 ///
56 /// Using `#[should_panic]` modifies the behaviour of `#[googletest::test]` so
57 /// that the test panics (and passes) if any non-fatal assertion occurs.
58 /// For example, the following test passes:
59 ///
60 /// ```ignore
61 /// #[googletest::test]
62 /// #[should_panic]
63 /// fn passes_due_to_should_panic_and_failing_assertion() {
64 ///     let value = 2;
65 ///     expect_that!(value, eq(0));
66 /// }
67 /// ```
68 ///
69 /// [`googletest::Result`]: type.Result.html
70 #[proc_macro_attribute]
test( _args: proc_macro::TokenStream, input: proc_macro::TokenStream, ) -> proc_macro::TokenStream71 pub fn test(
72     _args: proc_macro::TokenStream,
73     input: proc_macro::TokenStream,
74 ) -> proc_macro::TokenStream {
75     let mut parsed_fn = parse_macro_input!(input as ItemFn);
76     let attrs = parsed_fn.attrs.drain(..).collect::<Vec<_>>();
77     let (mut sig, block) = (parsed_fn.sig, parsed_fn.block);
78     let (outer_return_type, trailer) =
79         if attrs.iter().any(|attr| attr.path().is_ident("should_panic")) {
80             (quote! { () }, quote! { .unwrap(); })
81         } else {
82             (
83                 quote! { std::result::Result<(), googletest::internal::test_outcome::TestFailure> },
84                 quote! {},
85             )
86         };
87     let output_type = match sig.output.clone() {
88         ReturnType::Type(_, output_type) => Some(output_type),
89         ReturnType::Default => None,
90     };
91     sig.output = ReturnType::Default;
92     let (maybe_closure, invocation) = if sig.asyncness.is_some() {
93         (
94             // In the async case, the ? operator returns from the *block* rather than the
95             // surrounding function. So we just put the test content in an async block. Async
96             // closures are still unstable (see https://github.com/rust-lang/rust/issues/62290),
97             // so we can't use the same solution as the sync case below.
98             quote! {},
99             quote! {
100                 async { #block }.await
101             },
102         )
103     } else {
104         (
105             // In the sync case, the ? operator returns from the surrounding function. So we must
106             // create a separate closure from which the ? operator can return in order to capture
107             // the output.
108             quote! {
109                 let test = move || #block;
110             },
111             quote! {
112                 test()
113             },
114         )
115     };
116     let function = if let Some(output_type) = output_type {
117         quote! {
118             #(#attrs)*
119             #sig -> #outer_return_type {
120                 #maybe_closure
121                 use googletest::internal::test_outcome::TestOutcome;
122                 TestOutcome::init_current_test_outcome();
123                 let result: #output_type = #invocation;
124                 TestOutcome::close_current_test_outcome(result)
125                 #trailer
126             }
127         }
128     } else {
129         quote! {
130             #(#attrs)*
131             #sig -> #outer_return_type {
132                 #maybe_closure
133                 use googletest::internal::test_outcome::TestOutcome;
134                 TestOutcome::init_current_test_outcome();
135                 #invocation;
136                 TestOutcome::close_current_test_outcome(googletest::Result::Ok(()))
137                 #trailer
138             }
139         }
140     };
141     let output = if attrs.iter().any(is_test_attribute) {
142         function
143     } else {
144         quote! {
145             #[::core::prelude::v1::test]
146             #function
147         }
148     };
149     output.into()
150 }
151 
is_test_attribute(attr: &Attribute) -> bool152 fn is_test_attribute(attr: &Attribute) -> bool {
153     let first_segment = match attr.path().segments.first() {
154         Some(first_segment) => first_segment,
155         None => return false,
156     };
157     let last_segment = match attr.path().segments.last() {
158         Some(last_segment) => last_segment,
159         None => return false,
160     };
161     last_segment.ident == "test"
162         || (first_segment.ident == "rstest"
163             && last_segment.ident == "rstest"
164             && attr.path().segments.len() <= 2)
165 }
166