xref: /aosp_15_r20/external/pytorch/c10/cuda/CUDADeviceAssertion.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAException.h>
4*da0073e9SAndroid Build Coastguard Worker #include <c10/macros/Macros.h>
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker namespace c10::cuda {
7*da0073e9SAndroid Build Coastguard Worker 
8*da0073e9SAndroid Build Coastguard Worker #ifdef TORCH_USE_CUDA_DSA
// Device-side bounded string copy from `src` into `dst`.
// Copies at most `C10_CUDA_DSA_MAX_STR_LEN - 1` characters and always writes
// a terminating NUL after them, so `dst` needs room for up to
// `C10_CUDA_DSA_MAX_STR_LEN` bytes.
static __device__ void dstrcpy(char* dst, const char* src) {
  int pos = 0;
  // Stop at the end of the source string or once the character budget
  // (`C10_CUDA_DSA_MAX_STR_LEN - 1`) is exhausted, whichever comes first.
  while (src[pos] != '\0' && pos < C10_CUDA_DSA_MAX_STR_LEN - 1) {
    dst[pos] = src[pos];
    ++pos;
  }
  dst[pos] = '\0';
}
19*da0073e9SAndroid Build Coastguard Worker 
// Records a single device-side assertion failure in `assertions_data`.
//
// A slot index is reserved by atomically incrementing
// `assertions_data->assertion_count`. If that index falls inside the buffer
// (below `C10_CUDA_DSA_ASSERTION_COUNT`) the message, file name, function
// name, line number, caller id, block index and thread index are stored in
// that slot; otherwise the failure is dropped without any output.
//
// Does nothing when `assertions_data` is null.
static __device__ void dsa_add_new_assertion_failure(
    DeviceAssertionsData* assertions_data,
    const char* assertion_msg,
    const char* filename,
    const char* function_name,
    const int line_number,
    const uint32_t caller,
    const dim3 block_id,
    const dim3 thread_id) {
  // A null `assertions_data` means device-side assertion checking was
  // turned off at run-time. (When it is turned off at compile time this
  // function is never called at all.)
  if (assertions_data == nullptr) {
    return;
  }

  // Reserve a slot with an atomic increment so that several threads can
  // report failures concurrently. Because the counter is bumped before the
  // slot is filled in, the CPU may notice the failure and start reacting
  // before the details below have been written.
  const auto slot = atomicAdd(&(assertions_data->assertion_count), 1);

  // Only fill in a record while there is still buffer space. Once we've run
  // out, further failures are ignored silently: printing from many threads
  // would be spammy, and the failures are most likely analogous anyway.
  if (slot < C10_CUDA_DSA_ASSERTION_COUNT) {
    auto* record = &assertions_data->assertions[slot];

    // Textual description of where and what failed.
    dstrcpy(record->assertion_msg, assertion_msg);
    dstrcpy(record->filename, filename);
    dstrcpy(record->function_name, function_name);
    record->line_number = line_number;
    record->caller = caller;

    // Coordinates of the failing thread.
    record->block_id[0] = block_id.x;
    record->block_id[1] = block_id.y;
    record->block_id[2] = block_id.z;
    record->thread_id[0] = thread_id.x;
    record->thread_id[1] = thread_id.y;
    record->thread_id[2] = thread_id.z;
  }
}
67*da0073e9SAndroid Build Coastguard Worker 
// Emulated kernel assertion used when TORCH_USE_CUDA_DSA is defined.
// If `condition` is false, the failure is recorded via
// `dsa_add_new_assertion_failure` and the failing thread returns from the
// enclosing function. Nothing here halts the other threads, so treat
// everything the kernel produces as garbage once an assertion has failed.
// NOTE: The expansion refers to `assertions_data` and `assertion_caller_id`
//       by name, so both must be accessible at the point of use (e.g. as
//       arguments of the kernel).
#define CUDA_KERNEL_ASSERT2(condition)                                     \
  do {                                                                     \
    if (C10_UNLIKELY(!(condition))) {                                      \
      /* Slot reservation is atomic, so threads may fail concurrently */   \
      c10::cuda::dsa_add_new_assertion_failure(                            \
          assertions_data, C10_STRINGIZE(condition), __FILE__,             \
          __FUNCTION__, __LINE__, assertion_caller_id, blockIdx,           \
          threadIdx);                                                      \
      /* Leave the enclosing function for this thread only; the host */    \
      /* is relied upon to check UVM and determine there was a problem */  \
      return;                                                              \
    }                                                                      \
  } while (0)
92*da0073e9SAndroid Build Coastguard Worker #else
// Without TORCH_USE_CUDA_DSA, CUDA_KERNEL_ASSERT2 expands to a plain `assert`.
#define CUDA_KERNEL_ASSERT2(condition) assert(condition)
94*da0073e9SAndroid Build Coastguard Worker #endif
95*da0073e9SAndroid Build Coastguard Worker 
96*da0073e9SAndroid Build Coastguard Worker } // namespace c10::cuda
97