#pragma once

#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>

#include <cstdint>
#include <ostream>
#include <vector>

// Memory format is not a property of a Tensor. It is a way to tell an
// operator how the result should be organized in memory, and nothing more.
// That means memory format should never be used as the return value of any
// tensor state interrogation function (internally or externally).
//
// Possible options are:
//  Preserve:
//    If any of the input tensors is in channels_last format, the operator
//    output should be in channels_last format.
//
//  Contiguous:
//    Regardless of input tensor formats, the output should be a contiguous
//    Tensor.
//
//  ChannelsLast:
//    Regardless of input tensor formats, the output should be in
//    channels_last format.
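//
// Example (illustrative): for a 4-d tensor with sizes [N, C, H, W],
//   Contiguous   (NCHW) strides are [C*H*W, H*W, W, 1]
//   ChannelsLast (NHWC) strides are [H*W*C, 1, W*C, C]
// i.e. ChannelsLast keeps the N, C, H, W order in `sizes`, but lays the data
// out in memory with C as the fastest-moving (innermost) dimension.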

namespace c10 {
enum class MemoryFormat : int8_t {
  Contiguous,
  Preserve,
  ChannelsLast,
  ChannelsLast3d,
  NumOptions
};

// If you are seeing this, it means that this call site was not checked as to
// whether the memory format could be preserved, and it was switched to the
// old default behaviour of contiguous.
#define LEGACY_CONTIGUOUS_MEMORY_FORMAT c10::get_contiguous_memory_format()

inline MemoryFormat get_contiguous_memory_format() {
  return MemoryFormat::Contiguous;
}
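
// Example (illustrative; `my_op` is a hypothetical call site, not part of
// this header): an operator that has not yet been audited for memory-format
// preservation would take the legacy default, which currently resolves to
// MemoryFormat::Contiguous:
//
//   Tensor my_op(
//       const Tensor& self,
//       MemoryFormat memory_format = LEGACY_CONTIGUOUS_MEMORY_FORMAT);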

inline std::ostream& operator<<(
    std::ostream& stream,
    at::MemoryFormat memory_format) {
  switch (memory_format) {
    case MemoryFormat::Preserve:
      return stream << "Preserve";
    case MemoryFormat::Contiguous:
      return stream << "Contiguous";
    case MemoryFormat::ChannelsLast:
      return stream << "ChannelsLast";
    case MemoryFormat::ChannelsLast3d:
      return stream << "ChannelsLast3d";
    default:
      TORCH_CHECK(false, "Unknown memory format ", memory_format);
  }
}
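
// Example (illustrative): the streaming operator above makes MemoryFormat
// values printable directly, e.g.
//   std::cout << MemoryFormat::ChannelsLast;  // prints "ChannelsLast"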

// Note: the channels_last stride indices are hardcoded here for better
// performance.
template <typename T>
inline std::vector<T> get_channels_last_strides_2d(ArrayRef<T> sizes) {
  std::vector<T> strides(sizes.size());
  switch (sizes.size()) {
    case 4:
      strides[1] = 1;
      strides[3] = sizes[1];
      strides[2] = strides[3] * sizes[3];
      strides[0] = strides[2] * sizes[2];
      return strides;
    case 3:
      strides[0] = 1;
      strides[2] = sizes[0];
      strides[1] = strides[2] * sizes[2];
      return strides;
    default:
      TORCH_INTERNAL_ASSERT(
          false, "ChannelsLast2d doesn't support size ", sizes.size());
  }
}

inline std::vector<int64_t> get_channels_last_strides_2d(IntArrayRef sizes) {
  return get_channels_last_strides_2d<int64_t>(sizes);
}
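
// Example (illustrative): for NCHW sizes {2, 3, 4, 5},
//   get_channels_last_strides_2d({2, 3, 4, 5}) == {60, 1, 15, 3}
// i.e. W moves with stride C (= 3), H with stride W*C (= 15), and N with
// stride H*W*C (= 60), while C is the innermost (stride 1) dimension.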

template <typename T>
std::vector<T> get_channels_last_strides_3d(ArrayRef<T> sizes) {
  std::vector<T> strides(sizes.size());
  switch (sizes.size()) {
    case 5:
      strides[1] = 1;
      strides[4] = sizes[1];
      strides[3] = strides[4] * sizes[4];
      strides[2] = strides[3] * sizes[3];
      strides[0] = strides[2] * sizes[2];
      return strides;
    case 4:
      strides[0] = 1;
      strides[3] = sizes[0];
      strides[2] = strides[3] * sizes[3];
      strides[1] = strides[2] * sizes[2];
      return strides;
    default:
      TORCH_INTERNAL_ASSERT(
          false, "ChannelsLast3d doesn't support size ", sizes.size());
  }
}

inline std::vector<int64_t> get_channels_last_strides_3d(IntArrayRef sizes) {
  return get_channels_last_strides_3d<int64_t>(sizes);
}
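
// Example (illustrative): for NCDHW sizes {2, 3, 4, 5, 6},
//   get_channels_last_strides_3d({2, 3, 4, 5, 6}) == {360, 1, 90, 18, 3}
// the NDHWC analogue of the 2d case: C is innermost, then W, H, D, N.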

// NOTE:
// Below are helper functions for is_channels_last_strides_xd.
// 1. Please do not combine these helper functions; each helper function
// handles exactly one case of sizes + memory_format. By doing this, the
// stride indices form a constant array that is accessed with constant
// indices, and the compiler fully unrolls the loop over the stride indices
// for better performance.
// 2. There are no error checks in the helper functions; the caller ensures
// the correctness of the input.
// 3. All helper functions have similar comments; only the first helper
// function is commented here.
template <typename T>
inline bool is_channels_last_strides_2d_s4(
    const ArrayRef<T> sizes,
    const ArrayRef<T> strides) {
  T min = 0;
  // special case for trivial C dimension; default to NCHW
  if (strides[1] == 0) {
    return false;
  }
  // loop over the stride indices
  for (auto& d : {1, 3, 2, 0}) {
    if (sizes[d] == 0) {
      return false;
    }
    if (strides[d] < min) {
      return false;
    }
    // Fallback to NCHW as the default layout for ambiguous cases.
    // This is the flaw of implicit memory_format from strides:
    // an N111 tensor with identical strides for its size-1 dimensions.
    // Two cases could lead us here:
    // a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1])
    // b. N11W contiguous Tensor sliced on the W-dimension.
    //    ([N,1,1,1]@[W,W,W,W])
    if (d == 0 && min == strides[1]) {
      return false;
    }
    // This is necessary to:
    // 1. distinguish the memory_format of N1H1;
    //      [H, 1, 1, 1] channels_last stride
    //      [H, H, 1, 1] contiguous stride
    // 2. permutation of 1C1W:
    //      [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3)
    //      [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as channels_last
    min = strides[d];
    if (sizes[d] > 1) {
      min *= sizes[d];
    }
  }
  return true;
}
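
// Example (illustrative):
//   sizes {2, 3, 4, 5}, strides {60, 1, 15, 3}  -> true  (channels_last)
//   sizes {2, 3, 4, 5}, strides {60, 20, 5, 1}  -> false (contiguous NCHW)
//   sizes {2, 1, 1, 1}, strides {1, 1, 1, 1}    -> false (ambiguous N111;
//                                                  defaults to NCHW)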

template <typename T>
inline bool is_channels_last_strides_3d_s5(
    const ArrayRef<T> sizes,
    const ArrayRef<T> strides) {
  T min = 0;
  if (strides[1] == 0) {
    return false;
  }
  for (auto& d : {1, 4, 3, 2, 0}) {
    if (sizes[d] == 0) {
      return false;
    }
    if (strides[d] < min) {
      return false;
    }
    if (d == 0 && min == strides[1]) {
      return false;
    }
    min = strides[d];
    if (sizes[d] > 1) {
      min *= sizes[d];
    }
  }
  return true;
}
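
// Example (illustrative):
//   sizes {2, 3, 4, 5, 6}, strides {360, 1, 90, 18, 3}  -> true  (channels_last 3d)
//   sizes {2, 3, 4, 5, 6}, strides {360, 120, 30, 6, 1} -> false (contiguous NCDHW)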

// Note [Ambiguous is_channels_last_strides_xd]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// The flaw of carrying memory_format implicitly through strides is very hard
// to WAR properly. issue #24090
// Without the history of permutations, we can't infer the memory_format of a
// tensor from the snapshot of its size & stride.
// e.g.
//
// 1. We can NOT specify the memory_format of an N111 tensor through strides
// in a meaningful way;
//
// 2. Two paths that end up with identical size/stride:
// an N11W contiguous tensor sliced at the w-dimension becomes
// [N,1,1,1]@[W,W,W,W];
// an NC11 channels_last tensor sliced at the c-dimension becomes
// [N,1,1,1]@[C,C,C,C].
// So if we see a tensor [N,1,1,1]@[X,X,X,X], there's no way for us to infer
// the memory_format of the original tensor.
//
// Due to these limitations, our temporary WAR `is_channels_last_strides` makes
// a best effort to infer whether the original memory_format of a tensor is
// at::MemoryFormat::ChannelsLast. The two objectives of this function (ordered
// by their importance):
// 1. Ensure that normal shape manipulation does not accidentally change the
// MemoryFormat of an existing tensor.
// 2. Allow users to mark tensors as MemoryFormat::ChannelsLast;
//
// The function does so by checking the strides of the tensor, including the
// strides of size-1 dimensions, even though conventionally PyTorch places no
// restriction on trivial strides (strides of size-1 dimensions).
//
// Note that this approach is a compromise. We have not solved the problem
// completely; in many cases we will not be able to infer the correct memory
// format.
// The implementation of `is_channels_last_strides` serves these objectives:
// MemoryFormat::ChannelsLast has to be explicitly opted into (no accidental
// conversion), with a best effort to maintain the ChannelsLast flag.
//
// Since this is not a bulletproof solution, through testing
// (aten/src/ATen/test/memory_format_test.cpp)
// a. we ensure that the common cases are supported;
// b. we identify the corner cases where the implementation compromises.
//
// By the time accumulated permutation is enabled to replace implicit
// memory_format through strides, we should update our tests and fix the
// issues they uncover.
//
// We use Channels Last 2d as an example above.
// This is a general problem for all the is_channels_last_strides_xd
// implementations. Please check the helper functions
// (is_channels_last_strides_*d_s*) for more details.

template <typename T>
inline bool is_channels_last_strides_2d(
    const ArrayRef<T> sizes,
    const ArrayRef<T> strides) {
  switch (sizes.size()) {
    case 4:
      return is_channels_last_strides_2d_s4(sizes, strides);
    // NOLINTNEXTLINE(bugprone-branch-clone)
    case 3:
      // TODO dim == 3 case will be enabled once it is fully tested
      return false;
    default:
      return false;
  }
}

template <typename T>
inline bool is_channels_last_strides_3d(
    const ArrayRef<T> sizes,
    const ArrayRef<T> strides) {
  switch (sizes.size()) {
    case 5:
      return is_channels_last_strides_3d_s5(sizes, strides);
    // NOLINTNEXTLINE(bugprone-branch-clone)
    case 4:
      // TODO dim == 4 case will be enabled once it is fully tested
      return false;
    default:
      return false;
  }
}

inline bool is_channels_last_strides_2d(
    const IntArrayRef sizes,
    const IntArrayRef strides) {
  return is_channels_last_strides_2d<int64_t>(sizes, strides);
}

inline bool is_channels_last_strides_3d(
    const IntArrayRef sizes,
    const IntArrayRef strides) {
  return is_channels_last_strides_3d<int64_t>(sizes, strides);
}
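
// Example (illustrative): round-tripping the two families of helpers above.
//   std::vector<int64_t> sizes = {2, 3, 4, 5};
//   auto strides = get_channels_last_strides_2d(sizes);        // {60, 1, 15, 3}
//   bool is_cl = is_channels_last_strides_2d(sizes, strides);  // true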

} // namespace c10