xref: /aosp_15_r20/external/pytorch/c10/cuda/test/impl/CUDAAssertionsTest_from_2_processes.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <c10/cuda/CUDADeviceAssertion.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>

#include <cerrno>
#include <chrono>
#include <cstring>
#include <iostream>
#include <string>
#include <thread>

// POSIX process APIs (fork/waitpid/_exit); unavailable on Windows.
#ifndef _MSC_VER
#include <sys/types.h>
#include <sys/wait.h>
#include <unistd.h>
#endif
13 
14 using ::testing::HasSubstr;
15 
16 const auto max_assertions_failure_str =
17     "Assertion failure " + std::to_string(C10_CUDA_DSA_ASSERTION_COUNT - 1);
18 
19 /**
20  * Device kernel that takes a single integer parameter as argument and
21  * will always trigger a device side assertion.
22  */
cuda_always_fail_assertion_kernel(const int a,TORCH_DSA_KERNEL_ARGS)23 __global__ void cuda_always_fail_assertion_kernel(
24     const int a,
25     TORCH_DSA_KERNEL_ARGS) {
26   CUDA_KERNEL_ASSERT2(a != a);
27 }
28 
29 /**
30  * Device kernel that takes a single integer parameter as argument and
31  * will never trigger a device side assertion.
32  */
cuda_always_succeed_assertion_kernel(const int a,TORCH_DSA_KERNEL_ARGS)33 __global__ void cuda_always_succeed_assertion_kernel(
34     const int a,
35     TORCH_DSA_KERNEL_ARGS) {
36   CUDA_KERNEL_ASSERT2(a == a);
37 }
38 
39 // Windows doesn't like `fork`
40 #ifndef _MSC_VER
41 /**
42  * TEST: Triggering device side assertion from 2 different processes from CPU.
43  * The following code is testing if two processes from CPU that are running
44  * GPU kernels (not necessarily simultaneously) and are asserting & writing
45  * to the respective UVMs, mess up anything for each other.
46  * Once parent process's kernel launch fails and causes a device-side assertion
47  * and is still alive when the second process is interacting with the GPU,
48  * trying to launch another kernel.
49  */
cuda_device_assertions_from_2_processes()50 void cuda_device_assertions_from_2_processes() {
51   const auto n1 = fork();
52   if (n1 == 0) {
53     // This is the parent process, that will call an assertion failure.
54     // This should execute before the child process.
55     // We are achieving this by putting the child process to sleep.
56     TORCH_DSA_KERNEL_LAUNCH(
57         cuda_always_fail_assertion_kernel,
58         1, /* Blocks */
59         1, /* Threads */
60         0, /* Shared mem */
61         c10::cuda::getStreamFromPool(), /* Stream */
62         1);
63     try {
64       c10::cuda::device_synchronize();
65       throw std::runtime_error("Test didn't fail, but should have.");
66     } catch (const c10::Error& err) {
67       const auto err_str = std::string(err.what());
68       ASSERT_THAT(
69           err_str,
70           HasSubstr(
71               "1 CUDA device-side assertion failures were found on GPU #0!"));
72     }
73     // Keep this alive so we can see what happened to the other process
74     std::this_thread::sleep_for(std::chrono::milliseconds(3000));
75   } else {
76     // This is the child process
77     // We put it to sleep for next 2 seconds, to make sure that the parent has
78     // asserted a failure already.
79     std::this_thread::sleep_for(std::chrono::milliseconds(2000));
80     TORCH_DSA_KERNEL_LAUNCH(
81         cuda_always_succeed_assertion_kernel,
82         1, /* Blocks */
83         1, /* Threads */
84         0, /* Shared mem */
85         c10::cuda::getStreamFromPool(), /* Stream */
86         1);
87     try {
88       c10::cuda::device_synchronize();
89     } catch (const c10::Error& err) {
90       ASSERT_TRUE(false); // This kernel should not have failed, but did.
91     }
92     // End the child process
93     exit(0);
94   }
95 }
96 
// Runs the two-process device-side assertion scenario. The test is skipped
// when the build was compiled without TORCH_USE_CUDA_DSA.
TEST(CUDATest, cuda_device_assertions_from_2_processes) {
#ifndef TORCH_USE_CUDA_DSA
  GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled at compile time.";
#else
  // DSA must also be switched on at runtime through the launch registry.
  auto& launch_registry =
      c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref();
  launch_registry.enabled_at_runtime = true;
  cuda_device_assertions_from_2_processes();
#endif
}
105 
106 #else
107 
108 #endif
109