xref: /aosp_15_r20/external/ComputeLibrary/src/cpu/kernels/cast/generic/neon/fp16.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2016-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)
25 
26 #include "arm_compute/core/TensorInfo.h"
27 #include "src/cpu/kernels/CpuCastKernel.h"
28 #include "src/cpu/kernels/cast/list.h"
29 #include "support/SaturateCast.h"
30 
31 namespace arm_compute
32 {
33 namespace cpu
34 {
neon_qasymm8_signed_to_fp16_cast(const ITensor * _src,ITensor * _dst,const ThreadInfo & info,ConvertPolicy _policy,const Window & window)35 void neon_qasymm8_signed_to_fp16_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window)
36 {
37     ARM_COMPUTE_UNUSED(info);
38     ARM_COMPUTE_UNUSED(_policy);
39 
40     const auto window_start_x = static_cast<int>(window.x().start());
41     const auto window_end_x   = static_cast<int>(window.x().end());
42     const int  window_step_x  = 16;
43 
44     ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
45     ARM_COMPUTE_ERROR_ON(_src == _dst);
46 
47     ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
48 
49     Window win{ window };
50     win.set(Window::DimX, Window::Dimension(0, 1, 1));
51 
52     Iterator src(_src, win);
53     Iterator dst(_dst, win);
54     execute_window_loop(win, [&](const Coordinates &)
55     {
56         const auto src_ptr = reinterpret_cast<const int8_t *>(src.ptr());
57         const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr());
58         int        x       = window_start_x;
59 
60         for(; x <= (window_end_x - window_step_x); x += window_step_x)
61         {
62             const int8x16_t texels_s8 = vld1q_s8(src_ptr + x);
63 
64             const int16x8x2_t texels =
65             {
66                 {
67                     vmovl_s8(vget_low_s8(texels_s8)),
68                     vmovl_s8(vget_high_s8(texels_s8))
69                 }
70             };
71             vst1q_f16(dst_ptr + x, vcvtq_f16_s16(texels.val[0]));
72             vst1q_f16(dst_ptr + x + 8, vcvtq_f16_s16(texels.val[1]));
73         }
74 
75         // Compute left-over elements
76         for(; x < window_end_x; ++x)
77         {
78             *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x));
79         }
80     },
81     src, dst);
82 }
83 
neon_s32_to_fp16_cast(const ITensor * _src,ITensor * _dst,const ThreadInfo & info,ConvertPolicy _policy,const Window & window)84 void neon_s32_to_fp16_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window)
85 {
86     ARM_COMPUTE_UNUSED(info);
87     ARM_COMPUTE_UNUSED(_policy);
88 
89     const auto window_start_x = static_cast<int>(window.x().start());
90     const auto window_end_x   = static_cast<int>(window.x().end());
91     const int  window_step_x  = 16;
92 
93     ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
94     ARM_COMPUTE_ERROR_ON(_src == _dst);
95 
96     ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
97 
98     Window win{ window };
99     win.set(Window::DimX, Window::Dimension(0, 1, 1));
100 
101     Iterator src(_src, win);
102     Iterator dst(_dst, win);
103 
104     execute_window_loop(win, [&](const Coordinates &)
105     {
106         const auto src_ptr = reinterpret_cast<const int32_t *>(src.ptr());
107         const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr());
108 
109         int x = window_start_x;
110         for(; x <= (window_end_x - window_step_x); x += window_step_x)
111         {
112             const float32x4x4_t texels =
113             {
114                 {
115                     vcvtq_f32_s32(vld1q_s32(src_ptr + x)),
116                     vcvtq_f32_s32(vld1q_s32(src_ptr + x + 4)),
117                     vcvtq_f32_s32(vld1q_s32(src_ptr + x + 8)),
118                     vcvtq_f32_s32(vld1q_s32(src_ptr + x + 12))
119                 }
120             };
121 
122             vst1q_f16(dst_ptr + x, vcombine_f16(vcvt_f16_f32(texels.val[0]), vcvt_f16_f32(texels.val[1])));
123             vst1q_f16(dst_ptr + x + 8, vcombine_f16(vcvt_f16_f32(texels.val[2]), vcvt_f16_f32(texels.val[3])));
124         }
125 
126         // Compute left-over elements
127         for(; x < window_end_x; ++x)
128         {
129             *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x));
130         }
131     },
132     src, dst);
133 }
134 
neon_fp32_to_fp16_cast(const ITensor * _src,ITensor * _dst,const ThreadInfo & info,ConvertPolicy _policy,const Window & window)135 void neon_fp32_to_fp16_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window)
136 {
137     ARM_COMPUTE_UNUSED(info);
138     ARM_COMPUTE_UNUSED(_policy);
139 
140     const auto window_start_x = static_cast<int>(window.x().start());
141     const auto window_end_x   = static_cast<int>(window.x().end());
142     const int  window_step_x  = 16;
143 
144     ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
145     ARM_COMPUTE_ERROR_ON(_src == _dst);
146 
147     ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
148 
149     Window win{ window };
150     win.set(Window::DimX, Window::Dimension(0, 1, 1));
151 
152     Iterator src(_src, win);
153     Iterator dst(_dst, win);
154 
155     execute_window_loop(win, [&](const Coordinates &)
156     {
157         const auto src_ptr = reinterpret_cast<const float *>(src.ptr());
158         const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr());
159 
160         int x = window_start_x;
161         for(; x <= (window_end_x - window_step_x); x += window_step_x)
162         {
163             const float32x4x4_t texels =
164             {
165                 {
166                     vld1q_f32(src_ptr + x),
167                     vld1q_f32(src_ptr + x + 4),
168                     vld1q_f32(src_ptr + x + 8),
169                     vld1q_f32(src_ptr + x + 12)
170                 }
171             };
172 
173             vst1q_f16(dst_ptr + x, vcombine_f16(vcvt_f16_f32(texels.val[0]), vcvt_f16_f32(texels.val[1])));
174             vst1q_f16(dst_ptr + x + 8, vcombine_f16(vcvt_f16_f32(texels.val[2]), vcvt_f16_f32(texels.val[3])));
175         }
176 
177         // Compute left-over elements
178         for(; x < window_end_x; ++x)
179         {
180             *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x));
181         }
182     },
183     src, dst);
184 }
185 
neon_fp16_to_other_dt_cast(const ITensor * _src,ITensor * _dst,const ThreadInfo & info,ConvertPolicy _policy,const Window & window)186 void neon_fp16_to_other_dt_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window)
187 {
188     ARM_COMPUTE_UNUSED(info);
189     ARM_COMPUTE_UNUSED(_policy);
190 
191     const auto window_start_x = static_cast<int>(window.x().start());
192     const auto window_end_x   = static_cast<int>(window.x().end());
193     const int  window_step_x  = 16;
194 
195     ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
196     ARM_COMPUTE_ERROR_ON(_src == _dst);
197 
198     ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
199 
200     Window win{ window };
201     win.set(Window::DimX, Window::Dimension(0, 1, 1));
202 
203     Iterator src(_src, win);
204     Iterator dst(_dst, win);
205     switch(_dst->info()->data_type())
206     {
207         case DataType::QASYMM8_SIGNED:
208         {
209             /* Down-conversion F16 -> QASYMM8_SIGNED (Always saturating) */
210             execute_window_loop(win, [&](const Coordinates &)
211             {
212                 const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr());
213                 const auto dst_ptr = reinterpret_cast<int8_t *>(dst.ptr());
214 
215                 int x = window_start_x;
216                 for(; x <= (window_end_x - window_step_x); x += window_step_x)
217                 {
218                     const float16x8x2_t texels =
219                     {
220                         {
221                             vld1q_f16(src_ptr + x),
222                             vld1q_f16(src_ptr + x + 8),
223                         }
224                     };
225 
226                     vst1q_s8(dst_ptr + x, vcombine_s8(vqmovn_s16(vcvtq_s16_f16(texels.val[0])), vqmovn_s16(vcvtq_s16_f16(texels.val[1]))));
227                 }
228 
229                 // Compute left-over elements
230                 for(; x < window_end_x; ++x)
231                 {
232                     *(dst_ptr + x) = utils::cast::saturate_cast<int8_t>(*(src_ptr + x));
233                 }
234             },
235             src, dst);
236             break;
237         }
238         case DataType::QASYMM8:
239         case DataType::U8:
240         {
241             /* Down-conversion F16 -> QASYMM8/U8 (Always saturating) */
242             execute_window_loop(win, [&](const Coordinates &)
243             {
244                 const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr());
245                 const auto dst_ptr = reinterpret_cast<uint8_t *>(dst.ptr());
246 
247                 int x = window_start_x;
248                 for(; x <= (window_end_x - window_step_x); x += window_step_x)
249                 {
250                     const float16x8x2_t texels =
251                     {
252                         {
253                             vld1q_f16(src_ptr + x),
254                             vld1q_f16(src_ptr + x + 8),
255                         }
256                     };
257 
258                     vst1q_u8(dst_ptr + x, vcombine_u8(vqmovun_s16(vcvtq_s16_f16(texels.val[0])), vqmovun_s16(vcvtq_s16_f16(texels.val[1]))));
259                 }
260 
261                 // Compute left-over elements
262                 for(; x < window_end_x; ++x)
263                 {
264                     *(dst_ptr + x) = utils::cast::saturate_cast<uint8_t>(*(src_ptr + x));
265                 }
266 
267             },
268             src, dst);
269             break;
270         }
271         case DataType::F32:
272         {
273             /* Up-conversion F16 -> F32 */
274             execute_window_loop(win, [&](const Coordinates &)
275             {
276                 const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr());
277                 const auto dst_ptr = reinterpret_cast<float *>(dst.ptr());
278 
279                 int x = window_start_x;
280                 for(; x <= (window_end_x - window_step_x); x += window_step_x)
281                 {
282                     const float16x8x2_t texels =
283                     {
284                         {
285                             vld1q_f16(src_ptr + x),
286                             vld1q_f16(src_ptr + x + 8)
287                         }
288                     };
289                     vst1q_f32(dst_ptr + x, vcvt_f32_f16(vget_low_f16(texels.val[0])));
290                     vst1q_f32(dst_ptr + x + 4, vcvt_f32_f16(vget_high_f16(texels.val[0])));
291                     vst1q_f32(dst_ptr + x + 8, vcvt_f32_f16(vget_low_f16(texels.val[1])));
292                     vst1q_f32(dst_ptr + x + 12, vcvt_f32_f16(vget_high_f16(texels.val[1])));
293                 }
294 
295                 // Compute left-over elements
296                 for(; x < window_end_x; ++x)
297                 {
298                     *(dst_ptr + x) = static_cast<float>(*(src_ptr + x));
299                 }
300             },
301             src, dst);
302             break;
303         }
304         case DataType::S32:
305         {
306             /* Up-conversion F16 -> S32 */
307             execute_window_loop(win, [&](const Coordinates &)
308             {
309                 const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr());
310                 const auto dst_ptr = reinterpret_cast<int32_t *>(dst.ptr());
311 
312                 int x = window_start_x;
313                 for(; x <= (window_end_x - window_step_x); x += window_step_x)
314                 {
315                     const float16x8x2_t texels =
316                     {
317                         {
318                             vld1q_f16(src_ptr + x),
319                             vld1q_f16(src_ptr + x + 8)
320                         }
321                     };
322 
323                     vst1q_s32(dst_ptr + x, vcvtq_s32_f32(vcvt_f32_f16(vget_low_f16(texels.val[0]))));
324                     vst1q_s32(dst_ptr + x + 4, vcvtq_s32_f32(vcvt_f32_f16(vget_high_f16(texels.val[0]))));
325                     vst1q_s32(dst_ptr + x + 8, vcvtq_s32_f32(vcvt_f32_f16(vget_low_f16(texels.val[1]))));
326                     vst1q_s32(dst_ptr + x + 12, vcvtq_s32_f32(vcvt_f32_f16(vget_high_f16(texels.val[1]))));
327                 }
328 
329                 // Compute left-over elements
330                 for(; x < window_end_x; ++x)
331                 {
332                     *(dst_ptr + x) = static_cast<int32_t>(*(src_ptr + x));
333                 }
334             },
335             src, dst);
336             break;
337         }
338         default:
339             ARM_COMPUTE_ERROR("dst data type not supported");
340     }
341 }
342 
neon_u8_to_fp16_cast(const ITensor * _src,ITensor * _dst,const ThreadInfo & info,ConvertPolicy _policy,const Window & window)343 void neon_u8_to_fp16_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window)
344 {
345     ARM_COMPUTE_UNUSED(info);
346     ARM_COMPUTE_UNUSED(_policy);
347 
348     const auto window_start_x = static_cast<int>(window.x().start());
349     const auto window_end_x   = static_cast<int>(window.x().end());
350     const int  window_step_x  = 16;
351 
352     ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
353     ARM_COMPUTE_ERROR_ON(_src == _dst);
354 
355     ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
356 
357     Window win{ window };
358     win.set(Window::DimX, Window::Dimension(0, 1, 1));
359 
360     Iterator src(_src, win);
361     Iterator dst(_dst, win);
362     /* Up-conversion U8 -> F16 */
363     execute_window_loop(win, [&](const Coordinates &)
364     {
365         const auto src_ptr = reinterpret_cast<const uint8_t *>(src.ptr());
366         const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr());
367 
368         int x = window_start_x;
369         for(; x <= (window_end_x - window_step_x); x += window_step_x)
370         {
371             const uint8x16_t texels_u8 = vld1q_u8(src_ptr + x);
372 
373             const int16x8x2_t texels =
374             {
375                 {
376                     vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(texels_u8))),
377                     vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(texels_u8)))
378                 }
379             };
380             vst1q_f16(dst_ptr + x, vcvtq_f16_s16(texels.val[0]));
381             vst1q_f16(dst_ptr + x + 8, vcvtq_f16_s16(texels.val[1]));
382         }
383 
384         // Compute left-over elements
385         for(; x < window_end_x; ++x)
386         {
387             *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x));
388         }
389     },
390     src, dst);
391     return;
392 }
393 
394 } // namespace cpu
395 } // namespace arm_compute
396 #endif /* #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) */
397