/*
 * Copyright (c) 2017-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_gemm.hpp"
#include "gemm_common.hpp"
#include "gemm_hybrid.hpp"
#include "gemm_hybrid_indirect.hpp"
#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
#include "gemv_batched.hpp"
#include "gemv_pretransposed.hpp"

#include "kernels/a32_sgemm_8x6.hpp"
#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_ffhybrid_fp32_mla_6x16.hpp"
#include "kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp"
#include "kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp"
#include "kernels/a64_ffinterleaved_fp32_mla_8x12.hpp"
#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_hybrid_fp32bf16fp32_mmla_4x24.hpp"
#include "kernels/a64_hybrid_fp32bf16fp32_mmla_6x16.hpp"
#include "kernels/a64_hybrid_fp32_mla_4x24.hpp"
#include "kernels/a64_hybrid_fp32_mla_6x16.hpp"
#include "kernels/a64_hybrid_fp32_mla_8x4.hpp"
#include "kernels/a64_interleaved_bf16fp32_mmla_8x12.hpp"
#include "kernels/a64_sgemm_8x12.hpp"
#include "kernels/a64_sgemm_8x6.hpp"
#include "kernels/a64_smallK_hybrid_fp32_mla_6x4.hpp"
#include "kernels/a64_smallK_hybrid_fp32_mla_8x4.hpp"

#ifdef ARM_COMPUTE_ENABLE_SVE
#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/sve_ffhybrid_fp32_mla_6x4VL.hpp"
#include "kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp"
#include "kernels/sve_ffinterleaved_fp32_mla_8x3VL.hpp"
#include "kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp"
#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#ifdef ARM_COMPUTE_ENABLE_SME2
#include "kernels/sme2_gemv_fp32_mla_16VL.hpp"
#include "kernels/sme2_gemv_fp32bf16fp32_dot_16VL.hpp"
#include "kernels/sme2_interleaved_nomerge_fp32_mopa_1VLx4VL.hpp"
#include "kernels/sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL.hpp"
#include "kernels/sme2_interleaved_nomerge_fp32_mopa_2VLx2VL.hpp"
#include "kernels/sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL.hpp"
#include "kernels/sme2_interleaved_nomerge_fp32_mopa_4VLx1VL.hpp"
#include "kernels/sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL.hpp"
#endif // ARM_COMPUTE_ENABLE_SME2

#include "kernels/sve_hybrid_fp32bf16fp32_mmla_4x6VL.hpp"
#include "kernels/sve_hybrid_fp32bf16fp32_mmla_6x4VL.hpp"
#include "kernels/sve_hybrid_fp32_mla_6x4VL.hpp"
#include "kernels/sve_hybrid_fp32_mla_8x1VL.hpp"
#include "kernels/sve_interleaved_bf16fp32_mmla_8x3VL.hpp"
#include "kernels/sve_interleaved_fp32_mla_8x3VL.hpp"
#include "kernels/sve_interleaved_fp32_mmla_8x3VL.hpp"
#include "kernels/sve_smallK_hybrid_fp32_mla_8x1VL.hpp"
#endif // ARM_COMPUTE_ENABLE_SVE

namespace arm_gemm {

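// Kernel selection table for FP32 GEMM.  Each plain entry gives the GemmMethod,
// a kernel name, an "is supported" predicate, an "is recommended" predicate
// (nullptr meaning "no constraint"), and a factory lambda that instantiates the
// strategy.  Entries built via with_estimate() instead supply a cycle estimate
// so otherwise-eligible candidates can be compared; all else being equal,
// entries nearer the top of the table are generally preferred.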
static const GemmImplementation<float, float> gemm_fp32_methods[] =
{
// GEMV cases - starting with 'gemv_batched' wrapper to turn batched GEMV into GEMM.
{
    GemmMethod::GEMV_BATCHED,
    "gemv_batched",
    [](const GemmArgs &args) { return args._Msize==1 && args._nbatches>1 && !args._indirect_input; },
    nullptr,
    [](const GemmArgs &args) { return new GemvBatched<float, float>(args); }
},
#ifdef __aarch64__
#ifdef ARM_COMPUTE_ENABLE_BF16
// "fast mode" (BF16) kernels
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "a64_interleaved_bf16fp32_mmla_8x12",
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); },
    [](const GemmArgs &args) { return GemmInterleaved<cls_a64_interleaved_bf16fp32_mmla_8x12, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_interleaved_bf16fp32_mmla_8x12, float, float>(args); }
),

GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "a64_hybrid_fp32bf16fp32_mmla_6x16",
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); },
    [](const GemmArgs &args) { return GemmHybridIndirect<cls_a64_hybrid_fp32bf16fp32_mmla_6x16, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_fp32bf16fp32_mmla_6x16, float, float>(args); }
),
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "a64_hybrid_fp32bf16fp32_mmla_4x24",
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); },
    [](const GemmArgs &args) { return GemmHybridIndirect<cls_a64_hybrid_fp32bf16fp32_mmla_4x24, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_fp32bf16fp32_mmla_4x24, float, float>(args); }
),
#endif // ARM_COMPUTE_ENABLE_BF16
#ifdef ARM_COMPUTE_ENABLE_SVE
#ifdef ARM_COMPUTE_ENABLE_SME2
// SME kernels
{
    GemmMethod::GEMM_HYBRID,
    "sme2_gemv_fp32bf16fp32_dot_16VL",
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input; },
    nullptr,
    [](const GemmArgs &args) { return new GemvPretransposed<cls_sme2_gemv_fp32bf16fp32_dot_16VL, float, float>(args); }
},
{
    GemmMethod::GEMM_HYBRID,
    "sme2_gemv_fp32_mla_16VL",
    [](const GemmArgs &args) { return args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input; },
    nullptr,
    [](const GemmArgs &args) { return new GemvPretransposed<cls_sme2_gemv_fp32_mla_16VL, float, float>(args); }
},
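// The SME2 MOPA kernels below come in three tile shapes; the recommendation
// lambdas steer problems to 1VLx4VL or 4VLx1VL depending on how the M or N
// extent compares with the SME vector length, with 2VLx2VL (no predicate) as
// the general fallback.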
#ifdef ARM_COMPUTE_ENABLE_BF16
{
    GemmMethod::GEMM_INTERLEAVED,
    "sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL",
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); },
    [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
                               return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
    [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL, float, float>(args); }
},
#endif // ARM_COMPUTE_ENABLE_BF16
{
    GemmMethod::GEMM_INTERLEAVED,
    "sme2_interleaved_nomerge_fp32_mopa_1VLx4VL",
    [](const GemmArgs &args) { return args._ci->has_sme2(); },
    [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
                               return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
    [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_1VLx4VL, float, float>(args); }
},
#ifdef ARM_COMPUTE_ENABLE_BF16
{
    GemmMethod::GEMM_INTERLEAVED,
    "sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL",
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); },
    [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
                               return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
    [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL, float, float>(args); }
},
#endif // ARM_COMPUTE_ENABLE_BF16
{
    GemmMethod::GEMM_INTERLEAVED,
    "sme2_interleaved_nomerge_fp32_mopa_4VLx1VL",
    [](const GemmArgs &args) { return args._ci->has_sme2(); },
    [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
                               return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
    [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_4VLx1VL, float, float>(args); }
},
#ifdef ARM_COMPUTE_ENABLE_BF16
{
    GemmMethod::GEMM_INTERLEAVED,
    "sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL",
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); },
    nullptr,
    [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL, float, float>(args); }
},
#endif // ARM_COMPUTE_ENABLE_BF16
{
    GemmMethod::GEMM_INTERLEAVED,
    "sme2_interleaved_nomerge_fp32_mopa_2VLx2VL",
    [](const GemmArgs &args) { return args._ci->has_sme2(); },
    nullptr,
    [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_2VLx2VL, float, float>(args); }
},
#endif // ARM_COMPUTE_ENABLE_SME2
#ifdef ARM_COMPUTE_ENABLE_BF16
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "sve_interleaved_bf16fp32_mmla_8x3VL",
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_svebf16(); },
    [](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_bf16fp32_mmla_8x3VL, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_bf16fp32_mmla_8x3VL, float, float>(args); }
),
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "sve_hybrid_fp32bf16fp32_mmla_6x4VL",
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); },
    [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_6x4VL, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_6x4VL, float, float>(args); }
),
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "sve_hybrid_fp32bf16fp32_mmla_4x6VL",
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); },
    [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_4x6VL, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_4x6VL, float, float>(args); }
),
#endif // ARM_COMPUTE_ENABLE_BF16
#ifdef ARM_COMPUTE_ENABLE_SVEF32MM
// MMLA next due to higher throughput (which is SVE only)
// Prefer this in all cases, except if fast mode is requested and BF16 is available.
{
    GemmMethod::GEMM_INTERLEAVED,
    "sve_interleaved_fp32_mmla_8x3VL",
    [](const GemmArgs &args) { return args._ci->has_svef32mm() && (args._Ksize>4); },
    [](const GemmArgs &args) { return !(args._fast_mode && args._ci->has_bf16()); },
    [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_fp32_mmla_8x3VL, float, float>(args); }
},
#endif // ARM_COMPUTE_ENABLE_SVEF32MM
// SVE kernels
{
    GemmMethod::GEMM_HYBRID,
    "sve_smallK_hybrid_fp32_mla_8x1VL",
    [](const GemmArgs &args) { return args._ci->has_sve() && args._Ksize <= 24 && !args._indirect_input; },
    nullptr,
    [](const GemmArgs &args) { return new GemmHybrid<cls_sve_smallK_hybrid_fp32_mla_8x1VL, float, float>(args); }
},
{
    GemmMethod::GEMM_HYBRID,
    "sve_hybrid_fp32_mla_8x1VL",
    [](const GemmArgs &args) { return args._ci->has_sve(); },
    [](const GemmArgs &args) { return (args._Nsize < 12); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_fp32_mla_8x1VL, float, float>(args); }
},
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "sve_hybrid_fp32_mla_6x4VL",
    [](const GemmArgs &args) { return args._ci->has_sve(); },
    [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_fp32_mla_6x4VL, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_fp32_mla_6x4VL, float, float>(args); }
),
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "sve_interleaved_fp32_mla_8x3VL",
    [](const GemmArgs &args) { return args._ci->has_sve(); },
    [](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_fp32_mla_8x3VL, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_fp32_mla_8x3VL, float, float>(args); }
),
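// Fixed-format kernels: these entries carry an extra KernelWeightFormat
// describing the pretransposed weight layout they expect, and are only
// candidates when a fixed weight format is requested.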
#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#ifdef ARM_COMPUTE_ENABLE_BF16
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "sve_ffinterleaved_bf16fp32_mmla_8x3VL",
    KernelWeightFormat::VL2VL_BL64_BF16,
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_svebf16(); },
    [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_sve_ffinterleaved_bf16fp32_mmla_8x3VL, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_sve_ffinterleaved_bf16fp32_mmla_8x3VL, float, float>(args); }
),
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "sve_ffhybrid_fp32bf16fp32_mmla_4x6VL",
    KernelWeightFormat::VL2VL_BL64_BF16,
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_svebf16(); },
    [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32bf16fp32_mmla_4x6VL, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32bf16fp32_mmla_4x6VL, float, float>(args); }
),
#endif
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "sve_ffinterleaved_fp32_mla_8x3VL",
    KernelWeightFormat::VL1VL_BL32,
    [](const GemmArgs &args) { return args._ci->has_sve(); },
    [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_sve_ffinterleaved_fp32_mla_8x3VL, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_sve_ffinterleaved_fp32_mla_8x3VL, float, float>(args); }
),
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "sve_ffhybrid_fp32_mla_6x4VL",
    KernelWeightFormat::VL1VL_BL32,
    [](const GemmArgs &args) { return args._ci->has_sve(); },
    [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32_mla_6x4VL, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32_mla_6x4VL, float, float>(args); }
),
#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#endif // ARM_COMPUTE_ENABLE_SVE
// Cortex-A35 specific kernel - use for any problem on A35, and never in any other cases.
{
    GemmMethod::GEMM_INTERLEAVED,
    "a64_sgemm_8x6",
    nullptr,
    [](const GemmArgs &args) { return args._ci->get_cpu_model() == CPUModel::A35; },
    [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_sgemm_8x6, float, float>(args); }
},
// Arm® Neon™ hybrid methods
{
    GemmMethod::GEMM_HYBRID,
    "a64_smallK_hybrid_fp32_mla_8x4",
    [](const GemmArgs &args) { return args._Ksize <= 8 && (args._Nsize % 4)==0 && !args._indirect_input; },
    nullptr,
    [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_fp32_mla_8x4, float, float>(args); }
},
{
    GemmMethod::GEMM_HYBRID,
    "a64_smallK_hybrid_fp32_mla_6x4",
    [](const GemmArgs &args) { return (args._Ksize > 8 && args._Ksize <= 16) && (args._Nsize % 4)==0 && !args._indirect_input; },
    nullptr,
    [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_fp32_mla_6x4, float, float>(args); }
},
{
    GemmMethod::GEMM_HYBRID,
    "a64_hybrid_fp32_mla_8x4",
    nullptr,
    [](const GemmArgs &args) { return (args._Nsize < 12); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_fp32_mla_8x4, float, float>(args); }
},
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "a64_hybrid_fp32_mla_4x24",
    nullptr,
    [](const GemmArgs &args) { return GemmHybridIndirect<cls_a64_hybrid_fp32_mla_4x24, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_fp32_mla_4x24, float, float>(args); }
),
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "a64_hybrid_fp32_mla_6x16",
    nullptr,
    [](const GemmArgs &args) { return GemmHybridIndirect<cls_a64_hybrid_fp32_mla_6x16, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_fp32_mla_6x16, float, float>(args); }
),
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "a64_sgemm_8x12",
    nullptr,
    [](const GemmArgs &args) { return GemmInterleaved<cls_a64_sgemm_8x12, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_sgemm_8x12, float, float>(args); }
),
#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#ifdef ARM_COMPUTE_ENABLE_BF16
// "fast mode" (BF16) kernels
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "a64_ffinterleaved_bf16fp32_mmla_8x12",
    KernelWeightFormat::VL256_BL64_BF16,
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); },
    [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_mmla_8x12, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_mmla_8x12, float, float>(args); }
),
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "a64_ffhybrid_fp32bf16fp32_mmla_4x24",
    KernelWeightFormat::VL256_BL64_BF16,
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); },
    [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32bf16fp32_mmla_4x24, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32bf16fp32_mmla_4x24, float, float>(args); }
),
#endif // BF16
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "a64_ffinterleaved_fp32_mla_8x12",
    KernelWeightFormat::VL128_BL32,
    nullptr,
    [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_fp32_mla_8x12, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_fp32_mla_8x12, float, float>(args); }
),
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "a64_ffhybrid_fp32_mla_6x16",
    KernelWeightFormat::VL128_BL32,
    nullptr,
    [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32_mla_6x16, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32_mla_6x16, float, float>(args); }
),
#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#endif // __aarch64__

#ifdef __arm__
{
    GemmMethod::GEMM_INTERLEAVED,
    "sgemm_8x6",
    nullptr,
    nullptr,
    [](const GemmArgs &args) { return new GemmInterleaved<sgemm_8x6, float, float>(args); }
},
#endif // __arm__
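// End-of-list sentinel: a DEFAULT entry with no predicates and no factory
// function terminates the table.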
{
    GemmMethod::DEFAULT,
    "",
    nullptr,
    nullptr,
    nullptr
}
};

/* Templated function to return this list. */
template<>
const GemmImplementation<float, float> *gemm_implementation_list<float, float>() {
    return gemm_fp32_methods;
}

/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<float, float> gemm<float, float, Nothing>(const GemmArgs &args, const Nothing &);
template bool has_opt_gemm<float, float, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
template KernelDescription get_gemm_method<float, float, Nothing>(const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<float, float, Nothing>(const GemmArgs &args, const Nothing &);

} // namespace arm_gemm