/*
 * Copyright (c) 2017-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#ifndef ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H
#define ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H

#include "src/core/NEON/NEFixedPoint.h"
#include "src/core/NEON/wrapper/wrapper.h"
#include "support/Requires.h"

#include <arm_neon.h>

namespace arm_compute
{
namespace detail
{
/** Loads a 3x3 matrix as a row (float).
 *
 * @param[in] ptr            Pointer to a float 3x3 matrix.
 * @param[in] weights_offset (Optional) Weights quantization offset.
 *
 * @return The loaded matrix.
 */
inline float32x4x3_t load_matrix_row(const float *ptr, int weights_offset = 0)
{
    ARM_COMPUTE_UNUSED(weights_offset);
    const float32x4x3_t r =
    {
        {
            vld1q_dup_f32(ptr),
            vld1q_dup_f32(1 + ptr),
            vld1q_dup_f32(2 + ptr)
        }
    };
    return r;
}
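
/* Illustrative sketch (hypothetical weight values, not part of the API):
   each returned vector holds one weight broadcast across all four lanes,
   ready for lane-wise multiply-accumulate against four input pixels.

       const float row[3] = {0.1f, 0.2f, 0.3f};
       const float32x4x3_t w = load_matrix_row(row);
       // w.val[0] = {0.1f, 0.1f, 0.1f, 0.1f}
       // w.val[1] = {0.2f, 0.2f, 0.2f, 0.2f}
       // w.val[2] = {0.3f, 0.3f, 0.3f, 0.3f}
*/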

/** Loads a 3x3 matrix as a row (uint8_t/int8_t).
 *
 * @param[in] ptr            Pointer to a uint8_t/int8_t 3x3 matrix.
 * @param[in] weights_offset (Optional) Weights quantization offset.
 *
 * @return The loaded matrix.
 */
template < typename T, ARM_COMPUTE_REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
inline int32x4x3_t load_matrix_row(const T *ptr, int weights_offset = 0)
{
    const int32x4_t v_weights_offset = vdupq_n_s32(weights_offset);

    /* ptr is a pointer to a row in a 3x3 matrix, the function returns 3 vectors holding exactly the same value in all lanes:
       r.val[0] contains the first element, r.val[1] the second element and r.val[2] the third element (in all lanes) */
    int32x4x3_t r =
    {
        {
            vaddq_s32(v_weights_offset, vdupq_n_s32(*ptr)),
            vaddq_s32(v_weights_offset, vdupq_n_s32(*(ptr + 1))),
            vaddq_s32(v_weights_offset, vdupq_n_s32(*(ptr + 2)))
        }
    };
    return r;
}
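
/* Illustrative sketch (hypothetical values): for quantized types the weights
   offset is folded into every lane at load time, so each returned vector
   already holds (weight + weights_offset) widened to int32.

       const uint8_t row[3] = {10, 20, 30};
       const int32x4x3_t w  = load_matrix_row(row, -5); // offset of -5
       // w.val[0] = {5, 5, 5, 5}, w.val[1] = {15, 15, 15, 15}, ...
*/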

/** Stores a float32x4x2_t array into a memory location.
 *
 * @param[in] buffer Pointer to the memory location where the values will be stored.
 * @param[in] values Values that will be stored.
 *
 */
template <unsigned int stridex>
void store_results(float *buffer, const float32x4x2_t &values);

template <>
inline void store_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, values.val[0]);
    vst1q_f32(buffer + 4, values.val[1]);
}

template <>
inline void store_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, values.val[0]);
}

template <>
inline void store_results<3>(float *buffer, const float32x4x2_t &values)
{
    vst1_f32(buffer, vget_low_f32(values.val[0]));
}
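
/* Note on the stride specializations (informal sketch): the convolution
   kernels in this file compact the valid outputs into the leading lanes of
   values.val[0] when stridex > 1, so progressively fewer elements are
   written back per iteration:

       store_results<1>  ->  8 floats stored (values.val[0] and values.val[1])
       store_results<2>  ->  4 floats stored (values.val[0])
       store_results<3>  ->  2 floats stored (low half of values.val[0])
*/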

/** Stores an int32x4x2_t array into a memory location.
 *
 * @param[in] buffer Pointer to the memory location where the values will be stored.
 * @param[in] values Values that will be stored.
 *
 */
template <unsigned int stridex>
void store_results(int32_t *buffer, const int32x4x2_t &values);

template <>
inline void store_results<1>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1q_s32(buffer, values.val[0]);
    vst1q_s32(buffer + 4, values.val[1]);
}

template <>
inline void store_results<2>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1q_s32(buffer, values.val[0]);
}

template <>
inline void store_results<3>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1_s32(buffer, vget_low_s32(values.val[0]));
}

template <unsigned int stridex>
inline void accumulate_results(float *buffer, const float32x4x2_t &values);

template <>
inline void accumulate_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
    vst1q_f32(buffer + 4, vaddq_f32(vld1q_f32(buffer + 4), values.val[1]));
}

template <>
inline void accumulate_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
}

template <>
inline void accumulate_results<3>(float *buffer, const float32x4x2_t &values)
{
    vst1_f32(buffer, vadd_f32(vld1_f32(buffer), vget_low_f32(values.val[0])));
}

template <unsigned int stridex>
void accumulate_results(int32_t *buffer, const int32x4x2_t &values);

template <>
inline void accumulate_results<1>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1q_s32(buffer, vaddq_s32(vld1q_s32(buffer), values.val[0]));
    vst1q_s32(buffer + 4, vaddq_s32(vld1q_s32(buffer + 4), values.val[1]));
}

template <>
inline void accumulate_results<2>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1q_s32(buffer, vaddq_s32(vld1q_s32(buffer), values.val[0]));
}

template <>
inline void accumulate_results<3>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1_s32(buffer, vadd_s32(vld1_s32(buffer), vget_low_s32(values.val[0])));
}
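
/* Illustrative sketch (hypothetical buffers): accumulate_results<N> is the
   read-modify-write counterpart of store_results<N>, used when partial sums
   from several input channels are reduced into the same output tile.

       int32_t acc[8] = {0};                                  // running tile
       const int32x4x2_t partial = { { vdupq_n_s32(1), vdupq_n_s32(2) } };
       accumulate_results<1>(acc, partial); // acc = {1,1,1,1, 2,2,2,2}
       accumulate_results<1>(acc, partial); // acc = {2,2,2,2, 4,4,4,4}
*/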

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** Stores a float16x8x2_t array into a memory location.
 *
 * @param[in] buffer Pointer to the memory location where the values will be stored.
 * @param[in] values Values that will be stored.
 *
 */
template <unsigned int stridex>
void store_results(float16_t *buffer, const float16x8x2_t &values);

template <>
inline void store_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, values.val[0]);
    vst1q_f16(buffer + 8, values.val[1]);
}

template <>
inline void store_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, values.val[0]);
}

template <>
inline void store_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1_f16(buffer, vget_low_f16(values.val[0]));
}

template <unsigned int stridex>
inline void accumulate_results(float16_t *buffer, const float16x8x2_t &values);

template <>
inline void accumulate_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
    vst1q_f16(buffer + 8, vaddq_f16(vld1q_f16(buffer + 8), values.val[1]));
}

template <>
inline void accumulate_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
}

template <>
inline void accumulate_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1_f16(buffer, vadd_f16(vld1_f16(buffer), vget_low_f16(values.val[0])));
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

/** Perform a 3x3 convolution for 4 consecutive elements on float32 when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset.
 *
 */
inline float32x4_t single_convolve_3x3_dilation(const float *in_top, const float *in_mid, const float *in_low,
                                                const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                                                const size_t dilation_x, int input_offset)
{
    ARM_COMPUTE_UNUSED(input_offset);

    const float32x4x3_t vtop =
    {
        {
            vld1q_f32(in_top),
            vld1q_f32(in_top + dilation_x),
            vld1q_f32(in_top + 2 * dilation_x)
        }
    };
    const float32x4x3_t vmid =
    {
        {
            vld1q_f32(in_mid),
            vld1q_f32(in_mid + dilation_x),
            vld1q_f32(in_mid + 2 * dilation_x)
        }
    };
    const float32x4x3_t vlow =
    {
        {
            vld1q_f32(in_low),
            vld1q_f32(in_low + dilation_x),
            vld1q_f32(in_low + 2 * dilation_x)
        }
    };
    float32x4_t out = vmulq_f32(vtop.val[0], m0.val[0]);
    out             = vmlaq_f32(out, vtop.val[1], m0.val[1]);
    out             = vmlaq_f32(out, vtop.val[2], m0.val[2]);

    out = vmlaq_f32(out, vmid.val[0], m1.val[0]);
    out = vmlaq_f32(out, vmid.val[1], m1.val[1]);
    out = vmlaq_f32(out, vmid.val[2], m1.val[2]);

    out = vmlaq_f32(out, vlow.val[0], m2.val[0]);
    out = vmlaq_f32(out, vlow.val[1], m2.val[1]);
    out = vmlaq_f32(out, vlow.val[2], m2.val[2]);

    return out;
}
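
/* Addressing sketch (hypothetical dilation_x = 2): output element x reads
   taps in_row[x], in_row[x + 2] and in_row[x + 4] from each row, so the
   three vld1q_f32 loads above fetch the taps for four consecutive outputs
   at once, e.g. for outputs 0..3 of the top row:

       vtop.val[0] = {in_top[0], in_top[1], in_top[2], in_top[3]}
       vtop.val[1] = {in_top[2], in_top[3], in_top[4], in_top[5]}
       vtop.val[2] = {in_top[4], in_top[5], in_top[6], in_top[7]}
*/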

/** Perform a 3x3 convolution for 8 consecutive elements on float32 when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] stridex      Stride value in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset.
 *
 */
inline float32x4x2_t convolve_3x3_dilation(const float *in_top, const float *in_mid, const float *in_low,
                                           const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                                           const size_t dilation_x, unsigned int stridex, int input_offset = 0)
{
    ARM_COMPUTE_ERROR_ON(stridex > 3);
    float32x4x2_t out =
    {
        {
            single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset),
            single_convolve_3x3_dilation(in_top + 4, in_mid + 4, in_low + 4, m0, m1, m2, dilation_x, input_offset)
        }
    };

    if(stridex == 2)
    {
        out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
        out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
        out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
    }
    else if(stridex == 3)
    {
        out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
    }

    return out;
}
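
/* Lane-compaction sketch: both four-element halves are computed at unit
   stride, then the surviving outputs are packed into out.val[0] ahead of the
   narrower store. With stridex == 2 the valid positions are 0, 2, 4 and 6,
   i.e. lanes {0, 2} of out.val[0] and lanes {0, 2} of out.val[1], so after
   the three vsetq_lane_f32 operations:

       out.val[0] = {p0, p2, p4, p6}
*/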

/** Perform a convolve3x3 on float32.
 *
 * @param[in]  in_top       Pointer to the first row of the input.
 * @param[in]  in_mid       Pointer to the second row of the input.
 * @param[in]  in_low       Pointer to the third row of the input.
 * @param[out] out_ptr      Pointer to the output.
 * @param[in]  m0           First row of the filter.
 * @param[in]  m1           Second row of the filter.
 * @param[in]  m2           Third row of the filter.
 * @param[in]  stridex      Stride value in elements across x.
 * @param[in]  input_offset (Optional) Input quantization offset.
 *
 */
template <bool accumulate>
void convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, float *out_ptr,
                  const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                  unsigned int stridex, int input_offset = 0);

template <bool accumulate>
inline void convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, float *out_ptr,
                         const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                         unsigned int stridex, int input_offset)
{
    ARM_COMPUTE_UNUSED(input_offset);
    ARM_COMPUTE_ERROR_ON(stridex > 3);

    float32x4x2_t out =
    {
        {
            vdupq_n_f32(0.f),
            vdupq_n_f32(0.f)
        }
    };
    if(stridex == 2)
    {
        const float32x4x2_t vtop     = vld2q_f32(in_top);
        const float32x4x2_t vmid     = vld2q_f32(in_mid);
        const float32x4x2_t vlow     = vld2q_f32(in_low);
        const float32x4_t   vtop_end = vld1q_f32(in_top + 8);
        const float32x4_t   vmid_end = vld1q_f32(in_mid + 8);
        const float32x4_t   vlow_end = vld1q_f32(in_low + 8);

        out.val[0] = vmulq_f32(vtop.val[0], m0.val[0]);

        out.val[0] = vmlaq_f32(out.val[0], vtop.val[1], m0.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop_end, 1), m0.val[2]);

        out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
        out.val[0] = vmlaq_f32(out.val[0], vmid.val[1], m1.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid_end, 1), m1.val[2]);

        out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
        out.val[0] = vmlaq_f32(out.val[0], vlow.val[1], m2.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow_end, 1), m2.val[2]);

        accumulate ? accumulate_results<2>(out_ptr, out) : store_results<2>(out_ptr, out);
    }
    else
    {
        const float32x4x3_t vtop =
        {
            {
                vld1q_f32(in_top),
                vld1q_f32(in_top + 4),
                vld1q_f32(in_top + 8)
            }
        };
        const float32x4x3_t vmid =
        {
            {
                vld1q_f32(in_mid),
                vld1q_f32(in_mid + 4),
                vld1q_f32(in_mid + 8)
            }
        };
        const float32x4x3_t vlow =
        {
            {
                vld1q_f32(in_low),
                vld1q_f32(in_low + 4),
                vld1q_f32(in_low + 8)
            }
        };
        out.val[0] = vmulq_f32(vtop.val[0], m0.val[0]);
        out.val[1] = vmulq_f32(vtop.val[1], m0.val[0]);

        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 1), m0.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 2), m0.val[2]);

        out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 1), m1.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 2), m1.val[2]);

        out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 1), m2.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 2), m2.val[2]);

        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 1), m0.val[1]);
        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 2), m0.val[2]);

        out.val[1] = vmlaq_f32(out.val[1], vmid.val[1], m1.val[0]);
        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 1), m1.val[1]);
        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 2), m1.val[2]);

        out.val[1] = vmlaq_f32(out.val[1], vlow.val[1], m2.val[0]);
        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 1), m2.val[1]);
        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 2), m2.val[2]);

        if(stridex == 3)
        {
            out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
            accumulate ? accumulate_results<3>(out_ptr, out) : store_results<3>(out_ptr, out);
        }
        else
        {
            accumulate ? accumulate_results<1>(out_ptr, out) : store_results<1>(out_ptr, out);
        }
    }
}
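
/* Usage sketch (hypothetical pointers and shapes, assuming a padded input
   plane of row stride `w`): one call produces one horizontal strip of
   outputs from three consecutive input rows.

       const float32x4x3_t m0 = load_matrix_row(weights);     // filter row 0
       const float32x4x3_t m1 = load_matrix_row(weights + 3); // filter row 1
       const float32x4x3_t m2 = load_matrix_row(weights + 6); // filter row 2
       convolve_3x3<false>(input, input + w, input + 2 * w, output,
                           m0, m1, m2, 1); // stridex = 1
       // <false> overwrites `output`; <true> would accumulate into it.
*/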

/** Perform a 3x3 convolution for 4 consecutive 8-bit elements when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] input_offset Input quantization offset.
 *
 */
template < typename T, ARM_COMPUTE_REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
inline int32x4_t single_convolve_3x3_dilation(const T *in_top, const T *in_mid, const T *in_low,
                                              const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                                              size_t dilation_x, int32_t input_offset)
{
    using VectorType    = typename std::conditional<std::is_same<T, uint8_t>::value, uint8x8x3_t, int8x8x3_t>::type;
    using OutputTagType = typename wrapper::traits::neon_bitvector_tag_t<int32_t, wrapper::traits::BitWidth::W128>;

    const int32x4_t v_input_offset = wrapper::vdup_n(input_offset, OutputTagType{});

    const VectorType vtop =
    {
        {
            wrapper::vload(in_top),
            wrapper::vload(in_top + dilation_x),
            wrapper::vload(in_top + 2 * dilation_x)
        }
    };
    const VectorType vmid =
    {
        {
            wrapper::vload(in_mid),
            wrapper::vload(in_mid + dilation_x),
            wrapper::vload(in_mid + 2 * dilation_x)
        }
    };
    const VectorType vlow =
    {
        {
            wrapper::vload(in_low),
            wrapper::vload(in_low + dilation_x),
            wrapper::vload(in_low + 2 * dilation_x)
        }
    };

    const int32x4x3_t vtop_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[1])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[2])))),
        }
    };
    const int32x4x3_t vmid_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[1])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[2])))),
        }
    };
    const int32x4x3_t vlow_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[1])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[2])))),
        }
    };

    int32x4_t out = wrapper::vmul(vtop_s32.val[0], m0.val[0]);
    out           = wrapper::vmla(out, vtop_s32.val[1], m0.val[1]);
    out           = wrapper::vmla(out, vtop_s32.val[2], m0.val[2]);

    out = wrapper::vmla(out, vmid_s32.val[0], m1.val[0]);
    out = wrapper::vmla(out, vmid_s32.val[1], m1.val[1]);
    out = wrapper::vmla(out, vmid_s32.val[2], m1.val[2]);

    out = wrapper::vmla(out, vlow_s32.val[0], m2.val[0]);
    out = wrapper::vmla(out, vlow_s32.val[1], m2.val[1]);
    out = wrapper::vmla(out, vlow_s32.val[2], m2.val[2]);

    return out;
}
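
/* Widening sketch: each 8-bit load above is lengthened to 16 bits (vmovl),
   the low half is taken and reinterpreted as signed, then vaddw widens it
   again to 32 bits while adding the input offset, so the multiply-accumulate
   runs entirely in int32. With a hypothetical input_offset of -128:

       raw uint8 tap 200  ->  200 + (-128) = 72 in the int32 lane
*/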

/** Perform a 3x3 convolution for 8 consecutive 8-bit elements when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] stridex      Stride value in elements across x.
 * @param[in] input_offset Input quantization offset.
 *
 */
template < typename T, ARM_COMPUTE_REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
inline int32x4x2_t convolve_3x3_dilation(const T *in_top, const T *in_mid, const T *in_low, const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                                         const size_t dilation_x, unsigned int stridex, int input_offset)
{
    ARM_COMPUTE_ERROR_ON(stridex > 3);
    int32x4x2_t out =
    {
        {
            single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset),
            single_convolve_3x3_dilation(in_top + 4, in_mid + 4, in_low + 4, m0, m1, m2, dilation_x, input_offset)
        }
    };

    if(stridex == 2)
    {
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 2), out.val[0], 1);
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 0), out.val[0], 2);
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 2), out.val[0], 3);
    }
    else if(stridex == 3)
    {
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 3), out.val[0], 1);
    }
    return out;
}

/** Perform a convolve3x3 on 8-bit elements.
 *
 * @param[in]  in_top       Pointer to the first row of the input.
 * @param[in]  in_mid       Pointer to the second row of the input.
 * @param[in]  in_low       Pointer to the third row of the input.
 * @param[out] out_ptr      Pointer to the output.
 * @param[in]  m0           First row of the filter.
 * @param[in]  m1           Second row of the filter.
 * @param[in]  m2           Third row of the filter.
 * @param[in]  stridex      Stride value in elements across x.
 * @param[in]  input_offset Input quantization offset.
 *
 */
template < bool accumulate, typename T1, typename T2, ARM_COMPUTE_REQUIRES_TA(std::is_same<T1, uint8_t>::value || std::is_same<T1, int8_t>::value) >
void convolve_3x3(const T1 *in_top, const T1 *in_mid, const T1 *in_low, T2 *out_ptr,
                  const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                  unsigned int stridex, int32_t input_offset)
{
    ARM_COMPUTE_ERROR_ON(stridex > 3);
    using VectorType    = typename std::conditional<std::is_same<T1, uint8_t>::value, uint8x8x2_t, int8x8x2_t>::type;
    using OutputTagType = typename wrapper::traits::neon_bitvector_tag_t<int32_t, wrapper::traits::BitWidth::W128>;

    const int32x4_t v_input_offset = wrapper::vdup_n(input_offset, OutputTagType{});

    const VectorType vtop =
    {
        {
            wrapper::vload(in_top),
            wrapper::vload(in_top + 8)
        }
    };
    const VectorType vmid =
    {
        {
            wrapper::vload(in_mid),
            wrapper::vload(in_mid + 8)
        }
    };
    const VectorType vlow =
    {
        {
            wrapper::vload(in_low),
            wrapper::vload(in_low + 8)
        }
    };

    const int32x4x3_t vtop_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgethigh(wrapper::vmovl(vtop.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[1])))),
        }
    };
    const int32x4x3_t vmid_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgethigh(wrapper::vmovl(vmid.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[1])))),
        }
    };
    const int32x4x3_t vlow_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgethigh(wrapper::vmovl(vlow.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[1])))),
        }
    };

    int32x4x2_t out
    {
        {
            wrapper::vdup_n(static_cast<int32_t>(0), OutputTagType{}),
            wrapper::vdup_n(static_cast<int32_t>(0), OutputTagType{}),
        }
    };

    // 0
    out.val[0] = wrapper::vmla(out.val[0], vtop_s32.val[0], m0.val[0]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_1(vtop_s32.val[0], vtop_s32.val[1]), m0.val[1]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_2(vtop_s32.val[0], vtop_s32.val[1]), m0.val[2]);

    out.val[0] = wrapper::vmla(out.val[0], vmid_s32.val[0], m1.val[0]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_1(vmid_s32.val[0], vmid_s32.val[1]), m1.val[1]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_2(vmid_s32.val[0], vmid_s32.val[1]), m1.val[2]);

    out.val[0] = wrapper::vmla(out.val[0], vlow_s32.val[0], m2.val[0]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_1(vlow_s32.val[0], vlow_s32.val[1]), m2.val[1]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_2(vlow_s32.val[0], vlow_s32.val[1]), m2.val[2]);

    // 1
    out.val[1] = wrapper::vmla(out.val[1], vtop_s32.val[1], m0.val[0]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_1(vtop_s32.val[1], vtop_s32.val[2]), m0.val[1]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_2(vtop_s32.val[1], vtop_s32.val[2]), m0.val[2]);

    out.val[1] = wrapper::vmla(out.val[1], vmid_s32.val[1], m1.val[0]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_1(vmid_s32.val[1], vmid_s32.val[2]), m1.val[1]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_2(vmid_s32.val[1], vmid_s32.val[2]), m1.val[2]);

    out.val[1] = wrapper::vmla(out.val[1], vlow_s32.val[1], m2.val[0]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_1(vlow_s32.val[1], vlow_s32.val[2]), m2.val[1]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_2(vlow_s32.val[1], vlow_s32.val[2]), m2.val[2]);

    if(stridex == 1)
    {
        accumulate ? accumulate_results<1>(out_ptr, out) : store_results<1>(out_ptr, out);
    }
    else if(stridex == 2)
    {
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 2), out.val[0], 1);
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 0), out.val[0], 2);
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 2), out.val[0], 3);

        accumulate ? accumulate_results<2>(out_ptr, out) : store_results<2>(out_ptr, out);
    }
    else if(stridex == 3)
    {
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 3), out.val[0], 1);
        accumulate ? accumulate_results<3>(out_ptr, out) : store_results<3>(out_ptr, out);
    }
}
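
/* Usage sketch (hypothetical quantization parameters): for asymmetric
   quantized data the offsets are typically the negated zero-points, so the
   arithmetic above runs on offset-corrected int32 values.

       const int32x4x3_t m0 = load_matrix_row(w_row0, -w_zero_point);
       const int32x4x3_t m1 = load_matrix_row(w_row1, -w_zero_point);
       const int32x4x3_t m2 = load_matrix_row(w_row2, -w_zero_point);
       int32_t acc[8];
       convolve_3x3<false>(in0, in1, in2, acc, m0, m1, m2,
                           1, -in_zero_point); // stridex = 1
*/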

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** Loads a 3x3 matrix as a row (float16_t).
 *
 * @param[in] ptr            Pointer to a float16_t 3x3 matrix.
 * @param[in] weights_offset (Optional) Weights quantization offset.
 *
 * @return The loaded matrix.
 */
inline float16x8x3_t load_matrix_row(const float16_t *ptr, int weights_offset = 0)
{
    ARM_COMPUTE_UNUSED(weights_offset);
    /* ptr is a pointer to a row in a 3x3 matrix, the function returns 3 vectors holding exactly the same value in all lanes:
       r.val[0] contains the first element, r.val[1] the second element and r.val[2] the third element (in all lanes) */
    const float16x8x3_t r =
    {
        {
            vld1q_dup_f16(ptr),
            vld1q_dup_f16(1 + ptr),
            vld1q_dup_f16(2 + ptr)
        }
    };
    return r;
}

/** Perform a 3x3 convolution for 8 consecutive elements on float16 when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset.
 *
 */
inline float16x8_t single_convolve_3x3_dilation(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
                                                const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                                const size_t dilation_x, int input_offset = 0)
{
    ARM_COMPUTE_UNUSED(input_offset);
    const float16x8x3_t vtop =
    {
        {
            vld1q_f16(in_top),
            vld1q_f16(in_top + dilation_x),
            vld1q_f16(in_top + 2 * dilation_x)
        }
    };
    const float16x8x3_t vmid =
    {
        {
            vld1q_f16(in_mid),
            vld1q_f16(in_mid + dilation_x),
            vld1q_f16(in_mid + 2 * dilation_x)
        }
    };
    const float16x8x3_t vlow =
    {
        {
            vld1q_f16(in_low),
            vld1q_f16(in_low + dilation_x),
            vld1q_f16(in_low + 2 * dilation_x)
        }
    };
    float16x8_t out = vmulq_f16(vtop.val[0], m0.val[0]);
    out             = vaddq_f16(out, vmulq_f16(vtop.val[1], m0.val[1]));
    out             = vaddq_f16(out, vmulq_f16(vtop.val[2], m0.val[2]));

    out = vaddq_f16(out, vmulq_f16(vmid.val[0], m1.val[0]));
    out = vaddq_f16(out, vmulq_f16(vmid.val[1], m1.val[1]));
    out = vaddq_f16(out, vmulq_f16(vmid.val[2], m1.val[2]));

    out = vaddq_f16(out, vmulq_f16(vlow.val[0], m2.val[0]));
    out = vaddq_f16(out, vmulq_f16(vlow.val[1], m2.val[1]));
    out = vaddq_f16(out, vmulq_f16(vlow.val[2], m2.val[2]));

    return out;
}

/** Perform a 3x3 convolution for 16 consecutive elements on float16 when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] stridex      Stride value in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset.
 *
 */
inline float16x8x2_t convolve_3x3_dilation(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
                                           const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                           const size_t dilation_x, unsigned int stridex, int input_offset = 0)
{
    float16x8x2_t out =
    {
        {
            single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset),
            single_convolve_3x3_dilation(in_top + 8, in_mid + 8, in_low + 8, m0, m1, m2, dilation_x, input_offset)
        }
    };

    if(stridex == 2)
    {
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 2), out.val[0], 1);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 4), out.val[0], 2);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 3);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 0), out.val[0], 4);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 2), out.val[0], 5);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 4), out.val[0], 6);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 6), out.val[0], 7);
    }
    else if(stridex == 3)
    {
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3);
    }

    return out;
}

/** Perform a convolve3x3 on float16.
 *
 * @param[in]  in_top       Pointer to the first row of the input.
 * @param[in]  in_mid       Pointer to the second row of the input.
 * @param[in]  in_low       Pointer to the third row of the input.
 * @param[out] out_ptr      Pointer to the output.
 * @param[in]  m0           First row of the filter.
 * @param[in]  m1           Second row of the filter.
 * @param[in]  m2           Third row of the filter.
 * @param[in]  stridex      Stride value in elements across x.
 * @param[in]  input_offset (Optional) Input quantization offset.
 *
 */
template <bool accumulate>
inline void convolve_3x3(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, float16_t *out_ptr,
                         const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                         unsigned int stridex, int input_offset = 0)
{
    ARM_COMPUTE_UNUSED(input_offset);

    float16x8x2_t out =
    {
        {
            vdupq_n_f16(0),
            vdupq_n_f16(0)
        }
    };
    if(stridex == 2)
    {
        const float16x8x2_t vtop     = vld2q_f16(in_top);
        const float16x8x2_t vmid     = vld2q_f16(in_mid);
        const float16x8x2_t vlow     = vld2q_f16(in_low);
        const float16x8_t   vtop_end = vld1q_f16(in_top + 16);
        const float16x8_t   vmid_end = vld1q_f16(in_mid + 16);
        const float16x8_t   vlow_end = vld1q_f16(in_low + 16);

        out.val[0] = vmulq_f16(vtop.val[0], m0.val[0]);

        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vtop.val[1], m0.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop_end, 1), m0.val[2]));

        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[1], m1.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid_end, 1), m1.val[2]));

        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[1], m2.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow_end, 1), m2.val[2]));

        accumulate ? accumulate_results<2>(out_ptr, out) : store_results<2>(out_ptr, out);
    }
    else
    {
        const float16x8x3_t vtop =
        {
            {
                vld1q_f16(in_top),
                vld1q_f16(in_top + 8),
                vld1q_f16(in_top + 16)
            }
        };
        const float16x8x3_t vmid =
        {
            {
                vld1q_f16(in_mid),
                vld1q_f16(in_mid + 8),
                vld1q_f16(in_mid + 16)
            }
        };
        const float16x8x3_t vlow =
        {
            {
                vld1q_f16(in_low),
                vld1q_f16(in_low + 8),
                vld1q_f16(in_low + 16)
            }
        };
        out.val[0] = vmulq_f16(vtop.val[0], m0.val[0]);
        out.val[1] = vmulq_f16(vtop.val[1], m0.val[0]);

        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 1), m0.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 2), m0.val[2]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 1), m1.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 2), m1.val[2]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 1), m2.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 2), m2.val[2]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 1), m0.val[1]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 2), m0.val[2]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vmid.val[1], m1.val[0]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 1), m1.val[1]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 2), m1.val[2]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vlow.val[1], m2.val[0]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 1), m2.val[1]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 2), m2.val[2]));

        if(stridex == 3)
        {
            out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1);
            out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2);
            out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3);

            accumulate ? accumulate_results<3>(out_ptr, out) : store_results<3>(out_ptr, out);
        }
        else
        {
            accumulate ? accumulate_results<1>(out_ptr, out) : store_results<1>(out_ptr, out);
        }
    }
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

/** Get the number of elements processed on 3x3 convolution.
 *
 * @param[in] num_elems_written_per_iteration Number of elements written per iteration on 3x3 convolution.
 * @param[in] stridex                         Stride value in elements across x.
 *
 * @return The number of elements processed.
 */
inline int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration, unsigned int stridex)
{
    switch(stridex)
    {
        case 1:
            return num_elems_written_per_iteration;
        case 2:
            return num_elems_written_per_iteration << 1;
        case 3:
            return num_elems_written_per_iteration * 3;
        default:
            ARM_COMPUTE_ERROR("stridex not supported");
            return 0;
    }
}
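
/* Worked sketch: with 8 output elements written per iteration the input
   window advances by 8, 16 or 24 elements for stridex 1, 2 or 3:

       get_input_num_elems_processed(8, 1) == 8
       get_input_num_elems_processed(8, 2) == 16
       get_input_num_elems_processed(8, 3) == 24
*/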
} // namespace detail
} // namespace arm_compute
#endif /* ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H */
965