xref: /aosp_15_r20/external/executorch/runtime/core/exec_aten/util/dim_order_util.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <cstdint>
12 
13 #include <executorch/runtime/core/error.h>
14 #include <executorch/runtime/platform/assert.h>
15 #include <executorch/runtime/platform/compiler.h>
16 
17 namespace executorch {
18 namespace runtime {
19 
namespace {
/**
 * Returns true if every entry of `dim_order` is a valid dimension index,
 * i.e. strictly less than `dims`.
 *
 * NOTE: this checks range validity only; duplicate entries are not
 * detected, so it does not prove dim_order is a true permutation.
 *
 * @param[in] dim_order pointer to dim_order array
 * @param[in] dims length of the dim_order array
 */
template <typename DimOrderType>
bool validate_dim_order(const DimOrderType* dim_order, const size_t dims) {
  // size_t index avoids the signed/unsigned comparison of the original
  // int32_t counter against the size_t bound.
  for (size_t i = 0; i < dims; ++i) {
    // The cast keeps the comparison unsigned for any DimOrderType; a
    // negative entry wraps to a huge value and is rejected, matching the
    // implicit-conversion behavior of the original comparison.
    if (static_cast<size_t>(dim_order[i]) >= dims) {
      return false;
    }
  }
  return true;
}
} // namespace
31 
/**
 * Check if a given dim_order array is equivalent to the contiguous dim order
 * of {0, 1, 2, 3, ...}
 *
 * @param[in] dim_order pointer to dim_order array
 * @param[in] dims length of the dim_order array
 */
template <typename DimOrderType>
inline bool is_contiguous_dim_order(
    const DimOrderType* dim_order,
    const size_t dims) {
  // size_t index avoids the signed/unsigned comparison of the original
  // `int` counter against `dims`.
  for (size_t i = 0; i < dims; ++i) {
    if (static_cast<size_t>(dim_order[i]) != i) {
      return false;
    }
  }
  // Vacuously true for 0-dim tensors.
  return true;
}
50 
/**
 * Check if a given dim_order array is equivalent to a channels last dim order.
 * Channels last dim order is only valid for 4-dim and 5-dim tensors.
 *
 * @param[in] dim_order pointer to dim_order array
 * @param[in] dims length of the dim_order array
 */
template <typename DimOrderType>
bool is_channels_last_dim_order(
    const DimOrderType* dim_order,
    const size_t dims) {
  if (dims != 4 && dims != 5) {
    return false;
  }
  // Channels-last moves the channels dim (index 1 in the dense layout) to
  // the innermost (fastest-moving) position, i.e. the dim order must be
  // {0, 2, 3, 1} for 4 dims or {0, 2, 3, 4, 1} for 5 dims.
  constexpr size_t kChannelsDim = 1;
  // Last value in the dim order must be the channels dim.
  if (static_cast<size_t>(dim_order[dims - 1]) != kChannelsDim) {
    return false;
  }
  // Batch dim stays outermost.
  if (dim_order[0] != 0) {
    return false;
  }
  // Middle entries must be the remaining dims in increasing order.
  // size_t index avoids the signed/unsigned comparison of the original
  // `int` counter against `dims - 1`.
  for (size_t d = 1; d < dims - 1; ++d) {
    if (static_cast<size_t>(dim_order[d]) != d + 1) {
      return false;
    }
  }
  return true;
}
84 
/*
 * Computes strides from sizes according to a dimension order. The dim order
 * lists dimensions from outermost (slowest moving) to innermost (fastest
 * moving) in memory. Example:
 *   sizes = [2, 3, 4, 5] with dim_names = [N, C, H, W] and
 *   dim_order = [0, 2, 3, 1]  =>  strides = [60, 1, 15, 3]
 *
 * param[in]: sizes, pointer to sizes array
 * param[in]: dim_order, pointer to dimension order array
 * param[in]: dims, number of dims; sizes and dim_order must hold dims entries
 * param[out]: strides, pointer to strides array that is filled in
 *
 * NB: ArrayRef is deliberately not used here because including
 * kernel_types.h would create a circular dependency: kernel_types depends on
 * executorch_kernel_types in lean mode, which compiles TensorImpl.cpp, and
 * executorch_kernel_types needs dim_order_utils for its resize impl. That is
 * also why this function is templated. Better ideas welcome!
 * TODO(T148342910)
 *
 * Note that this function does not check that the provided dim order is
 * valid; it should only be used once validity has been established. The
 * checked variant is dim_order_to_stride below.
 */
template <typename SizesType, typename DimOrderType, typename StridesType>
inline void dim_order_to_stride_nocheck(
    const SizesType* sizes,
    const DimOrderType* dim_order,
    const size_t dims,
    StridesType* strides) {
  // Zero-dim tensors carry no stride information; nothing to do.
  if (dims == 0) {
    return;
  }
  // The innermost (fastest-moving) dimension always has stride 1.
  strides[dim_order[dims - 1]] = 1;
  // Walk outward through the dim order: each dimension's stride is the
  // next-inner stride scaled by the next-inner size. Size-0 dims reuse the
  // inner stride unchanged to avoid zeroing out every outer stride.
  for (size_t k = dims - 1; k > 0; --k) {
    const auto inner = dim_order[k];
    const auto outer = dim_order[k - 1];
    strides[outer] =
        (sizes[inner] == 0) ? strides[inner] : strides[inner] * sizes[inner];
  }
}
135 
/**
 * Checked variant of dim_order_to_stride_nocheck: validates the dim order
 * before translating sizes to strides.
 *
 * @param[in] sizes pointer to sizes array
 * @param[in] dim_order pointer to dimension order array
 * @param[in] dims number of dims; sizes and dim_order must hold dims entries
 * @param[out] strides pointer to strides array that is filled in
 * @returns Error::Ok on success, Error::InvalidArgument if any dim_order
 *     entry is out of range (>= dims).
 */
template <typename SizesType, typename DimOrderType, typename StridesType>
ET_NODISCARD inline Error dim_order_to_stride(
    const SizesType* sizes,
    const DimOrderType* dim_order,
    const size_t dims,
    StridesType* strides) {
  // For 0 dim tensors, just return ok.
  if (dims == 0) {
    return Error::Ok;
  }
  // NOTE(review): validate_dim_order only range-checks entries; duplicate
  // entries are not rejected here.
  ET_CHECK_OR_RETURN_ERROR(
      validate_dim_order(dim_order, dims),
      InvalidArgument,
      "Invalid dim order. One of the value is larger than the number of dims %zu",
      dims);

  dim_order_to_stride_nocheck(sizes, dim_order, dims, strides);
  return Error::Ok;
}
155 
156 namespace internal {
157 
/**
 * Pairs a stride value with the dimension index it belongs to, so an array
 * of these can be sorted by stride while remembering each dim's position.
 */
template <typename StridesType, typename DimOrderType>
struct StrideDimOrder {
  // Value-initialize so a default-constructed element never carries
  // indeterminate values (the original defaulted ctor left both members
  // uninitialized).
  StridesType stride{};
  DimOrderType dim_order{};

  StrideDimOrder(StridesType stride, DimOrderType dim_order)
      : stride(stride), dim_order(dim_order) {}
  StrideDimOrder() = default;
  // Deliberately inverted: "greater" means smaller stride, so a generic
  // ascending sort using operator> yields strides in descending order.
  bool operator>(const StrideDimOrder& other) const {
    return stride < other.stride;
  }
};
171 
/**
 * Minimal in-place quicksort over a raw array, used to order
 * StrideDimOrder entries without pulling in <algorithm>. Elements are
 * compared with operator> only.
 */
template <typename ValueType>
struct Sorter {
 public:
  // Sorts arr[low..high] (both inclusive) in place; empty or single-element
  // ranges are left untouched.
  void quick_sort(ValueType arr[], int32_t low, int32_t high) {
    if (low >= high) {
      return;
    }
    ValueType pivot = arr[high];
    const int32_t pivot_pos = partition(arr, low, high, pivot);
    quick_sort(arr, low, pivot_pos - 1);
    quick_sort(arr, pivot_pos + 1, high);
  }

 private:
  // Exchanges two array slots by copy (ValueType is a small POD-ish pair).
  void swap(ValueType arr[], int32_t pos1, int32_t pos2) noexcept {
    ValueType tmp = arr[pos1];
    arr[pos1] = arr[pos2];
    arr[pos2] = tmp;
  }

  // Lomuto-style pass: everything not ordered after `pivot` is packed to the
  // front; returns the pivot's final resting index.
  int32_t
  partition(ValueType arr[], int32_t low, int32_t high, ValueType pivot) {
    int32_t next_slot = low;
    for (int32_t scan = low; scan <= high; ++scan) {
      if (!(arr[scan] > pivot)) {
        swap(arr, scan, next_slot++);
      }
    }
    return next_slot - 1;
  }
};
206 
207 } // namespace internal
208 
209 /*
210  * This utility translated strides to dimension order
211  * information. Dimension order specifies how the dimensions are laid out in the
212  * memory. For example for tensor with sizes [3, 5, 2] and strides [5, 1, 15],
213  * dim order should be [2, 0, 1], which is obtained by sorting strides in
214  * descending order. param[in]: sizes, pointer to sizes array param[in]:
215  * dim_order, pointer to dimension order array param[in]: dims, number of dims.
216  * Sizes and dim_order must be sizes to dims param[out]: strides, pointer to
217  * strides array that is filled in
218  *
219  * NB: Reason for not using ArrayRef is the dependency on kernel_types.h
220  * This header cannot be included, because of circular dep it causes.
221  * kernel_types depends on executorch_kernel_types in lean mode, which compiles
222  * TensorImpl.cpp. executorch_kernel_types needs to depend on dim_order_utils
223  * in order to utilize dim_order_to_stride in its resize impl. If
224  * dim_order_utils depends on kernel_type, we have circular deps. This is also
225  * the reason for templatizing this function. Better ideas welcome!
226  * TODO(T148342910)
227  */
228 template <typename DimOrderType, typename StridesType>
stride_to_dim_order(const StridesType * strides,const size_t dims,DimOrderType * dim_order)229 ET_NODISCARD inline Error stride_to_dim_order(
230     const StridesType* strides,
231     const size_t dims,
232     DimOrderType* dim_order) {
233   const size_t kMaxNumOfDimensions = 16;
234   ET_CHECK_OR_RETURN_ERROR(
235       dim_order != nullptr,
236       MemoryAllocationFailed,
237       "Need memory to get dim_order.");
238   ET_CHECK_OR_RETURN_ERROR(
239       dims <= kMaxNumOfDimensions,
240       NotSupported,
241       "dims %zu exceeds maximum allowed %zu",
242       dims,
243       kMaxNumOfDimensions);
244   internal::StrideDimOrder<StridesType, DimOrderType>
245       array[kMaxNumOfDimensions];
246   for (DimOrderType i = 0; i < dims; i++) {
247     array[i].dim_order = i;
248     array[i].stride = strides[i];
249   }
250 
251   internal::Sorter<internal::StrideDimOrder<StridesType, DimOrderType>> sorter;
252 
253   sorter.quick_sort(array, 0, dims - 1);
254 
255   for (auto i = 0; i < dims; i++) {
256     dim_order[i] = array[i].dim_order;
257   }
258   return Error::Ok;
259 }
260 } // namespace runtime
261 } // namespace executorch
262 
263 namespace torch {
264 namespace executor {
265 // TODO(T197294990): Remove these deprecated aliases once all users have moved
266 // to the new `::executorch` namespaces.
267 using ::executorch::runtime::dim_order_to_stride;
268 using ::executorch::runtime::dim_order_to_stride_nocheck;
269 using ::executorch::runtime::is_channels_last_dim_order;
270 using ::executorch::runtime::is_contiguous_dim_order;
271 using ::executorch::runtime::stride_to_dim_order;
272 } // namespace executor
273 } // namespace torch
274