#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/ceil_div.h>
#include <ATen/Context.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/Dispatch.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/native/Resize.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/diag.h>
#include <ATen/ops/diag_native.h>
#include <ATen/ops/trace_native.h>
#include <ATen/ops/tril_native.h>
#include <ATen/ops/triu_native.h>
#endif

#include <ATen/cuda/CUDAApplyUtils.cuh>

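// BOOL_SWITCH converts a runtime boolean into a compile-time constant so it can
// be used as a template argument. Usage (mirroring the kernel launches below):
//
//   BOOL_SWITCH(self.is_same(result), inplace, [&] {
//     triu_tril_kernel<scalar_t, IndexType, upper, elements_per_thread, inplace>
//         <<<grid, block, 0, stream>>>(...);
//   });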
#define BOOL_SWITCH(COND, CONST_NAME, ...)      \
  [&] {                                         \
    if (COND) {                                 \
      constexpr static bool CONST_NAME = true;  \
      return __VA_ARGS__();                     \
    } else {                                    \
      constexpr static bool CONST_NAME = false; \
      return __VA_ARGS__();                     \
    }                                           \
  }()

namespace at::native {

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triu/tril ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

constexpr static int block_size = 128;

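// Writes the upper (triu) or lower (tril) triangular part of `self` into
// `result`, zeroing everything on the other side of diagonal `k`. Each thread
// handles `elements_per_thread` consecutive entries of the last dimension,
// which is rounded up to a multiple of `elements_per_thread` (`last_dim_padded`;
// `N_padded` is the total padded element count). The flat index is decomposed
// into (batch dims..., row, col) using the sizes and strides stored in the
// TensorInfo arguments.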
template <typename scalar_t, typename IndexType, bool upper, int elements_per_thread, bool inplace>
C10_LAUNCH_BOUNDS_1(block_size)
__global__ void triu_tril_kernel(
    cuda::detail::TensorInfo<scalar_t, IndexType> result_info,
    const cuda::detail::TensorInfo<const scalar_t, IndexType> self_info,
    const int64_t k,
    const int64_t N_padded,
    const IndexType last_dim_padded) {
  int64_t linear_idx = (blockIdx.x * blockDim.x + threadIdx.x) * elements_per_thread;
  if (linear_idx >= N_padded) {
    return;
  }

  auto dims = self_info.dims;

  // Compute column index and row index
  IndexType col = linear_idx % last_dim_padded;
  linear_idx /= last_dim_padded;
  IndexType row = linear_idx % self_info.sizes[dims - 2];

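  // In-place fast path: if every element covered by this thread is kept by the
  // triangular mask, there is nothing to zero out, so return before computing
  // any memory offsets.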
  if constexpr (inplace) {
    bool mask_all_true = upper ? (col - row >= k) : (col + elements_per_thread - row <= k);
    if (mask_all_true)
      return;
  }

  // Compute the offset contributed by the row and column indices
  IndexType self_offset = 0, result_offset = 0;
  self_offset += self_info.strides[dims - 1] * col;
  result_offset += result_info.strides[dims - 1] * col;
  linear_idx /= self_info.sizes[dims - 2];
  self_offset += self_info.strides[dims - 2] * row;
  result_offset += result_info.strides[dims - 2] * row;

  // Compute the offsets contributed by the remaining (batch) dimensions
  IndexType running_index;
  #pragma unroll
  for (IndexType i = dims - 3; i >= 0; --i) {
    running_index = linear_idx % self_info.sizes[i];
    linear_idx /= self_info.sizes[i];
    self_offset += running_index * self_info.strides[i];
    result_offset += running_index * result_info.strides[i];
  }

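  // In-place: only the elements outside the triangle need to be zeroed.
  // Out-of-place: load a register fragment of `self` (skipped when no element
  // falls inside the triangle), zero the masked-out lanes, then store the whole
  // fragment to `result`.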
  if constexpr (inplace) {
    #pragma unroll
    for (int i = 0; i < elements_per_thread && col + i < self_info.sizes[dims - 1]; i++) {
      bool mask = upper ? (col + i - row >= k) : (col + i - row <= k);
      if (!mask)
        result_info.data[result_offset + i * result_info.strides[dims - 1]] = scalar_t(0);
    }
  } else {
    scalar_t frag[elements_per_thread] = {};
    bool has_mask = (upper && col + elements_per_thread - row >= k) || (!upper && col - row <= k);
    if (has_mask) {
      #pragma unroll
      for (int i = 0; i < elements_per_thread && col + i < self_info.sizes[dims - 1]; i++)
        frag[i] = self_info.data[self_offset + i * self_info.strides[dims - 1]];

      #pragma unroll
      for (int i = 0; i < elements_per_thread; i++) {
        bool mask = upper ? (col + i - row >= k) : (col + i - row <= k);
        frag[i] = mask ? frag[i] : scalar_t(0);
      }
    }

    #pragma unroll
    for (int i = 0; i < elements_per_thread && col + i < self_info.sizes[dims - 1]; i++)
      result_info.data[result_offset + i * result_info.strides[dims - 1]] = frag[i];
  }
}

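// Dispatches on dtype and on 32- vs 64-bit indexing, then launches
// triu_tril_kernel with `inplace` as a compile-time flag (true when `self` and
// `result` alias). `elements_per_thread` is chosen so each thread covers
// 8 bytes of the last dimension for element types up to 8 bytes, and a single
// element for larger types.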
template <bool upper>
void triu_tril_cuda_template(const Tensor& result, const Tensor& self, int64_t k, const char* name) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
      at::ScalarType::ComplexHalf,
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      at::ScalarType::Bool,
      self.scalar_type(), "triu_tril_cuda_template", [&] {
        constexpr int elements_per_thread = sizeof(scalar_t) < 8 ? 8 / sizeof(scalar_t) : 1;
        auto sizes = self.sizes();
        int64_t last_dim_padded = round_up<int64_t>(sizes.back(), elements_per_thread);
        int64_t N_padded = c10::multiply_integers(sizes.begin(), sizes.end() - 1) * last_dim_padded;
        dim3 dim_block = block_size;
        dim3 dim_grid((N_padded / elements_per_thread + dim_block.x - 1) / dim_block.x);
        if (cuda::detail::canUse32BitIndexMath(result) && cuda::detail::canUse32BitIndexMath(self)) {
          auto result_info = cuda::detail::getTensorInfo<scalar_t, int32_t>(result);
          auto self_info = cuda::detail::getTensorInfo<const scalar_t, int32_t>(self);
          BOOL_SWITCH(self.is_same(result), inplace, [&] {
            triu_tril_kernel<scalar_t, int32_t, upper, elements_per_thread, inplace>
                <<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
                    result_info, self_info, k, N_padded, last_dim_padded);
          });
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        } else {
          auto result_info = cuda::detail::getTensorInfo<scalar_t, int64_t>(result);
          auto self_info = cuda::detail::getTensorInfo<const scalar_t, int64_t>(self);
          BOOL_SWITCH(self.is_same(result), inplace, [&] {
            triu_tril_kernel<scalar_t, int64_t, upper, elements_per_thread, inplace>
                <<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
                    result_info, self_info, k, N_padded, last_dim_padded);
          });
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        }
      });
}

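// Structured-kernel implementations: `result` has already been allocated with
// the right shape by the corresponding meta function, so an empty input needs
// no kernel launch at all.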
TORCH_IMPL_FUNC(tril_cuda)(const Tensor& self, int64_t k, const Tensor &result) {
  if (self.numel() != 0) {
    triu_tril_cuda_template<false>(result, self, k, "tril");
  }
}

TORCH_IMPL_FUNC(triu_cuda)(const Tensor& self, int64_t k, const Tensor &result) {
  if (self.numel() != 0) {
    triu_tril_cuda_template<true>(result, self, k, "triu");
  }
}

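// trace(self) for a 2-D tensor: no dedicated kernel, just a diagonal view
// followed by a sum reduction.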
Tensor trace_cuda(const Tensor& self) {
  TORCH_CHECK(self.dim() == 2, "expected a matrix");
  return self.diagonal().sum();
}

} // namespace at::native