// xref: /aosp_15_r20/external/XNNPACK/src/f32-spmm/gen/16x1-minmax-wasmsimd-x86-x4.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Auto-generated file. Do not edit!
//   Template: src/f32-spmm/wasmsimd.c.in
//   Generator: tools/xngen
//
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <wasm_simd128.h>

#include <xnnpack/spmm.h>

xnn_f32_spmm_minmax_ukernel_16x1__wasmsimd_x86_x4(size_t mc,size_t nc,const float * restrict input,const float * restrict weights,const int32_t * restrict widx_dmap,const uint32_t * restrict nidx_nnzmap,float * restrict output,size_t output_stride,const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS (1)])17 void xnn_f32_spmm_minmax_ukernel_16x1__wasmsimd_x86_x4(
18     size_t mc,
19     size_t nc,
20     const float*restrict input,
21     const float*restrict weights,
22     const int32_t*restrict widx_dmap,
23     const uint32_t*restrict nidx_nnzmap,
24     float*restrict output,
25     size_t output_stride,
26     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
27 {
28   assert(mc != 0);
29   assert(mc % sizeof(float) == 0);
30   assert(nc != 0);
31 
32   const v128_t vmin = wasm_v128_load64_splat(params->wasmsimd.min);
33   const v128_t vmax = wasm_v128_load64_splat(params->wasmsimd.max);
34   size_t output_decrement = output_stride * nc - 16 * sizeof(float);
35   while XNN_LIKELY(mc >= 16 * sizeof(float)) {
36     const float*restrict w = weights;
37     const int32_t* dmap = widx_dmap;
38     const uint32_t* nnzmap = nidx_nnzmap;
39     size_t n = nc;
40     do {
41       uint32_t nnz = *nnzmap++;
42       v128_t vacc0123x0 = wasm_v128_load32_splat(w);
43       w += 1;
44       v128_t vacc0123x1 = wasm_f32x4_const_splat(0.0f);
45       v128_t vacc0123x2 = wasm_f32x4_const_splat(0.0f);
46       v128_t vacc0123x3 = wasm_f32x4_const_splat(0.0f);
47       v128_t vacc4567x0 = vacc0123x0;
48       v128_t vacc4567x1 = wasm_f32x4_const_splat(0.0f);
49       v128_t vacc4567x2 = wasm_f32x4_const_splat(0.0f);
50       v128_t vacc4567x3 = wasm_f32x4_const_splat(0.0f);
51       v128_t vacc89ABx0 = vacc0123x0;
52       v128_t vacc89ABx1 = wasm_f32x4_const_splat(0.0f);
53       v128_t vacc89ABx2 = wasm_f32x4_const_splat(0.0f);
54       v128_t vacc89ABx3 = wasm_f32x4_const_splat(0.0f);
55       v128_t vaccCDEFx0 = vacc0123x0;
56       v128_t vaccCDEFx1 = wasm_f32x4_const_splat(0.0f);
57       v128_t vaccCDEFx2 = wasm_f32x4_const_splat(0.0f);
58       v128_t vaccCDEFx3 = wasm_f32x4_const_splat(0.0f);
59       for (; nnz >= 4; nnz -= 4) {
60         const intptr_t diff0 = dmap[0];
61         const intptr_t diff1 = dmap[1];
62         const intptr_t diff2 = dmap[2];
63         const intptr_t diff3 = dmap[3];
64         dmap += 4;
65         const v128_t vi0123x0 = wasm_v128_load(input);
66         const v128_t vi4567x0 = wasm_v128_load(input + 4);
67         const v128_t vi89ABx0 = wasm_v128_load(input + 8);
68         const v128_t viCDEFx0 = wasm_v128_load(input + 12);
69         input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff0);
70         const v128_t vw0 = wasm_v128_load32_splat(w);
71         w += 1;
72         vacc0123x0 = wasm_f32x4_add(vacc0123x0, wasm_f32x4_mul(vi0123x0, vw0));
73         vacc4567x0 = wasm_f32x4_add(vacc4567x0, wasm_f32x4_mul(vi4567x0, vw0));
74         vacc89ABx0 = wasm_f32x4_add(vacc89ABx0, wasm_f32x4_mul(vi89ABx0, vw0));
75         vaccCDEFx0 = wasm_f32x4_add(vaccCDEFx0, wasm_f32x4_mul(viCDEFx0, vw0));
76         const v128_t vi0123x1 = wasm_v128_load(input);
77         const v128_t vi4567x1 = wasm_v128_load(input + 4);
78         const v128_t vi89ABx1 = wasm_v128_load(input + 8);
79         const v128_t viCDEFx1 = wasm_v128_load(input + 12);
80         input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff1);
81         const v128_t vw1 = wasm_v128_load32_splat(w);
82         w += 1;
83         vacc0123x1 = wasm_f32x4_add(vacc0123x1, wasm_f32x4_mul(vi0123x1, vw1));
84         vacc4567x1 = wasm_f32x4_add(vacc4567x1, wasm_f32x4_mul(vi4567x1, vw1));
85         vacc89ABx1 = wasm_f32x4_add(vacc89ABx1, wasm_f32x4_mul(vi89ABx1, vw1));
86         vaccCDEFx1 = wasm_f32x4_add(vaccCDEFx1, wasm_f32x4_mul(viCDEFx1, vw1));
87         const v128_t vi0123x2 = wasm_v128_load(input);
88         const v128_t vi4567x2 = wasm_v128_load(input + 4);
89         const v128_t vi89ABx2 = wasm_v128_load(input + 8);
90         const v128_t viCDEFx2 = wasm_v128_load(input + 12);
91         input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff2);
92         const v128_t vw2 = wasm_v128_load32_splat(w);
93         w += 1;
94         vacc0123x2 = wasm_f32x4_add(vacc0123x2, wasm_f32x4_mul(vi0123x2, vw2));
95         vacc4567x2 = wasm_f32x4_add(vacc4567x2, wasm_f32x4_mul(vi4567x2, vw2));
96         vacc89ABx2 = wasm_f32x4_add(vacc89ABx2, wasm_f32x4_mul(vi89ABx2, vw2));
97         vaccCDEFx2 = wasm_f32x4_add(vaccCDEFx2, wasm_f32x4_mul(viCDEFx2, vw2));
98         const v128_t vi0123x3 = wasm_v128_load(input);
99         const v128_t vi4567x3 = wasm_v128_load(input + 4);
100         const v128_t vi89ABx3 = wasm_v128_load(input + 8);
101         const v128_t viCDEFx3 = wasm_v128_load(input + 12);
102         input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff3);
103         const v128_t vw3 = wasm_v128_load32_splat(w);
104         w += 1;
105         vacc0123x3 = wasm_f32x4_add(vacc0123x3, wasm_f32x4_mul(vi0123x3, vw3));
106         vacc4567x3 = wasm_f32x4_add(vacc4567x3, wasm_f32x4_mul(vi4567x3, vw3));
107         vacc89ABx3 = wasm_f32x4_add(vacc89ABx3, wasm_f32x4_mul(vi89ABx3, vw3));
108         vaccCDEFx3 = wasm_f32x4_add(vaccCDEFx3, wasm_f32x4_mul(viCDEFx3, vw3));
109       }
110       v128_t vacc0123 = vacc0123x0;
111       v128_t vacc4567 = vacc4567x0;
112       v128_t vacc89AB = vacc89ABx0;
113       v128_t vaccCDEF = vaccCDEFx0;
114       vacc0123 = wasm_f32x4_add(vacc0123, vacc0123x1);
115       vacc4567 = wasm_f32x4_add(vacc4567, vacc4567x1);
116       vacc89AB = wasm_f32x4_add(vacc89AB, vacc89ABx1);
117       vaccCDEF = wasm_f32x4_add(vaccCDEF, vaccCDEFx1);
118       vacc0123 = wasm_f32x4_add(vacc0123, vacc0123x2);
119       vacc4567 = wasm_f32x4_add(vacc4567, vacc4567x2);
120       vacc89AB = wasm_f32x4_add(vacc89AB, vacc89ABx2);
121       vaccCDEF = wasm_f32x4_add(vaccCDEF, vaccCDEFx2);
122       vacc0123 = wasm_f32x4_add(vacc0123, vacc0123x3);
123       vacc4567 = wasm_f32x4_add(vacc4567, vacc4567x3);
124       vacc89AB = wasm_f32x4_add(vacc89AB, vacc89ABx3);
125       vaccCDEF = wasm_f32x4_add(vaccCDEF, vaccCDEFx3);
126       if XNN_LIKELY(nnz != 0) {
127         do {
128           const intptr_t diff = *dmap++;
129           const v128_t vi0123 = wasm_v128_load(input);
130           const v128_t vi4567 = wasm_v128_load(input + 4);
131           const v128_t vi89AB = wasm_v128_load(input + 8);
132           const v128_t viCDEF = wasm_v128_load(input + 12);
133           input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
134           const v128_t vw = wasm_v128_load32_splat(w); w += 1;
135           vacc0123 = wasm_f32x4_add(vacc0123, wasm_f32x4_mul(vi0123, vw));
136           vacc4567 = wasm_f32x4_add(vacc4567, wasm_f32x4_mul(vi4567, vw));
137           vacc89AB = wasm_f32x4_add(vacc89AB, wasm_f32x4_mul(vi89AB, vw));
138           vaccCDEF = wasm_f32x4_add(vaccCDEF, wasm_f32x4_mul(viCDEF, vw));
139         } while (--nnz != 0);
140       }
141       v128_t vout0123 = wasm_f32x4_pmin(vmax, vacc0123);
142       v128_t vout4567 = wasm_f32x4_pmin(vmax, vacc4567);
143       v128_t vout89AB = wasm_f32x4_pmin(vmax, vacc89AB);
144       v128_t voutCDEF = wasm_f32x4_pmin(vmax, vaccCDEF);
145       vout0123 = wasm_f32x4_pmax(vmin, vout0123);
146       vout4567 = wasm_f32x4_pmax(vmin, vout4567);
147       vout89AB = wasm_f32x4_pmax(vmin, vout89AB);
148       voutCDEF = wasm_f32x4_pmax(vmin, voutCDEF);
149       wasm_v128_store(output, vout0123);
150       wasm_v128_store(output + 4, vout4567);
151       wasm_v128_store(output + 8, vout89AB);
152       wasm_v128_store(output + 12, voutCDEF);
153       output = (float*restrict) ((uintptr_t) output + output_stride);
154     } while (--n != 0);
155     output = (float*restrict) ((uintptr_t) output - output_decrement);
156     input += 16;
157     mc -= 16 * sizeof(float);
158   }
159   if XNN_UNLIKELY(mc != 0) {
160     output_decrement += 8 * sizeof(float);
161     if (mc & (8 * sizeof(float))) {
162       const float*restrict w = weights;
163       const int32_t* dmap = widx_dmap;
164       const uint32_t* nnzmap = nidx_nnzmap;
165       size_t n = nc;
166       do {
167         uint32_t nnz = *nnzmap++;
168         v128_t vacc0123 = wasm_v128_load32_splat(w); w += 1;
169         v128_t vacc4567 = vacc0123;
170         if XNN_LIKELY(nnz != 0) {
171           do {
172             const intptr_t diff = *dmap++;
173             const v128_t vi0123 = wasm_v128_load(input);
174             const v128_t vi4567 = wasm_v128_load(input + 4);
175             input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
176             const v128_t vw = wasm_v128_load32_splat(w); w += 1;
177             vacc0123 = wasm_f32x4_add(vacc0123, wasm_f32x4_mul(vi0123, vw));
178             vacc4567 = wasm_f32x4_add(vacc4567, wasm_f32x4_mul(vi4567, vw));
179           } while (--nnz != 0);
180         }
181         v128_t vout0123 = wasm_f32x4_pmin(vmax, vacc0123);
182         v128_t vout4567 = wasm_f32x4_pmin(vmax, vacc4567);
183         vout0123 = wasm_f32x4_pmax(vmin, vout0123);
184         vout4567 = wasm_f32x4_pmax(vmin, vout4567);
185         wasm_v128_store(output, vout0123);
186 
187         wasm_v128_store(output + 4, vout4567);
188         output = (float*restrict) ((uintptr_t) output + output_stride);
189       } while (--n != 0);
190       output = (float*restrict) ((uintptr_t) output - output_decrement);
191       input += 8;
192     }
193     output_decrement += 4 * sizeof(float);
194     if (mc & (4 * sizeof(float))) {
195       const float*restrict w = weights;
196       const int32_t* dmap = widx_dmap;
197       const uint32_t* nnzmap = nidx_nnzmap;
198       size_t n = nc;
199       do {
200         uint32_t nnz = *nnzmap++;
201         v128_t vacc0123 = wasm_v128_load32_splat(w); w += 1;
202         if XNN_LIKELY(nnz != 0) {
203           do {
204             const intptr_t diff = *dmap++;
205             const v128_t vi0123 = wasm_v128_load(input);
206             input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
207             const v128_t vw = wasm_v128_load32_splat(w); w += 1;
208             vacc0123 = wasm_f32x4_add(vacc0123, wasm_f32x4_mul(vi0123, vw));
209           } while (--nnz != 0);
210         }
211         v128_t vout0123 = wasm_f32x4_pmin(vmax, vacc0123);
212         vout0123 = wasm_f32x4_pmax(vmin, vout0123);
213         wasm_v128_store(output, vout0123);
214 
215         output = (float*restrict) ((uintptr_t) output + output_stride);
216       } while (--n != 0);
217       output = (float*restrict) ((uintptr_t) output - output_decrement);
218       input += 4;
219     }
220     output_decrement += 2 * sizeof(float);
221     if (mc & (2 * sizeof(float))) {
222       const float*restrict w = weights;
223       const int32_t* dmap = widx_dmap;
224       const uint32_t* nnzmap = nidx_nnzmap;
225       size_t n = nc;
226       do {
227         uint32_t nnz = *nnzmap++;
228         v128_t vacc01 = wasm_v128_load32_splat(w); w += 1;
229         if XNN_LIKELY(nnz != 0) {
230           do {
231             const intptr_t diff = *dmap++;
232             const v128_t vi01 = wasm_v128_load64_splat(input);
233             input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
234             const v128_t vw = wasm_v128_load32_splat(w); w += 1;
235             vacc01 = wasm_f32x4_add(vacc01, wasm_f32x4_mul(vi01, vw));
236           } while (--nnz != 0);
237         }
238         v128_t vout01 = wasm_f32x4_pmin(vmax, vacc01);
239         vout01 = wasm_f32x4_pmax(vmin, vout01);
240         *((double*) output) = wasm_f64x2_extract_lane(vout01, 0);
241 
242         output = (float*restrict) ((uintptr_t) output + output_stride);
243       } while (--n != 0);
244       output = (float*restrict) ((uintptr_t) output - output_decrement);
245       input += 2;
246     }
247     output_decrement += 1 * sizeof(float);
248     if (mc & (1 * sizeof(float))) {
249       const float*restrict w = weights;
250       const int32_t* dmap = widx_dmap;
251       const uint32_t* nnzmap = nidx_nnzmap;
252       size_t n = nc;
253       do {
254         uint32_t nnz = *nnzmap++;
255         v128_t vacc0 = wasm_v128_load32_splat(w); w += 1;
256         if XNN_LIKELY(nnz != 0) {
257           do {
258             const intptr_t diff = *dmap++;
259             const v128_t vi0 = wasm_v128_load32_splat(input);
260             input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
261             const v128_t vw = wasm_v128_load32_splat(w); w += 1;
262             vacc0 = wasm_f32x4_add(vacc0, wasm_f32x4_mul(vi0, vw));
263           } while (--nnz != 0);
264         }
265         v128_t vout0 = wasm_f32x4_pmin(vmax, vacc0);
266         vout0 = wasm_f32x4_pmax(vmin, vout0);
267         *output = wasm_f32x4_extract_lane(vout0, 0);
268 
269         output = (float*restrict) ((uintptr_t) output + output_stride);
270       } while (--n != 0);
271       output = (float*restrict) ((uintptr_t) output - output_decrement);
272       input += 1;
273     }
274   }
275 }
276