1 #include <gmock/gmock.h>
2 #include <gtest/gtest.h>
3
4 #include <c10/cuda/CUDADeviceAssertion.h>
5 #include <c10/cuda/CUDAException.h>
6 #include <c10/cuda/CUDAFunctions.h>
7 #include <c10/cuda/CUDAStream.h>
8
9 #include <chrono>
10 #include <iostream>
11 #include <string>
12 #include <thread>
13
14 using ::testing::HasSubstr;
15
/**
 * Kernel that raises a device-side assertion from exactly one (block, thread)
 * pair and does nothing in every other thread.
 *
 * Only blockIdx.x and threadIdx.x are inspected, so the kernel is meant for a
 * 1-D launch configuration.
 *
 * @param bad_thread threadIdx.x of the thread that should fail the assertion.
 * @param bad_block  blockIdx.x of the block that should fail the assertion.
 * TORCH_DSA_KERNEL_ARGS: trailing DSA bookkeeping parameters; these are not
 * passed explicitly by the caller (presumably supplied by
 * TORCH_DSA_KERNEL_LAUNCH — verify against the macro definition).
 */
__global__ void cuda_device_assertions_fail_on_thread_block_kernel(
    const int bad_thread,
    const int bad_block,
    TORCH_DSA_KERNEL_ARGS) {
  // Casting the built-in indices to int keeps both comparisons signed; the
  // indices never exceed INT_MAX, so the result matches an unsigned compare.
  const bool is_target_thread = static_cast<int>(threadIdx.x) == bad_thread;
  const bool is_target_block = static_cast<int>(blockIdx.x) == bad_block;

  // Every thread other than the chosen one leaves without doing anything.
  if (!is_target_thread || !is_target_block) {
    return;
  }

  CUDA_KERNEL_ASSERT2(false); // Condition is deliberately always false.
}
32
/**
 * TEST: trigger a device-side assertion on exactly one thread of a
 * <<<1024, 128>>> launch, then verify that the c10::Error raised on the next
 * device synchronization reports:
 * - the failing thread and block;
 * - the launching device;
 * - the kernel name;
 * - the file and function of the launch;
 * - the stream.
 *
 * The kernel is told to assert only in thread 29 of block 937; all other
 * threads are no-ops.
 */
void cuda_device_assertions_catches_thread_and_block_and_device() {
  const auto stream = c10::cuda::getStreamFromPool();

  // Capture the device that is current at launch time instead of assuming
  // device 0. Presumably the DSA launch registry records the current device
  // for each launch — confirm against CUDADeviceAssertionHost if this changes.
  const int launch_device = static_cast<int>(c10::cuda::current_device());

  TORCH_DSA_KERNEL_LAUNCH(
      cuda_device_assertions_fail_on_thread_block_kernel,
      1024, /* Blocks */
      128, /* Threads */
      0, /* Shared mem */
      stream, /* Stream */
      29, /* bad thread */
      937 /* bad block */
  );

  try {
    // Kernel launches are asynchronous; the assertion failure only surfaces
    // once the host synchronizes with the device.
    c10::cuda::device_synchronize();
    FAIL() << "Test didn't fail, but should have.";
  } catch (const c10::Error& err) {
    const auto err_str = std::string(err.what());
    ASSERT_THAT(
        err_str, HasSubstr("Thread ID that failed assertion = [29,0,0]"));
    ASSERT_THAT(
        err_str, HasSubstr("Block ID that failed assertion = [937,0,0]"));
    ASSERT_THAT(
        err_str,
        HasSubstr(
            "Device that launched kernel = " + std::to_string(launch_device)));
    ASSERT_THAT(
        err_str,
        HasSubstr(
            "Name of kernel launched that led to failure = cuda_device_assertions_fail_on_thread_block_kernel"));
    ASSERT_THAT(
        err_str, HasSubstr("File containing kernel launch = " __FILE__));
    ASSERT_THAT(
        err_str,
        HasSubstr(
            "Function containing kernel launch = " +
            std::string(__FUNCTION__)));
    ASSERT_THAT(
        err_str,
        HasSubstr(
            "Stream kernel was launched on = " + std::to_string(stream.id())));
  }
}
78
// Runs the single-thread/single-block DSA check only when device-side
// assertion support was compiled in; otherwise the test is reported as skipped.
TEST(CUDATest, cuda_device_assertions_catches_thread_and_block_and_device) {
#ifndef TORCH_USE_CUDA_DSA
  GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled at compile time.";
#else
  // Compile-time support alone is not enough: the launch registry's runtime
  // switch is turned on before the kernel is launched.
  auto& launch_registry =
      c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref();
  launch_registry.enabled_at_runtime = true;
  cuda_device_assertions_catches_thread_and_block_and_device();
#endif
}
87