1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <cassert>
7 #include <cstddef>
8 #include <limits>
9
10 #include <xnnpack.h>
11 #include <xnnpack/aarch64-assembler.h>
12 #include <xnnpack/allocator.h>
13 #include <xnnpack/igemm.h>
14
15
16 namespace xnnpack {
17 namespace aarch64 {
18 namespace {
// JIT code generator for the f32 1x8 IGEMM micro-kernel. Inherits the
// instruction-emission API (ldp, fmla, bind, ...) and the output code
// buffer handling from the AArch64 Assembler base class.
class Generator : public Assembler {
  using Assembler::Assembler;

 public:
  // Emits the micro-kernel into the code buffer; see definition for the
  // meaning of each specialization parameter.
  void generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, float min, float max);
};
25
// void xnn_f32_igemm_minmax_ukernel_1x8__aarch64_neonfma_prfm_cortex_a75(
27 // size_t mr, (x0) - unused. mr = 1
28 // size_t nc, x1
29 // size_t kc, x2 / x0
30 // size_t ks, x3 / x9
31 // const float**restrict a, x4
32 // const float*restrict w, x5
33 // float*restrict c, x6
34 // size_t cm_stride, (x7) - unused
35 // size_t cn_stride, [sp] -> x10
36 // size_t a_offset, [sp + 8] -> x11
37 // const float* zero, [sp + 16] -> x12
38 // const xnn_f32_minmax_params params [sp + 24] -> (x8)
39
40 // d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.
41
42 // A pointer
43 // x8 a0
44
45 // C pointer
46 // x6 c0
47
48 // Converted from: src/f32-igemm/gen/1x8-minmax-aarch64-neonfma-prfm-cortex-a75.S
// Emits the 1x8 f32 IGEMM micro-kernel.
//   prefetch  - if true, emit PRFM (PLDL1KEEP) instructions on the weights stream.
//   max_mr    - unused here (this kernel handles a single row of A; mr = 1).
//   nc_mod_nr - nc % 8; only asserted, the generated code handles all remainders.
//   kc        - bytes of K per A pointer; must be a non-zero multiple of sizeof(float).
//   ks        - A-pointer loop extent in bytes (MR * sizeof(void*) per step).
//   min/max   - clamp bounds; clamp code is omitted when a bound is +/-infinity.
// Register usage follows the comment block above (a0 = x8, c0 = x6,
// accumulators v16-v19, weights v20-v27, clamp bounds v30/v31).
void Generator::generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, float min, float max)
{
  assert(nc_mod_nr < 8);
  assert(kc != 0);
  assert(kc % sizeof(float) == 0);
  assert(ks != 0);

  Label l0, l1, l2, l3, l4, l5, l6, l7, l8, l9, l10, l11, l12, l13;
  // Skip emitting clamp instructions for bounds that are no-ops.
  const bool clamp_min = min != -std::numeric_limits<float>::infinity();
  const bool clamp_max = max != +std::numeric_limits<float>::infinity();

  // Load cn_stride, a_offset
  ldp(x10, x11, mem[sp]);

  // Load zero, params pointer
  ldp(x12, x8, mem[sp, 16]);

  // Load min/max values.
  // params holds {min, max} contiguously: LD2R splats both into v30/v31.
  if (clamp_max) {
    ld2r({v30.v4s(), v31.v4s()}, mem[x8]);
  } else if (clamp_min) {
    if (min == 0.f) {
      // Clamp-at-zero: materialize 0 without touching params memory.
      movi(v30.v4s(), 0);
    } else {
      ld1r({v30.v4s()}, mem[x8]);
    }
  }

  bind(l0);
  // Load initial bias from w into accumulators
  ldp(q16, q17, mem[x5], 32);
  movi(v18.v4s(), 0);  // second set of C for pipelining FMLA
  if (prefetch) {
    prfm(kPLDL1KEEP, mem[x5]);
  }
  movi(v19.v4s(), 0);
  if (prefetch) {
    prfm(kPLDL1KEEP, mem[x5, 64]);
    prfm(kPLDL1KEEP, mem[x5, 128]);
    prfm(kPLDL1KEEP, mem[x5, 192]);
  }

  mov(x9, x3);  // p = ks

  bind(l1);
  // Load next A pointer
  ldr(x8, mem[x4], 8);

  // If a0 is the zero pointer, use it as-is (reads zeros); otherwise
  // apply a_offset. CSEL picks the result based on the comparison above.
  cmp(x8, x12);  // if a0 == zero
  add(x8, x8, x11);  // a0 += a_offset
  csel(x8, x12, x8, kEQ);  //  a0 = zero, else += a0 + a_offset

  // Is there at least 8 floats (32 bytes) for prologue + epilogue?
  subs(x0, x2, 32);  // k = kc - 32
  b_lo(l4);

  // 16 prologue
  // Read first block of A and B.
  ldp(q20, q21, mem[x5], 32);
  ldp(q22, q23, mem[x5], 32);
  ldp(q24, q25, mem[x5], 32);
  ldp(q26, q27, mem[x5], 32);
  ldr(q0, mem[x8], 16);

  // Is there at least 8.  yes do main loop
  subs(x0, x0, 32);
  b_lo(l3);

  // Main loop - 8 floats of A (32 bytes)
  // Software-pipelined: FMAs for one half overlap loads for the other half.
  bind(l2);
  // First block of 4.  FMA for first 4, loads for 2nd block of 4.
  fmla(v16.v4s(), v20.v4s(), v0.s()[0]);
  ldr(q1, mem[x8], 16);
  fmla(v17.v4s(), v21.v4s(), v0.s()[0]);
  ldp(q20, q21, mem[x5], 32);
  fmla(v18.v4s(), v22.v4s(), v0.s()[1]);
  fmla(v19.v4s(), v23.v4s(), v0.s()[1]);
  ldp(q22, q23, mem[x5], 32);
  fmla(v16.v4s(), v24.v4s(), v0.s()[2]);
  fmla(v17.v4s(), v25.v4s(), v0.s()[2]);
  ldp(q24, q25, mem[x5], 32);
  if (prefetch) {
    prfm(kPLDL1KEEP, mem[x5, 128]);
  }
  fmla(v18.v4s(), v26.v4s(), v0.s()[3]);
  if (prefetch) {
    prfm(kPLDL1KEEP, mem[x5, 256]);
  }
  fmla(v19.v4s(), v27.v4s(), v0.s()[3]);
  ldp(q26, q27, mem[x5], 32);

  // Second block of 4.  FMA for second 4, loads for 1st block of 4.
  fmla(v16.v4s(), v20.v4s(), v1.s()[0]);
  ldr(q0, mem[x8], 16);
  fmla(v17.v4s(), v21.v4s(), v1.s()[0]);
  ldp(q20, q21, mem[x5], 32);
  fmla(v18.v4s(), v22.v4s(), v1.s()[1]);
  fmla(v19.v4s(), v23.v4s(), v1.s()[1]);
  ldp(q22, q23, mem[x5], 32);
  fmla(v16.v4s(), v24.v4s(), v1.s()[2]);
  fmla(v17.v4s(), v25.v4s(), v1.s()[2]);
  ldp(q24, q25, mem[x5], 32);
  if (prefetch) {
    prfm(kPLDL1KEEP, mem[x5, 128]);
  }
  fmla(v18.v4s(), v26.v4s(), v1.s()[3]);
  if (prefetch) {
    prfm(kPLDL1KEEP, mem[x5, 256]);
  }
  fmla(v19.v4s(), v27.v4s(), v1.s()[3]);
  subs(x0, x0, 32);
  ldp(q26, q27, mem[x5], 32);
  b_hs(l2);

  bind(l3);
  // Epilogue
  // Same shape as the main loop body, but the second half issues no loads
  // since the prologue data has now been consumed.

  // First block of 4.  FMA for first 4, loads for 2nd block of 4.
  fmla(v16.v4s(), v20.v4s(), v0.s()[0]);
  ldr(q1, mem[x8], 16);
  fmla(v17.v4s(), v21.v4s(), v0.s()[0]);
  ldp(q20, q21, mem[x5], 32);
  fmla(v18.v4s(), v22.v4s(), v0.s()[1]);
  fmla(v19.v4s(), v23.v4s(), v0.s()[1]);
  ldp(q22, q23, mem[x5], 32);
  fmla(v16.v4s(), v24.v4s(), v0.s()[2]);
  fmla(v17.v4s(), v25.v4s(), v0.s()[2]);
  ldp(q24, q25, mem[x5], 32);
  if (prefetch) {
    prfm(kPLDL1KEEP, mem[x5, 128]);
  }
  fmla(v18.v4s(), v26.v4s(), v0.s()[3]);
  if (prefetch) {
    prfm(kPLDL1KEEP, mem[x5, 256]);
  }
  fmla(v19.v4s(), v27.v4s(), v0.s()[3]);
  ldp(q26, q27, mem[x5], 32);

  // Second block of 4.  no loads
  fmla(v16.v4s(), v20.v4s(), v1.s()[0]);
  fmla(v17.v4s(), v21.v4s(), v1.s()[0]);
  fmla(v18.v4s(), v22.v4s(), v1.s()[1]);
  fmla(v19.v4s(), v23.v4s(), v1.s()[1]);
  fmla(v16.v4s(), v24.v4s(), v1.s()[2]);
  fmla(v17.v4s(), v25.v4s(), v1.s()[2]);
  fmla(v18.v4s(), v26.v4s(), v1.s()[3]);
  fmla(v19.v4s(), v27.v4s(), v1.s()[3]);

  bind(l4);
  // Remainder handling: test bits of the remaining byte count (x0) from
  // high to low - 16, then 8, then 4 bytes of A.
  // Is there a remainder?- 4 floats of A (16 bytes)
  tbnz(x0, 4, l6);
  // Is there a remainder?- 2 floats of A (8 bytes)
  tbnz(x0, 3, l7);
  // Is there a remainder?- 1 float of A (4 bytes)
  tbnz(x0, 2, l9);

  bind(l5);
  // ks loop
  subs(x9, x9, 8);  // ks -= MR * sizeof(void*)
  b_hi(l1);

  // Fold the pipelined accumulator pair back into one set.
  fadd(v16.v4s(), v16.v4s(), v18.v4s());
  fadd(v17.v4s(), v17.v4s(), v19.v4s());

  // Clamp
  if (clamp_min) {
    fmax(v16.v4s(), v16.v4s(), v30.v4s());
    fmax(v17.v4s(), v17.v4s(), v30.v4s());
  }
  if (clamp_max) {
    fmin(v16.v4s(), v16.v4s(), v31.v4s());
    fmin(v17.v4s(), v17.v4s(), v31.v4s());
  }

  // Store full 1 x 8
  subs(x1, x1, 8);
  b_lo(l10);  // fewer than 8 columns left: partial store path

  stp(q16, q17, mem[x6]);
  add(x6, x6, x10);

  sub(x4, x4, x3);  // a -= ks

  // nc loop
  b_hi(l0);

  ret();

  bind(l6);
  // Remainder- 4 floats of A (16 bytes)
  ldp(q20, q21, mem[x5], 32);
  ldr(q0, mem[x8], 16);
  fmla(v16.v4s(), v20.v4s(), v0.s()[0]);
  fmla(v17.v4s(), v21.v4s(), v0.s()[0]);
  ldp(q22, q23, mem[x5], 32);
  ldp(q24, q25, mem[x5], 32);
  ldp(q26, q27, mem[x5], 32);
  fmla(v18.v4s(), v22.v4s(), v0.s()[1]);
  fmla(v19.v4s(), v23.v4s(), v0.s()[1]);
  fmla(v16.v4s(), v24.v4s(), v0.s()[2]);
  fmla(v17.v4s(), v25.v4s(), v0.s()[2]);
  fmla(v18.v4s(), v26.v4s(), v0.s()[3]);
  fmla(v19.v4s(), v27.v4s(), v0.s()[3]);

  tbz(x0, 3, l8);
  bind(l7);
  // Remainder- 2 floats of A (8 bytes)
  ldp(q20, q21, mem[x5], 32);
  ldr(d0, mem[x8], 8);
  fmla(v16.v4s(), v20.v4s(), v0.s()[0]);
  fmla(v17.v4s(), v21.v4s(), v0.s()[0]);
  ldp(q22, q23, mem[x5], 32);
  fmla(v18.v4s(), v22.v4s(), v0.s()[1]);
  fmla(v19.v4s(), v23.v4s(), v0.s()[1]);
  bind(l8);
  tbz(x0, 2, l5);
  bind(l9);
  // Remainder- 1 float of A (4 bytes)
  ldp(q20, q21, mem[x5], 32);
  ldr(s0, mem[x8], 4);
  fmla(v16.v4s(), v20.v4s(), v0.s()[0]);
  fmla(v17.v4s(), v21.v4s(), v0.s()[0]);
  b(l5);

  bind(l10);
  // Store odd channels: write 4, then 2, then 1 column per the low bits
  // of the remaining column count (x1), shifting data down between steps.
  tbz(x1, 2, l11);
  str(q16, mem[x6], 16);
  mov(v16.v16b(), v17.v16b());

  bind(l11);
  tbz(x1, 1, l12);
  str(d16, mem[x6], 8);
  dup(d16, v16.d()[1]);

  bind(l12);
  tbz(x1, 0, l13);
  str(s16, mem[x6], 4);
  bind(l13);
  ret();

  align(16, AlignInstruction::kHlt);
}
292 } // namespace
293 } // namespace aarch64
294 } // namespace xnnpack
295
xnn_generate_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a75(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,size_t ks,const void * params)296 xnn_status_t xnn_generate_f32_igemm_ukernel_1x8__aarch64_neonfma_cortex_a75(
297 xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, const void* params)
298 {
299 using namespace xnnpack::aarch64;
300 Generator g(code);
301 assert(params != nullptr);
302 auto jit_params = static_cast<const jit_gemm_params*>(params);
303 g.generate(false, max_mr, nc_mod_nr, kc, ks, jit_params->f32_minmax.min, jit_params->f32_minmax.max);
304 g.finalize();
305 if (g.error() != xnnpack::Error::kNoError) {
306 return xnn_status_invalid_state;
307 }
308 return xnn_status_success;
309 }
310
xnn_generate_f32_igemm_ukernel_1x8__aarch64_neonfma_prfm_cortex_a75(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,size_t ks,const void * params)311 xnn_status_t xnn_generate_f32_igemm_ukernel_1x8__aarch64_neonfma_prfm_cortex_a75(
312 xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, const void* params)
313 {
314 using namespace xnnpack::aarch64;
315 Generator g(code);
316 assert(params != nullptr);
317 auto jit_params = static_cast<const jit_gemm_params*>(params);
318 g.generate(true, max_mr, nc_mod_nr, kc, ks, jit_params->f32_minmax.min, jit_params->f32_minmax.max);
319 g.finalize();
320 if (g.error() != xnnpack::Error::kNoError) {
321 return xnn_status_invalid_state;
322 }
323 return xnn_status_success;
324 }
325