1 #![cfg(feature = "invocation")]
2
3 use std::{
4 sync::{
5 atomic::{AtomicUsize, Ordering},
6 Arc, Barrier,
7 },
8 thread::spawn,
9 time::Duration,
10 };
11
12 use jni::{objects::AutoLocal, sys::jint, Executor};
13
14 use rusty_fork::rusty_fork_test;
15
16 mod util;
17 use util::{jvm, AtomicIntegerProxy};
18
#[test]
fn single_thread() {
    // Drive the proxy exclusively from the test's own thread.
    test_single_thread(Executor::new(jvm().clone()));
}
24
#[test]
fn serialized_threads() {
    // The proxy is handed between threads, but is never used by two at once.
    test_serialized_threads(Executor::new(jvm().clone()));
}
30
#[test]
fn concurrent_threads() {
    // How many worker threads contend on the shared counter simultaneously.
    const THREAD_NUM: usize = 8;

    test_concurrent_threads(Executor::new(jvm().clone()), THREAD_NUM);
}
37
test_single_thread(executor: Executor)38 fn test_single_thread(executor: Executor) {
39 let mut atomic = AtomicIntegerProxy::new(executor, 0).unwrap();
40 assert_eq!(0, atomic.get().unwrap());
41 assert_eq!(1, atomic.increment_and_get().unwrap());
42 assert_eq!(3, atomic.add_and_get(2).unwrap());
43 assert_eq!(3, atomic.get().unwrap());
44 }
45
test_serialized_threads(executor: Executor)46 fn test_serialized_threads(executor: Executor) {
47 let mut atomic = AtomicIntegerProxy::new(executor, 0).unwrap();
48 assert_eq!(0, atomic.get().unwrap());
49 let jh = spawn(move || {
50 assert_eq!(1, atomic.increment_and_get().unwrap());
51 assert_eq!(3, atomic.add_and_get(2).unwrap());
52 atomic
53 });
54 let mut atomic = jh.join().unwrap();
55 assert_eq!(3, atomic.get().unwrap());
56 }
57
test_concurrent_threads(executor: Executor, thread_num: usize)58 fn test_concurrent_threads(executor: Executor, thread_num: usize) {
59 const ITERS_PER_THREAD: usize = 10_000;
60
61 let mut atomic = AtomicIntegerProxy::new(executor, 0).unwrap();
62 let barrier = Arc::new(Barrier::new(thread_num));
63 let mut threads = Vec::new();
64
65 for _ in 0..thread_num {
66 let barrier = Arc::clone(&barrier);
67 let mut atomic = atomic.clone();
68 let jh = spawn(move || {
69 barrier.wait();
70 for _ in 0..ITERS_PER_THREAD {
71 atomic.increment_and_get().unwrap();
72 }
73 });
74 threads.push(jh);
75 }
76 for jh in threads {
77 jh.join().unwrap();
78 }
79 let expected = (ITERS_PER_THREAD * thread_num) as jint;
80 assert_eq!(expected, atomic.get().unwrap());
81 }
82
// We need to test `JavaVM::destroy()` in a separate (forked) process; otherwise
// destroying the shared JVM would break all the other tests.
rusty_fork_test! {
    // Checks that `JavaVM::destroy()` can be called while other threads that
    // attached to the VM are still alive, and that every spawned thread still
    // runs to completion afterwards.
    //
    // Two kinds of worker threads are spawned:
    // - `THREAD_NUM` "normal" threads attached via `attach_current_thread()`.
    //   They sleep after attaching and only then make a JNI call, so they are
    //   presumably still attached when the main thread enters `destroy()`.
    // - `DAEMON_THREAD_NUM` threads attached via
    //   `attach_current_thread_as_daemon()`. They finish all JNI work and
    //   explicitly detach *before* the main thread calls `destroy()`.
    //
    // Every worker increments `atomic` exactly once as its last step, so the
    // final assertion proves that all of them ran to completion.
    #[test]
    fn test_destroy() {
        const THREAD_NUM: usize = 2;
        const DAEMON_THREAD_NUM: usize = 2;
        static MATH_CLASS: &str = "java/lang/Math";

        // We don't test this using an `Executor` because we don't want to
        // attach all the threads as daemon threads.

        let jvm = jvm().clone();

        // Completion counter: each worker adds 1 when it finishes.
        let atomic = Arc::new(AtomicUsize::new(0));

        // Every worker plus the main thread (the `+ 1`) meets here once attached.
        let attach_barrier = Arc::new(Barrier::new(THREAD_NUM + DAEMON_THREAD_NUM + 1));
        // The daemon workers plus the main thread (the `+ 1`) meet here once the
        // daemons have detached from the VM.
        let daemons_detached_barrier = Arc::new(Barrier::new(DAEMON_THREAD_NUM + 1));
        let mut threads = Vec::new();

        // Non-daemon workers.
        for _ in 0..THREAD_NUM {
            let attach_barrier = Arc::clone(&attach_barrier);
            let jvm = jvm.clone();
            let atomic = atomic.clone();
            let jh = spawn(move || {
                // NOTE(review): `env` is presumably a guard that detaches this
                // thread when it is dropped at the end of the closure — confirm.
                let mut env = jvm.attach_current_thread().unwrap();
                println!("java thread attach");
                attach_barrier.wait();
                println!("java thread run");
                // Presumably delays the JNI use below until the main thread is
                // already inside `destroy()`, which then has to wait for this
                // non-daemon thread.
                std::thread::sleep(Duration::from_millis(250));

                println!("use before destroy...");
                // Make some token JNI call. `_class` is declared after `env`, so
                // it is dropped (releasing the local reference) before `env`.
                let _class = AutoLocal::new(env.find_class(MATH_CLASS).unwrap(), &env);

                atomic.fetch_add(1, Ordering::SeqCst);

                println!("java thread finished");
            });
            threads.push(jh);
        }

        // Daemon workers.
        for _ in 0..DAEMON_THREAD_NUM {
            let attach_barrier = Arc::clone(&attach_barrier);
            let daemons_detached_barrier = Arc::clone(&daemons_detached_barrier);
            let jvm = jvm.clone();
            let atomic = atomic.clone();
            let jh = spawn(move || {
                // We have to be _very_ careful to ensure we have finished accessing the
                // JavaVM before it gets destroyed, including dropping the AutoLocal
                // for the `MATH_CLASS`. The inner scope guarantees that both `env`
                // and `_class` are dropped before we detach below.
                {
                    let mut env = jvm.attach_current_thread_as_daemon().unwrap();
                    println!("daemon thread attach");
                    attach_barrier.wait();
                    println!("daemon thread run");

                    println!("daemon JVM use before destroy...");

                    let _class = AutoLocal::new(env.find_class(MATH_CLASS).unwrap(), &env);
                }

                // For it to be safe to call `JavaVM::destroy()`, we need to ensure that
                // daemon threads are detached from the JavaVM ahead of time. This is
                // because `JavaVM::destroy()` does not synchronize with daemon threads
                // or wait for them to exit, so we would effectively trigger a
                // use-after-free when the daemon threads later exit and try to
                // automatically detach from the `JavaVM`.
                //
                // SAFETY: We won't be accessing any (invalid) `JNIEnv` once we have
                // detached this thread; the scope above has already dropped every
                // JNI-backed value this thread created.
                unsafe {
                    jvm.detach_current_thread();
                }

                // Tell the main thread that this daemon no longer uses the VM.
                daemons_detached_barrier.wait();

                // Keep running (without touching the VM) for ~1s so this thread
                // is still alive after the VM has been destroyed.
                for _ in 0..10 {
                    std::thread::sleep(Duration::from_millis(100));
                    println!("daemon thread running");
                }

                atomic.fetch_add(1, Ordering::SeqCst);

                println!("daemon thread finished");
            });
            threads.push(jh);
        }

        // Once this wait returns, we know that all worker threads have been
        // attached to the JVM.
        println!("MAIN: waiting for threads attached barrier");
        attach_barrier.wait();

        // Before we try to destroy the JavaVM, we need to be sure that the daemon
        // threads have finished using the VM, since `jvm.destroy()` won't wait
        // for daemon threads to exit.
        println!("MAIN: waiting for daemon threads detached barrier");
        daemons_detached_barrier.wait();

        // SAFETY: The daemon threads have detached (barrier above), and we drop
        // the `jvm` variable immediately after `destroy()` returns so that this
        // thread cannot cause any use-after-free through it.
        unsafe {
            println!("MAIN: calling DestroyJavaVM()...");
            jvm.destroy().unwrap();
            drop(jvm);
            println!("MAIN: jvm destroyed");
        }

        // Join every worker. Each one must finish its work and bump `atomic`
        // even though the VM has already been destroyed.
        println!("MAIN: joining (waiting for) all threads");
        let mut joined = 0;
        for jh in threads {
            jh.join().unwrap();
            joined += 1;
            println!(
                "joined {joined} threads, atomic = {}",
                atomic.load(Ordering::SeqCst)
            );
        }

        assert_eq!(
            atomic.load(Ordering::SeqCst),
            THREAD_NUM + DAEMON_THREAD_NUM
        );
    }

}
212