// aten/src/ATen/native/cuda/TriangularOps.cu
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/ceil_div.h>
#include <ATen/Context.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/Dispatch.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/native/Resize.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/diag.h>
#include <ATen/ops/diag_native.h>
#include <ATen/ops/trace_native.h>
#include <ATen/ops/tril_native.h>
#include <ATen/ops/triu_native.h>
#endif

#include <ATen/cuda/CUDAApplyUtils.cuh>

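// Turns a runtime boolean into a compile-time constant so it can be used as a
// non-type template parameter; used below to select the inplace/out-of-place
// kernel instantiation without duplicating the launch code.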
#define BOOL_SWITCH(COND, CONST_NAME, ...)      \
  [&] {                                         \
    if (COND) {                                 \
      constexpr static bool CONST_NAME = true;  \
      return __VA_ARGS__();                     \
    } else {                                    \
      constexpr static bool CONST_NAME = false; \
      return __VA_ARGS__();                     \
    }                                           \
  }()

namespace at::native {

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triu/tril ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

constexpr static int block_size = 128;

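// Writes the upper (triu) or lower (tril) triangular part of `self` into
// `result`, zeroing everything on the other side of diagonal `k`. Each thread
// handles `elements_per_thread` consecutive elements along the last dimension;
// that dimension is padded up to a multiple of `elements_per_thread`, so the
// out-of-range columns are guarded by the `col + i < sizes[dims - 1]` checks.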
template <typename scalar_t, typename IndexType, bool upper, int elements_per_thread, bool inplace>
C10_LAUNCH_BOUNDS_1(block_size)
__global__ void triu_tril_kernel(
    cuda::detail::TensorInfo<scalar_t, IndexType> result_info,
    const cuda::detail::TensorInfo<const scalar_t, IndexType> self_info,
    const int64_t k,
    const int64_t N_padded,
    const IndexType last_dim_padded) {
  int64_t linear_idx = (blockIdx.x * blockDim.x + threadIdx.x) * elements_per_thread;
  if (linear_idx >= N_padded) {
    return;
  }

  auto dims = self_info.dims;

  // Compute column index and row index
  IndexType col = linear_idx % last_dim_padded;
  linear_idx /= last_dim_padded;
  IndexType row = linear_idx % self_info.sizes[dims - 2];

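  // In-place fast path: if every element in this thread's chunk lies inside
  // the kept triangle, there is nothing to zero and the thread can exit early.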
  if constexpr (inplace) {
    bool mask_all_true = upper ? (col - row >= k) : (col + elements_per_thread - row <= k);
    if (mask_all_true)
      return;
  }

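  // Offsets are accumulated from per-dimension strides, so the kernel handles
  // non-contiguous inputs and batched (>2-D) tensors alike.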
  // Compute offset
  IndexType self_offset = 0, result_offset = 0;
  self_offset += self_info.strides[dims - 1] * col;
  result_offset += result_info.strides[dims - 1] * col;
  linear_idx /= self_info.sizes[dims - 2];
  self_offset += self_info.strides[dims - 2] * row;
  result_offset += result_info.strides[dims - 2] * row;

  // Compute remaining offsets
  IndexType running_index;
  #pragma unroll
  for (IndexType i = dims - 3; i >= 0; --i) {
    running_index = linear_idx % self_info.sizes[i];
    linear_idx /= self_info.sizes[i];
    self_offset += running_index * self_info.strides[i];
    result_offset += running_index * result_info.strides[i];
  }

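  // Write phase. In place: zero only the elements that fall outside the kept
  // triangle. Out of place: if the chunk may intersect the kept triangle, load
  // it into registers, zero the masked-out lanes, and write the fragment back;
  // otherwise the zero-initialized fragment is written as-is.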
  if constexpr (inplace) {
    #pragma unroll
    for (int i = 0; i < elements_per_thread && col + i < self_info.sizes[dims - 1]; i++) {
      bool mask = upper ? (col + i - row >= k) : (col + i - row <= k);
      if (!mask)
        result_info.data[result_offset + i * result_info.strides[dims - 1]] = scalar_t(0);
    }
  } else {
    scalar_t frag[elements_per_thread] = {};
    bool has_mask = (upper && col + elements_per_thread - row >= k) || (!upper && col - row <= k);
    if (has_mask) {
      #pragma unroll
      for (int i = 0; i < elements_per_thread && col + i < self_info.sizes[dims - 1]; i++)
        frag[i] = self_info.data[self_offset + i * self_info.strides[dims - 1]];

      #pragma unroll
      for (int i = 0; i < elements_per_thread; i++) {
        bool mask = upper ? (col + i - row >= k) : (col + i - row <= k);
        frag[i] = mask ? frag[i] : scalar_t(0);
      }
    }

    #pragma unroll
    for (int i = 0; i < elements_per_thread && col + i < self_info.sizes[dims - 1]; i++)
      result_info.data[result_offset + i * result_info.strides[dims - 1]] = frag[i];
  }
}

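// Host-side launcher: dispatches on dtype, pads the last dimension so each
// thread covers `elements_per_thread` consecutive elements (roughly 8 bytes),
// uses 32-bit index math when both tensors allow it and 64-bit otherwise, and
// selects the inplace kernel variant when `self` and `result` are the same
// tensor.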
template <bool upper>
void triu_tril_cuda_template(const Tensor& result, const Tensor& self, int64_t k, const char* name) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
      at::ScalarType::ComplexHalf,
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      at::ScalarType::Bool,
      self.scalar_type(), "triu_tril_cuda_template", [&] {
    constexpr int elements_per_thread = sizeof(scalar_t) < 8 ? 8 / sizeof(scalar_t) : 1;
    auto sizes = self.sizes();
    int64_t last_dim_padded = round_up<int64_t>(sizes.back(), elements_per_thread);
    int64_t N_padded = c10::multiply_integers(sizes.begin(), sizes.end() - 1) * last_dim_padded;
    dim3 dim_block = block_size;
    dim3 dim_grid((N_padded / elements_per_thread + dim_block.x - 1) / dim_block.x);
    if (cuda::detail::canUse32BitIndexMath(result) && cuda::detail::canUse32BitIndexMath(self)) {
      auto result_info = cuda::detail::getTensorInfo<scalar_t, int32_t>(result);
      auto self_info = cuda::detail::getTensorInfo<const scalar_t, int32_t>(self);
      BOOL_SWITCH(self.is_same(result), inplace, [&] {
        triu_tril_kernel<scalar_t, int32_t, upper, elements_per_thread, inplace>
          <<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
            result_info, self_info, k, N_padded, last_dim_padded);
      });
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      auto result_info = cuda::detail::getTensorInfo<scalar_t, int64_t>(result);
      auto self_info = cuda::detail::getTensorInfo<const scalar_t, int64_t>(self);
      BOOL_SWITCH(self.is_same(result), inplace, [&] {
        triu_tril_kernel<scalar_t, int64_t, upper, elements_per_thread, inplace>
          <<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
            result_info, self_info, k, N_padded, last_dim_padded);
      });
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  });
}

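// Structured-op entry points for tril/triu on CUDA (registered via
// TORCH_IMPL_FUNC); the kernel launch is skipped entirely for empty tensors.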
TORCH_IMPL_FUNC(tril_cuda)(const Tensor& self, int64_t k, const Tensor &result) {
  if (self.numel() != 0) {
    triu_tril_cuda_template<false>(result, self, k, "tril");
  }
}

TORCH_IMPL_FUNC(triu_cuda)(const Tensor& self, int64_t k, const Tensor &result) {
  if (self.numel() != 0) {
    triu_tril_cuda_template<true>(result, self, k, "triu");
  }
}

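// trace: checks for a 2-D input and sums the main diagonal via
// diagonal().sum() rather than a dedicated CUDA kernel.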
Tensor trace_cuda(const Tensor& self) {
  TORCH_CHECK(self.dim() == 2, "expected a matrix");
  return self.diagonal().sum();
}

} // namespace at::native