/*
 * Copyright (c) 2017-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_gemm.hpp"
#include "gemm_common.hpp"
#include "gemm_hybrid.hpp"
#include "gemm_hybrid_indirect.hpp"
#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
#include "gemv_batched.hpp"
#include "gemv_pretransposed.hpp"

#include "kernels/a32_sgemm_8x6.hpp"
#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_ffhybrid_fp32_mla_6x16.hpp"
#include "kernels/a64_ffhybrid_fp32bf16fp32_mmla_4x24.hpp"
#include "kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp"
#include "kernels/a64_ffinterleaved_fp32_mla_8x12.hpp"
#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_hybrid_fp32bf16fp32_mmla_4x24.hpp"
#include "kernels/a64_hybrid_fp32bf16fp32_mmla_6x16.hpp"
#include "kernels/a64_hybrid_fp32_mla_4x24.hpp"
#include "kernels/a64_hybrid_fp32_mla_6x16.hpp"
#include "kernels/a64_hybrid_fp32_mla_8x4.hpp"
#include "kernels/a64_interleaved_bf16fp32_mmla_8x12.hpp"
#include "kernels/a64_sgemm_8x12.hpp"
#include "kernels/a64_sgemm_8x6.hpp"
#include "kernels/a64_smallK_hybrid_fp32_mla_6x4.hpp"
#include "kernels/a64_smallK_hybrid_fp32_mla_8x4.hpp"
#ifdef ARM_COMPUTE_ENABLE_SVE
#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/sve_ffhybrid_fp32_mla_6x4VL.hpp"
#include "kernels/sve_ffhybrid_fp32bf16fp32_mmla_4x6VL.hpp"
#include "kernels/sve_ffinterleaved_fp32_mla_8x3VL.hpp"
#include "kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp"
#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#ifdef ARM_COMPUTE_ENABLE_SME2
#include "kernels/sme2_gemv_fp32_mla_16VL.hpp"
#include "kernels/sme2_gemv_fp32bf16fp32_dot_16VL.hpp"
#include "kernels/sme2_interleaved_nomerge_fp32_mopa_1VLx4VL.hpp"
#include "kernels/sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL.hpp"
#include "kernels/sme2_interleaved_nomerge_fp32_mopa_2VLx2VL.hpp"
#include "kernels/sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL.hpp"
#include "kernels/sme2_interleaved_nomerge_fp32_mopa_4VLx1VL.hpp"
#include "kernels/sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL.hpp"
#endif // ARM_COMPUTE_ENABLE_SME2

#include "kernels/sve_hybrid_fp32bf16fp32_mmla_4x6VL.hpp"
#include "kernels/sve_hybrid_fp32bf16fp32_mmla_6x4VL.hpp"
#include "kernels/sve_hybrid_fp32_mla_6x4VL.hpp"
#include "kernels/sve_hybrid_fp32_mla_8x1VL.hpp"
#include "kernels/sve_interleaved_bf16fp32_mmla_8x3VL.hpp"
#include "kernels/sve_interleaved_fp32_mla_8x3VL.hpp"
#include "kernels/sve_interleaved_fp32_mmla_8x3VL.hpp"
#include "kernels/sve_smallK_hybrid_fp32_mla_8x1VL.hpp"
#endif // ARM_COMPUTE_ENABLE_SVE

namespace arm_gemm {

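// Each entry in this table describes one candidate FP32 GEMM implementation: the
// GemmMethod it reports, the kernel name, an optional KernelWeightFormat (fixed-format
// entries only), an "is supported" predicate (nullptr = always supported), either an
// "is recommended" predicate or a cycle estimator (for with_estimate() entries), and a
// function that instantiates the kernel wrapper. Candidates are considered in list
// order, with cycle estimates (where provided) used to choose between suitable kernels.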
static const GemmImplementation<float, float> gemm_fp32_methods[] =
{
// GEMV cases - starting with 'gemv_batched' wrapper to turn batched GEMV into GEMM.
{
    GemmMethod::GEMV_BATCHED,
    "gemv_batched",
    [](const GemmArgs &args) { return args._Msize==1 && args._nbatches>1 && !args._indirect_input; },
    nullptr,
    [](const GemmArgs &args) { return new GemvBatched<float, float>(args); }
},
#ifdef __aarch64__
#ifdef ARM_COMPUTE_ENABLE_BF16
// "fast mode" (BF16) kernels
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "a64_interleaved_bf16fp32_mmla_8x12",
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); },
    [](const GemmArgs &args) { return GemmInterleaved<cls_a64_interleaved_bf16fp32_mmla_8x12, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_interleaved_bf16fp32_mmla_8x12, float, float>(args); }
),

GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "a64_hybrid_fp32bf16fp32_mmla_6x16",
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); },
    [](const GemmArgs &args) { return GemmHybridIndirect<cls_a64_hybrid_fp32bf16fp32_mmla_6x16, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_fp32bf16fp32_mmla_6x16, float, float>(args); }
),
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "a64_hybrid_fp32bf16fp32_mmla_4x24",
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); },
    [](const GemmArgs &args) { return GemmHybridIndirect<cls_a64_hybrid_fp32bf16fp32_mmla_4x24, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_fp32bf16fp32_mmla_4x24, float, float>(args); }
),
#endif // ARM_COMPUTE_ENABLE_BF16
#ifdef ARM_COMPUTE_ENABLE_SVE
#ifdef ARM_COMPUTE_ENABLE_SME2
// SME kernels
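// The MOPA tile variants below are chosen according to how M and N compare with the SME
// vector length (VL): 1VLx4VL is recommended when M <= VL or 2VL < M <= 3VL, 4VLx1VL when
// the same holds for N, and the square 2VLx2VL variant is the general fallback.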
{
    GemmMethod::GEMM_HYBRID,
    "sme2_gemv_fp32bf16fp32_dot_16VL",
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input; },
    nullptr,
    [](const GemmArgs &args) { return new GemvPretransposed<cls_sme2_gemv_fp32bf16fp32_dot_16VL, float, float>(args); }
},
{
    GemmMethod::GEMM_HYBRID,
    "sme2_gemv_fp32_mla_16VL",
    [](const GemmArgs &args) { return args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input; },
    nullptr,
    [](const GemmArgs &args) { return new GemvPretransposed<cls_sme2_gemv_fp32_mla_16VL, float, float>(args); }
},
#ifdef ARM_COMPUTE_ENABLE_BF16
{
    GemmMethod::GEMM_INTERLEAVED,
    "sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL",
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); },
    [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
                               return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
    [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL, float, float>(args); }
},
#endif // ARM_COMPUTE_ENABLE_BF16
{
    GemmMethod::GEMM_INTERLEAVED,
    "sme2_interleaved_nomerge_fp32_mopa_1VLx4VL",
    [](const GemmArgs &args) { return args._ci->has_sme2(); },
    [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
                               return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
    [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_1VLx4VL, float, float>(args); }
},
#ifdef ARM_COMPUTE_ENABLE_BF16
{
    GemmMethod::GEMM_INTERLEAVED,
    "sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL",
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); },
    [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
                               return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
    [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL, float, float>(args); }
},
#endif // ARM_COMPUTE_ENABLE_BF16
{
    GemmMethod::GEMM_INTERLEAVED,
    "sme2_interleaved_nomerge_fp32_mopa_4VLx1VL",
    [](const GemmArgs &args) { return args._ci->has_sme2(); },
    [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
                               return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
    [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_4VLx1VL, float, float>(args); }
},
#ifdef ARM_COMPUTE_ENABLE_BF16
{
    GemmMethod::GEMM_INTERLEAVED,
    "sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL",
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_sme2(); },
    nullptr,
    [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL, float, float>(args); }
},
#endif // ARM_COMPUTE_ENABLE_BF16
{
    GemmMethod::GEMM_INTERLEAVED,
    "sme2_interleaved_nomerge_fp32_mopa_2VLx2VL",
    [](const GemmArgs &args) { return args._ci->has_sme2(); },
    nullptr,
    [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_fp32_mopa_2VLx2VL, float, float>(args); }
},
#endif // ARM_COMPUTE_ENABLE_SME2
#ifdef ARM_COMPUTE_ENABLE_BF16
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "sve_interleaved_bf16fp32_mmla_8x3VL",
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_svebf16(); },
    [](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_bf16fp32_mmla_8x3VL, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_bf16fp32_mmla_8x3VL, float, float>(args); }
),
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "sve_hybrid_fp32bf16fp32_mmla_6x4VL",
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); },
    [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_6x4VL, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_6x4VL, float, float>(args); }
),
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "sve_hybrid_fp32bf16fp32_mmla_4x6VL",
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); },
    [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_4x6VL, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_fp32bf16fp32_mmla_4x6VL, float, float>(args); }
),
#endif // ARM_COMPUTE_ENABLE_BF16
#ifdef ARM_COMPUTE_ENABLE_SVEF32MM
// MMLA next due to higher throughput (which is SVE only)
// Prefer this in all cases, except if fast mode is requested and BF16 is available.
{
    GemmMethod::GEMM_INTERLEAVED,
    "sve_interleaved_fp32_mmla_8x3VL",
    [](const GemmArgs &args) { return args._ci->has_svef32mm() && (args._Ksize>4); },
    [](const GemmArgs &args) { return !(args._fast_mode && args._ci->has_bf16()); },
    [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_fp32_mmla_8x3VL, float, float>(args); }
},
#endif // ARM_COMPUTE_ENABLE_SVEF32MM
// SVE kernels
{
    GemmMethod::GEMM_HYBRID,
    "sve_smallK_hybrid_fp32_mla_8x1VL",
    [](const GemmArgs &args) { return args._ci->has_sve() && args._Ksize <= 24 && !args._indirect_input; },
    nullptr,
    [](const GemmArgs &args) { return new GemmHybrid<cls_sve_smallK_hybrid_fp32_mla_8x1VL, float, float>(args); }
},
{
    GemmMethod::GEMM_HYBRID,
    "sve_hybrid_fp32_mla_8x1VL",
    [](const GemmArgs &args) { return args._ci->has_sve(); },
    [](const GemmArgs &args) { return (args._Nsize < 12); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_fp32_mla_8x1VL, float, float>(args); }
},
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "sve_hybrid_fp32_mla_6x4VL",
    [](const GemmArgs &args) { return args._ci->has_sve(); },
    [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_fp32_mla_6x4VL, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_fp32_mla_6x4VL, float, float>(args); }
),
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "sve_interleaved_fp32_mla_8x3VL",
    [](const GemmArgs &args) { return args._ci->has_sve(); },
    [](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_fp32_mla_8x3VL, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_fp32_mla_8x3VL, float, float>(args); }
),
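// Fixed-format ("ff") kernels work on weights that the caller has already packed into the
// blocked layout named by the KernelWeightFormat field, rather than having the GEMM pack
// them internally; they are typically only candidates when a fixed weight format is requested.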
#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#ifdef ARM_COMPUTE_ENABLE_BF16
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "sve_ffinterleaved_bf16fp32_mmla_8x3VL",
    KernelWeightFormat::VL2VL_BL64_BF16,
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_svebf16(); },
    [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_sve_ffinterleaved_bf16fp32_mmla_8x3VL, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_sve_ffinterleaved_bf16fp32_mmla_8x3VL, float, float>(args); }
),
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "sve_ffhybrid_fp32bf16fp32_mmla_4x6VL",
    KernelWeightFormat::VL2VL_BL64_BF16,
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_svebf16(); },
    [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32bf16fp32_mmla_4x6VL, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32bf16fp32_mmla_4x6VL, float, float>(args); }
),
#endif // ARM_COMPUTE_ENABLE_BF16
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "sve_ffinterleaved_fp32_mla_8x3VL",
    KernelWeightFormat::VL1VL_BL32,
    [](const GemmArgs &args) { return args._ci->has_sve(); },
    [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_sve_ffinterleaved_fp32_mla_8x3VL, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_sve_ffinterleaved_fp32_mla_8x3VL, float, float>(args); }
),
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "sve_ffhybrid_fp32_mla_6x4VL",
    KernelWeightFormat::VL1VL_BL32,
    [](const GemmArgs &args) { return args._ci->has_sve(); },
    [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32_mla_6x4VL, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp32_mla_6x4VL, float, float>(args); }
),
#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#endif // ARM_COMPUTE_ENABLE_SVE
// Cortex-A35 specific kernel - use for any problem on A35, and never in any other cases.
{
    GemmMethod::GEMM_INTERLEAVED,
    "a64_sgemm_8x6",
    nullptr,
    [](const GemmArgs &args) { return args._ci->get_cpu_model() == CPUModel::A35; },
    [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_sgemm_8x6, float, float>(args); }
},
// Arm® Neon™ hybrid methods
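// The two smallK variants below handle fixed small depths (K <= 8 and 8 < K <= 16) when N
// is a multiple of 4 and the input is not indirect; the general hybrid and interleaved
// kernels that follow cover everything else.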
{
    GemmMethod::GEMM_HYBRID,
    "a64_smallK_hybrid_fp32_mla_8x4",
    [](const GemmArgs &args) { return args._Ksize <= 8 && (args._Nsize % 4)==0 && !args._indirect_input; },
    nullptr,
    [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_fp32_mla_8x4, float, float>(args); }
},
{
    GemmMethod::GEMM_HYBRID,
    "a64_smallK_hybrid_fp32_mla_6x4",
    [](const GemmArgs &args) { return (args._Ksize > 8 && args._Ksize <= 16) && (args._Nsize % 4)==0 && !args._indirect_input; },
    nullptr,
    [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_fp32_mla_6x4, float, float>(args); }
},
{
    GemmMethod::GEMM_HYBRID,
    "a64_hybrid_fp32_mla_8x4",
    nullptr,
    [](const GemmArgs &args) { return (args._Nsize < 12); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_fp32_mla_8x4, float, float>(args); }
},
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "a64_hybrid_fp32_mla_4x24",
    nullptr,
    [](const GemmArgs &args) { return GemmHybridIndirect<cls_a64_hybrid_fp32_mla_4x24, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_fp32_mla_4x24, float, float>(args); }
),
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "a64_hybrid_fp32_mla_6x16",
    nullptr,
    [](const GemmArgs &args) { return GemmHybridIndirect<cls_a64_hybrid_fp32_mla_6x16, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_fp32_mla_6x16, float, float>(args); }
),
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "a64_sgemm_8x12",
    nullptr,
    [](const GemmArgs &args) { return GemmInterleaved<cls_a64_sgemm_8x12, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_sgemm_8x12, float, float>(args); }
),
#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#ifdef ARM_COMPUTE_ENABLE_BF16
// "fast mode" (BF16) kernels
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "a64_ffinterleaved_bf16fp32_mmla_8x12",
    KernelWeightFormat::VL256_BL64_BF16,
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); },
    [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_mmla_8x12, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_mmla_8x12, float, float>(args); }
),
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "a64_ffhybrid_fp32bf16fp32_mmla_4x24",
    KernelWeightFormat::VL256_BL64_BF16,
    [](const GemmArgs &args) { return args._fast_mode && args._ci->has_bf16(); },
    [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32bf16fp32_mmla_4x24, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32bf16fp32_mmla_4x24, float, float>(args); }
),
#endif // ARM_COMPUTE_ENABLE_BF16
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "a64_ffinterleaved_fp32_mla_8x12",
    KernelWeightFormat::VL128_BL32,
    nullptr,
    [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_fp32_mla_8x12, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_fp32_mla_8x12, float, float>(args); }
),
GemmImplementation<float, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "a64_ffhybrid_fp32_mla_6x16",
    KernelWeightFormat::VL128_BL32,
    nullptr,
    [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32_mla_6x16, float, float>::estimate_cycles<float>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp32_mla_6x16, float, float>(args); }
),
#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#endif // __aarch64__

#ifdef __arm__
{
    GemmMethod::GEMM_INTERLEAVED,
    "sgemm_8x6",
    nullptr,
    nullptr,
    [](const GemmArgs &args) { return new GemmInterleaved<sgemm_8x6, float, float>(args); }
},
#endif // __arm__
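// List terminator: the DEFAULT entry with null predicates and no instantiation function
// marks the end of the table.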
{
    GemmMethod::DEFAULT,
    "",
    nullptr,
    nullptr,
    nullptr
}
};

/* Templated function to return this list. */
template<>
const GemmImplementation<float, float> *gemm_implementation_list<float, float>() {
    return gemm_fp32_methods;
}

/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<float, float> gemm<float, float, Nothing>(const GemmArgs &args, const Nothing &);
template bool has_opt_gemm<float, float, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
template KernelDescription get_gemm_method<float, float, Nothing>(const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<float, float, Nothing> (const GemmArgs &args, const Nothing &);
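// Typical callers go through the templated front end above rather than reading this table
// directly. A minimal sketch (illustrative only; GemmArgs construction is abbreviated and
// may differ between library versions):
//
//   arm_gemm::GemmArgs args(/* CPUInfo*, M, N, K, batching/multi/threading parameters... */);
//   auto gemm = arm_gemm::gemm<float, float>(args, arm_gemm::Nothing());
//   // 'gemm' is a UniqueGemmCommon<float, float>; the caller then provides the input/output
//   // arrays and executes the work through the GemmCommon interface.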

} // namespace arm_gemm