#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <c10/cuda/CUDADeviceAssertion.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>

#include <atomic>
#include <chrono>
#include <iostream>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>
13
14 using ::testing::HasSubstr;
15
// Substring the test expects to find in the reported error: the text
// "Assertion failure " followed by the value C10_CUDA_DSA_ASSERTION_COUNT - 1.
const std::string max_assertions_failure_str =
    std::string("Assertion failure ") +
    std::to_string(C10_CUDA_DSA_ASSERTION_COUNT - 1);
18
/**
 * Device kernel whose assertion can never hold: `a != a` is false for any
 * integer, so the device-side assertion fails on every execution.
 *
 * @param a Arbitrary integer; its value does not affect the outcome.
 *
 * The remaining parameters are supplied by TORCH_DSA_KERNEL_ARGS.
 */
__global__ void cuda_always_fail_assertion_kernel(
    const int a, TORCH_DSA_KERNEL_ARGS) {
  // NOTE(review): the expression is intentionally left textually as-is; the
  // assertion macro may capture its text for the error report - confirm
  // before rewriting it.
  CUDA_KERNEL_ASSERT2(a != a);
}
28
/**
 * TEST: Triggering device side assertions from several host threads at once.
 * Ten host threads each launch the always-failing kernel as <<<10,128>>>
 * (10 blocks of 128 threads each), released together by a shared start flag.
 *
 * The following device synchronization is expected to raise a c10::Error
 * whose message is checked for the expected assertion index, device, kernel
 * name and source file.
 *
 * Throws std::runtime_error if the synchronization does not raise, so that
 * the surrounding test fails.
 */
void cuda_device_assertions_multiple_writes_from_blocks_and_threads() {
  constexpr int kHostThreads = 10;
  constexpr int kBlocks = 10;
  constexpr int kThreadsPerBlock = 128;

  // Start signal shared by this thread and all launcher threads. It has to be
  // atomic: a plain bool written here and read concurrently in the spin loops
  // is a data race, and the compiler would be free to hoist the read out of
  // the loop and spin forever.
  std::atomic<bool> run_threads{false};

  // Create a function to launch kernel that waits for a signal, to try to
  // ensure everything is happening simultaneously
  const auto launch_the_kernel = [&]() {
    // Busy loop waiting for the signal to go
    while (!run_threads.load(std::memory_order_acquire)) {
    }

    TORCH_DSA_KERNEL_LAUNCH(
        cuda_always_fail_assertion_kernel,
        kBlocks, /* Blocks */
        kThreadsPerBlock, /* Threads */
        0, /* Shared mem */
        c10::cuda::getCurrentCUDAStream(), /* Stream */
        1);
  };

  // Spin up a bunch of busy-looping threads
  std::vector<std::thread> threads;
  threads.reserve(kHostThreads);
  for (int i = 0; i < kHostThreads; i++) {
    threads.emplace_back(launch_the_kernel);
  }

  // Paranoid - give all the threads time to reach their spin loop
  std::this_thread::sleep_for(std::chrono::milliseconds(100));

  // Release every launcher thread at once
  run_threads.store(true, std::memory_order_release);

  // Clean-up
  for (auto& worker : threads) {
    worker.join();
  }

  try {
    c10::cuda::device_synchronize();
    // Not a c10::Error, so it is not swallowed by the handler below and
    // propagates to the caller.
    throw std::runtime_error("Test didn't fail, but should have.");
  } catch (const c10::Error& err) {
    const auto err_str = std::string(err.what());
    ASSERT_THAT(err_str, HasSubstr(max_assertions_failure_str));
    ASSERT_THAT(err_str, HasSubstr("Device that launched kernel = 0"));
    ASSERT_THAT(
        err_str,
        HasSubstr(
            "Name of kernel launched that led to failure = cuda_always_fail_assertion_kernel"));
    ASSERT_THAT(
        err_str, HasSubstr("File containing kernel launch = " __FILE__));
  }
}
85
/**
 * GoogleTest entry point. When TORCH_USE_CUDA_DSA is not defined at compile
 * time the test is skipped; otherwise the kernel launch registry is switched
 * on at runtime and the multi-threaded assertion scenario is executed.
 */
TEST(CUDATest, cuda_device_assertions_multiple_writes_from_blocks_and_threads) {
#ifndef TORCH_USE_CUDA_DSA
  GTEST_SKIP() << "CUDA device-side assertions (DSA) was not enabled at compile time.";
#else
  auto& launch_registry =
      c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref();
  launch_registry.enabled_at_runtime = true;
  cuda_device_assertions_multiple_writes_from_blocks_and_threads();
#endif
}
94