xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/TriangularOpsUtils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/Tensor.h>
2 #include <ATen/native/LinearAlgebraUtils.h>
3 
4 namespace at::native {
5 
6 /*
7  * Given batches of matrices with arbitrary batch dim,
8  * computes the number of batches for Triu and Tril. This ignores stride 0 dimension
9  */
batchCountTrilTriu(const Tensor & batched_matrices)10 static inline int64_t batchCountTrilTriu(const Tensor& batched_matrices) {
11   int64_t result = 1;
12   for (int64_t i = 0; i < batched_matrices.ndimension() - 2; i++) {
13     if (batched_matrices.stride(i) != 0) {
14       result *= batched_matrices.size(i);
15     }
16   }
17   return result;
18 }
19 
20 /* Checks a necessary property for the triu and tril implementations, hence the name.
21  * Here batch contiguity is checked for tensors with greater than 4 dimensions.
22  * Contiguous tensors and tensors with less than 3 dimensions pass this check
23  */
checkTrilTriuBatchContiguous(const Tensor & tensor,bool allow_zero_stride)24 static inline std::tuple<bool, Tensor> checkTrilTriuBatchContiguous(const Tensor& tensor, bool allow_zero_stride) {
25   // Complete contiguity is the most desired property, which is why
26   // we return true if the tensor is contiguous
27   if (tensor.is_contiguous()) {
28     auto default_strides_for_size = batched_matrix_contiguous_strides(tensor.sizes());
29     if (tensor.strides() == default_strides_for_size) {
30       return std::make_tuple(true, tensor);
31     } else {
32       return std::make_tuple(false, tensor.as_strided(tensor.sizes(), default_strides_for_size));
33     }
34   }
35 
36   int64_t dims = tensor.dim();
37 
38   // Tensors with dimension less than 4 are handled by default
39   if (allow_zero_stride && dims <= 3) {
40     return std::make_tuple(true, tensor);
41   }
42 
43   int64_t expected_stride = tensor.size(-1) * tensor.size(-2);
44   for (int64_t i = dims - 3; i >= 0; i--) {
45     // Skip trivial dimension;
46     if (allow_zero_stride && i == 0 && (tensor.stride(i) == 0 || tensor.size(i) == 1)) {
47       continue;
48     }
49     if (expected_stride != tensor.stride(i)) {
50       return std::make_tuple(false, tensor.contiguous());
51     }
52     expected_stride *= tensor.size(i);
53   }
54   return std::make_tuple(true, tensor);
55 }
56 
57 }  // namespace at::native
58