/*
 * Copyright (c) 2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#pragma once

/* arm_conv kernels share a lot of similarities in how they address input and
 * output tensors. Consequently, this file contains common approaches to
 * preparing these tensor descriptions. Generic (i.e., untyped) methods are
 * contained within the `arm_conv::addressing` namespace, and typed wrappers
 * are provided within an anonymous namespace within `arm_conv`. The various
 * methods are described below.
 */

#include <cstddef>

namespace arm_conv {
namespace addressing {

/* Pointer array
 * -------------
 *
 * Constructs an array of pointers which point into an `array_rows` x
 * `array_cols` chunk of a tensor. The array of pointers is written into
 * `dest`.
 *
 * `base_ptr` should point at the first VALID element of the chunk of the
 * tensor (i.e., if there is one padded row and one padded column, then
 * `base_ptr` should point at the element which will be at position (1, 1) in
 * the array). `ld_row` and `ld_col` are in bytes, and describe the strides
 * over rows and columns (respectively) of the NHWC-ordered tensor.
 * `pad_buffer` should point at a suitably sized (and initialised) area of
 * memory which can be addressed by those elements of the array which
 * represent padding.
 *
 * `pad_top` and `pad_left` describe the padding on the top and left of the
 * array, respectively, and `valid_rows` and `valid_cols` describe the number
 * of rows and columns between the element pointed to by `base_ptr` and the
 * edge of the image (that is, `valid_rows` may be greater than `array_rows`,
 * and likewise for the columns).
 */
void fill_pointer_array(
  size_t element_size,
  void **dest, unsigned int array_rows, unsigned int array_cols,
  void *base_ptr, size_t ld_row, size_t ld_col,
  void *pad_buffer,
  unsigned int pad_top, unsigned int valid_rows,
  unsigned int pad_left, unsigned int valid_cols
);
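
/* Example (an illustrative sketch, not part of the original header; the names
 * `window`, `input`, `pad`, `ld_row` and `ld_col` are hypothetical):
 *
 *   // Build a 3 x 3 window of float pointers whose top row and left column
 *   // fall in padding. `input` points at the first valid element, i.e. the
 *   // element which lands at position (1, 1) of the window.
 *   float *window[3 * 3];
 *   float pad[16] = {};  // assumed large enough for the kernel's reads
 *   fill_pointer_array(
 *     sizeof(float), (void **) window, 3, 3,
 *     (void *) input, ld_row, ld_col,
 *     (void *) pad,
 *     1, 2,  // pad_top = 1; two valid rows remain at and below `input`
 *     1, 2   // pad_left = 1; two valid columns remain at and right of `input`
 *   );
 */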

/* Interleaved multi-point pointer array
 * -------------------------------------
 *
 * For each point in an `output_rows` x `output_cols` array, constructs a
 * `kernel_rows` x `kernel_cols` array of pointers. The pointers are
 * interleaved as follows:
 *
 *   for ki in kernel_rows:
 *       for kj in kernel_cols:
 *           for oi in output_rows:
 *               for oj in output_cols:
 *                   get pointer for point (oi*stride_rows + ki, oj*stride_cols + kj)
 *
 * Other arguments are as for `fill_pointer_array`.
 *
 * The name reflects that this is the form of addressing used by "generic"
 * depthwise and pooling kernels.
 */
void fill_pointer_array_generic_kernel(
  size_t element_size,
  void **dest,
  unsigned int output_rows, unsigned int output_cols,
  unsigned int kernel_rows, unsigned int kernel_cols,
  unsigned int stride_rows, unsigned int stride_cols,
  void *base_ptr, size_t ld_row, size_t ld_col,
  void *pad_buffer,
  unsigned int pad_top, unsigned int valid_rows,
  unsigned int pad_left, unsigned int valid_cols
);
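
/* Example (an illustrative sketch, not part of the original header; the names
 * `ptrs`, `input`, `pad`, `ld_row` and `ld_col` are hypothetical):
 *
 *   // 2 x 2 output points, a 3 x 3 kernel and unit stride: 3*3*2*2 pointers
 *   // in total, stored kernel-point major as in the loop nest above. The
 *   // underlying 4 x 4 input window has one padded row at the top and one
 *   // padded column on the left.
 *   float *ptrs[3 * 3 * 2 * 2];
 *   fill_pointer_array_generic_kernel(
 *     sizeof(float), (void **) ptrs,
 *     2, 2,  // output_rows, output_cols
 *     3, 3,  // kernel_rows, kernel_cols
 *     1, 1,  // stride_rows, stride_cols
 *     (void *) input, ld_row, ld_col,
 *     (void *) pad,
 *     1, 3,  // pad_top, valid_rows
 *     1, 3   // pad_left, valid_cols
 *   );
 */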

/* NCHW-patch addressed by row
 * ---------------------------
 *
 * Constructs an array of pointers, each of which points at a row of an
 * NCHW-ordered patch of a tensor. Memory addressed by the pointers may lie
 * outside of the original tensor, and should therefore not be written to
 * (modifications will be lost).
 *
 * `dest_row_pointers` should point at a list of `patch_rows` pointers, each
 * of which will point at a 1 x `patch_cols` NCHW-ordered sample of the source
 * tensor.
 *
 * `dest_patch` should point to an `element_size * patch_rows * patch_cols`-byte
 * area of memory which can be written to by this function to form samples of
 * the source tensor.
 *
 * `src_ptr` should point at the first VALID element of the chunk of the
 * tensor (i.e., if there is one padded row and one padded column, then
 * `src_ptr` should point at the element which will be at position (1, 1) in
 * the array). `ld_row` and `ld_col` are in bytes, and describe the strides
 * over rows and columns (respectively) of the NHWC-ordered tensor. If
 * `ld_col == element_size`, then copies from the source tensor are elided and
 * source data may be addressed directly.
 *
 * `pad_row` should point to a `patch_cols` array of (appropriately
 * initialised) padding values.
 *
 * Other arguments are as for `fill_pointer_array`.
 */
void fill_nchw_patch_array(
  size_t element_size,
  const void **dest_row_pointers,  // Array of pointers to each row of the patch
  void *dest_patch,  // Pointer to space which can be used to construct the patch
  unsigned int patch_rows, unsigned int patch_cols,  // Patch size
  const void *src_ptr, size_t ld_row, size_t ld_col,  // Source tensor
  const void *pad_row,  // Pointer to a row of padding values
  unsigned int pad_top, unsigned int valid_rows,
  unsigned int pad_left, unsigned int valid_cols
);
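
/* Example (an illustrative sketch, not part of the original header; the names
 * `row_ptrs`, `patch`, `pad_row`, `src`, `ld_row` and `ld_col` are
 * hypothetical):
 *
 *   // Gather a 4 x 4 NCHW-ordered patch whose top row is padding. Each entry
 *   // of `row_ptrs` may end up pointing into the source tensor (when the row
 *   // is fully valid and ld_col == element_size), into `patch` (when a copy
 *   // has to be constructed), or at `pad_row` (when the row is pure padding).
 *   const float *row_ptrs[4];
 *   float patch[4 * 4];           // scratch space for constructed rows
 *   const float pad_row[4] = {};  // one row's worth of padding values
 *   fill_nchw_patch_array(
 *     sizeof(float),
 *     (const void **) row_ptrs, (void *) patch,
 *     4, 4,  // patch_rows, patch_cols
 *     (const void *) src, ld_row, ld_col,
 *     (const void *) pad_row,
 *     1, 3,  // pad_top, valid_rows
 *     0, 4   // pad_left, valid_cols
 *   );
 */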

/* Interleaved multi-point patch array
 * -----------------------------------
 *
 * As `fill_nchw_patch_array`, but producing one row pointer per output row
 * per kernel point, in the interleaved order used by "generic" kernels (see
 * `fill_pointer_array_generic_kernel`). Arguments are as for the functions
 * above.
 */
void fill_patch_array_generic_kernel(
  size_t element_size,
  const void **dest_pointers,  // Pointers: one per output row per kernel point
  void *dest_patch,  // Pointer to space which can be used to construct the patch
  unsigned int output_rows, unsigned int output_cols,
  unsigned int kernel_rows, unsigned int kernel_cols,
  unsigned int stride_rows, unsigned int stride_cols,
  const void *src_ptr, size_t ld_row, size_t ld_col,  // Source tensor
  const void *pad_row,  // Pointer to a row of padding values
  unsigned int pad_top, unsigned int valid_rows,
  unsigned int pad_left, unsigned int valid_cols
);
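
/* Example (an illustrative sketch, not part of the original header; the names
 * `ptrs`, `patch`, `pad_row`, `src`, `ld_row` and `ld_col` are hypothetical,
 * and the scratch-buffer sizes are assumptions):
 *
 *   // 2 x 2 output points, a 3 x 3 kernel and unit stride: one pointer per
 *   // output row per kernel point, i.e. 2 * 3 * 3 pointers in total.
 *   const float *ptrs[2 * 3 * 3];
 *   float patch[2 * 3 * 3 * 4];   // scratch space, assumed large enough for
 *                                 // any rows the function must construct
 *   const float pad_row[4] = {};  // one row's worth of padding values
 *   fill_patch_array_generic_kernel(
 *     sizeof(float),
 *     (const void **) ptrs, (void *) patch,
 *     2, 2,  // output_rows, output_cols
 *     3, 3,  // kernel_rows, kernel_cols
 *     1, 1,  // stride_rows, stride_cols
 *     (const void *) src, ld_row, ld_col,
 *     (const void *) pad_row,
 *     1, 3,  // pad_top, valid_rows
 *     1, 3   // pad_left, valid_cols
 *   );
 */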

}  // namespace addressing

namespace {

/* Pointer array
 * -------------
 *
 * See `addressing::fill_pointer_array`. No copies are made by this method;
 * the memory pointed to by the pointer array lies within the base tensor and
 * the padding buffer.
 */
template <typename T>
inline void fill_pointer_array(
  T **dest, unsigned int array_rows, unsigned int array_cols,
  T *base_ptr, size_t ld_row, size_t ld_col,
  T *pad_buffer,
  unsigned int pad_top, unsigned int valid_rows,
  unsigned int pad_left, unsigned int valid_cols
)
{
  addressing::fill_pointer_array(
    sizeof(T), (void **) dest, array_rows, array_cols,
    (void *) base_ptr, ld_row, ld_col,
    (void *) pad_buffer,
    pad_top, valid_rows,
    pad_left, valid_cols
  );
}
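
/* Example (an illustrative sketch, not part of the original header; the names
 * `window`, `input`, `pad`, `ld_row` and `ld_col` are hypothetical):
 *
 *   // As for the untyped variant above, but the element size is deduced from
 *   // T and no casts are needed at the call site.
 *   float *window[3 * 3];
 *   fill_pointer_array(
 *     window, 3, 3,
 *     input, ld_row, ld_col,
 *     pad,
 *     1, 2,  // pad_top, valid_rows
 *     1, 2   // pad_left, valid_cols
 *   );
 */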


/* Interleaved multi-point pointer array
 * -------------------------------------
 *
 * See `addressing::fill_pointer_array_generic_kernel`. No copies are made by
 * this method; the memory pointed to by the pointer array lies within the
 * base tensor and the padding buffer.
 */
template <typename T>
inline void fill_pointer_array_generic_kernel(
  T **dest,
  unsigned int output_rows, unsigned int output_cols,
  unsigned int kernel_rows, unsigned int kernel_cols,
  unsigned int stride_rows, unsigned int stride_cols,
  T *base_ptr, size_t ld_row, size_t ld_col,
  T *pad_buffer,
  unsigned int pad_top, unsigned int valid_rows,
  unsigned int pad_left, unsigned int valid_cols
)
{
  addressing::fill_pointer_array_generic_kernel(
    sizeof(T),
    (void **) dest,
    output_rows, output_cols,
    kernel_rows, kernel_cols,
    stride_rows, stride_cols,
    (void *) base_ptr, ld_row, ld_col,
    (void *) pad_buffer,
    pad_top, valid_rows,
    pad_left, valid_cols
  );
}

/* NCHW-patch addressed by row
 * ---------------------------
 *
 * See `addressing::fill_nchw_patch_array`.
 */
template <typename T>
inline void fill_nchw_patch_array(
  const T **dest_row_pointers,  // Array of pointers to each row of the patch
  T *dest_patch,  // Pointer to space which can be used to construct the patch
  unsigned int patch_rows, unsigned int patch_cols,  // Patch size
  const T *src_ptr, size_t ld_row, size_t ld_col,  // Source tensor
  const T *pad_row,  // Pointer to a row of padding values
  unsigned int pad_top, unsigned int valid_rows,
  unsigned int pad_left, unsigned int valid_cols
)
{
  addressing::fill_nchw_patch_array(
    sizeof(T),
    reinterpret_cast<const void **>(dest_row_pointers),
    reinterpret_cast<void *>(dest_patch),
    patch_rows, patch_cols,
    reinterpret_cast<const void *>(src_ptr), ld_row, ld_col,
    reinterpret_cast<const void *>(pad_row),
    pad_top, valid_rows,
    pad_left, valid_cols
  );
}

/* Interleaved multi-point patch array
 * -----------------------------------
 *
 * See `addressing::fill_patch_array_generic_kernel`.
 */
template <typename T>
inline void fill_patch_array_generic_kernel(
  const T **dest_pointers,  // Pointers: one per output row per kernel point
  T *dest_patch,  // Pointer to space which can be used to construct the patch
  unsigned int output_rows, unsigned int output_cols,
  unsigned int kernel_rows, unsigned int kernel_cols,
  unsigned int stride_rows, unsigned int stride_cols,
  const T *src_ptr, size_t ld_row, size_t ld_col,  // Source tensor
  const T *pad_row,  // Pointer to a row of padding values
  unsigned int pad_top, unsigned int valid_rows,
  unsigned int pad_left, unsigned int valid_cols
)
{
  addressing::fill_patch_array_generic_kernel(
    sizeof(T),
    reinterpret_cast<const void **>(dest_pointers),
    reinterpret_cast<void *>(dest_patch),
    output_rows, output_cols,
    kernel_rows, kernel_cols,
    stride_rows, stride_cols,
    reinterpret_cast<const void *>(src_ptr), ld_row, ld_col,
    reinterpret_cast<const void *>(pad_row),
    pad_top, valid_rows,
    pad_left, valid_cols
  );
}

}  // namespace {anonymous}
}  // namespace arm_conv