#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorMeta.h>
#include <ATen/native/TriangularOpsUtils.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/trace_backward_native.h>
#include <ATen/ops/tril_native.h>
#include <ATen/ops/triu_native.h>
#include <ATen/ops/zeros.h>
#endif

namespace at::meta {

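// Shape-checking "meta" functions for tril/triu: both ops require an input
// with at least two dimensions and produce an output with the same sizes and
// options as the input.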
TORCH_META_FUNC(tril)(const Tensor& self, int64_t k) {
  TORCH_CHECK(self.dim() >= 2, "tril: input tensor must have at least 2 dimensions")
  set_output_raw_strided(0, self.sizes(), {}, self.options());
}

TORCH_META_FUNC(triu)(const Tensor& self, int64_t k) {
  TORCH_CHECK(self.dim() >= 2, "triu: input tensor must have at least 2 dimensions")
  set_output_raw_strided(0, self.sizes(), {}, self.options());
}

}  // namespace at::meta

namespace at::native {
namespace {

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triu/tril ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

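// Writes the kept triangle of a single n x m matrix into `result`, zeroing
// everything on the other side of diagonal `k`. Rows are processed in
// parallel; when `inplace` is false the kept elements are also copied over
// from `self`. Illustrative example for k = 0 on a 3 x 3 input:
//
//   triu keeps           tril keeps
//   [a b c]              [a 0 0]
//   [0 e f]              [d e 0]
//   [0 0 i]              [g h i]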
template <typename scalar_t>
void apply_triu_tril_single(
    scalar_t* result,
    const scalar_t* self,
    bool inplace,
    int64_t k,
    int64_t n,
    int64_t m,
    int64_t res_row_stride,
    int64_t res_col_stride,
    int64_t self_row_stride,
    int64_t self_col_stride,
    bool upper) {
  constexpr int64_t zero = 0;

  if (upper) {
    parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
      for (int64_t i : c10::irange(start, end)) {
        for (int64_t j = 0; j < std::min(m, i + k); j++) {
          result[i * res_row_stride + j * res_col_stride] = static_cast<scalar_t>(0);
        }
        if (!inplace) {  // copy the rest of self if not in-place
          for (int64_t j = std::max(zero, i + k); j < m; j++) {
            result[i * res_row_stride + j * res_col_stride] = self[i * self_row_stride + j * self_col_stride];
          }
        }
      }
    });
  } else {
    parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
      for (int64_t i : c10::irange(start, end)) {
        for (int64_t j = std::max(zero, i + k + 1); j < m; j++) {
          result[i * res_row_stride + j * res_col_stride] = static_cast<scalar_t>(0);
        }
        if (!inplace) {  // copy the rest of self if not in-place
          for (int64_t j = zero; j < std::min(m, i + k + 1); j++) {
            result[i * res_row_stride + j * res_col_stride] = self[i * self_row_stride + j * self_col_stride];
          }
        }
      }
    });
  }
}

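// Batched driver: runs apply_triu_tril_single on each n x m matrix in the
// batch, stepping through memory by stride(-3). The caller is expected to
// have made the batch dimensions contiguous enough for this stepping to be
// valid (see checkTrilTriuBatchContiguous in compute_triu_tril below).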
template <typename scalar_t>
void apply_triu_tril(const Tensor& result, const Tensor& self, bool inplace, int64_t k, bool upper) {
  auto n = self.size(-2);
  auto m = self.size(-1);
  auto self_data = self.const_data_ptr<scalar_t>();
  auto self_stride = (self.dim() > 2 && self.stride(-3) > 0) ? self.stride(-3) : 1;
  auto batchsize = batchCountTrilTriu(result);
  auto self_row_stride = self.stride(-2);
  auto self_col_stride = self.stride(-1);

  auto result_data = result.data_ptr<scalar_t>();
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t result_stride, result_row_stride, result_col_stride;
  if (result_data != self_data) {
    result_stride = (result.dim() > 2 && result.stride(-3) > 0) ? result.stride(-3) : 1;
    result_row_stride = result.stride(-2);
    result_col_stride = result.stride(-1);
  } else {
    result_stride = self_stride;
    result_row_stride = self_row_stride;
    result_col_stride = self_col_stride;
  }

  parallel_for(0, batchsize, 0, [&](int64_t start, int64_t end) {
    for (const auto b : c10::irange(start, end)) {
      const scalar_t* self_batch = &self_data[b * self_stride];
      scalar_t* result_batch = &result_data[b * result_stride];
      apply_triu_tril_single<scalar_t>(
          result_batch,
          self_batch,
          inplace,
          k,
          n,
          m,
          result_row_stride,
          result_col_stride,
          self_row_stride,
          self_col_stride,
          upper);
    }
  });
}

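// Tag types selecting which variant compute_triu_tril instantiates.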
struct UpperTriangle {
  static constexpr const char* op_name = "triu";
  static constexpr bool upper = true;
};

struct LowerTriangle {
  static constexpr const char* op_name = "tril";
  static constexpr bool upper = false;
};

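// Common entry point for the CPU tril/triu kernels: handles the empty-tensor
// fast path, falls back to a batch-contiguous temporary when an in-place
// update is requested on a layout the kernel cannot step through directly,
// and dispatches over all standard dtypes plus Bool, Half, BFloat16 and
// ComplexHalf.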
template <typename Triangle>
void compute_triu_tril(const Tensor& self, int64_t k, const Tensor& result) {
  if (self.numel() == 0) {
    return;
  }

  bool inplace_op = self.is_same(result);

  bool inplace_update = false;
  Tensor self_c;
  std::tie(inplace_update, self_c) = checkTrilTriuBatchContiguous(self, inplace_op);

  Tensor result_c;
  if (inplace_op && !inplace_update) {
    result_c = at::empty_like(result, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  } else {
    result_c = result;
  }

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
      ScalarType::ComplexHalf,
      ScalarType::BFloat16,
      ScalarType::Half,
      ScalarType::Bool,
      self.scalar_type(),
      Triangle::op_name,
      [&]{
        apply_triu_tril<scalar_t>(result_c, self_c, inplace_op && inplace_update, k, Triangle::upper);
      });

  if (inplace_op && !inplace_update) {
    result.copy_(result_c);
  }
}

}  // namespace

TORCH_IMPL_FUNC(tril_cpu)(const Tensor& self, int64_t k, const Tensor& result) {
  compute_triu_tril<LowerTriangle>(self, k, result);
}

TORCH_IMPL_FUNC(triu_cpu)(const Tensor& self, int64_t k, const Tensor& result) {
  compute_triu_tril<UpperTriangle>(self, k, result);
}

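// Backward of trace for an n x m input: since d(trace(x))/dx is the identity
// pattern, the scalar `grad` is scattered onto the flattened diagonal
// positions 0, m + 1, 2 * (m + 1), ... and the result is reshaped back to the
// input size. For example, with sizes = {2, 3} the filled flat indices are
// 0 and 4.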
Tensor trace_backward_symint(const Tensor& grad, c10::SymIntArrayRef sizes) {
  if (sizes.size() != 2) {
    throw std::runtime_error("expected matrix input");
  }

  auto grad_input = at::zeros_symint(sizes[0] * sizes[1], grad.options());
  auto indices = at::arange(0, grad_input.numel(), sizes[1] + 1, grad.options().dtype(at::kLong));
  // For composite compliance, use the out-of-place variant of
  // `index_fill` if the grad tensor is a Tensor subclass.
  if (isTensorSubclassLike(grad)) {
    grad_input = grad_input.index_fill(0, indices, grad);
  } else {
    grad_input.index_fill_(0, indices, grad);
  }
  return grad_input.view_symint(sizes);
}

}  // namespace at::native
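
// Minimal usage sketch (illustrative only; not part of this translation unit):
//
//   at::Tensor x = at::rand({3, 3});
//   at::Tensor lower = at::tril(x);           // keeps the main diagonal and below
//   at::Tensor upper = at::triu(x, /*k=*/1);  // keeps only elements strictly above the diagonal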