xref: /aosp_15_r20/external/pytorch/c10/core/MemoryFormat.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>

#include <cstdint>
#include <ostream>
#include <vector>

// Memory format is not a property of a Tensor. It is a way to tell an
// operator how the result should be organized in memory and nothing more. That
// means memory format should never be used as a return value of any tensor
// state interrogation function (internally or externally).
//
// Possible options are:
//  Preserve:
//    If any of the input tensors is in channels_last format, the operator
//    output should be in channels_last format.
//
//  Contiguous:
//    Regardless of the input tensors' format, the output should be a
//    contiguous Tensor.
//
//  ChannelsLast:
//    Regardless of the input tensors' format, the output should be in
//    channels_last format.
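//
// For example (illustrative only; neither call is defined in this header), a
// caller typically consumes a MemoryFormat value like
//   at::empty_like(self, self.options(), c10::MemoryFormat::Preserve);
//   tensor.contiguous(c10::MemoryFormat::ChannelsLast);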

namespace c10 {
enum class MemoryFormat : int8_t {
  Contiguous,
  Preserve,
  ChannelsLast,
  ChannelsLast3d,
  NumOptions
};

// If you are seeing this, it means that this call site was not checked to see
// whether the memory format could be preserved, and it was switched to the old
// default behaviour of contiguous.
#define LEGACY_CONTIGUOUS_MEMORY_FORMAT c10::get_contiguous_memory_format()
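// A typical (illustrative) use is at a call site that still assumes the old
// contiguous default, passing the macro wherever an optional MemoryFormat is
// expected, e.g.
//   at::empty(sizes, options, LEGACY_CONTIGUOUS_MEMORY_FORMAT);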

inline MemoryFormat get_contiguous_memory_format() {
  return MemoryFormat::Contiguous;
}

inline std::ostream& operator<<(
    std::ostream& stream,
    at::MemoryFormat memory_format) {
  switch (memory_format) {
    case MemoryFormat::Preserve:
      return stream << "Preserve";
    case MemoryFormat::Contiguous:
      return stream << "Contiguous";
    case MemoryFormat::ChannelsLast:
      return stream << "ChannelsLast";
    case MemoryFormat::ChannelsLast3d:
      return stream << "ChannelsLast3d";
    default:
      TORCH_CHECK(false, "Unknown memory format ", memory_format);
  }
}

// Note: The channels-last stride indices are hardcoded here for better
// performance.
template <typename T>
inline std::vector<T> get_channels_last_strides_2d(ArrayRef<T> sizes) {
  std::vector<T> strides(sizes.size());
  switch (sizes.size()) {
    case 4:
      strides[1] = 1;
      strides[3] = sizes[1];
      strides[2] = strides[3] * sizes[3];
      strides[0] = strides[2] * sizes[2];
      return strides;
    case 3:
      strides[0] = 1;
      strides[2] = sizes[0];
      strides[1] = strides[2] * sizes[2];
      return strides;
    default:
      TORCH_INTERNAL_ASSERT(
          false, "ChannelsLast2d doesn't support size ", sizes.size());
  }
}

inline std::vector<int64_t> get_channels_last_strides_2d(IntArrayRef sizes) {
  return get_channels_last_strides_2d<int64_t>(sizes);
}
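// Worked example (illustrative): for NCHW sizes {2, 3, 4, 5} this returns
// strides {60, 1, 15, 3}: C gets stride 1, W gets stride C == 3, H gets
// stride C * W == 15, and N gets stride C * W * H == 60.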

template <typename T>
std::vector<T> get_channels_last_strides_3d(ArrayRef<T> sizes) {
  std::vector<T> strides(sizes.size());
  switch (sizes.size()) {
    case 5:
      strides[1] = 1;
      strides[4] = sizes[1];
      strides[3] = strides[4] * sizes[4];
      strides[2] = strides[3] * sizes[3];
      strides[0] = strides[2] * sizes[2];
      return strides;
    case 4:
      strides[0] = 1;
      strides[3] = sizes[0];
      strides[2] = strides[3] * sizes[3];
      strides[1] = strides[2] * sizes[2];
      return strides;
    default:
      TORCH_INTERNAL_ASSERT(
          false, "ChannelsLast3d doesn't support size ", sizes.size());
  }
}

inline std::vector<int64_t> get_channels_last_strides_3d(IntArrayRef sizes) {
  return get_channels_last_strides_3d<int64_t>(sizes);
}
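// Worked example (illustrative): for NCDHW sizes {2, 3, 4, 5, 6} this returns
// strides {360, 1, 90, 18, 3}: C gets stride 1, and W, H, D, N get strides
// 3, 18, 90, and 360 respectively.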

// NOTE:
// Below are helper functions for is_channels_last_strides_xd.
// 1. Please do not combine these helper functions; each helper handles exactly
// one case of sizes + memory_format. This keeps the stride indices a constant
// array accessed with constant index numbers, so the compiler can fully unroll
// the loop over the stride indices for better performance.
// 2. There is no error checking in the helper functions; the caller ensures
// the correctness of the input.
// 3. All helper functions have similar comments; only the 1st helper function
// is commented here.
template <typename T>
inline bool is_channels_last_strides_2d_s4(
    const ArrayRef<T> sizes,
    const ArrayRef<T> strides) {
  T min = 0;
  // special case for trivial C dimension. default to NCHW
  if (strides[1] == 0) {
    return false;
  }
  // loop strides indices
  for (auto& d : {1, 3, 2, 0}) {
    if (sizes[d] == 0) {
      return false;
    }
    if (strides[d] < min) {
      return false;
    }
    // Fallback to NCHW as default layout for ambiguous cases
    // This is the flaw of implicit memory_format from strides.
    // N111 tensor with identical strides for size 1 dimension;
    // Two cases could lead us here:
    // a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1])
    // b. N11W contiguous Tensor sliced on the W-dimension.
    // ([N,1,1,1]@[W,W,W,W])
    if (d == 0 && min == strides[1]) {
      return false;
    }
    // This is necessary to:
    // 1. distinguish the memory_format of N1H1;
    //     [H, 1, 1, 1] channels_last stride
    //     [H, H, 1, 1] contiguous stride
    // 2. permutation of 1C1W:
    //     [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3)
    //     [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as channels_last
    min = strides[d];
    if (sizes[d] > 1) {
      min *= sizes[d];
    }
  }
  return true;
}
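// Example (illustrative): for sizes {2, 3, 4, 5}, the channels-last strides
// {60, 1, 15, 3} pass this check, while the contiguous strides {60, 20, 5, 1}
// fail at d == 3 because stride 1 < strides[1] * sizes[1] == 60.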

template <typename T>
inline bool is_channels_last_strides_3d_s5(
    const ArrayRef<T> sizes,
    const ArrayRef<T> strides) {
  T min = 0;
  if (strides[1] == 0) {
    return false;
  }
  for (auto& d : {1, 4, 3, 2, 0}) {
    if (sizes[d] == 0) {
      return false;
    }
    if (strides[d] < min) {
      return false;
    }
    if (d == 0 && min == strides[1]) {
      return false;
    }
    min = strides[d];
    if (sizes[d] > 1) {
      min *= sizes[d];
    }
  }
  return true;
}

// Note [Ambiguous is_channels_last_strides_xd]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// The flaw of carrying memory_format implicitly through strides is very hard
// to work around properly (see issue #24090).
// Without the history of permutations, we can't infer the memory_format of a
// tensor from a snapshot of its size & stride.
// e.g.
//
// 1. We can NOT specify the memory_format of an N111 tensor through strides in
//  a meaningful way;
//
// 2. Two paths can end up with identical size/stride:
//  an N11W contiguous tensor sliced at the w-dimension becomes
//    [N,1,1,1]@[W,W,W,W];
//  an NC11 channels_last tensor sliced at the c-dimension becomes
//    [N,1,1,1]@[C,C,C,C].
//    So if we see a tensor [N,1,1,1]@[X,X,X,X], there's no way for us to infer
//    the memory_format of the original tensor.
//
// Due to these limitations, our temporary workaround `is_channels_last_strides`
// makes a best effort to infer whether the original memory_format of a tensor
// is at::MemoryFormat::ChannelsLast. The two objectives of this function
// (ordered by importance):
//   1. Ensure that normal shape manipulation does not accidentally change the
//      MemoryFormat of an existing tensor.
//   2. Allow users to mark tensors as MemoryFormat::ChannelsLast.
//
// The function does so by checking the strides of the tensor, including the
// strides of size-1 dimensions, even though conventionally PyTorch places no
// restriction on trivial strides (strides of size-1 dimensions).
//
// Note that this approach is a compromise. We did not solve the problem
// completely; in many cases we will not be able to infer the correct memory
// format.
// The implementation of `is_channels_last_strides` serves these objectives:
// MemoryFormat::ChannelsLast has to be explicitly opted into (no accidental
// conversion), with a best effort to maintain the ChannelsLast flag.
//
// Because this is not a bulletproof solution, through testing
// (aten/src/ATen/test/memory_format_test.cpp)
//   a. we ensure that the common cases are supported;
//   b. we identify the corner cases where the implementation compromises.
//
// By the time accumulated permutations are enabled to replace implicit
// memory_format through strides, we should update our tests and fix the issues
// they reveal.
//
// We use Channels Last 2d as an example above.
// This is a general problem for all the is_channels_last_strides_xd
// implementations. Please check the helper functions
// (is_channels_last_strides_*d_s*) for more details.
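// As a concrete illustration of point 1 above: for sizes {N, 1, 1, 1}, the
// contiguous strides and get_channels_last_strides_2d() both yield
// {1, 1, 1, 1}, so the layout of an N111 tensor simply cannot be encoded in
// its strides.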

template <typename T>
inline bool is_channels_last_strides_2d(
    const ArrayRef<T> sizes,
    const ArrayRef<T> strides) {
  switch (sizes.size()) {
    case 4:
      return is_channels_last_strides_2d_s4(sizes, strides);
      // NOLINTNEXTLINE(bugprone-branch-clone)
    case 3:
      // TODO dim == 3 case will be enabled once it is fully tested
      return false;
    default:
      return false;
  }
}

template <typename T>
inline bool is_channels_last_strides_3d(
    const ArrayRef<T> sizes,
    const ArrayRef<T> strides) {
  switch (sizes.size()) {
    case 5:
      return is_channels_last_strides_3d_s5(sizes, strides);
      // NOLINTNEXTLINE(bugprone-branch-clone)
    case 4:
      // TODO dim == 4 case will be enabled once it is fully tested
      return false;
    default:
      return false;
  }
}

inline bool is_channels_last_strides_2d(
    const IntArrayRef sizes,
    const IntArrayRef strides) {
  return is_channels_last_strides_2d<int64_t>(sizes, strides);
}

inline bool is_channels_last_strides_3d(
    const IntArrayRef sizes,
    const IntArrayRef strides) {
  return is_channels_last_strides_3d<int64_t>(sizes, strides);
}

} // namespace c10