1 #pragma once
2 #include <c10/core/SymBool.h>
3 #include <c10/core/SymInt.h>
4 #include <c10/util/ArrayRef.h>
5 #include <c10/util/SmallVector.h>
6 #include <c10/util/irange.h>
7
8 #include <algorithm>
9 #include <cstdint>
10
11 namespace c10 {
12
13 template <typename T>
_compute_contiguous(ArrayRef<T> sizes,ArrayRef<T> strides,T numel)14 bool _compute_contiguous(ArrayRef<T> sizes, ArrayRef<T> strides, T numel) {
15 bool is_contiguous = true;
16 if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, 0))) {
17 return is_contiguous;
18 }
19 T z = 1;
20 // NB: make sure we do signed arithmetic
21 for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
22 const auto& size_d = sizes[d];
23 if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) {
24 if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(strides[d], z))) {
25 z *= size_d;
26 } else {
27 is_contiguous = false;
28 break;
29 }
30 }
31 }
32 return is_contiguous;
33 }
34
35 template <typename T>
_compute_channels_last_contiguous_2d(ArrayRef<T> sizes,ArrayRef<T> strides)36 bool _compute_channels_last_contiguous_2d(
37 ArrayRef<T> sizes,
38 ArrayRef<T> strides) {
39 // Please don't combine these code, constant array is used here to let
40 // compiler fully unroll the loop to get better performance
41 switch (sizes.size()) {
42 case 4: {
43 T expected = 1;
44 for (auto& d : {1, 3, 2, 0}) {
45 const auto& size_d = sizes[d];
46 if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) {
47 if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected))) {
48 return false;
49 }
50 expected *= size_d;
51 }
52 }
53 return true;
54 }
55 // NOLINTNEXTLINE(bugprone-branch-clone)
56 case 3:
57 // TODO dim == 3 case will be enabled once it is fully tested
58 return false;
59 default:
60 return false;
61 }
62 }
63
64 template <typename T>
_compute_channels_last_contiguous_3d(ArrayRef<T> sizes,ArrayRef<T> strides)65 bool _compute_channels_last_contiguous_3d(
66 ArrayRef<T> sizes,
67 ArrayRef<T> strides) {
68 // Please don't combine these code, constant array is used here to let
69 // compiler fully unroll the loop to get better performance
70 switch (sizes.size()) {
71 case 5: {
72 T expected = 1;
73 for (auto& d : {1, 4, 3, 2, 0}) {
74 const auto& size_d = sizes[d];
75 if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(size_d, 1))) {
76 if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(strides[d], expected))) {
77 return false;
78 }
79 expected *= size_d;
80 }
81 }
82 return true;
83 }
84 // NOLINTNEXTLINE(bugprone-branch-clone)
85 case 4:
86 // TODO dim == 4 case will be enabled once it is fully tested
87 return false;
88 default:
89 return false;
90 }
91 }
92
93 template <typename T>
_compute_non_overlapping_and_dense(ArrayRef<T> sizes,ArrayRef<T> strides)94 bool _compute_non_overlapping_and_dense(
95 ArrayRef<T> sizes,
96 ArrayRef<T> strides) {
97 auto dim = sizes.size();
98 if (dim == 1) {
99 return sizes[0] < 2 || strides[0] == 1;
100 }
101 SmallVector<int64_t, 5> perm;
102 perm.resize(dim);
103 for (const auto i : c10::irange(dim)) {
104 perm[i] = i;
105 }
106 // Sort by strides, leaving 0 and 1 sized dims at the end of the array
107 std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) {
108 if (sizes[a] < 2) {
109 return false;
110 } else if (sizes[b] < 2) {
111 return true;
112 }
113 return strides[a] < strides[b];
114 });
115 T require_stride = 1;
116 for (const auto i : c10::irange(dim)) {
117 const auto& size_perm_i = sizes[perm[i]];
118 if (size_perm_i < 2) {
119 return true;
120 }
121 if (strides[perm[i]] != require_stride) {
122 return false;
123 }
124 require_stride *= size_perm_i;
125 }
126 return true;
127 }
128
129 } // namespace c10
130