1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/Parallel.h>
5 #include <ATen/TensorMeta.h>
6 #include <ATen/native/TriangularOpsUtils.h>
7 #include <ATen/TensorSubclassLikeUtils.h>
8 #include <c10/util/irange.h>
9
10 #ifndef AT_PER_OPERATOR_HEADERS
11 #include <ATen/Functions.h>
12 #include <ATen/NativeFunctions.h>
13 #else
14 #include <ATen/ops/arange.h>
15 #include <ATen/ops/empty_like.h>
16 #include <ATen/ops/trace_backward_native.h>
17 #include <ATen/ops/tril_native.h>
18 #include <ATen/ops/triu_native.h>
19 #include <ATen/ops/zeros.h>
20 #endif
21
22 namespace at::meta {
23
// Meta function for tril (namespace at::meta).
// Rejects inputs with fewer than two dimensions, then declares output 0 with
// the same sizes and tensor options as `self`, passing an empty stride list.
// The diagonal offset `k` is not consulted here.
TORCH_META_FUNC(tril)(const Tensor& self, int64_t k) {
  TORCH_CHECK(self.dim() >= 2, "tril: input tensor must have at least 2 dimensions");
  const auto output_sizes = self.sizes();
  set_output_raw_strided(0, output_sizes, {}, self.options());
}
28
// Meta function for triu (namespace at::meta).
// Rejects inputs with fewer than two dimensions, then declares output 0 with
// the same sizes and tensor options as `self`, passing an empty stride list.
// The diagonal offset `k` is not consulted here.
TORCH_META_FUNC(triu)(const Tensor& self, int64_t k) {
  TORCH_CHECK(self.dim() >= 2, "triu: input tensor must have at least 2 dimensions");
  const auto output_sizes = self.sizes();
  set_output_raw_strided(0, output_sizes, {}, self.options());
}
33
34 } // namespace at::meta
35
36 namespace at::native {
37 namespace {
38
39 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triu/tril ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
40
41 template <typename scalar_t>
apply_triu_tril_single(scalar_t * result,const scalar_t * self,bool inplace,int64_t k,int64_t n,int64_t m,int64_t res_row_stride,int64_t res_col_stride,int64_t self_row_stride,int64_t self_col_stride,bool upper)42 void apply_triu_tril_single(
43 scalar_t* result,
44 const scalar_t* self,
45 bool inplace,
46 int64_t k,
47 int64_t n,
48 int64_t m,
49 int64_t res_row_stride,
50 int64_t res_col_stride,
51 int64_t self_row_stride,
52 int64_t self_col_stride,
53 bool upper) {
54 constexpr int64_t zero = 0;
55
56 if (upper) {
57 parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
58 for (int64_t i : c10::irange(start, end)) {
59 for (int64_t j = 0; j < std::min(m, i + k); j++) {
60 result[i * res_row_stride + j * res_col_stride] = static_cast<scalar_t>(0);
61 }
62 if (!inplace) { // copy the rest of the self if not inplace
63 for (int64_t j = std::max(zero, i + k); j < m; j++) {
64 result[i * res_row_stride + j * res_col_stride] = self[i * self_row_stride + j * self_col_stride];
65 }
66 }
67 }
68 });
69 } else {
70 parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
71 for (int64_t i : c10::irange(start, end)) {
72 for (int64_t j = std::max(zero, i + k + 1); j < m; j++) {
73 result[i * res_row_stride + j * res_col_stride] = static_cast<scalar_t>(0);
74 }
75 if (!inplace) { // copy the rest of the self if not inplace
76 for (int64_t j = zero; j < std::min(m, i + k + 1); j++) {
77 result[i * res_row_stride + j * res_col_stride] = self[i * self_row_stride + j * self_col_stride];
78 }
79 }
80 }
81 });
82 }
83 }
84
85 template <typename scalar_t>
apply_triu_tril(const Tensor & result,const Tensor & self,bool inplace,int64_t k,bool upper)86 void apply_triu_tril(const Tensor& result, const Tensor& self, bool inplace, int64_t k, bool upper) {
87 auto n = self.size(-2);
88 auto m = self.size(-1);
89 auto self_data = self.const_data_ptr<scalar_t>();
90 auto self_stride = (self.dim() > 2 && self.stride(-3) > 0) ? self.stride(-3) : 1;
91 auto batchsize = batchCountTrilTriu(result);
92 auto self_row_stride = self.stride(-2);
93 auto self_col_stride = self.stride(-1);
94
95 auto result_data = result.data_ptr<scalar_t>();
96 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
97 int64_t result_stride, result_row_stride, result_col_stride;
98 if (result_data != self_data) {
99 result_stride = (result.dim() > 2 && result.stride(-3) > 0) ? result.stride(-3) : 1;
100 result_row_stride = result.stride(-2);
101 result_col_stride = result.stride(-1);
102 } else {
103 result_stride = self_stride;
104 result_row_stride = self_row_stride;
105 result_col_stride = self_col_stride;
106 }
107
108 parallel_for(0, batchsize, 0, [&](int64_t start, int64_t end) {
109 for (const auto b : c10::irange(start, end)) {
110 const scalar_t* self_batch = &self_data[b * self_stride];
111 scalar_t* result_batch = &result_data[b * result_stride];
112 apply_triu_tril_single<scalar_t>(
113 result_batch,
114 self_batch,
115 inplace,
116 k,
117 n,
118 m,
119 result_row_stride,
120 result_col_stride,
121 self_row_stride,
122 self_col_stride,
123 upper);
124 }
125 });
126 }
127
// Compile-time tag selecting the upper-triangular ("triu") variant of
// compute_triu_tril.
struct UpperTriangle {
  // Passed to apply_triu_tril as its `upper` argument.
  static constexpr bool upper = true;
  // Operator name handed to the dtype dispatch macro.
  static constexpr const char* op_name = "triu";
};
132
// Compile-time tag selecting the lower-triangular ("tril") variant of
// compute_triu_tril.
struct LowerTriangle {
  // Passed to apply_triu_tril as its `upper` argument.
  static constexpr bool upper = false;
  // Operator name handed to the dtype dispatch macro.
  static constexpr const char* op_name = "tril";
};
137
138 template <typename Triangle>
compute_triu_tril(const Tensor & self,int64_t k,const Tensor & result)139 void compute_triu_tril(const Tensor& self, int64_t k, const Tensor &result) {
140 if (self.numel() == 0) {
141 return;
142 }
143
144 bool inplace_op = self.is_same(result);
145
146 bool inplace_update = false;
147 Tensor self_c;
148 std::tie(inplace_update, self_c) = checkTrilTriuBatchContiguous(self, inplace_op);
149
150 Tensor result_c;
151 if (inplace_op && !inplace_update) {
152 result_c = at::empty_like(result, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
153 } else {
154 result_c = result;
155 }
156
157 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
158 ScalarType::ComplexHalf,
159 ScalarType::BFloat16,
160 ScalarType::Half,
161 ScalarType::Bool,
162 self.scalar_type(),
163 Triangle::op_name,
164 [&]{
165 apply_triu_tril<scalar_t>(result_c, self_c, inplace_op && inplace_update, k, Triangle::upper);
166 });
167
168 if (inplace_op && !inplace_update) {
169 result.copy_(result_c);
170 }
171 }
172
173 } // namespace
174
// CPU implementation of tril: forwards to the shared driver with the
// lower-triangle tag.
TORCH_IMPL_FUNC(tril_cpu)(const Tensor& self, int64_t k, const Tensor& result) {
  compute_triu_tril<LowerTriangle>(self, k, result);
}
178
// CPU implementation of triu: forwards to the shared driver with the
// upper-triangle tag.
TORCH_IMPL_FUNC(triu_cpu)(const Tensor& self, int64_t k, const Tensor& result) {
  compute_triu_tril<UpperTriangle>(self, k, result);
}
182
// Backward of trace for an input of shape `sizes`.
//
// Allocates a flat zero tensor with sizes[0] * sizes[1] elements, writes
// `grad` at flat positions 0, sizes[1] + 1, 2 * (sizes[1] + 1), ... (the main
// diagonal of a row-major sizes[0] x sizes[1] matrix) and returns it viewed
// as `sizes`.
//
// Throws std::runtime_error("expected matrix input") unless `sizes` has
// exactly two entries.
Tensor trace_backward_symint(const Tensor& grad, c10::SymIntArrayRef sizes) {
  if (sizes.size() != 2) {
    throw std::runtime_error("expected matrix input");
  }

  const auto& num_cols = sizes[1];
  auto flat_grad = at::zeros_symint(sizes[0] * num_cols, grad.options());

  // Consecutive diagonal entries are one full row plus one column apart.
  const auto index_options = grad.options().dtype(at::kLong);
  auto diag_positions = at::arange(0, flat_grad.numel(), num_cols + 1, index_options);

  if (!isTensorSubclassLike(grad)) {
    flat_grad.index_fill_(0, diag_positions, grad);
  } else {
    // For composite compliance, tensor subclasses go through the
    // out-of-place variant of `index_fill`.
    flat_grad = flat_grad.index_fill(0, diag_positions, grad);
  }
  return flat_grad.view_symint(sizes);
}
199
200 } // namespace at::native
201