/*
 * Copyright (c) 2016-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS)

#include "arm_compute/core/TensorInfo.h"
#include "src/cpu/kernels/CpuCastKernel.h"
#include "src/cpu/kernels/cast/list.h"
#include "support/SaturateCast.h"

namespace arm_compute
{
namespace cpu
{
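/** Up-cast each QASYMM8_SIGNED (int8) element of @p _src to F16 in @p _dst.
 *
 * Processes 16 elements per X iteration: sign-extend S8 -> S16 with vmovl_s8,
 * then convert with vcvtq_f16_s16. The conversion is exact for the whole S8
 * range, so the conversion policy is irrelevant and ignored.
 */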
void neon_qasymm8_signed_to_fp16_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_UNUSED(_policy);

    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());
    const int  window_step_x  = 16;

    ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
    ARM_COMPUTE_ERROR_ON(_src == _dst);

    Window win{ window };
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator src(_src, win);
    Iterator dst(_dst, win);
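    /* Up-conversion QASYMM8_SIGNED -> F16 */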
    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto src_ptr = reinterpret_cast<const int8_t *>(src.ptr());
        const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr());

        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const int8x16_t texels_s8 = vld1q_s8(src_ptr + x);

            const int16x8x2_t texels =
            {
                {
                    vmovl_s8(vget_low_s8(texels_s8)),
                    vmovl_s8(vget_high_s8(texels_s8))
                }
            };
            vst1q_f16(dst_ptr + x, vcvtq_f16_s16(texels.val[0]));
            vst1q_f16(dst_ptr + x + 8, vcvtq_f16_s16(texels.val[1]));
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x));
        }
    },
    src, dst);
}

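/** Cast each S32 element of @p _src to F16 in @p _dst.
 *
 * Processes 16 elements per X iteration: convert S32 -> F32 with
 * vcvtq_f32_s32, then narrow to F16 with vcvt_f16_f32. Values needing more
 * precision than F16 offers are rounded, and magnitudes above 65504 (the F16
 * maximum) overflow to infinity under the default rounding mode.
 */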
void neon_s32_to_fp16_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_UNUSED(_policy);

    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());
    const int  window_step_x  = 16;

    ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
    ARM_COMPUTE_ERROR_ON(_src == _dst);

    Window win{ window };
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator src(_src, win);
    Iterator dst(_dst, win);

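    /* Down-conversion S32 -> F16 (via F32) */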
    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto src_ptr = reinterpret_cast<const int32_t *>(src.ptr());
        const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr());

        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const float32x4x4_t texels =
            {
                {
                    vcvtq_f32_s32(vld1q_s32(src_ptr + x)),
                    vcvtq_f32_s32(vld1q_s32(src_ptr + x + 4)),
                    vcvtq_f32_s32(vld1q_s32(src_ptr + x + 8)),
                    vcvtq_f32_s32(vld1q_s32(src_ptr + x + 12))
                }
            };

            vst1q_f16(dst_ptr + x, vcombine_f16(vcvt_f16_f32(texels.val[0]), vcvt_f16_f32(texels.val[1])));
            vst1q_f16(dst_ptr + x + 8, vcombine_f16(vcvt_f16_f32(texels.val[2]), vcvt_f16_f32(texels.val[3])));
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x));
        }
    },
    src, dst);
}

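/** Down-cast each F32 element of @p _src to F16 in @p _dst.
 *
 * Processes 16 elements per X iteration, narrowing four F32 vectors with
 * vcvt_f16_f32. Rounding follows the current FP rounding mode
 * (round-to-nearest-even by default).
 */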
void neon_fp32_to_fp16_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_UNUSED(_policy);

    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());
    const int  window_step_x  = 16;

    ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
    ARM_COMPUTE_ERROR_ON(_src == _dst);

    Window win{ window };
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator src(_src, win);
    Iterator dst(_dst, win);

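    /* Down-conversion F32 -> F16 */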
    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto src_ptr = reinterpret_cast<const float *>(src.ptr());
        const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr());

        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const float32x4x4_t texels =
            {
                {
                    vld1q_f32(src_ptr + x),
                    vld1q_f32(src_ptr + x + 4),
                    vld1q_f32(src_ptr + x + 8),
                    vld1q_f32(src_ptr + x + 12)
                }
            };

            vst1q_f16(dst_ptr + x, vcombine_f16(vcvt_f16_f32(texels.val[0]), vcvt_f16_f32(texels.val[1])));
            vst1q_f16(dst_ptr + x + 8, vcombine_f16(vcvt_f16_f32(texels.val[2]), vcvt_f16_f32(texels.val[3])));
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x));
        }
    },
    src, dst);
}

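/** Cast each F16 element of @p _src to the data type of @p _dst.
 *
 * Dispatches on the destination data type:
 * - QASYMM8_SIGNED: narrow to S8 with saturation (always saturating).
 * - QASYMM8/U8:     narrow to U8 with saturation; negatives clamp to 0.
 * - F32:            exact widening with vcvt_f32_f16.
 * - S32:            widen to F32, then convert with truncation toward zero.
 * Any other destination type raises an error.
 */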
void neon_fp16_to_other_dt_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_UNUSED(_policy);

    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());
    const int  window_step_x  = 16;

    ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
    ARM_COMPUTE_ERROR_ON(_src == _dst);

    Window win{ window };
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator src(_src, win);
    Iterator dst(_dst, win);
    switch(_dst->info()->data_type())
    {
        case DataType::QASYMM8_SIGNED:
        {
            /* Down-conversion F16 -> QASYMM8_SIGNED (Always saturating) */
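            // vcvtq_s16_f16 rounds toward zero; vqmovn_s16 then saturates each
            // lane to the S8 range [-128, 127].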
            execute_window_loop(win, [&](const Coordinates &)
            {
                const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr());
                const auto dst_ptr = reinterpret_cast<int8_t *>(dst.ptr());

                int x = window_start_x;
                for(; x <= (window_end_x - window_step_x); x += window_step_x)
                {
                    const float16x8x2_t texels =
                    {
                        {
                            vld1q_f16(src_ptr + x),
                            vld1q_f16(src_ptr + x + 8),
                        }
                    };

                    vst1q_s8(dst_ptr + x, vcombine_s8(vqmovn_s16(vcvtq_s16_f16(texels.val[0])), vqmovn_s16(vcvtq_s16_f16(texels.val[1]))));
                }

                // Compute left-over elements
                for(; x < window_end_x; ++x)
                {
                    *(dst_ptr + x) = utils::cast::saturate_cast<int8_t>(*(src_ptr + x));
                }
            },
            src, dst);
            break;
        }
        case DataType::QASYMM8:
        case DataType::U8:
        {
            /* Down-conversion F16 -> QASYMM8/U8 (Always saturating) */
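            // vqmovun_s16 is a saturating signed-to-unsigned narrow: negative
            // values clamp to 0 and values above 255 clamp to 255.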
            execute_window_loop(win, [&](const Coordinates &)
            {
                const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr());
                const auto dst_ptr = reinterpret_cast<uint8_t *>(dst.ptr());

                int x = window_start_x;
                for(; x <= (window_end_x - window_step_x); x += window_step_x)
                {
                    const float16x8x2_t texels =
                    {
                        {
                            vld1q_f16(src_ptr + x),
                            vld1q_f16(src_ptr + x + 8),
                        }
                    };

                    vst1q_u8(dst_ptr + x, vcombine_u8(vqmovun_s16(vcvtq_s16_f16(texels.val[0])), vqmovun_s16(vcvtq_s16_f16(texels.val[1]))));
                }

                // Compute left-over elements
                for(; x < window_end_x; ++x)
                {
                    *(dst_ptr + x) = utils::cast::saturate_cast<uint8_t>(*(src_ptr + x));
                }
            },
            src, dst);
            break;
        }
        case DataType::F32:
        {
            /* Up-conversion F16 -> F32 */
            execute_window_loop(win, [&](const Coordinates &)
            {
                const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr());
                const auto dst_ptr = reinterpret_cast<float *>(dst.ptr());

                int x = window_start_x;
                for(; x <= (window_end_x - window_step_x); x += window_step_x)
                {
                    const float16x8x2_t texels =
                    {
                        {
                            vld1q_f16(src_ptr + x),
                            vld1q_f16(src_ptr + x + 8)
                        }
                    };
                    vst1q_f32(dst_ptr + x, vcvt_f32_f16(vget_low_f16(texels.val[0])));
                    vst1q_f32(dst_ptr + x + 4, vcvt_f32_f16(vget_high_f16(texels.val[0])));
                    vst1q_f32(dst_ptr + x + 8, vcvt_f32_f16(vget_low_f16(texels.val[1])));
                    vst1q_f32(dst_ptr + x + 12, vcvt_f32_f16(vget_high_f16(texels.val[1])));
                }

                // Compute left-over elements
                for(; x < window_end_x; ++x)
                {
                    *(dst_ptr + x) = static_cast<float>(*(src_ptr + x));
                }
            },
            src, dst);
            break;
        }
        case DataType::S32:
        {
            /* Up-conversion F16 -> S32 */
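            // vcvtq_s32_f32 truncates toward zero, matching the scalar
            // static_cast in the left-over loop.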
            execute_window_loop(win, [&](const Coordinates &)
            {
                const auto src_ptr = reinterpret_cast<const float16_t *>(src.ptr());
                const auto dst_ptr = reinterpret_cast<int32_t *>(dst.ptr());

                int x = window_start_x;
                for(; x <= (window_end_x - window_step_x); x += window_step_x)
                {
                    const float16x8x2_t texels =
                    {
                        {
                            vld1q_f16(src_ptr + x),
                            vld1q_f16(src_ptr + x + 8)
                        }
                    };

                    vst1q_s32(dst_ptr + x, vcvtq_s32_f32(vcvt_f32_f16(vget_low_f16(texels.val[0])))); 
                    vst1q_s32(dst_ptr + x + 4, vcvtq_s32_f32(vcvt_f32_f16(vget_high_f16(texels.val[0]))));
                    vst1q_s32(dst_ptr + x + 8, vcvtq_s32_f32(vcvt_f32_f16(vget_low_f16(texels.val[1]))));
                    vst1q_s32(dst_ptr + x + 12, vcvtq_s32_f32(vcvt_f32_f16(vget_high_f16(texels.val[1]))));
                }

                // Compute left-over elements
                for(; x < window_end_x; ++x)
                {
                    *(dst_ptr + x) = static_cast<int32_t>(*(src_ptr + x));
                }
            },
            src, dst);
            break;
        }
        default:
            ARM_COMPUTE_ERROR("dst data type not supported");
    }
}

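/** Up-cast each U8 element of @p _src to F16 in @p _dst.
 *
 * Processes 16 elements per X iteration: zero-extend U8 -> U16 with vmovl_u8,
 * reinterpret as S16 (safe, since values are at most 255), then convert with
 * vcvtq_f16_s16. The conversion is exact for the whole U8 range.
 */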
void neon_u8_to_fp16_cast(const ITensor *_src, ITensor *_dst, const ThreadInfo &info, ConvertPolicy _policy, const Window &window)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_UNUSED(_policy);

    const auto window_start_x = static_cast<int>(window.x().start());
    const auto window_end_x   = static_cast<int>(window.x().end());
    const int  window_step_x  = 16;

    ARM_COMPUTE_ERROR_ON_NULLPTR(_src, _dst);
    ARM_COMPUTE_ERROR_ON(_src == _dst);

    Window win{ window };
    win.set(Window::DimX, Window::Dimension(0, 1, 1));

    Iterator src(_src, win);
    Iterator dst(_dst, win);
    /* Up-conversion U8 -> F16 */
    execute_window_loop(win, [&](const Coordinates &)
    {
        const auto src_ptr = reinterpret_cast<const uint8_t *>(src.ptr());
        const auto dst_ptr = reinterpret_cast<float16_t *>(dst.ptr());

        int x = window_start_x;
        for(; x <= (window_end_x - window_step_x); x += window_step_x)
        {
            const uint8x16_t texels_u8 = vld1q_u8(src_ptr + x);

            const int16x8x2_t texels =
            {
                {
                    vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(texels_u8))),
                    vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(texels_u8)))
                }
            };
            vst1q_f16(dst_ptr + x, vcvtq_f16_s16(texels.val[0]));
            vst1q_f16(dst_ptr + x + 8, vcvtq_f16_s16(texels.val[1]));
        }

        // Compute left-over elements
        for(; x < window_end_x; ++x)
        {
            *(dst_ptr + x) = static_cast<float16_t>(*(src_ptr + x));
        }
    },
    src, dst);
}

} // namespace cpu
} // namespace arm_compute
#endif /* #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && defined(ENABLE_FP16_KERNELS) */