1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAException.h>
4*da0073e9SAndroid Build Coastguard Worker #include <c10/macros/Macros.h>
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Worker namespace c10::cuda {
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Worker #ifdef TORCH_USE_CUDA_DSA
9*da0073e9SAndroid Build Coastguard Worker // Copy string from `src` to `dst`
dstrcpy(char * dst,const char * src)10*da0073e9SAndroid Build Coastguard Worker static __device__ void dstrcpy(char* dst, const char* src) {
11*da0073e9SAndroid Build Coastguard Worker int i = 0;
12*da0073e9SAndroid Build Coastguard Worker // Copy string from source to destination, ensuring that it
13*da0073e9SAndroid Build Coastguard Worker // isn't longer than `C10_CUDA_DSA_MAX_STR_LEN-1`
14*da0073e9SAndroid Build Coastguard Worker while (*src != '\0' && i++ < C10_CUDA_DSA_MAX_STR_LEN - 1) {
15*da0073e9SAndroid Build Coastguard Worker *dst++ = *src++;
16*da0073e9SAndroid Build Coastguard Worker }
17*da0073e9SAndroid Build Coastguard Worker *dst = '\0';
18*da0073e9SAndroid Build Coastguard Worker }
19*da0073e9SAndroid Build Coastguard Worker
dsa_add_new_assertion_failure(DeviceAssertionsData * assertions_data,const char * assertion_msg,const char * filename,const char * function_name,const int line_number,const uint32_t caller,const dim3 block_id,const dim3 thread_id)20*da0073e9SAndroid Build Coastguard Worker static __device__ void dsa_add_new_assertion_failure(
21*da0073e9SAndroid Build Coastguard Worker DeviceAssertionsData* assertions_data,
22*da0073e9SAndroid Build Coastguard Worker const char* assertion_msg,
23*da0073e9SAndroid Build Coastguard Worker const char* filename,
24*da0073e9SAndroid Build Coastguard Worker const char* function_name,
25*da0073e9SAndroid Build Coastguard Worker const int line_number,
26*da0073e9SAndroid Build Coastguard Worker const uint32_t caller,
27*da0073e9SAndroid Build Coastguard Worker const dim3 block_id,
28*da0073e9SAndroid Build Coastguard Worker const dim3 thread_id) {
29*da0073e9SAndroid Build Coastguard Worker // `assertions_data` may be nullptr if device-side assertion checking
30*da0073e9SAndroid Build Coastguard Worker // is disabled at run-time. If it is disabled at compile time this
31*da0073e9SAndroid Build Coastguard Worker // function will never be called
32*da0073e9SAndroid Build Coastguard Worker if (!assertions_data) {
33*da0073e9SAndroid Build Coastguard Worker return;
34*da0073e9SAndroid Build Coastguard Worker }
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Worker // Atomically increment so other threads can fail at the same time
37*da0073e9SAndroid Build Coastguard Worker // Note that incrementing this means that the CPU can observe that
38*da0073e9SAndroid Build Coastguard Worker // a failure has happened and can begin to respond before we've
39*da0073e9SAndroid Build Coastguard Worker // written information about that failure out to the buffer.
40*da0073e9SAndroid Build Coastguard Worker const auto nid = atomicAdd(&(assertions_data->assertion_count), 1);
41*da0073e9SAndroid Build Coastguard Worker
42*da0073e9SAndroid Build Coastguard Worker if (nid >= C10_CUDA_DSA_ASSERTION_COUNT) {
43*da0073e9SAndroid Build Coastguard Worker // At this point we're ran out of assertion buffer space.
44*da0073e9SAndroid Build Coastguard Worker // We could print a message about this, but that'd get
45*da0073e9SAndroid Build Coastguard Worker // spammy if a lot of threads did it, so we just silently
46*da0073e9SAndroid Build Coastguard Worker // ignore any other assertion failures. In most cases the
47*da0073e9SAndroid Build Coastguard Worker // failures will all probably be analogous anyway.
48*da0073e9SAndroid Build Coastguard Worker return;
49*da0073e9SAndroid Build Coastguard Worker }
50*da0073e9SAndroid Build Coastguard Worker
51*da0073e9SAndroid Build Coastguard Worker // Write information about the assertion failure to memory.
52*da0073e9SAndroid Build Coastguard Worker // Note that this occurs only after the `assertion_count`
53*da0073e9SAndroid Build Coastguard Worker // increment broadcasts that there's been a problem.
54*da0073e9SAndroid Build Coastguard Worker auto& self = assertions_data->assertions[nid];
55*da0073e9SAndroid Build Coastguard Worker dstrcpy(self.assertion_msg, assertion_msg);
56*da0073e9SAndroid Build Coastguard Worker dstrcpy(self.filename, filename);
57*da0073e9SAndroid Build Coastguard Worker dstrcpy(self.function_name, function_name);
58*da0073e9SAndroid Build Coastguard Worker self.line_number = line_number;
59*da0073e9SAndroid Build Coastguard Worker self.caller = caller;
60*da0073e9SAndroid Build Coastguard Worker self.block_id[0] = block_id.x;
61*da0073e9SAndroid Build Coastguard Worker self.block_id[1] = block_id.y;
62*da0073e9SAndroid Build Coastguard Worker self.block_id[2] = block_id.z;
63*da0073e9SAndroid Build Coastguard Worker self.thread_id[0] = thread_id.x;
64*da0073e9SAndroid Build Coastguard Worker self.thread_id[1] = thread_id.y;
65*da0073e9SAndroid Build Coastguard Worker self.thread_id[2] = thread_id.z;
66*da0073e9SAndroid Build Coastguard Worker }
67*da0073e9SAndroid Build Coastguard Worker
// Emulated kernel assertion.
//
// When `condition` is false, the failure (stringized condition, file,
// function, line, caller id, block index and thread index) is recorded via
// `c10::cuda::dsa_add_new_assertion_failure` and the enclosing function
// returns immediately. The kernel as a whole is not stopped, so everything it
// produces must be assumed to be garbage once an assertion has failed; the
// host is relied upon to inspect the assertion buffer and notice the problem.
//
// Requirements at the expansion site:
//   * `assertions_data` and `assertion_caller_id` must be in scope; the macro
//     assumes they are arguments of the kernel.
//   * The enclosing function must return `void`, because the macro expands to
//     a bare `return;`.
#define CUDA_KERNEL_ASSERT2(condition)                                       \
  do {                                                                       \
    if (C10_UNLIKELY(!(condition))) {                                        \
      /* Recording uses an atomic counter, so threads may fail together */   \
      c10::cuda::dsa_add_new_assertion_failure(                              \
          assertions_data,                                                   \
          C10_STRINGIZE(condition),                                          \
          __FILE__,                                                          \
          __FUNCTION__,                                                      \
          __LINE__,                                                          \
          assertion_caller_id,                                               \
          blockIdx,                                                          \
          threadIdx);                                                        \
      /* Leave the current function early; the host checks the buffer */     \
      /* later to find out that something went wrong */                      \
      return;                                                                \
    }                                                                        \
  } while (0)
92*da0073e9SAndroid Build Coastguard Worker #else
// `TORCH_USE_CUDA_DSA` is not defined, so no failure record is written:
// the macro simply forwards `condition` to `assert`.
#define CUDA_KERNEL_ASSERT2(condition) assert(condition)
94*da0073e9SAndroid Build Coastguard Worker #endif
95*da0073e9SAndroid Build Coastguard Worker
96*da0073e9SAndroid Build Coastguard Worker } // namespace c10::cuda
97