/*******************************************************************************
* Copyright (c) 2018-2024 Cadence Design Systems, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files (the
* "Software"), to use this Software with Cadence processor cores only and
* not with any other processors and platforms, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

******************************************************************************/
#include "xa_type_def.h"
#include "xa_nnlib_common_fpu.h"
#include "xa_nn_common.h"
#include "xa_nnlib_err_chk.h"
#include "xa_nnlib_kernels_api.h"


#if HAVE_VFPU
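/* 2D broadcast helper: adds the single [in_lc]-length row at p_inp2 to each of
 * the out_lc rows of length in_lc at p_inp1. sign_flag set means the caller
 * swapped the inputs, so operands are added in (inp2 + inp1) order. */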
static void internal_elm_add_broadcast_2D_f32xf32_f32(FLOAT32 * __restrict__ p_out,
                    const FLOAT32 * __restrict__ p_inp1,
                    const FLOAT32 * __restrict__ p_inp2,
                    WORD32 out_lc,
                    WORD32 in_lc,
                    xtbool sign_flag)
{
  int i, j;

  xtfloatx2 * __restrict__ p_a = (xtfloatx2 *)p_inp1;
  xtfloatx2 * __restrict__ p_b = (xtfloatx2 *)p_inp2;
  xtfloatx2 *__restrict__ p_c = (xtfloatx2 *)p_out;

  int num_simd2_ops;
  int num_scalar_ops;

  if(out_lc)
  {
    num_simd2_ops = in_lc >> 1;
    num_scalar_ops = in_lc & 1;
  }
  else
  {
    num_simd2_ops = (in_lc >> 2) << 1;
    num_scalar_ops = in_lc & 3;
  }

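  /* Each row is processed two floats at a time as xtfloatx2 pairs; for the
   * out_lc > 0 calls in this file the per-row remainder is at most one
   * element and is handled with scalar ops. */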
  xtfloatx2 x1, x2, y;
  xtfloat a0, b0, c0;

  /* For computing inp2 + inp1 */
  if(sign_flag){
    for(i = 0; i < out_lc; i++)
    {
      p_a = (xtfloatx2 *)&p_inp1[i * in_lc];
      p_b = (xtfloatx2 *)p_inp2;
      p_c = (xtfloatx2 *)&p_out[i * in_lc];
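      /* 8-byte aligned pointers take the aligned SX2 load/store path;
       * otherwise ae_valign-based unaligned accesses are used. */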
      if(((((unsigned)p_a)&7) == 0) && ((((unsigned)p_b)&7) == 0) && ((((unsigned)p_c)&7) == 0))
      {
        for(j = 0; j < num_simd2_ops; j++)
        {
          XT_LSX2IP(x1, p_a, 2 * sizeof(FLOAT32));
          XT_LSX2IP(x2, p_b, 2 * sizeof(FLOAT32));
          y = XT_ADD_SX2(x2, x1);
          XT_SSX2IP(y, p_c, 2 * sizeof(FLOAT32));
        }
      }
      else
      {
        ae_valign vinp1, vinp2, out_a = AE_ZALIGN64();
        vinp1 = XT_LASX2PP(p_a);
        vinp2 = XT_LASX2PP(p_b);
        for(j = 0; j < num_simd2_ops; j++)
        {
          XT_LASX2IP(x1, vinp1, p_a);
          XT_LASX2IP(x2, vinp2, p_b);
          y = XT_ADD_SX2(x2, x1);
          XT_SASX2IP(y, out_a, p_c);
        }
        XT_SASX2POSFP(out_a, (xtfloatx2 *)p_c);
      }
      if(num_scalar_ops !=0)
      {
        XT_LSIP(a0, (xtfloat *)p_a, sizeof(FLOAT32));
        XT_LSIP(b0, (xtfloat *)p_b, sizeof(FLOAT32));
        c0 = XT_ADD_S(b0, a0);
        XT_SSI(c0, (xtfloat *)p_c, 0);
      }
    }
  }
  /* For computing inp1 + inp2 */
  else
  {
    for(i = 0; i < out_lc; i++)
    {
      p_a = (xtfloatx2 *)&p_inp1[i * in_lc];
      p_b = (xtfloatx2 *)p_inp2;
      p_c = (xtfloatx2 *)&p_out[i * in_lc];
      if(((((unsigned)p_a)&7) == 0) && ((((unsigned)p_b)&7) == 0) && ((((unsigned)p_c)&7) == 0))
      {
        for(j = 0; j < num_simd2_ops; j++)
        {
          XT_LSX2IP(x1, p_a, 2 * sizeof(FLOAT32));
          XT_LSX2IP(x2, p_b, 2 * sizeof(FLOAT32));
          y = XT_ADD_SX2(x1, x2);
          XT_SSX2IP(y, p_c, 2 * sizeof(FLOAT32));
        }
      }
      else
      {
        ae_valign vinp1, vinp2, out_a = AE_ZALIGN64();
        vinp1 = XT_LASX2PP(p_a);
        vinp2 = XT_LASX2PP(p_b);

        for(j = 0; j < num_simd2_ops; j++)
        {
          XT_LASX2IP(x1, vinp1, p_a);
          XT_LASX2IP(x2, vinp2, p_b);
          y = XT_ADD_SX2(x1, x2);
          XT_SASX2IP(y, out_a, p_c);
        }
        XT_SASX2POSFP(out_a, (xtfloatx2 *)p_c);
      }
      if(num_scalar_ops !=0)
      {
        XT_LSIP(a0, (xtfloat *)p_a, sizeof(FLOAT32));
        XT_LSIP(b0, (xtfloat *)p_b, sizeof(FLOAT32));
        c0 = XT_ADD_S(a0, b0);
        XT_SSI(c0, (xtfloat *)p_c, 0);
      }
    }
  }
}

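/* Scalar-broadcast helper: adds the single element at p_inp2 to each of the
 * num_elm elements at p_inp1. sign_flag set means the caller swapped the
 * inputs, so operands are added in (inp2 + inp1) order. */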
static void internal_elm_add_broadcast_f32xf32_f32(FLOAT32 * __restrict__ p_out,
                    const FLOAT32 * __restrict__ p_inp1,
                    const FLOAT32 * __restrict__ p_inp2,
                    WORD32 num_elm,
                    xtbool sign_flag)
{
  int i;
  xtfloatx2 * __restrict__ p_a = (xtfloatx2 *)p_inp1;
  xtfloatx2 * __restrict__ p_b = (xtfloatx2 *)p_inp2;
  xtfloatx2 *__restrict__ p_c = (xtfloatx2 *)p_out;

  const int num_simd2_ops = num_elm >> 1;
  const int num_scalar_ops = num_elm & 1;

  xtfloat a0_7, out;
  xtfloatx2 x1, x2, y;
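  /* Load the broadcast element once; it is reused as one operand of every
   * addition below. */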
  x2 = XT_LSI((xtfloat *)p_b, 0);

  /* For computing inp2 + inp1 */
  if(sign_flag){
    if(((((unsigned)p_a)&7) == 0) && ((((unsigned)p_c)&7) == 0))
    {
      for(i=0; i<num_simd2_ops; i++)
      {
        XT_LSX2IP(x1, p_a, 2 * sizeof(FLOAT32));
        y = XT_ADD_SX2(x2, x1);
        XT_SSX2IP(y, p_c, 2 * sizeof(FLOAT32));
      }
    }
    else
    {
      ae_valign inp1_a, out_a;
      inp1_a = XT_LASX2PP(p_a);
      out_a = AE_ZALIGN64();
      for(i=0; i<num_simd2_ops; i++)
      {
        XT_LASX2IP(x1, inp1_a, p_a);
        y = XT_ADD_SX2(x2, x1);
        XT_SASX2IP(y, out_a, p_c);
      }
      XT_SASX2POSFP(out_a, (xtfloatx2 *)p_c);
    }
    if(num_scalar_ops !=0)
    {
      XT_LSIP(a0_7, (xtfloat *)p_a, sizeof(FLOAT32));
      out = XT_ADD_S(x2, a0_7);
      XT_SSI(out, (xtfloat *)p_c, 0);
    }
  }
  /* For computing inp1 + inp2 */
  else
  {
    if(((((unsigned)p_a)&7) == 0) && ((((unsigned)p_c)&7) == 0))
    {
      for(i=0; i<num_simd2_ops; i++)
      {
        XT_LSX2IP(x1, p_a, 2 * sizeof(FLOAT32));
        y = XT_ADD_SX2(x1, x2);
        XT_SSX2IP(y, p_c, 2 * sizeof(FLOAT32));
      }
    }
    else
    {
      ae_valign inp1_a, out_a;
      inp1_a = XT_LASX2PP(p_a);
      out_a = AE_ZALIGN64();
      for(i=0; i<num_simd2_ops; i++)
      {
        XT_LASX2IP(x1, inp1_a, p_a);
        y = XT_ADD_SX2(x1, x2);
        XT_SASX2IP(y, out_a, p_c);
      }
      XT_SASX2POSFP(out_a, (xtfloatx2 *)p_c);
    }
    if(num_scalar_ops !=0)
    {
      XT_LSIP(a0_7, (xtfloat *)p_a, sizeof(FLOAT32));
      out = XT_ADD_S(a0_7, x2);
      XT_SSI(out, (xtfloat *)p_c, 0);
    }
  }
}
#endif

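/* Element-wise addition of two 4D float tensors with broadcasting: in every
 * dimension the input sizes must either match or one of them must be 1, and
 * the output size must be the larger of the two. Returns 0 on success, -1 on
 * invalid arguments. */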
WORD32 xa_nn_elm_add_broadcast_4D_f32xf32_f32(FLOAT32 * __restrict__ p_out,
                      const WORD32 *const p_out_shape,
                      const FLOAT32 * __restrict__ p_inp1,
                      const WORD32 *const p_inp1_shape,
                      const FLOAT32 * __restrict__ p_inp2,
                      const WORD32 *const p_inp2_shape)
{
  /* NULL pointer checks */
  XA_NNLIB_ARG_CHK_PTR(p_out, -1);
  XA_NNLIB_ARG_CHK_PTR(p_inp1, -1);
  XA_NNLIB_ARG_CHK_PTR(p_inp2, -1);
  XA_NNLIB_ARG_CHK_PTR(p_out_shape, -1);
  XA_NNLIB_ARG_CHK_PTR(p_inp1_shape, -1);
  XA_NNLIB_ARG_CHK_PTR(p_inp2_shape, -1);
  /* Pointer alignment checks */
  XA_NNLIB_ARG_CHK_ALIGN(p_out, sizeof(FLOAT32), -1);
  XA_NNLIB_ARG_CHK_ALIGN(p_inp1, sizeof(FLOAT32), -1);
  XA_NNLIB_ARG_CHK_ALIGN(p_inp2, sizeof(FLOAT32), -1);
  XA_NNLIB_ARG_CHK_ALIGN(p_out_shape, sizeof(WORD32), -1);
  XA_NNLIB_ARG_CHK_ALIGN(p_inp1_shape, sizeof(WORD32), -1);
  XA_NNLIB_ARG_CHK_ALIGN(p_inp2_shape, sizeof(WORD32), -1);

  /* Check shapes */
  int i;
  xtbool sign_flag;
  for(i = 0; i < 4; i++)
  {
    if((p_inp1_shape[i] != p_inp2_shape[i] && p_inp1_shape[i] != 1 && p_inp2_shape[i] != 1) ||
       (p_out_shape[i] != (p_inp1_shape[i] > p_inp2_shape[i] ? p_inp1_shape[i] : p_inp2_shape[i])))
    {
      return -1;
    }
  }

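  /* Compute row-major strides for both inputs; each pair of strides is packed
   * into one ae_int32x2 so a single AE_MULP32X2 updates both per dimension. */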
  WORD32 inp1_strides[4], inp2_strides[4];
  inp1_strides[3] = 1;
  inp2_strides[3] = 1;
  for(i = 2; i >= 0; i--)
  {
    ae_int32x2 d_str, d_shape;
    d_str = AE_MOVDA32X2(inp1_strides[i + 1], inp2_strides[i + 1]);
    d_shape = AE_MOVDA32X2(p_inp1_shape[i + 1], p_inp2_shape[i + 1]);
    d_str = AE_MULP32X2(d_str, d_shape);
    inp1_strides[i] = AE_MOVAD32_H(d_str);
    inp2_strides[i] = AE_MOVAD32_L(d_str);
  }

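  /* Where the shapes differ, the size-1 side gets a zero stride so its data is
   * re-read (broadcast). inp1_const/inp2_const track whether an entire input
   * reduces to a single element. */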
  int need_broadcast = 0;
  int inp1_const = 1, inp2_const = 1;
  for(i = 0; i < 4; i++)
  {
    if(p_inp1_shape[i] != p_inp2_shape[i])
    {
      if(p_inp1_shape[i] == 1)
        inp1_strides[i] = 0;
      else
        inp2_strides[i] = 0;

      need_broadcast = 1;
    }
    if(p_inp1_shape[i] != 1)
      inp1_const &= 0;
    if(p_inp2_shape[i] != 1)
      inp2_const &= 0;
  }
  int itr0, itr1, itr2;

  FLOAT32 *p_out_tmp = p_out;
  const FLOAT32 *__restrict__ p_inp1_tmp = p_inp1;
  const FLOAT32 *__restrict__ p_inp2_tmp = p_inp2;
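  /* Dispatch: no broadcast -> one flat add over the whole tensor; matching
   * innermost strides -> row-wise 2D helper; one fully constant input ->
   * scalar broadcast over the full output; otherwise -> scalar broadcast per
   * innermost row. */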
  if(need_broadcast == 0)
  {
    sign_flag = 0;
    internal_elm_add_broadcast_2D_f32xf32_f32(
      p_out,
      p_inp1,
      p_inp2,
      1,
      p_out_shape[0] * inp1_strides[0],
      sign_flag);
  }
  else if(inp1_strides[3] == inp2_strides[3])
  {
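    /* Innermost dimension is not broadcast. If dimension 2 of inp1 is
     * broadcast, swap the inputs (recording the swap in sign_flag) so the
     * broadcast operand is passed as p_inp2, and tile the work as
     * [p_out_shape[2] x p_out_shape[3]] rows; otherwise a whole
     * [p_out_shape[2] * p_out_shape[3]] block is processed per outer
     * iteration. */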
    WORD32 in_lc, out_lc;
    sign_flag = 0;
    in_lc = p_out_shape[2] * p_out_shape[3];
    out_lc = 1;
    if(inp1_strides[2] == 0)
    {
      const FLOAT32 *tmp;
      tmp = p_inp1_tmp; p_inp1_tmp = p_inp2_tmp; p_inp2_tmp = tmp;
      sign_flag = 1;
      int tmp_strides[2];
      tmp_strides[0] = inp1_strides[0];
      tmp_strides[1] = inp1_strides[1];

      inp1_strides[0] = inp2_strides[0];
      inp1_strides[1] = inp2_strides[1];

      inp2_strides[0] = tmp_strides[0];
      inp2_strides[1] = tmp_strides[1];
      in_lc = p_out_shape[3];
      out_lc = p_out_shape[2];
    }
    else if(inp2_strides[2] == 0)
    {
      in_lc = p_out_shape[3];
      out_lc = p_out_shape[2];
    }

    for(itr0 = 0; itr0 < p_out_shape[0]; itr0++)
    {
      const FLOAT32 *__restrict__ p_inp1_tmp0 = p_inp1_tmp;
      const FLOAT32 *__restrict__ p_inp2_tmp0 = p_inp2_tmp;
      for(itr1 = 0; itr1 < p_out_shape[1]; itr1++)
      {
        internal_elm_add_broadcast_2D_f32xf32_f32(
          p_out_tmp,
          p_inp1_tmp0,
          p_inp2_tmp0,
          out_lc,
          in_lc,
          sign_flag);
        p_out_tmp += in_lc * out_lc;
        p_inp1_tmp0 += inp1_strides[1];
        p_inp2_tmp0 += inp2_strides[1];
      }
      p_inp1_tmp += inp1_strides[0];
      p_inp2_tmp += inp2_strides[0];
    }
  }
  else if(inp1_const == 1 || inp2_const == 1)
  {
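    /* One whole input is a single element: add it to every element of the
     * other input in one call, swapping first (and recording it in sign_flag)
     * so the constant input is passed as p_inp2. */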
    sign_flag = 0;
    if(inp1_strides[3] == 0)
    {
      sign_flag = 1;
      const FLOAT32 *tmp;
      tmp = p_inp1_tmp; p_inp1_tmp = p_inp2_tmp; p_inp2_tmp = tmp;
    }
    internal_elm_add_broadcast_f32xf32_f32(
      p_out_tmp,
      p_inp1_tmp,
      p_inp2_tmp,
      p_out_shape[0] * p_out_shape[1] * p_out_shape[2] * p_out_shape[3],
      sign_flag);
  }
  else
  {
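    /* General case: the innermost dimension of one input is broadcast. Walk
     * the three outer dimensions with per-input strides and add one innermost
     * row at a time, swapping the inputs first (and recording it in
     * sign_flag) if inp1 is the broadcast one. */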
    sign_flag = 0;
    if(inp1_strides[3] == 0)
    {
      const FLOAT32 *tmp;
      tmp = p_inp1_tmp; p_inp1_tmp = p_inp2_tmp; p_inp2_tmp = tmp;
      sign_flag = 1;
      int tmp_strides[3];
      tmp_strides[0] = inp1_strides[0];
      tmp_strides[1] = inp1_strides[1];
      tmp_strides[2] = inp1_strides[2];

      inp1_strides[0] = inp2_strides[0];
      inp1_strides[1] = inp2_strides[1];
      inp1_strides[2] = inp2_strides[2];

      inp2_strides[0] = tmp_strides[0];
      inp2_strides[1] = tmp_strides[1];
      inp2_strides[2] = tmp_strides[2];
    }
    for(itr0 = 0; itr0 < p_out_shape[0]; itr0++)
    {
      const FLOAT32 *__restrict__ p_inp1_tmp0 = p_inp1_tmp;
      const FLOAT32 *__restrict__ p_inp2_tmp0 = p_inp2_tmp;
      for(itr1 = 0; itr1 < p_out_shape[1]; itr1++)
      {
        const FLOAT32 *__restrict__ p_inp1_tmp1 = p_inp1_tmp0;
        const FLOAT32 *__restrict__ p_inp2_tmp1 = p_inp2_tmp0;
        for(itr2 = 0; itr2 < p_out_shape[2]; itr2++)
        {
          {
            internal_elm_add_broadcast_f32xf32_f32(
              p_out_tmp,
              p_inp1_tmp1,
              p_inp2_tmp1,
              p_out_shape[3],
              sign_flag);
          }
          p_out_tmp += p_out_shape[3];
          p_inp1_tmp1 += inp1_strides[2];
          p_inp2_tmp1 += inp2_strides[2];
        }
        p_inp1_tmp0 += inp1_strides[1];
        p_inp2_tmp0 += inp2_strides[1];
      }
      p_inp1_tmp += inp1_strides[0];
      p_inp2_tmp += inp2_strides[0];
    }
  }
  return 0;
}