/*
 * Copyright (c) 2017-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#ifndef ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H
#define ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H

#include "src/core/NEON/NEFixedPoint.h"
#include "src/core/NEON/wrapper/wrapper.h"
#include "support/Requires.h"

#include <arm_neon.h>

namespace arm_compute
{
namespace detail
{
/** Loads a 3x3 matrix as a row (float).
 *
 * @param[in] ptr            Pointer to a float 3x3 matrix.
 * @param[in] weights_offset (Optional) Weights quantization offset.
 *
 * @return The loaded matrix.
 */
inline float32x4x3_t load_matrix_row(const float *ptr, int weights_offset = 0)
{
    ARM_COMPUTE_UNUSED(weights_offset);
    const float32x4x3_t r =
    {
        {
            vld1q_dup_f32(ptr),
            vld1q_dup_f32(1 + ptr),
            vld1q_dup_f32(2 + ptr)
        }
    };
    return r;
}

/** Loads a 3x3 matrix as a row (uint8_t/int8_t).
 *
 * @param[in] ptr            Pointer to a uint8_t/int8_t 3x3 matrix.
 * @param[in] weights_offset (Optional) Weights quantization offset.
 *
 * @return The loaded matrix.
 */
template < typename T, ARM_COMPUTE_REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
inline int32x4x3_t load_matrix_row(const T *ptr, int weights_offset = 0)
{
    const int32x4_t v_weights_offset = vdupq_n_s32(weights_offset);

    /* ptr is a pointer to a row in a 3x3 matrix, the function returns 3 vectors holding exactly the same value in all lanes:
       r.val[0] contains the first element, r.val[1] the second element and r.val[2] the third element (in all lanes) */
    int32x4x3_t r =
    {
        {
            vaddq_s32(v_weights_offset, vdupq_n_s32(*ptr)),
            vaddq_s32(v_weights_offset, vdupq_n_s32(*(ptr + 1))),
            vaddq_s32(v_weights_offset, vdupq_n_s32(*(ptr + 2)))
        }
    };
    return r;
}

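/* A minimal usage sketch for the two load_matrix_row() overloads above (illustrative only;
 * the weight arrays and the offset value are hypothetical, not part of this header). For
 * quantized weights the caller typically passes the negated zero-point so that every lane
 * of the returned vectors holds (weight + weights_offset).
 *
 *     // First row of a 3x3 FP32 filter, laid out contiguously.
 *     const float weights_row_f32[3] = { 0.25f, 0.5f, 0.25f };
 *     const float32x4x3_t m0 = load_matrix_row(weights_row_f32);
 *     // m0.val[0] holds 0.25f in all lanes, m0.val[1] holds 0.5f, m0.val[2] holds 0.25f.
 *
 *     // First row of a 3x3 QASYMM8 filter with an assumed zero-point of 128.
 *     const uint8_t weights_row_q[3] = { 130, 128, 126 };
 *     const int32x4x3_t m0_q = load_matrix_row(weights_row_q, -128);
 *     // m0_q.val[0] holds 2 in all lanes, m0_q.val[1] holds 0, m0_q.val[2] holds -2.
 */
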
/** Stores a float32x4x2_t array into a memory location.
 *
 * @param[in] buffer Pointer to the memory location where the values will be stored.
 * @param[in] values Values that will be stored.
 *
 */
template <unsigned int stridex>
void store_results(float *buffer, const float32x4x2_t &values);

template <>
inline void store_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, values.val[0]);
    vst1q_f32(buffer + 4, values.val[1]);
}

template <>
inline void store_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, values.val[0]);
}

template <>
inline void store_results<3>(float *buffer, const float32x4x2_t &values)
{
    vst1_f32(buffer, vget_low_f32(values.val[0]));
}

/** Stores an int32x4x2_t array into a memory location.
 *
 * @param[in] buffer Pointer to the memory location where the values will be stored.
 * @param[in] values Values that will be stored.
 *
 */
template <unsigned int stridex>
void store_results(int32_t *buffer, const int32x4x2_t &values);

template <>
inline void store_results<1>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1q_s32(buffer, values.val[0]);
    vst1q_s32(buffer + 4, values.val[1]);
}

template <>
inline void store_results<2>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1q_s32(buffer, values.val[0]);
}

template <>
inline void store_results<3>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1_s32(buffer, vget_low_s32(values.val[0]));
}

template <unsigned int stridex>
inline void accumulate_results(float *buffer, const float32x4x2_t &values);

template <>
inline void accumulate_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
    vst1q_f32(buffer + 4, vaddq_f32(vld1q_f32(buffer + 4), values.val[1]));
}

template <>
inline void accumulate_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
}

template <>
inline void accumulate_results<3>(float *buffer, const float32x4x2_t &values)
{
    vst1_f32(buffer, vadd_f32(vld1_f32(buffer), vget_low_f32(values.val[0])));
}

template <unsigned int stridex>
void accumulate_results(int32_t *buffer, const int32x4x2_t &values);

template <>
inline void accumulate_results<1>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1q_s32(buffer, vaddq_s32(vld1q_s32(buffer), values.val[0]));
    vst1q_s32(buffer + 4, vaddq_s32(vld1q_s32(buffer + 4), values.val[1]));
}

template <>
inline void accumulate_results<2>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1q_s32(buffer, vaddq_s32(vld1q_s32(buffer), values.val[0]));
}

template <>
inline void accumulate_results<3>(int32_t *buffer, const int32x4x2_t &values)
{
    vst1_s32(buffer, vadd_s32(vld1_s32(buffer), vget_low_s32(values.val[0])));
}

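/* The store_results() / accumulate_results() specializations above differ only in how many
 * of the computed lanes they write back: <1> writes all 8 values, <2> writes the first 4
 * (the caller has already compacted every stride-2 result into val[0]) and <3> writes the
 * first 2. A minimal sketch (illustrative; the destination buffer is hypothetical):
 *
 *     int32_t dst[8] = {};
 *     const int32x4x2_t acc = { { vdupq_n_s32(1), vdupq_n_s32(2) } };
 *     store_results<1>(dst, acc);      // dst = { 1, 1, 1, 1, 2, 2, 2, 2 }
 *     accumulate_results<1>(dst, acc); // dst = { 2, 2, 2, 2, 4, 4, 4, 4 }
 */
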
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** Stores a float16x8x2_t array into a memory location.
 *
 * @param[in] buffer Pointer to the memory location where the values will be stored.
 * @param[in] values Values that will be stored.
 *
 */
template <unsigned int stridex>
void store_results(float16_t *buffer, const float16x8x2_t &values);

template <>
inline void store_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, values.val[0]);
    vst1q_f16(buffer + 8, values.val[1]);
}

template <>
inline void store_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, values.val[0]);
}

template <>
inline void store_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1_f16(buffer, vget_low_f16(values.val[0]));
}

template <unsigned int stridex>
inline void accumulate_results(float16_t *buffer, const float16x8x2_t &values);

template <>
inline void accumulate_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
    vst1q_f16(buffer + 8, vaddq_f16(vld1q_f16(buffer + 8), values.val[1]));
}

template <>
inline void accumulate_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
}

template <>
inline void accumulate_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1_f16(buffer, vadd_f16(vld1_f16(buffer), vget_low_f16(values.val[0])));
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

/** Perform a 3x3 convolution for 4 consecutive elements on float32 when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset.
 *
 */
inline float32x4_t single_convolve_3x3_dilation(const float *in_top, const float *in_mid, const float *in_low,
                                                const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                                                const size_t dilation_x, int input_offset)
{
    ARM_COMPUTE_UNUSED(input_offset);

    const float32x4x3_t vtop =
    {
        {
            vld1q_f32(in_top),
            vld1q_f32(in_top + dilation_x),
            vld1q_f32(in_top + 2 * dilation_x)
        }
    };
    const float32x4x3_t vmid =
    {
        {
            vld1q_f32(in_mid),
            vld1q_f32(in_mid + dilation_x),
            vld1q_f32(in_mid + 2 * dilation_x)
        }
    };
    const float32x4x3_t vlow =
    {
        {
            vld1q_f32(in_low),
            vld1q_f32(in_low + dilation_x),
            vld1q_f32(in_low + 2 * dilation_x)
        }
    };
    float32x4_t out = vmulq_f32(vtop.val[0], m0.val[0]);
    out = vmlaq_f32(out, vtop.val[1], m0.val[1]);
    out = vmlaq_f32(out, vtop.val[2], m0.val[2]);

    out = vmlaq_f32(out, vmid.val[0], m1.val[0]);
    out = vmlaq_f32(out, vmid.val[1], m1.val[1]);
    out = vmlaq_f32(out, vmid.val[2], m1.val[2]);

    out = vmlaq_f32(out, vlow.val[0], m2.val[0]);
    out = vmlaq_f32(out, vlow.val[1], m2.val[1]);
    out = vmlaq_f32(out, vlow.val[2], m2.val[2]);

    return out;
}

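/* A minimal sketch of how single_convolve_3x3_dilation() is typically driven (illustrative
 * only; the row pointers, output pointer and dilation value are hypothetical). Each call
 * reads three dilated taps per input row and produces 4 consecutive output elements.
 *
 *     const size_t dilation_x = 2;
 *     const float32x4_t out4  = single_convolve_3x3_dilation(in_top, in_mid, in_low,
 *                                                            m0, m1, m2, dilation_x, 0);
 *     vst1q_f32(out_ptr, out4);
 */
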
/** Perform a 3x3 convolution for 8 consecutive elements on float32 when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] stridex      Stride value in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset.
 *
 */
inline float32x4x2_t convolve_3x3_dilation(const float *in_top, const float *in_mid, const float *in_low,
                                           const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                                           const size_t dilation_x, unsigned int stridex, int input_offset = 0)
{
    ARM_COMPUTE_ERROR_ON(stridex > 3);
    float32x4x2_t out =
    {
        {
            single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset),
            single_convolve_3x3_dilation(in_top + 4, in_mid + 4, in_low + 4, m0, m1, m2, dilation_x, input_offset)
        }
    };

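    // Compact the strided results into out.val[0] so that the matching store_results<stridex>()
    // overload can write them contiguously: for stridex == 2, lanes {0,1,2,3} receive the results
    // computed for input positions {0,2,4,6}; for stridex == 3, lanes {0,1} receive positions {0,3}.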
    if(stridex == 2)
    {
        out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
        out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
        out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
    }
    else if(stridex == 3)
    {
        out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
    }

    return out;
}

/** Perform a convolve3x3 on float32.
 *
 * @param[in]  in_top       Pointer to the first row of the input.
 * @param[in]  in_mid       Pointer to the second row of the input.
 * @param[in]  in_low       Pointer to the third row of the input.
 * @param[out] out_ptr      Pointer to the output.
 * @param[in]  m0           First row of the filter.
 * @param[in]  m1           Second row of the filter.
 * @param[in]  m2           Third row of the filter.
 * @param[in]  stridex      Stride value in elements across x.
 * @param[in]  input_offset (Optional) Input quantization offset.
 *
 */
template <bool accumulate>
void convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, float *out_ptr,
                  const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                  unsigned int stridex, int input_offset = 0);

template <bool accumulate>
inline void convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, float *out_ptr,
                         const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2,
                         unsigned int stridex, int input_offset)
{
    ARM_COMPUTE_UNUSED(input_offset);
    ARM_COMPUTE_ERROR_ON(stridex > 3);

    float32x4x2_t out =
    {
        {
            vdupq_n_f32(0.f),
            vdupq_n_f32(0.f)
        }
    };
    if(stridex == 2)
    {
        const float32x4x2_t vtop = vld2q_f32(in_top);
        const float32x4x2_t vmid = vld2q_f32(in_mid);
        const float32x4x2_t vlow = vld2q_f32(in_low);
        const float32x4_t vtop_end = vld1q_f32(in_top + 8);
        const float32x4_t vmid_end = vld1q_f32(in_mid + 8);
        const float32x4_t vlow_end = vld1q_f32(in_low + 8);

        out.val[0] = vmulq_f32(vtop.val[0], m0.val[0]);

        out.val[0] = vmlaq_f32(out.val[0], vtop.val[1], m0.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop_end, 1), m0.val[2]);

        out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
        out.val[0] = vmlaq_f32(out.val[0], vmid.val[1], m1.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid_end, 1), m1.val[2]);

        out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
        out.val[0] = vmlaq_f32(out.val[0], vlow.val[1], m2.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow_end, 1), m2.val[2]);

        accumulate ? accumulate_results<2>(out_ptr, out) : store_results<2>(out_ptr, out);
    }
    else
    {
        const float32x4x3_t vtop =
        {
            {
                vld1q_f32(in_top),
                vld1q_f32(in_top + 4),
                vld1q_f32(in_top + 8)
            }
        };
        const float32x4x3_t vmid =
        {
            {
                vld1q_f32(in_mid),
                vld1q_f32(in_mid + 4),
                vld1q_f32(in_mid + 8)
            }
        };
        const float32x4x3_t vlow =
        {
            {
                vld1q_f32(in_low),
                vld1q_f32(in_low + 4),
                vld1q_f32(in_low + 8)
            }
        };
        out.val[0] = vmulq_f32(vtop.val[0], m0.val[0]);
        out.val[1] = vmulq_f32(vtop.val[1], m0.val[0]);

        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 1), m0.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 2), m0.val[2]);

        out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 1), m1.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 2), m1.val[2]);

        out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 1), m2.val[1]);
        out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 2), m2.val[2]);

        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 1), m0.val[1]);
        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 2), m0.val[2]);

        out.val[1] = vmlaq_f32(out.val[1], vmid.val[1], m1.val[0]);
        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 1), m1.val[1]);
        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 2), m1.val[2]);

        out.val[1] = vmlaq_f32(out.val[1], vlow.val[1], m2.val[0]);
        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 1), m2.val[1]);
        out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 2), m2.val[2]);

        if(stridex == 3)
        {
            out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
            accumulate ? accumulate_results<3>(out_ptr, out) : store_results<3>(out_ptr, out);
        }
        else
        {
            accumulate ? accumulate_results<1>(out_ptr, out) : store_results<1>(out_ptr, out);
        }
    }
}

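/* A minimal sketch of a row loop built on convolve_3x3() (illustrative only; the row pointers,
 * output pointer, out_width and the filter vectors m0/m1/m2 are hypothetical). At stridex == 1
 * each call writes 8 consecutive outputs (4 for stridex == 2, 2 for stridex == 3) and the input
 * pointers advance by get_input_num_elems_processed() elements per iteration.
 *
 *     const unsigned int stridex  = 1;
 *     const unsigned int out_step = 8;
 *     const int          in_step  = get_input_num_elems_processed(out_step, stridex);
 *     for(unsigned int x = 0; x < out_width; x += out_step)
 *     {
 *         convolve_3x3<false>(in_top, in_mid, in_low, out_ptr, m0, m1, m2, stridex, 0);
 *         in_top += in_step;
 *         in_mid += in_step;
 *         in_low += in_step;
 *         out_ptr += out_step;
 *     }
 */
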
/** Perform a 3x3 convolution for 4 consecutive 8-bit elements when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] input_offset Input quantization offset.
 *
 */
template < typename T, ARM_COMPUTE_REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
inline int32x4_t single_convolve_3x3_dilation(const T *in_top, const T *in_mid, const T *in_low,
                                              const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                                              size_t dilation_x, int32_t input_offset)
{
    using VectorType    = typename std::conditional<std::is_same<T, uint8_t>::value, uint8x8x3_t, int8x8x3_t>::type;
    using OutputTagType = typename wrapper::traits::neon_bitvector_tag_t<int32_t, wrapper::traits::BitWidth::W128>;

    const int32x4_t v_input_offset = wrapper::vdup_n(input_offset, OutputTagType{});

    const VectorType vtop =
    {
        {
            wrapper::vload(in_top),
            wrapper::vload(in_top + dilation_x),
            wrapper::vload(in_top + 2 * dilation_x)
        }
    };
    const VectorType vmid =
    {
        {
            wrapper::vload(in_mid),
            wrapper::vload(in_mid + dilation_x),
            wrapper::vload(in_mid + 2 * dilation_x)
        }
    };
    const VectorType vlow =
    {
        {
            wrapper::vload(in_low),
            wrapper::vload(in_low + dilation_x),
            wrapper::vload(in_low + 2 * dilation_x)
        }
    };

    const int32x4x3_t vtop_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[1])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[2])))),
        }
    };
    const int32x4x3_t vmid_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[1])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[2])))),
        }
    };
    const int32x4x3_t vlow_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[1])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[2])))),
        }
    };

    int32x4_t out = wrapper::vmul(vtop_s32.val[0], m0.val[0]);
    out = wrapper::vmla(out, vtop_s32.val[1], m0.val[1]);
    out = wrapper::vmla(out, vtop_s32.val[2], m0.val[2]);

    out = wrapper::vmla(out, vmid_s32.val[0], m1.val[0]);
    out = wrapper::vmla(out, vmid_s32.val[1], m1.val[1]);
    out = wrapper::vmla(out, vmid_s32.val[2], m1.val[2]);

    out = wrapper::vmla(out, vlow_s32.val[0], m2.val[0]);
    out = wrapper::vmla(out, vlow_s32.val[1], m2.val[1]);
    out = wrapper::vmla(out, vlow_s32.val[2], m2.val[2]);

    return out;
}

/** Perform a 3x3 convolution for 8 consecutive 8-bit elements when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] stridex      Stride value in elements across x.
 * @param[in] input_offset Input quantization offset.
 *
 */
template < typename T, ARM_COMPUTE_REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
inline int32x4x2_t convolve_3x3_dilation(const T *in_top, const T *in_mid, const T *in_low, const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                                         const size_t dilation_x, unsigned int stridex, int input_offset)
{
    ARM_COMPUTE_ERROR_ON(stridex > 3);
    int32x4x2_t out =
    {
        {
            single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset),
            single_convolve_3x3_dilation(in_top + 4, in_mid + 4, in_low + 4, m0, m1, m2, dilation_x, input_offset)
        }
    };

    if(stridex == 2)
    {
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 2), out.val[0], 1);
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 0), out.val[0], 2);
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 2), out.val[0], 3);
    }
    else if(stridex == 3)
    {
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 3), out.val[0], 1);
    }
    return out;
}

/** Perform a convolve3x3 on 8-bit elements.
 *
 * @param[in]  in_top       Pointer to the first row of the input.
 * @param[in]  in_mid       Pointer to the second row of the input.
 * @param[in]  in_low       Pointer to the third row of the input.
 * @param[out] out_ptr      Pointer to the output.
 * @param[in]  m0           First row of the filter.
 * @param[in]  m1           Second row of the filter.
 * @param[in]  m2           Third row of the filter.
 * @param[in]  stridex      Stride value in elements across x.
 * @param[in]  input_offset Input quantization offset.
 *
 */
template < bool accumulate, typename T1, typename T2, ARM_COMPUTE_REQUIRES_TA(std::is_same<T1, uint8_t>::value || std::is_same<T1, int8_t>::value) >
void convolve_3x3(const T1 *in_top, const T1 *in_mid, const T1 *in_low, T2 *out_ptr,
                  const int32x4x3_t &m0, const int32x4x3_t &m1, const int32x4x3_t &m2,
                  unsigned int stridex, int32_t input_offset)
{
    ARM_COMPUTE_ERROR_ON(stridex > 3);
    using VectorType    = typename std::conditional<std::is_same<T1, uint8_t>::value, uint8x8x2_t, int8x8x2_t>::type;
    using OutputTagType = typename wrapper::traits::neon_bitvector_tag_t<int32_t, wrapper::traits::BitWidth::W128>;

    const int32x4_t v_input_offset = wrapper::vdup_n(input_offset, OutputTagType{});

    const VectorType vtop =
    {
        {
            wrapper::vload(in_top),
            wrapper::vload(in_top + 8)
        }
    };
    const VectorType vmid =
    {
        {
            wrapper::vload(in_mid),
            wrapper::vload(in_mid + 8)
        }
    };
    const VectorType vlow =
    {
        {
            wrapper::vload(in_low),
            wrapper::vload(in_low + 8)
        }
    };

    const int32x4x3_t vtop_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgethigh(wrapper::vmovl(vtop.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vtop.val[1])))),
        }
    };
    const int32x4x3_t vmid_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgethigh(wrapper::vmovl(vmid.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vmid.val[1])))),
        }
    };
    const int32x4x3_t vlow_s32 =
    {
        {
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgethigh(wrapper::vmovl(vlow.val[0])))),
            wrapper::vaddw(v_input_offset, wrapper::vreinterpret(wrapper::vgetlow(wrapper::vmovl(vlow.val[1])))),
        }
    };

    int32x4x2_t out
    {
        {
            wrapper::vdup_n(static_cast<int32_t>(0), OutputTagType{}),
            wrapper::vdup_n(static_cast<int32_t>(0), OutputTagType{}),
        }
    };

    // out.val[0]: first 4 output elements
    out.val[0] = wrapper::vmla(out.val[0], vtop_s32.val[0], m0.val[0]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_1(vtop_s32.val[0], vtop_s32.val[1]), m0.val[1]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_2(vtop_s32.val[0], vtop_s32.val[1]), m0.val[2]);

    out.val[0] = wrapper::vmla(out.val[0], vmid_s32.val[0], m1.val[0]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_1(vmid_s32.val[0], vmid_s32.val[1]), m1.val[1]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_2(vmid_s32.val[0], vmid_s32.val[1]), m1.val[2]);

    out.val[0] = wrapper::vmla(out.val[0], vlow_s32.val[0], m2.val[0]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_1(vlow_s32.val[0], vlow_s32.val[1]), m2.val[1]);
    out.val[0] = wrapper::vmla(out.val[0], wrapper::vext_2(vlow_s32.val[0], vlow_s32.val[1]), m2.val[2]);

    // out.val[1]: next 4 output elements
    out.val[1] = wrapper::vmla(out.val[1], vtop_s32.val[1], m0.val[0]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_1(vtop_s32.val[1], vtop_s32.val[2]), m0.val[1]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_2(vtop_s32.val[1], vtop_s32.val[2]), m0.val[2]);

    out.val[1] = wrapper::vmla(out.val[1], vmid_s32.val[1], m1.val[0]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_1(vmid_s32.val[1], vmid_s32.val[2]), m1.val[1]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_2(vmid_s32.val[1], vmid_s32.val[2]), m1.val[2]);

    out.val[1] = wrapper::vmla(out.val[1], vlow_s32.val[1], m2.val[0]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_1(vlow_s32.val[1], vlow_s32.val[2]), m2.val[1]);
    out.val[1] = wrapper::vmla(out.val[1], wrapper::vext_2(vlow_s32.val[1], vlow_s32.val[2]), m2.val[2]);

    if(stridex == 1)
    {
        accumulate ? accumulate_results<1>(out_ptr, out) : store_results<1>(out_ptr, out);
    }
    else if(stridex == 2)
    {
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 2), out.val[0], 1);
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 0), out.val[0], 2);
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[1], 2), out.val[0], 3);

        accumulate ? accumulate_results<2>(out_ptr, out) : store_results<2>(out_ptr, out);
    }
    else if(stridex == 3)
    {
        out.val[0] = wrapper::vsetlane(wrapper::vgetlane(out.val[0], 3), out.val[0], 1);
        accumulate ? accumulate_results<3>(out_ptr, out) : store_results<3>(out_ptr, out);
    }
}

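/* A minimal sketch of the quantized path above (illustrative only; the pointers, stride and
 * quantization info objects are hypothetical). The weights zero-point is folded into the
 * filter vectors by load_matrix_row(), while the (negated) input zero-point is passed to
 * convolve_3x3() and added to every widened input sample; requantization of the int32
 * accumulators happens outside these helpers.
 *
 *     const int weights_offset = -weights_qinfo.uniform().offset;
 *     const int input_offset   = -input_qinfo.uniform().offset;
 *
 *     const int32x4x3_t m0 = load_matrix_row(weights_ptr + 0 * weights_stride, weights_offset);
 *     const int32x4x3_t m1 = load_matrix_row(weights_ptr + 1 * weights_stride, weights_offset);
 *     const int32x4x3_t m2 = load_matrix_row(weights_ptr + 2 * weights_stride, weights_offset);
 *
 *     convolve_3x3<false>(in_top, in_mid, in_low, acc_s32_ptr, m0, m1, m2, stridex, input_offset);
 */
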
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** Loads a 3x3 matrix as a row (float16_t).
 *
 * @param[in] ptr            Pointer to a float16_t 3x3 matrix.
 * @param[in] weights_offset (Optional) Weights quantization offset.
 *
 * @return The loaded matrix.
 */
inline float16x8x3_t load_matrix_row(const float16_t *ptr, int weights_offset = 0)
{
    ARM_COMPUTE_UNUSED(weights_offset);
    /* ptr is a pointer to a row in a 3x3 matrix, the function returns 3 vectors holding exactly the same value in all lanes:
       r.val[0] contains the first element, r.val[1] the second element and r.val[2] the third element (in all lanes) */
    const float16x8x3_t r =
    {
        {
            vld1q_dup_f16(ptr),
            vld1q_dup_f16(1 + ptr),
            vld1q_dup_f16(2 + ptr)
        }
    };
    return r;
}

/** Perform a 3x3 convolution for 8 consecutive elements on float16 when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset.
 *
 */
inline float16x8_t single_convolve_3x3_dilation(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
                                                const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                                const size_t dilation_x, int input_offset = 0)
{
    ARM_COMPUTE_UNUSED(input_offset);
    const float16x8x3_t vtop =
    {
        {
            vld1q_f16(in_top),
            vld1q_f16(in_top + dilation_x),
            vld1q_f16(in_top + 2 * dilation_x)
        }
    };
    const float16x8x3_t vmid =
    {
        {
            vld1q_f16(in_mid),
            vld1q_f16(in_mid + dilation_x),
            vld1q_f16(in_mid + 2 * dilation_x)
        }
    };
    const float16x8x3_t vlow =
    {
        {
            vld1q_f16(in_low),
            vld1q_f16(in_low + dilation_x),
            vld1q_f16(in_low + 2 * dilation_x)
        }
    };
    float16x8_t out = vmulq_f16(vtop.val[0], m0.val[0]);
    out = vaddq_f16(out, vmulq_f16(vtop.val[1], m0.val[1]));
    out = vaddq_f16(out, vmulq_f16(vtop.val[2], m0.val[2]));

    out = vaddq_f16(out, vmulq_f16(vmid.val[0], m1.val[0]));
    out = vaddq_f16(out, vmulq_f16(vmid.val[1], m1.val[1]));
    out = vaddq_f16(out, vmulq_f16(vmid.val[2], m1.val[2]));

    out = vaddq_f16(out, vmulq_f16(vlow.val[0], m2.val[0]));
    out = vaddq_f16(out, vmulq_f16(vlow.val[1], m2.val[1]));
    out = vaddq_f16(out, vmulq_f16(vlow.val[2], m2.val[2]));

    return out;
}

/** Perform a 3x3 convolution for 16 consecutive elements on float16 when dilation.x() or dilation.y() is not 1.
 *
 * @param[in] in_top       Pointer to the first row of the input.
 * @param[in] in_mid       Pointer to the second row of the input.
 * @param[in] in_low       Pointer to the third row of the input.
 * @param[in] m0           First row of the filter.
 * @param[in] m1           Second row of the filter.
 * @param[in] m2           Third row of the filter.
 * @param[in] dilation_x   Dilation, in elements across x.
 * @param[in] stridex      Stride value in elements across x.
 * @param[in] input_offset (Optional) Input quantization offset.
 *
 */
inline float16x8x2_t convolve_3x3_dilation(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low,
                                           const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                           const size_t dilation_x, unsigned int stridex, int input_offset = 0)
{
    float16x8x2_t out =
    {
        {
            single_convolve_3x3_dilation(in_top, in_mid, in_low, m0, m1, m2, dilation_x, input_offset),
            single_convolve_3x3_dilation(in_top + 8, in_mid + 8, in_low + 8, m0, m1, m2, dilation_x, input_offset)
        }
    };

    if(stridex == 2)
    {
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 2), out.val[0], 1);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 4), out.val[0], 2);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 3);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 0), out.val[0], 4);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 2), out.val[0], 5);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 4), out.val[0], 6);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 6), out.val[0], 7);
    }
    else if(stridex == 3)
    {
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2);
        out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3);
    }

    return out;
}

/** Perform a convolve3x3 on float16.
 *
 * @param[in]  in_top       Pointer to the first row of the input.
 * @param[in]  in_mid       Pointer to the second row of the input.
 * @param[in]  in_low       Pointer to the third row of the input.
 * @param[out] out_ptr      Pointer to the output.
 * @param[in]  m0           First row of the filter.
 * @param[in]  m1           Second row of the filter.
 * @param[in]  m2           Third row of the filter.
 * @param[in]  stridex      Stride value in elements across x.
 * @param[in]  input_offset (Optional) Input quantization offset.
 *
 */
template <bool accumulate>
inline void convolve_3x3(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, float16_t *out_ptr,
                         const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                         unsigned int stridex, int input_offset = 0)
{
    ARM_COMPUTE_UNUSED(input_offset);

    float16x8x2_t out =
    {
        {
            vdupq_n_f16(0),
            vdupq_n_f16(0)
        }
    };
    if(stridex == 2)
    {
        const float16x8x2_t vtop = vld2q_f16(in_top);
        const float16x8x2_t vmid = vld2q_f16(in_mid);
        const float16x8x2_t vlow = vld2q_f16(in_low);
        const float16x8_t vtop_end = vld1q_f16(in_top + 16);
        const float16x8_t vmid_end = vld1q_f16(in_mid + 16);
        const float16x8_t vlow_end = vld1q_f16(in_low + 16);

        out.val[0] = vmulq_f16(vtop.val[0], m0.val[0]);

        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vtop.val[1], m0.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop_end, 1), m0.val[2]));

        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[1], m1.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid_end, 1), m1.val[2]));

        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[1], m2.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow_end, 1), m2.val[2]));

        accumulate ? accumulate_results<2>(out_ptr, out) : store_results<2>(out_ptr, out);
    }
    else
    {
        const float16x8x3_t vtop =
        {
            {
                vld1q_f16(in_top),
                vld1q_f16(in_top + 8),
                vld1q_f16(in_top + 16)
            }
        };
        const float16x8x3_t vmid =
        {
            {
                vld1q_f16(in_mid),
                vld1q_f16(in_mid + 8),
                vld1q_f16(in_mid + 16)
            }
        };
        const float16x8x3_t vlow =
        {
            {
                vld1q_f16(in_low),
                vld1q_f16(in_low + 8),
                vld1q_f16(in_low + 16)
            }
        };
        out.val[0] = vmulq_f16(vtop.val[0], m0.val[0]);
        out.val[1] = vmulq_f16(vtop.val[1], m0.val[0]);

        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 1), m0.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 2), m0.val[2]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 1), m1.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 2), m1.val[2]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 1), m2.val[1]));
        out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 2), m2.val[2]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 1), m0.val[1]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 2), m0.val[2]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vmid.val[1], m1.val[0]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 1), m1.val[1]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 2), m1.val[2]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vlow.val[1], m2.val[0]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 1), m2.val[1]));
        out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 2), m2.val[2]));

        if(stridex == 3)
        {
            out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1);
            out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2);
            out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3);

            accumulate ? accumulate_results<3>(out_ptr, out) : store_results<3>(out_ptr, out);
        }
        else
        {
            accumulate ? accumulate_results<1>(out_ptr, out) : store_results<1>(out_ptr, out);
        }
    }
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

/** Get the number of elements processed on 3x3 convolution.
 *
 * @param[in] num_elems_written_per_iteration Number of elements written per iteration on 3x3 convolution.
 * @param[in] stridex                         Stride value in elements across x.
 *
 * @return The number of elements processed.
 */
inline int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration, unsigned int stridex)
{
    switch(stridex)
    {
        case 1:
            return num_elems_written_per_iteration;
        case 2:
            return num_elems_written_per_iteration << 1;
        case 3:
            return num_elems_written_per_iteration * 3;
        default:
            ARM_COMPUTE_ERROR("stridex not supported");
            return 0;
    }
}
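// For example, with 8 output elements written per iteration the input pointers advance by
// get_input_num_elems_processed(8, 1) == 8, get_input_num_elems_processed(8, 2) == 16 and
// get_input_num_elems_processed(8, 3) == 24 elements respectively.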
} // namespace detail
} // namespace arm_compute
#endif /* ARM_COMPUTE_NEDIRECTCONVOLUTIONDETAIL_H */