xref: /aosp_15_r20/external/pytorch/c10/core/Contiguity.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <c10/core/SymBool.h>
3 #include <c10/core/SymInt.h>
4 #include <c10/util/ArrayRef.h>
5 #include <c10/util/SmallVector.h>
6 #include <c10/util/irange.h>
7 
8 #include <algorithm>
9 #include <cstdint>
10 
11 namespace c10 {
12 
13 template <typename T>
_compute_contiguous(ArrayRef<T> sizes,ArrayRef<T> strides,T numel)14 bool _compute_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
15   bool is_contiguous = true;
16   if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, 0))) {
17     return is_contiguous;
18   }
19   T z = 1;
20   // NB: make sure we do signed arithmetic
21   for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
22     const auto& size_d = sizes[d];
23     if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) {
24       if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(strides[d], z))) {
25         z *= size_d;
26       } else {
27         is_contiguous = false;
28         break;
29       }
30     }
31   }
32   return is_contiguous;
33 }
34 
35 template <typename T>
_compute_channels_last_contiguous_2d(ArrayRef<T> sizes,ArrayRef<T> strides)36 bool _compute_channels_last_contiguous_2d(
37     ArrayRef<T> sizes,
38     ArrayRef<T> strides) {
39   // Please don't combine these code, constant array is used here to let
40   // compiler fully unroll the loop to get better performance
41   switch (sizes.size()) {
42     case 4: {
43       T expected = 1;
44       for (auto& d : {1, 3, 2, 0}) {
45         const auto& size_d = sizes[d];
46         if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) {
47           if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected))) {
48             return false;
49           }
50           expected *= size_d;
51         }
52       }
53       return true;
54     }
55       // NOLINTNEXTLINE(bugprone-branch-clone)
56     case 3:
57       // TODO dim == 3 case will be enabled once it is fully tested
58       return false;
59     default:
60       return false;
61   }
62 }
63 
64 template <typename T>
_compute_channels_last_contiguous_3d(ArrayRef<T> sizes,ArrayRef<T> strides)65 bool _compute_channels_last_contiguous_3d(
66     ArrayRef<T> sizes,
67     ArrayRef<T> strides) {
68   // Please don't combine these code, constant array is used here to let
69   // compiler fully unroll the loop to get better performance
70   switch (sizes.size()) {
71     case 5: {
72       T expected = 1;
73       for (auto& d : {1, 4, 3, 2, 0}) {
74         const auto& size_d = sizes[d];
75         if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) {
76           if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected))) {
77             return false;
78           }
79           expected *= size_d;
80         }
81       }
82       return true;
83     }
84       // NOLINTNEXTLINE(bugprone-branch-clone)
85     case 4:
86       // TODO dim == 4 case will be enabled once it is fully tested
87       return false;
88     default:
89       return false;
90   }
91 }
92 
93 template <typename T>
_compute_non_overlapping_and_dense(ArrayRef<T> sizes,ArrayRef<T> strides)94 bool _compute_non_overlapping_and_dense(
95     ArrayRef<T> sizes,
96     ArrayRef<T> strides) {
97   auto dim = sizes.size();
98   if (dim == 1) {
99     return sizes[0] < 2 || strides[0] == 1;
100   }
101   SmallVector<int64_t, 5> perm;
102   perm.resize(dim);
103   for (const auto i : c10::irange(dim)) {
104     perm[i] = i;
105   }
106   // Sort by strides, leaving 0 and 1 sized dims at the end of the array
107   std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) {
108     if (sizes[a] < 2) {
109       return false;
110     } else if (sizes[b] < 2) {
111       return true;
112     }
113     return strides[a] < strides[b];
114   });
115   T require_stride = 1;
116   for (const auto i : c10::irange(dim)) {
117     const auto& size_perm_i = sizes[perm[i]];
118     if (size_perm_i < 2) {
119       return true;
120     }
121     if (strides[perm[i]] != require_stride) {
122       return false;
123     }
124     require_stride *= size_perm_i;
125   }
126   return true;
127 }
128 
129 } // namespace c10
130