xref: /aosp_15_r20/external/XNNPACK/src/f32-igemm/1x8-aarch64-neonfma-cortex-a75.cc (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <cassert>
7 #include <cstddef>
8 #include <limits>
9 
10 #include <xnnpack.h>
11 #include <xnnpack/aarch64-assembler.h>
12 #include <xnnpack/allocator.h>
13 #include <xnnpack/igemm.h>
14 
15 
16 namespace xnnpack {
17 namespace aarch64 {
18 namespace {
19 class Generator : public Assembler {
20   using Assembler::Assembler;
21 
22 public:
23   void generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, float min, float max);
24 };
25 
26 // void xnn_f32_igemm_minmax_ukernel_4x8__aarch64_neonfma_prfm_cortex_a75(
27 //     size_t mr,                         (x0) - unused.  mr = 1
28 //     size_t nc,                         x1
29 //     size_t kc,                         x2 / x0
30 //     size_t ks,                         x3 / x9
31 //     const float**restrict a,           x4
32 //     const float*restrict w,            x5
33 //     float*restrict c,                  x6
34 //     size_t cm_stride,                  (x7) - unused
35 //     size_t cn_stride,                  [sp] -> x10
36 //     size_t a_offset,                   [sp + 8] -> x11
37 //     const float* zero,                 [sp + 16] -> x12
38 //     const xnn_f32_minmax_params params [sp + 24] -> (x8)
39 
40 // d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.
41 
42 // A pointer
43 // x8  a0
44 
45 // C pointer
46 // x6  c0
47 
48 // Converted from: src/f32-igemm/gen/1x8-minmax-aarch64-neonfma-prfm-cortex-a75.S
generate(bool prefetch,size_t max_mr,size_t nc_mod_nr,size_t kc,size_t ks,float min,float max)49 void Generator::generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, float min, float max)
50 {
51   assert(nc_mod_nr < 8);
52   assert(kc != 0);
53   assert(kc % sizeof(float) == 0);
54   assert(ks != 0);
55 
56   Label l0, l1, l2, l3, l4, l5, l6, l7, l8, l9, l10, l11, l12, l13;
57   const bool clamp_min = min != -std::numeric_limits<float>::infinity();
58   const bool clamp_max = max != +std::numeric_limits<float>::infinity();
59 
60   // Load cn_stride, a_offset
61   ldp(x10, x11, mem[sp]);
62 
63   // Load zero, params pointer
64   ldp(x12, x8, mem[sp, 16]);
65 
66   // Load min/max values
67   if (clamp_max) {
68     ld2r({v30.v4s(), v31.v4s()}, mem[x8]);
69   } else if (clamp_min) {
70     if (min == 0.f) {
71       movi(v30.v4s(), 0);
72     } else {
73       ld1r({v30.v4s()}, mem[x8]);
74     }
75   }
76 
77   bind(l0);
78   // Load initial bias from w into accumulators
79   ldp(q16, q17, mem[x5], 32);
80   movi(v18.v4s(), 0); // second set of C for pipelining FMLA
81   if (prefetch) {
82     prfm(kPLDL1KEEP, mem[x5]);
83   }
84   movi(v19.v4s(), 0);
85   if (prefetch) {
86     prfm(kPLDL1KEEP, mem[x5, 64]);
87     prfm(kPLDL1KEEP, mem[x5, 128]);
88     prfm(kPLDL1KEEP, mem[x5, 192]);
89   }
90 
91   mov(x9, x3); // p = ks
92 
93   bind(l1);
94   // Load next A pointer
95   ldr(x8, mem[x4], 8);
96 
97   cmp(x8, x12);           // if a0 == zero
98   add(x8, x8, x11);       // a0 += a_offset
99   csel(x8, x12, x8, kEQ); //   a0 = zero, else += a0 + a_offset
100 
101   // Is there at least 8 floats (32 bytes) for prologue + epilogue?
102   subs(x0, x2, 32); // k = kc - 32
103   b_lo(l4);
104 
105   // 16 prologue
106   // Read first block of A and B.
107   ldp(q20, q21, mem[x5], 32);
108   ldp(q22, q23, mem[x5], 32);
109   ldp(q24, q25, mem[x5], 32);
110   ldp(q26, q27, mem[x5], 32);
111   ldr(q0, mem[x8], 16);
112 
113   // Is there at least 8.  yes do main loop
114   subs(x0, x0, 32);
115   b_lo(l3);
116 
117   // Main loop - 8 floats of A (32 bytes)
118   bind(l2);
119   // First block of 4.  FMA for first 4, loads for 2nd block of 4.
120   fmla(v16.v4s(), v20.v4s(), v0.s()[0]);
121   ldr(q1, mem[x8], 16);
122   fmla(v17.v4s(), v21.v4s(), v0.s()[0]);
123   ldp(q20, q21, mem[x5], 32);
124   fmla(v18.v4s(), v22.v4s(), v0.s()[1]);
125   fmla(v19.v4s(), v23.v4s(), v0.s()[1]);
126   ldp(q22, q23, mem[x5], 32);
127   fmla(v16.v4s(), v24.v4s(), v0.s()[2]);
128   fmla(v17.v4s(), v25.v4s(), v0.s()[2]);
129   ldp(q24, q25, mem[x5], 32);
130   if (prefetch) {
131     prfm(kPLDL1KEEP, mem[x5, 128]);
132   }
133   fmla(v18.v4s(), v26.v4s(), v0.s()[3]);
134   if (prefetch) {
135     prfm(kPLDL1KEEP, mem[x5, 256]);
136   }
137   fmla(v19.v4s(), v27.v4s(), v0.s()[3]);
138   ldp(q26, q27, mem[x5], 32);
139 
140   // Second block of 4.  FMA for second 4, loads for 1st block of 4.
141   fmla(v16.v4s(), v20.v4s(), v1.s()[0]);
142   ldr(q0, mem[x8], 16);
143   fmla(v17.v4s(), v21.v4s(), v1.s()[0]);
144   ldp(q20, q21, mem[x5], 32);
145   fmla(v18.v4s(), v22.v4s(), v1.s()[1]);
146   fmla(v19.v4s(), v23.v4s(), v1.s()[1]);
147   ldp(q22, q23, mem[x5], 32);
148   fmla(v16.v4s(), v24.v4s(), v1.s()[2]);
149   fmla(v17.v4s(), v25.v4s(), v1.s()[2]);
150   ldp(q24, q25, mem[x5], 32);
151   if (prefetch) {
152     prfm(kPLDL1KEEP, mem[x5, 128]);
153   }
154   fmla(v18.v4s(), v26.v4s(), v1.s()[3]);
155   if (prefetch) {
156     prfm(kPLDL1KEEP, mem[x5, 256]);
157   }
158   fmla(v19.v4s(), v27.v4s(), v1.s()[3]);
159   subs(x0, x0, 32);
160   ldp(q26, q27, mem[x5], 32);
161   b_hs(l2);
162 
163   bind(l3);
164   // Epilogue
165 
166   // First block of 4.  FMA for first 4, loads for 2nd block of 4.
167   fmla(v16.v4s(), v20.v4s(), v0.s()[0]);
168   ldr(q1, mem[x8], 16);
169   fmla(v17.v4s(), v21.v4s(), v0.s()[0]);
170   ldp(q20, q21, mem[x5], 32);
171   fmla(v18.v4s(), v22.v4s(), v0.s()[1]);
172   fmla(v19.v4s(), v23.v4s(), v0.s()[1]);
173   ldp(q22, q23, mem[x5], 32);
174   fmla(v16.v4s(), v24.v4s(), v0.s()[2]);
175   fmla(v17.v4s(), v25.v4s(), v0.s()[2]);
176   ldp(q24, q25, mem[x5], 32);
177   if (prefetch) {
178     prfm(kPLDL1KEEP, mem[x5, 128]);
179   }
180   fmla(v18.v4s(), v26.v4s(), v0.s()[3]);
181   if (prefetch) {
182     prfm(kPLDL1KEEP, mem[x5, 256]);
183   }
184   fmla(v19.v4s(), v27.v4s(), v0.s()[3]);
185   ldp(q26, q27, mem[x5], 32);
186 
187   // Second block of 4.  no loads
188   fmla(v16.v4s(), v20.v4s(), v1.s()[0]);
189   fmla(v17.v4s(), v21.v4s(), v1.s()[0]);
190   fmla(v18.v4s(), v22.v4s(), v1.s()[1]);
191   fmla(v19.v4s(), v23.v4s(), v1.s()[1]);
192   fmla(v16.v4s(), v24.v4s(), v1.s()[2]);
193   fmla(v17.v4s(), v25.v4s(), v1.s()[2]);
194   fmla(v18.v4s(), v26.v4s(), v1.s()[3]);
195   fmla(v19.v4s(), v27.v4s(), v1.s()[3]);
196 
197   bind(l4);
198   // Is there a remainder?- 4 floats of A (16 bytes)
199   tbnz(x0, 4, l6);
200   // Is there a remainder?- 2 floats of A (8 bytes)
201   tbnz(x0, 3, l7);
202   // Is there a remainder?- 1 float of A (4 bytes)
203   tbnz(x0, 2, l9);
204 
205   bind(l5);
206   // ks loop
207   subs(x9, x9, 8); // ks -= MR * sizeof(void*)
208   b_hi(l1);
209 
210   fadd(v16.v4s(), v16.v4s(), v18.v4s());
211   fadd(v17.v4s(), v17.v4s(), v19.v4s());
212 
213   // Clamp
214   if (clamp_min) {
215     fmax(v16.v4s(), v16.v4s(), v30.v4s());
216     fmax(v17.v4s(), v17.v4s(), v30.v4s());
217   }
218   if (clamp_max) {
219     fmin(v16.v4s(), v16.v4s(), v31.v4s());
220     fmin(v17.v4s(), v17.v4s(), v31.v4s());
221   }
222 
223   // Store full 1 x 8
224   subs(x1, x1, 8);
225   b_lo(l10);
226 
227   stp(q16, q17, mem[x6]);
228   add(x6, x6, x10);
229 
230   sub(x4, x4, x3); // a -= ks
231 
232   // nc loop
233   b_hi(l0);
234 
235   ret();
236 
237   bind(l6);
238   // Remainder- 4 floats of A (16 bytes)
239   ldp(q20, q21, mem[x5], 32);
240   ldr(q0, mem[x8], 16);
241   fmla(v16.v4s(), v20.v4s(), v0.s()[0]);
242   fmla(v17.v4s(), v21.v4s(), v0.s()[0]);
243   ldp(q22, q23, mem[x5], 32);
244   ldp(q24, q25, mem[x5], 32);
245   ldp(q26, q27, mem[x5], 32);
246   fmla(v18.v4s(), v22.v4s(), v0.s()[1]);
247   fmla(v19.v4s(), v23.v4s(), v0.s()[1]);
248   fmla(v16.v4s(), v24.v4s(), v0.s()[2]);
249   fmla(v17.v4s(), v25.v4s(), v0.s()[2]);
250   fmla(v18.v4s(), v26.v4s(), v0.s()[3]);
251   fmla(v19.v4s(), v27.v4s(), v0.s()[3]);
252 
253   tbz(x0, 3, l8);
254   bind(l7);
255   // Remainder- 2 floats of A (8 bytes)
256   ldp(q20, q21, mem[x5], 32);
257   ldr(d0, mem[x8], 8);
258   fmla(v16.v4s(), v20.v4s(), v0.s()[0]);
259   fmla(v17.v4s(), v21.v4s(), v0.s()[0]);
260   ldp(q22, q23, mem[x5], 32);
261   fmla(v18.v4s(), v22.v4s(), v0.s()[1]);
262   fmla(v19.v4s(), v23.v4s(), v0.s()[1]);
263   bind(l8);
264   tbz(x0, 2, l5);
265   bind(l9);
266   // Remainder- 1 float of A (4 bytes)
267   ldp(q20, q21, mem[x5], 32);
268   ldr(s0, mem[x8], 4);
269   fmla(v16.v4s(), v20.v4s(), v0.s()[0]);
270   fmla(v17.v4s(), v21.v4s(), v0.s()[0]);
271   b(l5);
272 
273   bind(l10);
274   // Store odd channels
275   tbz(x1, 2, l11);
276   str(q16, mem[x6], 16);
277   mov(v16.v16b(), v17.v16b());
278 
279   bind(l11);
280   tbz(x1, 1, l12);
281   str(d16, mem[x6], 8);
282   dup(d16, v16.d()[1]);
283 
284   bind(l12);
285   tbz(x1, 0, l13);
286   str(s16, mem[x6], 4);
287   bind(l13);
288   ret();
289 
290   align(16, AlignInstruction::kHlt);
291 }
292 } // namespace
293 } // namespace aarch64
294 } // namespace xnnpack
295 
xnn_generate_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a75(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,size_t ks,const void * params)296 xnn_status_t xnn_generate_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a75(
297     xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, const void* params)
298 {
299   using namespace xnnpack::aarch64;
300   Generator g(code);
301   assert(params != nullptr);
302   auto jit_params = static_cast<const jit_gemm_params*>(params);
303   g.generate(false, max_mr, nc_mod_nr, kc, ks, jit_params->f32_minmax.min, jit_params->f32_minmax.max);
304   g.finalize();
305   if (g.error() != xnnpack::Error::kNoError) {
306     return xnn_status_invalid_state;
307   }
308   return xnn_status_success;
309 }
310 
xnn_generate_f32_igemm_ukernel_1x8__aarch64_neonfma_prfm_cortex_a75(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,size_t ks,const void * params)311 xnn_status_t xnn_generate_f32_igemm_ukernel_1x8__aarch64_neonfma_prfm_cortex_a75(
312     xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, const void* params)
313 {
314   using namespace xnnpack::aarch64;
315   Generator g(code);
316   assert(params != nullptr);
317   auto jit_params = static_cast<const jit_gemm_params*>(params);
318   g.generate(true, max_mr, nc_mod_nr, kc, ks, jit_params->f32_minmax.min, jit_params->f32_minmax.max);
319   g.finalize();
320   if (g.error() != xnnpack::Error::kNoError) {
321     return xnn_status_invalid_state;
322   }
323   return xnn_status_success;
324 }
325