1 #include <gmock/gmock.h>
2 #include <gtest/gtest.h>
3 
4 #include <c10/cuda/CUDADeviceAssertion.h>
5 #include <c10/cuda/CUDAException.h>
6 #include <c10/cuda/CUDAFunctions.h>
7 #include <c10/cuda/CUDAStream.h>
8 
9 #include <chrono>
10 #include <iostream>
11 #include <string>
12 #include <thread>
13 
14 using ::testing::HasSubstr;
15 
/**
 * Device kernel that triggers a device-side assertion on exactly one
 * (block, thread) pair; every other thread of the launch is a no-op.
 *
 * Only threadIdx.x and blockIdx.x are inspected, so this assumes a 1-D launch.
 *
 * @param bad_thread x-index (within its block) of the thread that asserts.
 * @param bad_block x-index of the block containing the asserting thread.
 * TORCH_DSA_KERNEL_ARGS: extra arguments used by the device-side-assertion
 * machinery. The launch site passes only bad_thread/bad_block, so these are
 * presumably appended by TORCH_DSA_KERNEL_LAUNCH.
 */
__global__ void cuda_device_assertions_fail_on_thread_block_kernel(
    const int bad_thread,
    const int bad_block,
    TORCH_DSA_KERNEL_ARGS) {
  // threadIdx/blockIdx components are unsigned; cast them so the comparison is
  // int-vs-int (no -Wsign-compare). A negative argument simply never matches,
  // exactly as with the implicit unsigned conversion.
  if (static_cast<int>(threadIdx.x) == bad_thread &&
      static_cast<int>(blockIdx.x) == bad_block) {
    // Deliberate, unconditional failure for the one selected thread.
    CUDA_KERNEL_ASSERT2(false);
  }
}
32 
/**
 * TEST: trigger a device-side assertion on exactly one thread of a
 * <<<1024, 128>>> launch and verify that the host-side error message reports
 * the failing thread, block, launching device, kernel name, launch file,
 * launching function, and stream.
 *
 * The kernel takes the block and thread that should assert as arguments; all
 * other threads of the launch are no-ops.
 */
void cuda_device_assertions_catches_thread_and_block_and_device() {
  // Launch geometry and the single (block, thread) pair that must assert.
  // Defined once so the launch and the expected message cannot drift apart.
  constexpr int kNumBlocks = 1024;
  constexpr int kNumThreads = 128;
  constexpr int kBadThread = 29;
  constexpr int kBadBlock = 937;
  static_assert(
      kBadThread < kNumThreads, "bad thread must exist in the launch");
  static_assert(kBadBlock < kNumBlocks, "bad block must exist in the launch");

  // Expect the device that is current at launch time rather than assuming 0.
  const int launch_device = static_cast<int>(c10::cuda::current_device());
  const auto stream = c10::cuda::getStreamFromPool();
  TORCH_DSA_KERNEL_LAUNCH(
      cuda_device_assertions_fail_on_thread_block_kernel,
      kNumBlocks, /* Blocks */
      kNumThreads, /* Threads */
      0, /* Shared mem */
      stream, /* Stream */
      kBadThread, /* bad thread */
      kBadBlock /* bad block */
  );

  try {
    // Kernel execution is asynchronous: the assertion only surfaces on the
    // host at a synchronization point such as this one.
    c10::cuda::device_synchronize();
    FAIL() << "Test didn't fail, but should have.";
  } catch (const c10::Error& err) {
    const auto err_str = std::string(err.what());
    ASSERT_THAT(
        err_str,
        HasSubstr(
            "Thread ID that failed assertion = [" +
            std::to_string(kBadThread) + ",0,0]"));
    ASSERT_THAT(
        err_str,
        HasSubstr(
            "Block ID that failed assertion = [" + std::to_string(kBadBlock) +
            ",0,0]"));
    ASSERT_THAT(
        err_str,
        HasSubstr(
            "Device that launched kernel = " + std::to_string(launch_device)));
    ASSERT_THAT(
        err_str,
        HasSubstr(
            "Name of kernel launched that led to failure = cuda_device_assertions_fail_on_thread_block_kernel"));
    ASSERT_THAT(
        err_str, HasSubstr("File containing kernel launch = " __FILE__));
    // __FUNCTION__ names this helper, which is where the launch happened.
    ASSERT_THAT(
        err_str,
        HasSubstr(
            "Function containing kernel launch = " +
            std::string(__FUNCTION__)));
    ASSERT_THAT(
        err_str,
        HasSubstr(
            "Stream kernel was launched on = " + std::to_string(stream.id())));
  }
}
78 
// Runs the single-thread device-side-assertion scenario when DSA support was
// compiled in (TORCH_USE_CUDA_DSA); otherwise the test is reported as skipped.
TEST(CUDATest, cuda_device_assertions_catches_thread_and_block_and_device) {
#ifndef TORCH_USE_CUDA_DSA
  GTEST_SKIP()
      << "CUDA device-side assertions (DSA) was not enabled at compile time.";
#else
  // Compile-time support exists; also switch DSA on at runtime through the
  // kernel-launch registry singleton before exercising it.
  auto& launch_registry =
      c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref();
  launch_registry.enabled_at_runtime = true;
  cuda_device_assertions_catches_thread_and_block_and_device();
#endif
}
87