/*******************************************************************************
* Copyright (c) 2018-2024 Cadence Design Systems, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files (the
* "Software"), to use this Software with Cadence processor cores only and
* not with any other processors and platforms, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

******************************************************************************/
#include "xa_type_def.h"
#include "xa_nnlib_common_fpu.h"
#include "xa_nn_common.h"
#include "xa_nnlib_err_chk.h"
#include "xa_nnlib_kernels_api.h"


#if HAVE_VFPU
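/*
 * Helper: element-wise add with 2D broadcasting.
 *
 * Treats p_inp1 as out_lc rows of in_lc elements and p_inp2 as a single row
 * of in_lc elements that is re-read for every row, i.e.
 *   out[i][j] = inp1[i][j] + inp2[j].
 * sign_flag selects the operand order (inp2 + inp1 vs. inp1 + inp2); for
 * addition the result is the same either way, so the flag is presumably kept
 * for symmetry with the non-commutative kernels (e.g. subtract) that share
 * this structure.
 */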
static void internal_elm_add_broadcast_2D_f32xf32_f32(FLOAT32 * __restrict__ p_out,
                    const    FLOAT32 * __restrict__ p_inp1,
                    const    FLOAT32 * __restrict__ p_inp2,
                             WORD32  out_lc,
                             WORD32  in_lc,
                             xtbool  sign_flag)
{
  int i, j;

  xtfloatx2  * __restrict__ p_a = (xtfloatx2 *)p_inp1;
  xtfloatx2  * __restrict__ p_b = (xtfloatx2 *)p_inp2;
  xtfloatx2  * __restrict__ p_c = (xtfloatx2 *)p_out;

  int num_simd2_ops;
  int num_scalar_ops;

  if(out_lc)
  {
    num_simd2_ops = in_lc >> 1;
    num_scalar_ops = in_lc & 1;
  }
  else
  {
    num_simd2_ops = (in_lc >> 2) << 1;
    num_scalar_ops = in_lc & 3;
  }
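  /*
   * Number of 2-wide SIMD iterations and leftover scalar elements per row.
   * With out_lc == 0 the row loops below never execute, so the else branch
   * appears to be kept only for structural consistency with related kernels.
   */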

  xtfloatx2 x1, x2, y;
  xtfloat a0, b0, c0;

  /* For computing inp2 + inp1 */
  if(sign_flag){
    for(i = 0; i < out_lc; i++)
    {
      p_a = (xtfloatx2 *)&p_inp1[i * in_lc];
      p_b = (xtfloatx2 *)p_inp2;
      p_c = (xtfloatx2 *)&p_out[i * in_lc];
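      /* All three pointers 8-byte aligned: use aligned SIMD loads/stores. */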
      if(((((unsigned)p_a)&7) == 0) && ((((unsigned)p_b)&7) == 0) && ((((unsigned)p_c)&7) == 0))
      {
        for(j = 0; j < num_simd2_ops; j++)
        {
          XT_LSX2IP(x1, p_a, 2 * sizeof(FLOAT32));
          XT_LSX2IP(x2, p_b, 2 * sizeof(FLOAT32));
          y = XT_ADD_SX2(x2, x1);
          XT_SSX2IP(y, p_c, 2 * sizeof(FLOAT32));
        }
      }
      else
      {
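        /* Unaligned path: prime the ae_valign registers, use aligning
           load/store variants, and flush the output alignment register at
           the end. */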
        ae_valign vinp1, vinp2, out_a = AE_ZALIGN64();
        vinp1 = XT_LASX2PP(p_a);
        vinp2 = XT_LASX2PP(p_b);
        for(j = 0; j < num_simd2_ops; j++)
        {
          XT_LASX2IP(x1, vinp1, p_a);
          XT_LASX2IP(x2, vinp2, p_b);
          y = XT_ADD_SX2(x2, x1);
          XT_SASX2IP(y, out_a, p_c);
        }
        XT_SASX2POSFP(out_a, (xtfloatx2 *)p_c);
      }
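      /* Handle the odd trailing element, if any, with scalar ops. */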
      if(num_scalar_ops != 0)
      {
        XT_LSIP(a0, (xtfloat *)p_a, sizeof(FLOAT32));
        XT_LSIP(b0, (xtfloat *)p_b, sizeof(FLOAT32));
        c0 = XT_ADD_S(b0, a0);
        XT_SSI(c0, (xtfloat *)p_c, 0);
      }
    }
  }
  /* For computing inp1 + inp2 */
  else
  {
    for(i = 0; i < out_lc; i++)
    {
      p_a = (xtfloatx2 *)&p_inp1[i * in_lc];
      p_b = (xtfloatx2 *)p_inp2;
      p_c = (xtfloatx2 *)&p_out[i * in_lc];
      if(((((unsigned)p_a)&7) == 0) && ((((unsigned)p_b)&7) == 0) && ((((unsigned)p_c)&7) == 0))
      {
        for(j = 0; j < num_simd2_ops; j++)
        {
          XT_LSX2IP(x1, p_a, 2 * sizeof(FLOAT32));
          XT_LSX2IP(x2, p_b, 2 * sizeof(FLOAT32));
          y = XT_ADD_SX2(x1, x2);
          XT_SSX2IP(y, p_c, 2 * sizeof(FLOAT32));
        }
      }
      else
      {
        ae_valign vinp1, vinp2, out_a = AE_ZALIGN64();
        vinp1 = XT_LASX2PP(p_a);
        vinp2 = XT_LASX2PP(p_b);

        for(j = 0; j < num_simd2_ops; j++)
        {
          XT_LASX2IP(x1, vinp1, p_a);
          XT_LASX2IP(x2, vinp2, p_b);
          y = XT_ADD_SX2(x1, x2);
          XT_SASX2IP(y, out_a, p_c);
        }
        XT_SASX2POSFP(out_a, (xtfloatx2 *)p_c);
      }
      if(num_scalar_ops != 0)
      {
        XT_LSIP(a0, (xtfloat *)p_a, sizeof(FLOAT32));
        XT_LSIP(b0, (xtfloat *)p_b, sizeof(FLOAT32));
        c0 = XT_ADD_S(a0, b0);
        XT_SSI(c0, (xtfloat *)p_c, 0);
      }
    }
  }
}

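/*
 * Helper: element-wise add of a vector and a single scalar.
 *
 * Adds the one element pointed to by p_inp2 to each of the num_elm elements
 * of p_inp1:
 *   out[i] = inp1[i] + inp2[0].
 * As above, sign_flag only swaps the operand order of the (commutative)
 * addition.
 */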
static void internal_elm_add_broadcast_f32xf32_f32(FLOAT32 * __restrict__ p_out,
                    const    FLOAT32 * __restrict__ p_inp1,
                    const    FLOAT32 * __restrict__ p_inp2,
                             WORD32  num_elm,
                             xtbool  sign_flag)
{
  int i;
  xtfloatx2  * __restrict__ p_a = (xtfloatx2 *)p_inp1;
  xtfloatx2  * __restrict__ p_b = (xtfloatx2 *)p_inp2;
  xtfloatx2  * __restrict__ p_c = (xtfloatx2 *)p_out;

  const int num_simd2_ops = num_elm >> 1;
  const int num_scalar_ops = num_elm & 1;

  xtfloat a0_7, out;
  xtfloatx2 x1, x2, y;
  x2 = XT_LSI((xtfloat *)p_b, 0);
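  /* x2 now holds the broadcast scalar; assigning the xtfloat load to an
     xtfloatx2 makes the single value available in both SIMD lanes. */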

  /* For computing inp2 + inp1 */
  if(sign_flag){
    if(((((unsigned)p_a)&7) == 0) && ((((unsigned)p_c)&7) == 0))
    {
      for(i=0; i<num_simd2_ops; i++)
      {
        XT_LSX2IP(x1, p_a, 2 * sizeof(FLOAT32));
        y = XT_ADD_SX2(x2, x1);
        XT_SSX2IP(y, p_c, 2 * sizeof(FLOAT32));
      }
    }
    else
    {
      ae_valign inp1_a, out_a;
      inp1_a = XT_LASX2PP(p_a);
      out_a = AE_ZALIGN64();
      for(i=0; i<num_simd2_ops; i++)
      {
        XT_LASX2IP(x1, inp1_a, p_a);
        y = XT_ADD_SX2(x2, x1);
        XT_SASX2IP(y, out_a, p_c);
      }
      XT_SASX2POSFP(out_a, (xtfloatx2 *)p_c);
    }
    if(num_scalar_ops != 0)
    {
      XT_LSIP(a0_7, (xtfloat *)p_a, sizeof(FLOAT32));
      out = XT_ADD_S(x2, a0_7);
      XT_SSI(out, (xtfloat *)p_c, 0);
    }
  }
  /* For computing inp1 + inp2 */
  else
  {
    if(((((unsigned)p_a)&7) == 0) && ((((unsigned)p_c)&7) == 0))
    {
      for(i=0; i<num_simd2_ops; i++)
      {
        XT_LSX2IP(x1, p_a, 2 * sizeof(FLOAT32));
        y = XT_ADD_SX2(x1, x2);
        XT_SSX2IP(y, p_c, 2 * sizeof(FLOAT32));
      }
    }
    else
    {
      ae_valign inp1_a, out_a;
      inp1_a = XT_LASX2PP(p_a);
      out_a = AE_ZALIGN64();
      for(i=0; i<num_simd2_ops; i++)
      {
        XT_LASX2IP(x1, inp1_a, p_a);
        y = XT_ADD_SX2(x1, x2);
        XT_SASX2IP(y, out_a, p_c);
      }
      XT_SASX2POSFP(out_a, (xtfloatx2 *)p_c);
    }
    if(num_scalar_ops != 0)
    {
      XT_LSIP(a0_7, (xtfloat *)p_a, sizeof(FLOAT32));
      out = XT_ADD_S(a0_7, x2);
      XT_SSI(out, (xtfloat *)p_c, 0);
    }
  }
}
#endif

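/*
 * Public entry point: element-wise add of two 4D tensors with NumPy-style
 * broadcasting.
 *
 * For each axis the input dims must either match or one of them must be 1,
 * and the output dim must be the larger of the two. For example (shapes are
 * illustrative): inp1 {2,3,4,5} + inp2 {1,3,1,5} -> out {2,3,4,5}.
 *
 * Returns 0 on success, -1 on a NULL pointer, misaligned pointer, or
 * incompatible shapes.
 */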
WORD32 xa_nn_elm_add_broadcast_4D_f32xf32_f32(FLOAT32 * __restrict__ p_out,
                      const WORD32 *const p_out_shape,
                      const FLOAT32 * __restrict__ p_inp1,
                      const WORD32 *const p_inp1_shape,
                      const FLOAT32 * __restrict__ p_inp2,
                      const WORD32 *const p_inp2_shape)
{
  /* NULL pointer checks */
  XA_NNLIB_ARG_CHK_PTR(p_out, -1);
  XA_NNLIB_ARG_CHK_PTR(p_inp1, -1);
  XA_NNLIB_ARG_CHK_PTR(p_inp2, -1);
  XA_NNLIB_ARG_CHK_PTR(p_out_shape, -1);
  XA_NNLIB_ARG_CHK_PTR(p_inp1_shape, -1);
  XA_NNLIB_ARG_CHK_PTR(p_inp2_shape, -1);
  /* Pointer alignment checks */
  XA_NNLIB_ARG_CHK_ALIGN(p_out, sizeof(FLOAT32), -1);
  XA_NNLIB_ARG_CHK_ALIGN(p_inp1, sizeof(FLOAT32), -1);
  XA_NNLIB_ARG_CHK_ALIGN(p_inp2, sizeof(FLOAT32), -1);
  XA_NNLIB_ARG_CHK_ALIGN(p_out_shape, sizeof(WORD32), -1);
  XA_NNLIB_ARG_CHK_ALIGN(p_inp1_shape, sizeof(WORD32), -1);
  XA_NNLIB_ARG_CHK_ALIGN(p_inp2_shape, sizeof(WORD32), -1);

  /* Check shapes: for each axis the input dims must match or one must be 1,
     and the output dim must equal the larger of the two. */
  int i;
  xtbool sign_flag;
  for(i = 0; i < 4; i++)
  {
    if((p_inp1_shape[i] != p_inp2_shape[i] && p_inp1_shape[i] != 1 && p_inp2_shape[i] != 1) ||
       (p_out_shape[i] != (p_inp1_shape[i] > p_inp2_shape[i] ? p_inp1_shape[i] : p_inp2_shape[i])))
    {
      return -1;
    }
  }

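  /* Row-major strides for both inputs, computed innermost-out; AE_MULP32X2
     multiplies both inputs' strides in a single packed 32x2 op. */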
  WORD32 inp1_strides[4], inp2_strides[4];
  inp1_strides[3] = 1;
  inp2_strides[3] = 1;
  for(i = 2; i >= 0; i--)
  {
    ae_int32x2 d_str, d_shape;
    d_str = AE_MOVDA32X2(inp1_strides[i + 1], inp2_strides[i + 1]);
    d_shape = AE_MOVDA32X2(p_inp1_shape[i + 1], p_inp2_shape[i + 1]);
    d_str = AE_MULP32X2(d_str, d_shape);
    inp1_strides[i] = AE_MOVAD32_H(d_str);
    inp2_strides[i] = AE_MOVAD32_L(d_str);
  }

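  /* On any axis being broadcast (dim 1 vs. a larger dim), zero that input's
     stride so the same data is re-read across the axis. Also track whether
     either input is a single scalar (all dims 1). */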
  int need_broadcast = 0;
  int inp1_const = 1, inp2_const = 1;
  for(i = 0; i < 4; i++)
  {
    if(p_inp1_shape[i] != p_inp2_shape[i])
    {
      if(p_inp1_shape[i] == 1)
        inp1_strides[i] = 0;
      else
        inp2_strides[i] = 0;

      need_broadcast = 1;
    }
    if(p_inp1_shape[i] != 1)
      inp1_const = 0;
    if(p_inp2_shape[i] != 1)
      inp2_const = 0;
  }
  int itr0, itr1, itr2;

  FLOAT32 *p_out_tmp = p_out;
  const FLOAT32 *__restrict__ p_inp1_tmp = p_inp1;
  const FLOAT32 *__restrict__ p_inp2_tmp = p_inp2;
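  /*
   * Dispatch on the broadcast pattern:
   *  1. No broadcast at all: one flat 2D call over the whole buffer.
   *  2. Innermost dims match: 2D kernel per (axis 0, axis 1) slice.
   *  3. One input is a single scalar: one flat scalar-broadcast call.
   *  4. General case: scalar-broadcast kernel per innermost row.
   * Whenever the inputs are swapped so that the broadcast operand is always
   * p_inp2, sign_flag records the swap (a no-op for addition).
   */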
  if(need_broadcast == 0)
  {
    sign_flag = 0;
    internal_elm_add_broadcast_2D_f32xf32_f32(
                p_out,
                p_inp1,
                p_inp2,
                1,
                p_out_shape[0] * inp1_strides[0],
                sign_flag);
  }
  else if(inp1_strides[3] == inp2_strides[3])
  {
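    /* Innermost strides match (both 1), so whole rows can be added with the
       2D kernel. If inp1 is the one broadcast along axis 2, swap the inputs
       (and their outer strides) so the kernel always re-reads p_inp2. */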
    WORD32 in_lc, out_lc;
    sign_flag = 0;
    in_lc = p_out_shape[2] * p_out_shape[3];
    out_lc = 1;
    if(inp1_strides[2] == 0)
    {
      const FLOAT32 *tmp;
      tmp = p_inp1_tmp;   p_inp1_tmp = p_inp2_tmp;    p_inp2_tmp = tmp;
      sign_flag = 1;
      int tmp_strides[2];
      tmp_strides[0] = inp1_strides[0];
      tmp_strides[1] = inp1_strides[1];

      inp1_strides[0] = inp2_strides[0];
      inp1_strides[1] = inp2_strides[1];

      inp2_strides[0] = tmp_strides[0];
      inp2_strides[1] = tmp_strides[1];
      in_lc = p_out_shape[3];
      out_lc = p_out_shape[2];
    }
    else if(inp2_strides[2] == 0)
    {
      in_lc = p_out_shape[3];
      out_lc = p_out_shape[2];
    }

    for(itr0 = 0; itr0 < p_out_shape[0]; itr0++)
    {
      const FLOAT32 *__restrict__ p_inp1_tmp0 = p_inp1_tmp;
      const FLOAT32 *__restrict__ p_inp2_tmp0 = p_inp2_tmp;
      for(itr1 = 0; itr1 < p_out_shape[1]; itr1++)
      {
        internal_elm_add_broadcast_2D_f32xf32_f32(
            p_out_tmp,
            p_inp1_tmp0,
            p_inp2_tmp0,
            out_lc,
            in_lc,
            sign_flag);
        p_out_tmp += in_lc * out_lc;
        p_inp1_tmp0 += inp1_strides[1];
        p_inp2_tmp0 += inp2_strides[1];
      }
      p_inp1_tmp += inp1_strides[0];
      p_inp2_tmp += inp2_strides[0];
    }
  }
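  /* One input is a single element broadcast to the whole output: flatten all
     four axes into one call. */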
  else if(inp1_const == 1 || inp2_const == 1)
  {
    sign_flag = 0;
    if(inp1_strides[3] == 0)
    {
      sign_flag = 1;
      const FLOAT32 *tmp;
      tmp = p_inp1_tmp;   p_inp1_tmp = p_inp2_tmp;    p_inp2_tmp = tmp;
    }
    internal_elm_add_broadcast_f32xf32_f32(
        p_out_tmp,
        p_inp1_tmp,
        p_inp2_tmp,
        p_out_shape[0] * p_out_shape[1] * p_out_shape[2] * p_out_shape[3],
        sign_flag);
  }
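  /* General case: walk the three outer axes with per-input strides (zero on
     broadcast axes) and add one innermost row at a time. */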
  else
  {
    sign_flag = 0;
    if(inp1_strides[3] == 0)
    {
      const FLOAT32 *tmp;
      tmp = p_inp1_tmp;   p_inp1_tmp = p_inp2_tmp;    p_inp2_tmp = tmp;
      sign_flag = 1;
      int tmp_strides[3];
      tmp_strides[0] = inp1_strides[0];
      tmp_strides[1] = inp1_strides[1];
      tmp_strides[2] = inp1_strides[2];

      inp1_strides[0] = inp2_strides[0];
      inp1_strides[1] = inp2_strides[1];
      inp1_strides[2] = inp2_strides[2];

      inp2_strides[0] = tmp_strides[0];
      inp2_strides[1] = tmp_strides[1];
      inp2_strides[2] = tmp_strides[2];
    }
    for(itr0 = 0; itr0 < p_out_shape[0]; itr0++)
    {
      const FLOAT32 *__restrict__ p_inp1_tmp0 = p_inp1_tmp;
      const FLOAT32 *__restrict__ p_inp2_tmp0 = p_inp2_tmp;
      for(itr1 = 0; itr1 < p_out_shape[1]; itr1++)
      {
        const FLOAT32 *__restrict__ p_inp1_tmp1 = p_inp1_tmp0;
        const FLOAT32 *__restrict__ p_inp2_tmp1 = p_inp2_tmp0;
        for(itr2 = 0; itr2 < p_out_shape[2]; itr2++)
        {
          internal_elm_add_broadcast_f32xf32_f32(
              p_out_tmp,
              p_inp1_tmp1,
              p_inp2_tmp1,
              p_out_shape[3],
              sign_flag);
          p_out_tmp += p_out_shape[3];
          p_inp1_tmp1 += inp1_strides[2];
          p_inp2_tmp1 += inp2_strides[2];
        }
        p_inp1_tmp0 += inp1_strides[1];
        p_inp2_tmp0 += inp2_strides[1];
      }
      p_inp1_tmp += inp1_strides[0];
      p_inp2_tmp += inp2_strides[0];
    }
  }
  return 0;
}