1 // Copyright 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 use proc_macro2::TokenStream;
16 use syn::{punctuated::Punctuated, spanned::Spanned, ItemFn, LitStr, Token};
17 
18 mod meta;
19 mod meta_arg;
20 mod substitutions;
21 
22 use meta::JniMethodMeta;
23 use substitutions::substitute_method_chars;
24 
jni_method(meta: TokenStream, item: TokenStream) -> syn::Result<ItemFn>25 pub fn jni_method(meta: TokenStream, item: TokenStream) -> syn::Result<ItemFn> {
26     let meta = syn::parse2::<JniMethodMeta>(meta)?;
27     let mut func = syn::parse2::<ItemFn>(item)?;
28 
29     // Check that ABI is set to `extern "system"`
30     if let Some(
31         ref abi @ syn::Abi {
32             name: Some(ref abi_name),
33             ..
34         },
35     ) = func.sig.abi
36     {
37         if abi_name.value() != "system" {
38             return Err(syn::Error::new(
39                 abi.span(),
40                 "JNI methods are required to have the `extern \"system\"` ABI",
41             ));
42         }
43     } else {
44         return Err(syn::Error::new(
45             func.sig.span(),
46             "JNI methods are required to have the `extern \"system\"` ABI",
47         ));
48     }
49 
50     let export_attr = {
51         // Format the name of the function as expected by the JNI layer
52         let (method_name, method_name_span) = if let Some(meta_name) = &meta.method_name {
53             (meta_name.value(), meta_name.span())
54         } else {
55             (func.sig.ident.to_string(), func.sig.ident.span())
56         };
57         if method_name.starts_with("Java_") {
58             return Err(syn::Error::new(
59                 method_name_span,
60                 "The `jni_method` attribute will perform the JNI name formatting",
61             ));
62         }
63         let method_name = substitute_method_chars(&method_name);
64 
65         // NOTE: doesn't handle overload suffix
66         let link_name = LitStr::new(
67             &format!("Java_{class}_{method_name}", class = &meta.class_desc),
68             method_name_span,
69         );
70 
71         syn::parse_quote! { #[export_name = #link_name] }
72     };
73     func.attrs.push(export_attr);
74 
75     // Allow function name to be non_snake_case if we are using it as the Java method name
76     if meta.method_name.is_none() {
77         let allow_attr = syn::parse_quote! { #[allow(non_snake_case)] };
78         func.attrs.push(allow_attr);
79     }
80 
81     // Add a panic handler if requested
82     if let Some(panic_returns) = meta.panic_returns {
83         let block = &func.block;
84         let return_type = &func.sig.output;
85         let mut lifetimes = Punctuated::new();
86         for param in func.sig.generics.lifetimes() {
87             lifetimes.push_value(param.clone());
88             lifetimes.push_punct(<Token![,]>::default());
89         }
90 
91         let panic_check = quote::quote_spanned! { panic_returns.span() =>
92             #[cfg(not(panic = "unwind"))]
93             ::core::compile_error!("Cannot use `panic_returns` with non-unwinding panic handler");
94         };
95 
96         func.block = syn::parse_quote! {
97             {
98                 #panic_check
99                 match ::std::panic::catch_unwind(move || {
100                     #block
101                 }) {
102                     Ok(ret) => ret,
103                     Err(_err) => {
104                         fn __on_panic<#lifetimes>() #return_type { #panic_returns }
105                         __on_panic()
106                     },
107                 }
108             }
109         };
110     }
111 
112     // Return the modified function
113     Ok(func)
114 }
115 
/// Unit tests for [`jni_method`].
///
/// The fixtures (attribute arguments and the annotated native function) and the attribute
/// lookups are shared helpers, so each test only spells out what is specific to it.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use crate::test_util::contains_ident;
    use quote::{quote, ToTokens};

    /// Error message expected when the ABI is missing or is not `"system"`.
    const ABI_ERROR: &str = "JNI methods are required to have the `extern \"system\"` ABI";
    /// Error message expected when the method name is already in `Java_...` form.
    const ALREADY_MANGLED_ERROR: &str =
        "The `jni_method` attribute will perform the JNI name formatting";

    /// Attribute arguments used by most tests: class `Foo.Inner` in package `com.example`,
    /// with `panic_returns = false`.
    fn example_meta() -> proc_macro2::TokenStream {
        quote! {
            package = "com.example",
            class = "Foo.Inner",
            panic_returns = false,
        }
    }

    /// Builds the tokens of the native function used as test input, with the given ABI tokens
    /// (possibly empty) and function name.
    fn example_fn(abi: proc_macro2::TokenStream, name: &str) -> proc_macro2::TokenStream {
        let ident = quote::format_ident!("{}", name);
        quote! {
            #abi fn #ident<'local>(
                mut env: JNIEnv<'local>,
                this: JObject<'local>
            ) -> jint {
                123
            }
        }
    }

    /// Expands the example function `nativeFoo` without an explicit `method_name`.
    fn parse_example_output() -> syn::ItemFn {
        let func = example_fn(quote! { extern "system" }, "nativeFoo");
        jni_method(example_meta(), func).expect("failed to generate example")
    }

    /// Expands the example function `native_bar` with `method_name = "nativeBar"`.
    fn parse_example_output_method_name() -> syn::ItemFn {
        let meta = quote! {
            package = "com.example",
            class = "Foo.Inner",
            method_name = "nativeBar",
            panic_returns = false,
        };
        let func = example_fn(quote! { extern "system" }, "native_bar");
        jni_method(meta, func).expect("failed to generate example")
    }

    /// Runs the macro and returns its error, panicking if the expansion succeeds.
    fn expect_error(meta: proc_macro2::TokenStream, func: proc_macro2::TokenStream) -> syn::Error {
        match jni_method(meta, func) {
            Ok(_) => panic!("Should fail to generate code"),
            Err(err) => err,
        }
    }

    /// Returns the string value of the first `#[export_name = "..."]` attribute, if any.
    fn export_name_of(item: &syn::ItemFn) -> Option<String> {
        item.attrs.iter().find_map(|attr| match &attr.meta {
            syn::Meta::NameValue(nv) if nv.path.is_ident("export_name") => match &nv.value {
                syn::Expr::Lit(syn::ExprLit {
                    lit: syn::Lit::Str(name),
                    ..
                }) => Some(name.value()),
                _ => None,
            },
            _ => None,
        })
    }

    /// Returns whether the function carries an `#[allow(non_snake_case)]` attribute.
    fn allows_non_snake_case(item: &syn::ItemFn) -> bool {
        item.attrs.iter().any(|attr| match &attr.meta {
            syn::Meta::List(list) if list.path.is_ident("allow") => matches!(
                syn::parse2::<syn::Path>(list.tokens.clone()),
                Ok(lint) if lint.is_ident("non_snake_case")
            ),
            _ => false,
        })
    }

    #[test]
    fn can_parse() {
        let out = parse_example_output();

        assert!(contains_ident(out.into_token_stream(), "catch_unwind"));
    }

    #[test]
    fn check_output_is_itemfn() {
        let _item_fn = parse_example_output();
    }

    #[test]
    fn check_output_export_name() {
        let export_name = export_name_of(&parse_example_output())
            .expect("Failed to find `export_name` attribute");

        assert_eq!("Java_com_example_Foo_00024Inner_nativeFoo", export_name);
    }

    #[test]
    fn check_output_export_name_with_method_name() {
        let export_name = export_name_of(&parse_example_output_method_name())
            .expect("Failed to find `export_name` attribute");

        assert_eq!("Java_com_example_Foo_00024Inner_nativeBar", export_name);
    }

    #[test]
    fn check_output_allow_non_snake_case() {
        let out = parse_example_output();

        assert!(
            allows_non_snake_case(&out),
            "Failed to find `allow(non_snake_case)` attribute"
        );
    }

    #[test]
    fn check_output_allow_non_snake_case_not_present_with_method_name() {
        let out = parse_example_output_method_name();

        assert!(!allows_non_snake_case(&out));
    }

    #[test]
    fn no_panic_returns() {
        let meta = quote! {
            package = "com.example",
            class = "Foo.Inner",
        };
        let func = example_fn(quote! { extern "system" }, "nativeFoo");

        // Render a failure as `compile_error!` tokens so the first assertion catches it.
        let out = jni_method(meta, func)
            .map(|item_fn| item_fn.into_token_stream())
            .unwrap_or_else(|err| err.into_compile_error());

        assert!(!contains_ident(out.clone(), "compile_error"));
        assert!(!contains_ident(out, "catch_unwind"));
    }

    #[test]
    fn missing_extern() {
        let err = expect_error(example_meta(), example_fn(quote! {}, "nativeFoo"));

        assert!(err.to_string().contains(ABI_ERROR));
    }

    /// A bare `extern` with no ABI string must be rejected just like a missing `extern`.
    #[test]
    fn unnamed_extern() {
        let err = expect_error(example_meta(), example_fn(quote! { extern }, "nativeFoo"));

        assert!(err.to_string().contains(ABI_ERROR));
    }

    #[test]
    fn wrong_extern() {
        let err = expect_error(
            example_meta(),
            example_fn(quote! { extern "C" }, "nativeFoo"),
        );

        assert!(err.to_string().contains(ABI_ERROR));
    }

    #[test]
    fn already_mangled() {
        let func = example_fn(
            quote! { extern "system" },
            "Java_com_example_Foo_00024Inner_nativeFoo",
        );

        let err = expect_error(example_meta(), func);

        assert!(err.to_string().contains(ALREADY_MANGLED_ERROR));
    }

    #[test]
    fn already_mangled_method_name() {
        let meta = quote! {
            package = "com.example",
            class = "Foo.Inner",
            method_name = "Java_com_example_Foo_00024Inner_nativeFoo",
            panic_returns = false,
        };
        let func = example_fn(quote! { extern "system" }, "native_foo");

        let err = expect_error(meta, func);

        assert!(err.to_string().contains(ALREADY_MANGLED_ERROR));
    }
}
413