xref: /aosp_15_r20/external/XNNPACK/src/f32-gemm/1x8-aarch64-neonfma-cortex-a75.cc (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 
7 #include <cassert>
8 #include <cstddef>
9 #include <limits>
10 
11 #include <xnnpack.h>
12 #include <xnnpack/aarch64-assembler.h>
13 #include <xnnpack/allocator.h>
14 #include <xnnpack/gemm.h>
15 
16 namespace xnnpack {
17 namespace aarch64 {
18 namespace {
19 class Generator : public Assembler {
20   using Assembler::Assembler;
21 
22 public:
23   void generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, float min, float max);
24 };
25 
26 // void xnn_f32_gemm_minmax_ukernel_1x8__aarch64_neonfma_prfm_cortex_a75(
27 //     size_t mr,                (x0) - unused.  mr = 1
28 //     size_t nc,                x1
29 //     size_t kc,                x2 / x0
30 //     const uint8_t*restrict a, x3
31 //     size_t a_stride,          (x4) - unused
32 //     const void*restrict w,    x5
33 //     uint8_t*restrict c,       x6
34 //     size_t cm_stride,         (x7) - unused
35 //     size_t cn_stride,         [sp] -> x14
36 //     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])  [sp + 8] -> (x8)
37 
38 // d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.
39 
40 // A pointer
41 // x3  a0
42 
43 // C pointer
44 // x6  c0
45 
46 // Clamp v4 v5
47 
48 // Converted from: src/f32-gemm/gen/1x8-minmax-aarch64-neonfma-prfm-cortex-a75.S
generate(bool prefetch,size_t max_mr,size_t nc_mod_nr,size_t kc,float min,float max)49 void Generator::generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, float min, float max)
50 {
51   assert(nc_mod_nr < 8);
52   assert(kc != 0);
53   assert(kc % sizeof(float) == 0);
54 
55   Label l0, l1, l2, l3, l4, l5, l6, l7, l8, l9, l10, l11, l12;
56   const bool clamp_min = min != -std::numeric_limits<float>::infinity();
57   const bool clamp_max = max != +std::numeric_limits<float>::infinity();
58 
59   // Load cn_stride, params pointer
60   ldp(x14, x8, mem[sp]);
61 
62   // Load min/max values
63   if (clamp_max) {
64     ld2r({v4.v4s(), v5.v4s()}, mem[x8]);
65   } else if (clamp_min) {
66     if (min == 0.f) {
67       movi(v4.v4s(), 0);
68     } else {
69       ld1r({v4.v4s()}, mem[x8]);
70     }
71   }
72 
73   bind(l0);
74   // Load initial bias from w into accumulators
75   ldp(q16, q17, mem[x5], 32);
76 
77   movi(v18.v4s(), 0); // second set of C for pipelining FMLA
78   if (prefetch) {
79     prfm(kPLDL1KEEP, mem[x5]);
80   }
81   movi(v19.v4s(), 0);
82   if (prefetch) {
83     prfm(kPLDL1KEEP, mem[x5, 64]);
84     prfm(kPLDL1KEEP, mem[x5, 128]);
85     prfm(kPLDL1KEEP, mem[x5, 192]);
86   }
87 
88   // Is there at least 8 floats (32 bytes) for prologue + epilogue?
89   subs(x0, x2, 32); // k = kc - 32
90 
91   b_lo(l3);
92 
93   // 16 prologue
94   // Read first block of 1 A and B.
95   ldp(q20, q21, mem[x5], 32);
96   ldp(q22, q23, mem[x5], 32);
97   ldp(q24, q25, mem[x5], 32);
98   ldp(q26, q27, mem[x5], 32);
99   ldr(q0, mem[x3], 16);
100 
101   // Is there at least 32.  yes do main loop
102   subs(x0, x0, 32);
103   b_lo(l2);
104 
105   // Main loop - 8 floats of A (32 bytes)
106   bind(l1);
107   // First block of 4.  FMA for first 4, loads for 2nd block of 4.
108   fmla(v16.v4s(), v20.v4s(), v0.s()[0]);
109   ldr(q1, mem[x3], 16);
110   fmla(v17.v4s(), v21.v4s(), v0.s()[0]);
111   ldp(q20, q21, mem[x5], 32);
112   fmla(v18.v4s(), v22.v4s(), v0.s()[1]);
113   if (prefetch) {
114     prfm(kPLDL1KEEP, mem[x5, 96]);
115   }
116   fmla(v19.v4s(), v23.v4s(), v0.s()[1]);
117   ldp(q22, q23, mem[x5], 32);
118   fmla(v16.v4s(), v24.v4s(), v0.s()[2]);
119   fmla(v17.v4s(), v25.v4s(), v0.s()[2]);
120   ldp(q24, q25, mem[x5], 32);
121   fmla(v18.v4s(), v26.v4s(), v0.s()[3]);
122   fmla(v19.v4s(), v27.v4s(), v0.s()[3]);
123   ldp(q26, q27, mem[x5], 32);
124 
125   // Second block of 4.  FMA for second 4, loads for 1st block of 4.
126   fmla(v16.v4s(), v20.v4s(), v1.s()[0]);
127   ldr(q0, mem[x3], 16);
128   fmla(v17.v4s(), v21.v4s(), v1.s()[0]);
129   ldp(q20, q21, mem[x5], 32);
130   fmla(v18.v4s(), v22.v4s(), v1.s()[1]);
131   fmla(v19.v4s(), v23.v4s(), v1.s()[1]);
132   ldp(q22, q23, mem[x5], 32);
133   fmla(v16.v4s(), v24.v4s(), v1.s()[2]);
134   fmla(v17.v4s(), v25.v4s(), v1.s()[2]);
135   ldp(q24, q25, mem[x5], 32);
136   fmla(v18.v4s(), v26.v4s(), v1.s()[3]);
137   fmla(v19.v4s(), v27.v4s(), v1.s()[3]);
138   subs(x0, x0, 32);
139   ldp(q26, q27, mem[x5], 32);
140   b_hs(l1);
141 
142   bind(l2);
143   // Epilogue
144 
145   // First block of 4.  FMA for first 4, loads for 2nd block of 4.
146   fmla(v16.v4s(), v20.v4s(), v0.s()[0]);
147   ldr(q1, mem[x3], 16);
148   fmla(v17.v4s(), v21.v4s(), v0.s()[0]);
149   ldp(q20, q21, mem[x5], 32);
150   fmla(v18.v4s(), v22.v4s(), v0.s()[1]);
151   fmla(v19.v4s(), v23.v4s(), v0.s()[1]);
152   ldp(q22, q23, mem[x5], 32);
153   fmla(v16.v4s(), v24.v4s(), v0.s()[2]);
154   fmla(v17.v4s(), v25.v4s(), v0.s()[2]);
155   ldp(q24, q25, mem[x5], 32);
156   fmla(v18.v4s(), v26.v4s(), v0.s()[3]);
157   fmla(v19.v4s(), v27.v4s(), v0.s()[3]);
158   ldp(q26, q27, mem[x5], 32);
159 
160   // Second block of 4.  no loads
161   fmla(v16.v4s(), v20.v4s(), v1.s()[0]);
162   fmla(v17.v4s(), v21.v4s(), v1.s()[0]);
163   fmla(v18.v4s(), v22.v4s(), v1.s()[1]);
164   fmla(v19.v4s(), v23.v4s(), v1.s()[1]);
165   fmla(v16.v4s(), v24.v4s(), v1.s()[2]);
166   fmla(v17.v4s(), v25.v4s(), v1.s()[2]);
167   fmla(v18.v4s(), v26.v4s(), v1.s()[3]);
168   fmla(v19.v4s(), v27.v4s(), v1.s()[3]);
169 
170   bind(l3);
171   // Is there a remainder?- 4 floats of A (16 bytes)
172   tbnz(x0, 4, l5);
173   // Is there a remainder?- 2 floats of A (8 bytes)
174   tbnz(x0, 3, l6);
175   // Is there a remainder?- 1 float of A (4 bytes)
176   tbnz(x0, 2, l8);
177 
178   bind(l4);
179   fadd(v16.v4s(), v16.v4s(), v18.v4s());
180   subs(x1, x1, 8);
181   fadd(v17.v4s(), v17.v4s(), v19.v4s());
182 
183   // Clamp
184   if (clamp_min) {
185     fmax(v16.v4s(), v16.v4s(), v4.v4s());
186     fmax(v17.v4s(), v17.v4s(), v4.v4s());
187   }
188   if (clamp_max) {
189     fmin(v16.v4s(), v16.v4s(), v5.v4s());
190     fmin(v17.v4s(), v17.v4s(), v5.v4s());
191   }
192 
193   // Store full 1 x 8
194   b_lo(l9);
195 
196   stp(q16, q17, mem[x6]);
197   add(x6, x6, x14);
198 
199   sub(x3, x3, x2); // a0 -= kc
200 
201   b_hi(l0);
202 
203   ret();
204 
205   bind(l5);
206   // Remainder- 4 floats of A (16 bytes)
207   ldp(q20, q21, mem[x5], 32);
208   ldr(q0, mem[x3], 16);
209   fmla(v16.v4s(), v20.v4s(), v0.s()[0]);
210   fmla(v17.v4s(), v21.v4s(), v0.s()[0]);
211   ldp(q22, q23, mem[x5], 32);
212   ldp(q24, q25, mem[x5], 32);
213   ldp(q26, q27, mem[x5], 32);
214   fmla(v18.v4s(), v22.v4s(), v0.s()[1]);
215   fmla(v19.v4s(), v23.v4s(), v0.s()[1]);
216   fmla(v16.v4s(), v24.v4s(), v0.s()[2]);
217   fmla(v17.v4s(), v25.v4s(), v0.s()[2]);
218   fmla(v18.v4s(), v26.v4s(), v0.s()[3]);
219   fmla(v19.v4s(), v27.v4s(), v0.s()[3]);
220 
221   tbz(x0, 3, l7);
222   bind(l6);
223   // Remainder- 2 floats of A (8 bytes)
224   ldp(q20, q21, mem[x5], 32);
225   ldr(d0, mem[x3], 8);
226   fmla(v16.v4s(), v20.v4s(), v0.s()[0]);
227   fmla(v17.v4s(), v21.v4s(), v0.s()[0]);
228   ldp(q22, q23, mem[x5], 32);
229   fmla(v18.v4s(), v22.v4s(), v0.s()[1]);
230   fmla(v19.v4s(), v23.v4s(), v0.s()[1]);
231   bind(l7);
232   tbz(x0, 2, l4);
233   bind(l8);
234   // Remainder- 1 float of A (4 bytes)
235   ldp(q20, q21, mem[x5], 32);
236   ldr(s0, mem[x3], 4);
237   fmla(v16.v4s(), v20.v4s(), v0.s()[0]);
238   fmla(v17.v4s(), v21.v4s(), v0.s()[0]);
239   b(l4);
240 
241   // Store odd channels
242   bind(l9);
243   tbz(x1, 2, l10);
244   str(q16, mem[x6], 16);
245   mov(v16.v16b(), v17.v16b());
246 
247   bind(l10);
248   tbz(x1, 1, l11);
249   str(d16, mem[x6], 8);
250   dup(d16, v16.d()[1]);
251 
252   bind(l11);
253   tbz(x1, 0, l12);
254   str(s16, mem[x6]);
255   bind(l12);
256   ret();
257 
258   align(16, AlignInstruction::kHlt);
259 }
260 } // namespace
261 } // namespace aarch64
262 } // namespace xnnpack
263 
xnn_generate_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a75(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,const void * params)264 xnn_status_t xnn_generate_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a75(
265     xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params)
266 {
267   using namespace xnnpack::aarch64;
268   Generator g(code);
269   assert(params != nullptr);
270   const jit_gemm_params* gemm_params = static_cast<const jit_gemm_params*>(params);
271   g.generate(false, max_mr, nc_mod_nr, kc, gemm_params->f32_minmax.min, gemm_params->f32_minmax.max);
272   g.finalize();
273   if (g.error() != xnnpack::Error::kNoError) {
274     return xnn_status_invalid_state;
275   }
276   return xnn_status_success;
277 }
278 
xnn_generate_f32_gemm_ukernel_1x8__aarch64_neonfma_prfm_cortex_a75(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,const void * params)279 xnn_status_t xnn_generate_f32_gemm_ukernel_1x8__aarch64_neonfma_prfm_cortex_a75(
280     xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params)
281 {
282   using namespace xnnpack::aarch64;
283   Generator g(code);
284   assert(params != nullptr);
285   const jit_gemm_params* gemm_params = static_cast<const jit_gemm_params*>(params);
286   g.generate(true, max_mr, nc_mod_nr, kc, gemm_params->f32_minmax.min, gemm_params->f32_minmax.max);
287   g.finalize();
288   if (g.error() != xnnpack::Error::kNoError) {
289     return xnn_status_invalid_state;
290   }
291   return xnn_status_success;
292 }
293