/*
 * Copyright (c) 2017-2020, 2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_gemm.hpp"
#include "bfloat.hpp"
#include "gemm_common.hpp"
#include "gemm_hybrid.hpp"
#include "gemm_hybrid_indirect.hpp"
#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"
#include "gemv_batched.hpp"
#include "gemv_pretransposed.hpp"

#include "kernels/a32_sgemm_8x6.hpp"
#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_ffhybrid_bf16fp32_mmla_6x16.hpp"
#include "kernels/a64_ffinterleaved_bf16fp32_dot_8x12.hpp"
#include "kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp"
#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_hybrid_bf16fp32_dot_6x16.hpp"
#include "kernels/a64_hybrid_bf16fp32_mmla_6x16.hpp"
#include "kernels/a64_interleaved_bf16fp32_dot_8x12.hpp"
#include "kernels/a64_interleaved_bf16fp32_mmla_8x12.hpp"
#include "kernels/a64_sgemm_8x12.hpp"

#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/sve_ffhybrid_bf16fp32_mmla_6x4VL.hpp"
#include "kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp"
#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS

#ifdef ARM_COMPUTE_ENABLE_SVE
#ifdef ARM_COMPUTE_ENABLE_SME2
#include "kernels/sme2_gemv_bf16fp32_dot_16VL.hpp"
#include "kernels/sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL.hpp"
#include "kernels/sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL.hpp"
#include "kernels/sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL.hpp"
#endif // ARM_COMPUTE_ENABLE_SME2

#include "kernels/sve_hybrid_bf16fp32_dot_6x4VL.hpp"
#include "kernels/sve_hybrid_bf16fp32_mmla_6x4VL.hpp"
#include "kernels/sve_interleaved_bf16fp32_dot_8x3VL.hpp"
#include "kernels/sve_interleaved_bf16fp32_mmla_8x3VL.hpp"
#endif // ARM_COMPUTE_ENABLE_SVE
namespace arm_gemm {

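// Candidate implementations for bf16 input / fp32 output GEMM, in priority
// order.  Each entry provides a method type, a kernel name, an optional
// "is supported" gate, an optional "is recommended" or cycle-estimate hook,
// and a factory lambda.  The selection logic in gemm_implementation.hpp
// walks this list, skips unsupported entries, and takes a recommended
// candidate immediately, otherwise the lowest estimated cost.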
static const GemmImplementation<bfloat16, float> gemm_bf16_methods[] =
{
#ifdef __aarch64__
#ifdef ARM_COMPUTE_ENABLE_BF16
#ifdef ARM_COMPUTE_ENABLE_SVE
#ifdef ARM_COMPUTE_ENABLE_SME2
// SME kernels
{
    GemmMethod::GEMM_HYBRID,
    "sme2_gemv_bf16fp32_dot_16VL",
    [](const GemmArgs &args) { return args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input; },
    nullptr,
    [](const GemmArgs &args) { return new GemvPretransposed<cls_sme2_gemv_bf16fp32_dot_16VL, bfloat16, float>(args); }
},
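// The three mopa tile shapes below trade rows for columns: 1VLx4VL and
// 4VLx1VL are recommended when M (respectively N) fits within one vector
// length or falls in the 2-3 VL band, where the skewed tile wastes fewer
// lanes; 2VLx2VL is the unconditioned default.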
{
    GemmMethod::GEMM_INTERLEAVED,
    "sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL",
    [](const GemmArgs &args) { return args._ci->has_sme2(); },
    [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
                               return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
    [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL, bfloat16, float>(args); }
},
{
    GemmMethod::GEMM_INTERLEAVED,
    "sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL",
    [](const GemmArgs &args) { return args._ci->has_sme2(); },
    [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
                               return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
    [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL, bfloat16, float>(args); }
},
{
    GemmMethod::GEMM_INTERLEAVED,
    "sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL",
    [](const GemmArgs &args) { return args._ci->has_sme2(); },
    nullptr,
    [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL, bfloat16, float>(args); }
},
#endif // ARM_COMPUTE_ENABLE_SME2
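// The entries below are built with with_estimate(), which supplies a
// per-kernel cycle estimate; selection then compares supported candidates
// by predicted cost rather than taking the first match.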
// gemm_bf16_interleaved
GemmImplementation<bfloat16, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "sve_interleaved_bf16fp32_mmla_8x3VL",
    [](const GemmArgs &args) { return args._ci->has_svebf16() && (args._Ksize>4); },
    [](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_bf16fp32_mmla_8x3VL, bfloat16, float>::estimate_cycles<bfloat16>(args); },
    [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_bf16fp32_mmla_8x3VL, bfloat16, float>(args); }
),
GemmImplementation<bfloat16, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "sve_hybrid_bf16fp32_mmla_6x4VL",
    [](const GemmArgs &args) { return args._ci->has_svebf16(); },
    [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_bf16fp32_mmla_6x4VL, bfloat16, float>::estimate_cycles<bfloat16>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_bf16fp32_mmla_6x4VL, bfloat16, float>(args); }
),
GemmImplementation<bfloat16, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "sve_hybrid_bf16fp32_dot_6x4VL",
    [](const GemmArgs &args) { return args._ci->has_svebf16(); },
    [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_bf16fp32_dot_6x4VL, bfloat16, float>::estimate_cycles<bfloat16>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_bf16fp32_dot_6x4VL, bfloat16, float>(args); }
),
GemmImplementation<bfloat16, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "sve_interleaved_bf16fp32_dot_8x3VL",
    [](const GemmArgs &args) { return args._ci->has_svebf16() && (args._Ksize>2); },
    [](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_bf16fp32_dot_8x3VL, bfloat16, float>::estimate_cycles<bfloat16>(args); },
    [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_bf16fp32_dot_8x3VL, bfloat16, float>(args); }
),
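// Fixed-format kernels expect weights pre-arranged in the layout named by
// the KernelWeightFormat argument (e.g. VL2VL_BL64) instead of repacking
// them internally; they are only chosen when the caller requests a fixed
// weight format.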
#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
GemmImplementation<bfloat16, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "sve_ffinterleaved_bf16fp32_mmla_8x3VL",
    KernelWeightFormat::VL2VL_BL64,
    [](const GemmArgs &args) { return args._ci->has_svebf16(); },
    [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_sve_ffinterleaved_bf16fp32_mmla_8x3VL, bfloat16, float>::estimate_cycles<bfloat16>(args); },
    [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_sve_ffinterleaved_bf16fp32_mmla_8x3VL, bfloat16, float>(args); }
),
GemmImplementation<bfloat16, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "sve_ffhybrid_bf16fp32_mmla_6x4VL",
    KernelWeightFormat::VL2VL_BL64,
    [](const GemmArgs &args) { return args._ci->has_svebf16(); },
    [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_bf16fp32_mmla_6x4VL, bfloat16, float>::estimate_cycles<bfloat16>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_bf16fp32_mmla_6x4VL, bfloat16, float>(args); }
),
#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#endif // ARM_COMPUTE_ENABLE_SVE
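// Plain AArch64 (NEON) kernels, used when the SVE/SME variants above are
// unavailable or lose on estimated cost.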
GemmImplementation<bfloat16, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "a64_hybrid_bf16fp32_mmla_6x16",
    [](const GemmArgs &args) { return args._ci->has_bf16(); },
    [](const GemmArgs &args) { return GemmHybridIndirect<cls_a64_hybrid_bf16fp32_mmla_6x16, bfloat16, float>::estimate_cycles<bfloat16>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_bf16fp32_mmla_6x16, bfloat16, float>(args); }
),
GemmImplementation<bfloat16, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "a64_interleaved_bf16fp32_mmla_8x12",
    [](const GemmArgs &args) { return args._ci->has_bf16() && (args._Ksize>4); },
    [](const GemmArgs &args) { return GemmInterleaved<cls_a64_interleaved_bf16fp32_mmla_8x12, bfloat16, float>::estimate_cycles<bfloat16>(args); },
    [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_interleaved_bf16fp32_mmla_8x12, bfloat16, float>(args); }
),
GemmImplementation<bfloat16, float>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "a64_hybrid_bf16fp32_dot_6x16",
    [](const GemmArgs &args) { return args._ci->has_bf16(); },
    [](const GemmArgs &args) { return GemmHybridIndirect<cls_a64_hybrid_bf16fp32_dot_6x16, bfloat16, float>::estimate_cycles<bfloat16>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_bf16fp32_dot_6x16, bfloat16, float>(args); }
),
GemmImplementation<bfloat16, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "a64_interleaved_bf16fp32_dot_8x12",
    [](const GemmArgs &args) { return args._ci->has_bf16() && (args._Ksize>2); },
    [](const GemmArgs &args) { return GemmInterleaved<cls_a64_interleaved_bf16fp32_dot_8x12, bfloat16, float>::estimate_cycles<bfloat16>(args); },
    [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_interleaved_bf16fp32_dot_8x12, bfloat16, float>(args); }
),
#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
GemmImplementation<bfloat16, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "a64_ffinterleaved_bf16fp32_mmla_8x12",
    KernelWeightFormat::VL256_BL64,
    [](const GemmArgs &args) { return args._ci->has_bf16(); },
    [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_mmla_8x12, bfloat16, float>::estimate_cycles<bfloat16>(args); },
    [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_mmla_8x12, bfloat16, float>(args); }
),
GemmImplementation<bfloat16, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "a64_ffhybrid_bf16fp32_mmla_6x16",
    KernelWeightFormat::VL256_BL64,
    [](const GemmArgs &args) { return args._ci->has_bf16(); },
    [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_bf16fp32_mmla_6x16, bfloat16, float>::estimate_cycles<bfloat16>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_bf16fp32_mmla_6x16, bfloat16, float>(args); }
),
GemmImplementation<bfloat16, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "a64_ffinterleaved_bf16fp32_dot_8x12",
    KernelWeightFormat::VL128_BL32,
    [](const GemmArgs &args) { return args._ci->has_bf16(); },
    [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_dot_8x12, bfloat16, float>::estimate_cycles<bfloat16>(args); },
    [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_dot_8x12, bfloat16, float>(args); }
),
#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
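// Fallback with no gate: a plain FP32 kernel; bf16 operands are widened to
// float as the input blocks are packed.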
GemmImplementation<bfloat16, float>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "a64_sgemm_8x12",
    nullptr,
    [](const GemmArgs &args) { return GemmInterleaved<cls_a64_sgemm_8x12, bfloat16, float>::estimate_cycles<bfloat16>(args); },
    [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_sgemm_8x12, bfloat16, float>(args); }
),
#endif // ARM_COMPUTE_ENABLE_BF16
#elif defined(__arm__)
{
    GemmMethod::GEMM_INTERLEAVED,
    "sgemm_8x6",
    nullptr,
    nullptr,
    [](const GemmArgs &args) { return new GemmInterleaved<sgemm_8x6, bfloat16, float>(args); }
},
#else
# error "Unknown Architecture"
#endif
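// List terminator: selection stops when it reaches the DEFAULT entry.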
{
    GemmMethod::DEFAULT,
    "",
    nullptr,
    nullptr,
    nullptr
}
};

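// Expose the candidate table above to the generic selection code in
// gemm_implementation.hpp.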
template<>
const GemmImplementation<bfloat16, float> *gemm_implementation_list<bfloat16, float>() {
    return gemm_bf16_methods;
}

/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<bfloat16, float> gemm<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
template bool has_opt_gemm<bfloat16, float, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
template KernelDescription get_gemm_method<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
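
// A minimal usage sketch (illustrative only: the GemmArgs constructor
// signature varies between library versions, so the argument list below is
// an assumption):
//
//   GemmArgs args(&cpu_info, M, N, K, /*Ksections*/ 1, /*nbatches*/ 1,
//                 /*nmulti*/ 1, /*indirect_input*/ false, Activation(),
//                 /*maxthreads*/ 1);
//   auto gemm_fn = gemm<bfloat16, float, Nothing>(args, Nothing());
//
// gemm_fn is then a UniqueGemmCommon<bfloat16, float> wrapping whichever
// kernel was selected from gemm_bf16_methods above.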

} // namespace arm_gemm