/*
 * Copyright (c) 2017 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#pragma once
#ifndef DOXYGEN_SKIP_THIS
#include <cstdint>
#endif /* DOXYGEN_SKIP_THIS */
#include "arm.hpp"

namespace reorder {
/** Re-order a tensor from NCHW format to NHWC.
 *
 * @note The stride parameters are optional and are provided to allow padding in either the input or output tensors.
 *
 * @param[in] in Input tensor in NCHW format.
 * @param[out] out Output tensor, to be written in NHWC format.
 * @param n_batches Number of batches in the tensors.
 * @param n_channels Number of channels in the tensors.
 * @param n_rows Height of the tensor.
 * @param n_cols Width of the tensor.
 * @param in_batch_stride Stride over batches in the input tensor. If `0` defaults to `n_channels * in_channel_stride`.
 * @param in_channel_stride Stride over channels in the input tensor. If `0` defaults to `n_rows * in_row_stride`.
 * @param in_row_stride Stride over rows in the input tensor. If `0` defaults to `n_cols`.
 * @param out_batch_stride Stride over batches in the output tensor. If `0` defaults to `n_rows * out_row_stride`.
 * @param out_row_stride Stride over rows in the output tensor. If `0` defaults to `n_cols * out_col_stride`.
 * @param out_col_stride Stride over columns in the output tensor. If `0` defaults to `n_channels`.
 */
template <typename T>
inline void nchw_to_nhwc(
  const T* const in,
  T* const out,
  const int n_batches,
  const int n_channels,
  const int n_rows,
  const int n_cols,
  int in_batch_stride=0,
  int in_channel_stride=0,
  int in_row_stride=0,
  int out_batch_stride=0,
  int out_row_stride=0,
  int out_col_stride=0
);
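//
// A minimal usage sketch (illustrative only: `src` and `dst` stand for
// hypothetical buffers holding one batch of a 2-channel, 3x4 tensor, tightly
// packed so that every stride defaults from the dimensions):
//
//   float src[1 * 2 * 3 * 4];  // NCHW: [batch][channel][row][col]
//   float dst[1 * 2 * 3 * 4];  // NHWC: [batch][row][col][channel]
//   reorder::nchw_to_nhwc(src, dst, 1, 2, 3, 4);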

/** Re-order a tensor from NHWC format to NCHW.
 *
 * @note The stride parameters are optional and are provided to allow padding in either the input or output tensors.
 *
 * @param[in] in Input tensor in NHWC format.
 * @param[out] out Output tensor, to be written in NCHW format.
 * @param n_batches Number of batches in the tensors.
 * @param n_rows Height of the tensor.
 * @param n_cols Width of the tensor.
 * @param n_channels Number of channels in the tensors.
 * @param in_batch_stride Stride over batches in the input tensor. If `0` defaults to `n_rows * in_row_stride`.
 * @param in_row_stride Stride over rows in the input tensor. If `0` defaults to `n_cols * in_col_stride`.
 * @param in_col_stride Stride over columns in the input tensor. If `0` defaults to `n_channels`.
 * @param out_batch_stride Stride over batches in the output tensor. If `0` defaults to `n_channels * out_channel_stride`.
 * @param out_channel_stride Stride over channels in the output tensor. If `0` defaults to `n_rows * out_row_stride`.
 * @param out_row_stride Stride over rows in the output tensor. If `0` defaults to `n_cols`.
 */
template <typename T>
inline void nhwc_to_nchw(
  const T* const in,  // Input data in NHWC form
  T* const out,       // Output data in NCHW form
  const int n_batches,
  const int n_rows,
  const int n_cols,
  const int n_channels,
  int in_batch_stride=0,
  int in_row_stride=0,
  int in_col_stride=0,
  int out_batch_stride=0,
  int out_channel_stride=0,
  int out_row_stride=0
);
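//
// A minimal usage sketch (illustrative only; note that the dimension order
// differs from nchw_to_nhwc: batches, rows, cols, channels):
//
//   float src[1 * 3 * 4 * 2];  // NHWC: [batch][row][col][channel]
//   float dst[1 * 3 * 4 * 2];  // NCHW: [batch][channel][row][col]
//   reorder::nhwc_to_nchw(src, dst, 1, 3, 4, 2);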

/** Re-order a weight tensor from [Output feature map x Input feature map x
 *  Height x Width] format to [Height x Width x Input feature map x Output
 *  feature map] format.
 */
template <typename T>
inline void ofm_ifm_h_w_to_h_w_ifm_ofm(
  const T* const in,  // Input in [Output x Input x Height x Width] form
  T* const out,       // Output in [Height x Width x Input x Output] form
  const int n_output_feature_maps,
  const int n_input_feature_maps,
  const int n_rows,
  const int n_cols,
  int in_output_feature_map_stride=0,
  int in_input_feature_map_stride=0,
  int in_row_stride=0,
  int out_row_stride=0,
  int out_col_stride=0,
  int out_input_feature_map_stride=0
);
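//
// A minimal usage sketch (illustrative only: hypothetical weights for a 3x3
// kernel with 8 input and 16 output feature maps, tightly packed):
//
//   float w_oihw[16 * 8 * 3 * 3];  // [ofm][ifm][row][col]
//   float w_hwio[3 * 3 * 8 * 16];  // [row][col][ifm][ofm]
//   reorder::ofm_ifm_h_w_to_h_w_ifm_ofm(w_oihw, w_hwio, 16, 8, 3, 3);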

/** Re-order a weight tensor from [Height x Width x Input feature map x Output
 *  feature map] format to [Output feature map x Input feature map x Height x
 *  Width] format.
 */
template <typename T>
inline void h_w_ifm_ofm_to_ofm_ifm_h_w(
  const T* const in,  // Input in [Height x Width x Input x Output] form
  T* const out,       // Output in [Output x Input x Height x Width] form
  const int n_rows,
  const int n_cols,
  const int n_input_feature_maps,
  const int n_output_feature_maps,
  int in_row_stride=0,
  int in_col_stride=0,
  int in_input_feature_map_stride=0,
  int out_output_feature_map_stride=0,
  int out_input_feature_map_stride=0,
  int out_row_stride=0
);
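//
// A minimal usage sketch (illustrative only; the inverse of the re-ordering
// shown above, with the spatial dimensions passed first):
//
//   float w_hwio[3 * 3 * 8 * 16];  // [row][col][ifm][ofm]
//   float w_oihw[16 * 8 * 3 * 3];  // [ofm][ifm][row][col]
//   reorder::h_w_ifm_ofm_to_ofm_ifm_h_w(w_hwio, w_oihw, 3, 3, 8, 16);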

/*****************************************************************************/
/* 32-bit implementation : NCHW -> NHWC
 */
template <>
inline void nchw_to_nhwc(
  const int32_t* const in,
  int32_t* const out,
  const int n_batches,
  const int n_channels,
  const int n_rows,
  const int n_cols,
  int in_batch_stride,
  int in_channel_stride,
  int in_row_stride,
  int out_batch_stride,
  int out_row_stride,
  int out_col_stride
)
{
  typedef int32_t T;

  // Fill in the stride values
  in_row_stride = (in_row_stride) ? in_row_stride : n_cols;
  in_channel_stride = (in_channel_stride) ? in_channel_stride
                                          : n_rows * in_row_stride;
  in_batch_stride = (in_batch_stride) ? in_batch_stride
                                      : n_channels * in_channel_stride;

  out_col_stride = (out_col_stride) ? out_col_stride : n_channels;
  out_row_stride = (out_row_stride) ? out_row_stride : n_cols * out_col_stride;
  out_batch_stride = (out_batch_stride) ? out_batch_stride
                                        : n_rows * out_row_stride;

  // Perform the re-ordering
  for (int n = 0; n < n_batches; n++)
  {
    const T* const in_batch = in + n*in_batch_stride;
    T* const out_batch = out + n*out_batch_stride;

    for (int i = 0; i < n_rows; i++)
    {
      const T* const in_row = in_batch + i*in_row_stride;
      T* const out_row = out_batch + i*out_row_stride;

      int j = 0, j_remaining = n_cols;
#ifdef __arm_any__
      for (; j_remaining >= 4; j += 4, j_remaining -= 4)
      {
        int c = 0, c_remaining = n_channels;
        for (; c_remaining >= 4; c += 4, c_remaining -= 4)
        {
          // Read 4 channels worth of 4 columns, then zip to produce 4 columns
          // worth of 4 channels.
          int32x4_t channel_pixels[4];
          channel_pixels[0] = vld1q_s32(in_row + (c + 0)*in_channel_stride + j);
          channel_pixels[1] = vld1q_s32(in_row + (c + 1)*in_channel_stride + j);
          channel_pixels[2] = vld1q_s32(in_row + (c + 2)*in_channel_stride + j);
          channel_pixels[3] = vld1q_s32(in_row + (c + 3)*in_channel_stride + j);

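          // Two rounds of zips implement a 4x4 transpose: after the second
          // round, each vector holds the four channel values of one column.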
          const auto zip1 = vzipq_s32(channel_pixels[0], channel_pixels[2]);
          const auto zip2 = vzipq_s32(channel_pixels[1], channel_pixels[3]);
          const auto out_0 = vzipq_s32(zip1.val[0], zip2.val[0]);
          const auto out_1 = vzipq_s32(zip1.val[1], zip2.val[1]);

          vst1q_s32(out_row + (j + 0)*out_col_stride + c, out_0.val[0]);
          vst1q_s32(out_row + (j + 1)*out_col_stride + c, out_0.val[1]);
          vst1q_s32(out_row + (j + 2)*out_col_stride + c, out_1.val[0]);
          vst1q_s32(out_row + (j + 3)*out_col_stride + c, out_1.val[1]);
        }
        for (; c_remaining; c++, c_remaining--)
        {
          for (int _j = 0; _j < 4; _j++)
          {
            const T* const in_col = in_row + j + _j;
            T* const out_col = out_row + (j + _j)*out_col_stride;
            const T* const in_channel = in_col + c*in_channel_stride;
            out_col[c] = *(in_channel);
          }
        }
      }
      for (; j_remaining >= 2; j += 2, j_remaining -= 2)
      {
        int c = 0, c_remaining = n_channels;
        for (; c_remaining >= 2; c += 2, c_remaining -= 2)
        {
          // Read 2 channels worth of 2 columns, then zip to produce 2 columns
          // worth of 2 channels.
          int32x2_t channel_pixels[2];
          channel_pixels[0] = vld1_s32(in_row + (c + 0)*in_channel_stride + j);
          channel_pixels[1] = vld1_s32(in_row + (c + 1)*in_channel_stride + j);

          const auto output = vzip_s32(channel_pixels[0], channel_pixels[1]);

          vst1_s32(out_row + (j + 0)*out_col_stride + c, output.val[0]);
          vst1_s32(out_row + (j + 1)*out_col_stride + c, output.val[1]);
        }
        for (; c_remaining; c++, c_remaining--)
        {
          for (int _j = 0; _j < 2; _j++)
          {
            const T* const in_col = in_row + j + _j;
            T* const out_col = out_row + (j + _j)*out_col_stride;
            const T* const in_channel = in_col + c*in_channel_stride;
            out_col[c] = *(in_channel);
          }
        }
      }
#endif  // __arm_any__
      for (; j_remaining; j++, j_remaining--)
      {
        const T* const in_col = in_row + j;
        T* const out_col = out_row + j*out_col_stride;

        for (int c = 0; c < n_channels; c++)
        {
          const T* const in_channel = in_col + c*in_channel_stride;
          out_col[c] = *(in_channel);
        }
      }
    }
  }
}

template <>
inline void nchw_to_nhwc(
  const uint32_t* const in,
  uint32_t* const out,
  const int n_batches,
  const int n_channels,
  const int n_rows,
  const int n_cols,
  int in_batch_stride,
  int in_channel_stride,
  int in_row_stride,
  int out_batch_stride,
  int out_row_stride,
  int out_col_stride
)
{
  nchw_to_nhwc(
    reinterpret_cast<const int32_t*>(in),
    reinterpret_cast<int32_t*>(out),
    n_batches, n_channels, n_rows, n_cols,
    in_batch_stride, in_channel_stride, in_row_stride,
    out_batch_stride, out_row_stride, out_col_stride
  );
}

template <>
inline void nchw_to_nhwc(
  const float* const in,
  float* const out,
  const int n_batches,
  const int n_channels,
  const int n_rows,
  const int n_cols,
  int in_batch_stride,
  int in_channel_stride,
  int in_row_stride,
  int out_batch_stride,
  int out_row_stride,
  int out_col_stride
)
{
  nchw_to_nhwc(
    reinterpret_cast<const int32_t*>(in),
    reinterpret_cast<int32_t*>(out),
    n_batches, n_channels, n_rows, n_cols,
    in_batch_stride, in_channel_stride, in_row_stride,
    out_batch_stride, out_row_stride, out_col_stride
  );
}

/*****************************************************************************/
/* Generic implementation : NCHW -> NHWC
 */
template <typename T>
inline void nchw_to_nhwc(
  const T* const in,
  T* const out,
  const int n_batches,
  const int n_channels,
  const int n_rows,
  const int n_cols,
  int in_batch_stride,
  int in_channel_stride,
  int in_row_stride,
  int out_batch_stride,
  int out_row_stride,
  int out_col_stride
)
{
  // Fill in the stride values
  in_row_stride = (in_row_stride) ? in_row_stride : n_cols;
  in_channel_stride = (in_channel_stride) ? in_channel_stride
                                          : n_rows * in_row_stride;
  in_batch_stride = (in_batch_stride) ? in_batch_stride
                                      : n_channels * in_channel_stride;

  out_col_stride = (out_col_stride) ? out_col_stride : n_channels;
  out_row_stride = (out_row_stride) ? out_row_stride : n_cols * out_col_stride;
  out_batch_stride = (out_batch_stride) ? out_batch_stride
                                        : n_rows * out_row_stride;

  // Perform the re-ordering
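  // With the default (tightly packed) strides this computes
  //   out[n][i][j][c] = in[n][c][i][j]
  // for every batch n, row i, column j and channel c.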
  for (int n = 0; n < n_batches; n++)
  {
    const T* const in_batch = in + n*in_batch_stride;
    T* const out_batch = out + n*out_batch_stride;

    for (int i = 0; i < n_rows; i++)
    {
      const T* const in_row = in_batch + i*in_row_stride;
      T* const out_row = out_batch + i*out_row_stride;

      for (int j = 0; j < n_cols; j++)
      {
        const T* const in_col = in_row + j;
        T* const out_col = out_row + j*out_col_stride;

        for (int c = 0; c < n_channels; c++)
        {
          const T* const in_channel = in_col + c*in_channel_stride;
          out_col[c] = *(in_channel);
        }
      }
    }
  }
}

/*****************************************************************************/
/* 32-bit implementation : NHWC -> NCHW
 */
template <>
inline void nhwc_to_nchw(
  const int32_t* const in,  // Input data in NHWC form
  int32_t* const out,       // Output data in NCHW form
  const int n_batches,
  const int n_rows,
  const int n_cols,
  const int n_channels,
  int in_batch_stride,
  int in_row_stride,
  int in_col_stride,
  int out_batch_stride,
  int out_channel_stride,
  int out_row_stride
)
{
  typedef int32_t T;

  // Fill in stride values
  in_col_stride = (in_col_stride) ? in_col_stride : n_channels;
  in_row_stride = (in_row_stride) ? in_row_stride : n_cols * in_col_stride;
  in_batch_stride = (in_batch_stride) ? in_batch_stride
                                      : n_rows * in_row_stride;

  out_row_stride = (out_row_stride) ? out_row_stride : n_cols;
  out_channel_stride = (out_channel_stride) ? out_channel_stride
                                            : n_rows * out_row_stride;
  out_batch_stride = (out_batch_stride) ? out_batch_stride
                                        : n_channels * out_channel_stride;

  // Perform the re-ordering
  // For every batch
  for (int n = 0; n < n_batches; n++)
  {
    const T* const in_batch = in + n*in_batch_stride;
    T* const out_batch = out + n*out_batch_stride;

    // For every row
    for (int i = 0; i < n_rows; i++)
    {
      const T* const in_i = in_batch + i*in_row_stride;
      T* const out_i = out_batch + i*out_row_stride;

      // For every column, beginning with chunks of 4
      int j = 0, j_remaining = n_cols;
#ifdef __arm_any__
      for (; j_remaining >= 4; j += 4, j_remaining -= 4)
      {
        // For every channel, beginning with chunks of 4
        int c = 0, c_remaining = n_channels;
        for (; c_remaining >= 4; c += 4, c_remaining -= 4)
        {
          // Read 4 columns worth of 4 channels then zip to produce 4 channels
          // worth of 4 columns.
          int32x4_t pixel_channels[4];
          pixel_channels[0] = vld1q_s32(in_i + (j + 0)*in_col_stride + c);
          pixel_channels[1] = vld1q_s32(in_i + (j + 1)*in_col_stride + c);
          pixel_channels[2] = vld1q_s32(in_i + (j + 2)*in_col_stride + c);
          pixel_channels[3] = vld1q_s32(in_i + (j + 3)*in_col_stride + c);

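          // As above, two rounds of zips implement a 4x4 transpose: each
          // result vector holds the four column values of one channel.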
          const auto zip1 = vzipq_s32(pixel_channels[0], pixel_channels[2]);
          const auto zip2 = vzipq_s32(pixel_channels[1], pixel_channels[3]);
          const auto out_0 = vzipq_s32(zip1.val[0], zip2.val[0]);
          const auto out_1 = vzipq_s32(zip1.val[1], zip2.val[1]);

          vst1q_s32(out_i + j + (c + 0)*out_channel_stride, out_0.val[0]);
          vst1q_s32(out_i + j + (c + 1)*out_channel_stride, out_0.val[1]);
          vst1q_s32(out_i + j + (c + 2)*out_channel_stride, out_1.val[0]);
          vst1q_s32(out_i + j + (c + 3)*out_channel_stride, out_1.val[1]);
        }
        for (; c_remaining; c++, c_remaining--)
        {
          for (int _j = 0; _j < 4; _j++)
          {
            const T* const in_j = in_i + (j + _j)*in_col_stride;
            T* const out_j = out_i + (j + _j);

            const T* const in_channel = in_j + c;
            T* const out_channel = out_j + c*out_channel_stride;
            *(out_channel) = *(in_channel);
          }
        }
      }
      for (; j_remaining >= 2; j += 2, j_remaining -= 2)
      {
        int c = 0, c_remaining = n_channels;
        for (; c_remaining >= 2; c += 2, c_remaining -= 2)
        {
          // Read 2 columns worth of 2 channels then zip to produce 2 channels
          // worth of 2 columns.
          int32x2_t pixel_channels[2];
          pixel_channels[0] = vld1_s32(in_i + (j + 0)*in_col_stride + c);
          pixel_channels[1] = vld1_s32(in_i + (j + 1)*in_col_stride + c);

          const auto output = vzip_s32(pixel_channels[0], pixel_channels[1]);

          vst1_s32(out_i + j + (c + 0)*out_channel_stride, output.val[0]);
          vst1_s32(out_i + j + (c + 1)*out_channel_stride, output.val[1]);
        }
        for (; c_remaining; c++, c_remaining--)
        {
          for (int _j = 0; _j < 2; _j++)
          {
            const T* const in_j = in_i + (j + _j)*in_col_stride;
            T* const out_j = out_i + (j + _j);

            const T* const in_channel = in_j + c;
            T* const out_channel = out_j + c*out_channel_stride;
            *(out_channel) = *(in_channel);
          }
        }
      }
#endif  // __arm_any__
      for (; j_remaining; j++, j_remaining--)
      {
        const T* const in_j = in_i + j*in_col_stride;
        T* const out_j = out_i + j;

        // For every channel
        for (int c = 0; c < n_channels; c++)
        {
          const T* const in_channel = in_j + c;
          T* const out_channel = out_j + c*out_channel_stride;
          *(out_channel) = *(in_channel);
        }
      }
    }
  }
}

template <>
inline void nhwc_to_nchw(
  const uint32_t* const in,  // Input data in NHWC form
  uint32_t* const out,       // Output data in NCHW form
  const int n_batches,
  const int n_rows,
  const int n_cols,
  const int n_channels,
  int in_batch_stride,
  int in_row_stride,
  int in_col_stride,
  int out_batch_stride,
  int out_channel_stride,
  int out_row_stride
)
{
  // Redirect to the shared 32-bit (int32_t) implementation
  nhwc_to_nchw(
    reinterpret_cast<const int32_t*>(in),
    reinterpret_cast<int32_t*>(out),
    n_batches, n_rows, n_cols, n_channels,
    in_batch_stride, in_row_stride, in_col_stride,
    out_batch_stride, out_channel_stride, out_row_stride
  );
}

template <>
inline void nhwc_to_nchw(
  const float* const in,  // Input data in NHWC form
  float* const out,       // Output data in NCHW form
  const int n_batches,
  const int n_rows,
  const int n_cols,
  const int n_channels,
  int in_batch_stride,
  int in_row_stride,
  int in_col_stride,
  int out_batch_stride,
  int out_channel_stride,
  int out_row_stride
)
{
  // Redirect to the shared 32-bit (int32_t) implementation
  nhwc_to_nchw(
    reinterpret_cast<const int32_t*>(in),
    reinterpret_cast<int32_t*>(out),
    n_batches, n_rows, n_cols, n_channels,
    in_batch_stride, in_row_stride, in_col_stride,
    out_batch_stride, out_channel_stride, out_row_stride
  );
}

/*****************************************************************************/
/* Generic implementation : NHWC -> NCHW
 */
template <typename T>
inline void nhwc_to_nchw(
  const T* const in,  // Input data in NHWC form
  T* const out,       // Output data in NCHW form
  const int n_batches,
  const int n_rows,
  const int n_cols,
  const int n_channels,
  int in_batch_stride,
  int in_row_stride,
  int in_col_stride,
  int out_batch_stride,
  int out_channel_stride,
  int out_row_stride
)
{
  // Fill in stride values
  in_col_stride = (in_col_stride) ? in_col_stride : n_channels;
  in_row_stride = (in_row_stride) ? in_row_stride : n_cols * in_col_stride;
  in_batch_stride = (in_batch_stride) ? in_batch_stride
                                      : n_rows * in_row_stride;

  out_row_stride = (out_row_stride) ? out_row_stride : n_cols;
  out_channel_stride = (out_channel_stride) ? out_channel_stride
                                            : n_rows * out_row_stride;
  out_batch_stride = (out_batch_stride) ? out_batch_stride
                                        : n_channels * out_channel_stride;

  // Perform the re-ordering
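  // With the default (tightly packed) strides this computes
  //   out[n][c][i][j] = in[n][i][j][c]
  // for every batch n, channel c, row i and column j.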
  // For every batch
  for (int n = 0; n < n_batches; n++)
  {
    const T* const in_batch = in + n*in_batch_stride;
    T* const out_batch = out + n*out_batch_stride;

    // For every row
    for (int i = 0; i < n_rows; i++)
    {
      const T* const in_i = in_batch + i*in_row_stride;
      T* const out_i = out_batch + i*out_row_stride;

      // For every column
      for (int j = 0; j < n_cols; j++)
      {
        const T* const in_j = in_i + j*in_col_stride;
        T* const out_j = out_i + j;

        // For every channel
        for (int c = 0; c < n_channels; c++)
        {
          const T* const in_channel = in_j + c;
          T* const out_channel = out_j + c*out_channel_stride;
          *(out_channel) = *(in_channel);
        }
      }
    }
  }
}

/*****************************************************************************/
/* Generic weight re-order implementation.
 */
template <typename T>
inline void ofm_ifm_h_w_to_h_w_ifm_ofm(
  const T* const in,  // Input in [Output x Input x Height x Width] form
  T* const out,       // Output in [Height x Width x Input x Output] form
  const int n_output_feature_maps,
  const int n_input_feature_maps,
  const int n_rows,
  const int n_cols,
  int in_output_feature_map_stride,
  int in_input_feature_map_stride,
  int in_row_stride,
  int out_row_stride,
  int out_col_stride,
  int out_input_feature_map_stride
)
{
  // Fill in stride values
  in_row_stride = (in_row_stride)
    ? in_row_stride
    : n_cols;
  in_input_feature_map_stride = (in_input_feature_map_stride)
    ? in_input_feature_map_stride
    : n_rows * in_row_stride;
  in_output_feature_map_stride = (in_output_feature_map_stride)
    ? in_output_feature_map_stride
    : n_input_feature_maps * in_input_feature_map_stride;

  out_input_feature_map_stride = (out_input_feature_map_stride)
    ? out_input_feature_map_stride
    : n_output_feature_maps;
  out_col_stride = (out_col_stride)
    ? out_col_stride
    : n_input_feature_maps * out_input_feature_map_stride;
  out_row_stride = (out_row_stride)
    ? out_row_stride
    : n_cols * out_col_stride;

  // Perform the re-ordering
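  // With the default (tightly packed) strides this computes
  //   out[i][j][ifm][ofm] = in[ofm][ifm][i][j].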
  for (int i = 0; i < n_rows; i++)
  {
    const T* const in_row = in + i * in_row_stride;
    T* out_row = out + i * out_row_stride;

    for (int j = 0; j < n_cols; j++)
    {
      const T* const in_col = in_row + j;
      T* const out_col = out_row + j * out_col_stride;

      for (int ifm = 0; ifm < n_input_feature_maps; ifm++)
      {
        const T* const in_ifm = in_col + ifm * in_input_feature_map_stride;
        T* const out_ifm = out_col + ifm * out_input_feature_map_stride;

        for (int ofm = 0; ofm < n_output_feature_maps; ofm++)
        {
          const T* const in_ofm = in_ifm + ofm * in_output_feature_map_stride;
          T* const out_ofm = out_ifm + ofm;
          *(out_ofm) = *(in_ofm);
        }
      }
    }
  }
}

/*****************************************************************************/
/* Generic weight re-order implementation.
 */
template <typename T>
inline void h_w_ifm_ofm_to_ofm_ifm_h_w(
  const T* const in,  // Input in [Height x Width x Input x Output] form
  T* const out,       // Output in [Output x Input x Height x Width] form
  const int n_rows,
  const int n_cols,
  const int n_input_feature_maps,
  const int n_output_feature_maps,
  int in_row_stride,
  int in_col_stride,
  int in_input_feature_map_stride,
  int out_output_feature_map_stride,
  int out_input_feature_map_stride,
  int out_row_stride
)
{
  // Fill in the stride values
  in_input_feature_map_stride = (in_input_feature_map_stride)
    ? in_input_feature_map_stride
    : n_output_feature_maps;
  in_col_stride = (in_col_stride)
    ? in_col_stride
    : n_input_feature_maps * in_input_feature_map_stride;
  in_row_stride = (in_row_stride)
    ? in_row_stride
    : n_cols * in_col_stride;

  out_row_stride = (out_row_stride)
    ? out_row_stride
    : n_cols;
  out_input_feature_map_stride = (out_input_feature_map_stride)
    ? out_input_feature_map_stride
    : n_rows * out_row_stride;
  out_output_feature_map_stride = (out_output_feature_map_stride)
    ? out_output_feature_map_stride
    : n_input_feature_maps * out_input_feature_map_stride;

  // Perform the re-ordering
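  // With the default (tightly packed) strides this computes
  //   out[ofm][ifm][i][j] = in[i][j][ifm][ofm].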
  for (int i = 0; i < n_rows; i++)
  {
    const T* const in_row = in + i * in_row_stride;
    T* const out_row = out + i * out_row_stride;

    for (int j = 0; j < n_cols; j++)
    {
      const T* const in_col = in_row + j * in_col_stride;
      T* const out_col = out_row + j;

      for (int ifm = 0; ifm < n_input_feature_maps; ifm++)
      {
        const T* const in_ifm = in_col + ifm * in_input_feature_map_stride;
        T* const out_ifm = out_col + ifm * out_input_feature_map_stride;

        for (int ofm = 0; ofm < n_output_feature_maps; ofm++)
        {
          const T* const in_ofm = in_ifm + ofm;
          T* const out_ofm = out_ifm + ofm * out_output_feature_map_stride;
          *(out_ofm) = *(in_ofm);
        }
      }
    }
  }
}

}  // namespace reorder