xref: /aosp_15_r20/external/ComputeLibrary/src/core/NEON/kernels/arm_gemm/gemm_bf16.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2017-2020, 2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_gemm.hpp"
25 #include "bfloat.hpp"
26 #include "gemm_common.hpp"
27 #include "gemm_hybrid.hpp"
28 #include "gemm_hybrid_indirect.hpp"
29 #include "gemm_implementation.hpp"
30 #include "gemm_interleaved.hpp"
31 #include "gemv_batched.hpp"
32 #include "gemv_pretransposed.hpp"
33 
34 #include "kernels/a32_sgemm_8x6.hpp"
35 
36 #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
37 #include "kernels/a64_ffhybrid_bf16fp32_mmla_6x16.hpp"
38 #include "kernels/a64_ffinterleaved_bf16fp32_dot_8x12.hpp"
39 #include "kernels/a64_ffinterleaved_bf16fp32_mmla_8x12.hpp"
40 #endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
41 #include "kernels/a64_hybrid_bf16fp32_dot_6x16.hpp"
42 #include "kernels/a64_hybrid_bf16fp32_mmla_6x16.hpp"
43 #include "kernels/a64_interleaved_bf16fp32_dot_8x12.hpp"
44 #include "kernels/a64_interleaved_bf16fp32_mmla_8x12.hpp"
45 #include "kernels/a64_sgemm_8x12.hpp"
46 
47 #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
48 #include "kernels/sve_ffhybrid_bf16fp32_mmla_6x4VL.hpp"
49 #include "kernels/sve_ffinterleaved_bf16fp32_mmla_8x3VL.hpp"
50 #endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
51 
52 #ifdef ARM_COMPUTE_ENABLE_SVE
53 #ifdef ARM_COMPUTE_ENABLE_SME2
54 #include "kernels/sme2_gemv_bf16fp32_dot_16VL.hpp"
55 #include "kernels/sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL.hpp"
56 #include "kernels/sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL.hpp"
57 #include "kernels/sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL.hpp"
58 #endif // ARM_COMPUTE_ENABLE_SME2
59 
60 #include "kernels/sve_hybrid_bf16fp32_dot_6x4VL.hpp"
61 #include "kernels/sve_hybrid_bf16fp32_mmla_6x4VL.hpp"
62 #include "kernels/sve_interleaved_bf16fp32_dot_8x3VL.hpp"
63 #include "kernels/sve_interleaved_bf16fp32_mmla_8x3VL.hpp"
64 #endif // ARM_COMPUTE_ENABLE_SVE
65 
66 namespace arm_gemm {
67 
68 static const GemmImplementation<bfloat16, float> gemm_bf16_methods[] =
69 {
70 #ifdef __aarch64__
71 #ifdef ARM_COMPUTE_ENABLE_BF16
72 #ifdef ARM_COMPUTE_ENABLE_SVE
73 #ifdef ARM_COMPUTE_ENABLE_SME2
74 // SME kernels
75 {
76     GemmMethod::GEMM_HYBRID,
77     "sme2_gemv_bf16fp32_dot_16VL",
__anon4a82fdd20102() 78     [](const GemmArgs &args) { return args._ci->has_sme2() && args._Msize==1 && args._nbatches==1 && !args._indirect_input; },
79     nullptr,
__anon4a82fdd20202() 80     [](const GemmArgs &args) { return new GemvPretransposed<cls_sme2_gemv_bf16fp32_dot_16VL, bfloat16, float>(args); }
81 },
82 {
83     GemmMethod::GEMM_INTERLEAVED,
84     "sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL",
__anon4a82fdd20302() 85     [](const GemmArgs &args) { return args._ci->has_sme2(); },
__anon4a82fdd20402() 86     [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
87                                return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
__anon4a82fdd20502() 88     [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_1VLx4VL, bfloat16, float>(args); }
89 },
90 {
91     GemmMethod::GEMM_INTERLEAVED,
92     "sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL",
__anon4a82fdd20602() 93     [](const GemmArgs &args) { return args._ci->has_sme2(); },
__anon4a82fdd20702() 94     [](const GemmArgs &args) { const auto VL = sme::get_vector_length<float>();
95                                return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
__anon4a82fdd20802() 96     [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_4VLx1VL, bfloat16, float>(args); }
97 },
98 {
99     GemmMethod::GEMM_INTERLEAVED,
100     "sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL",
__anon4a82fdd20902() 101     [](const GemmArgs &args) { return args._ci->has_sme2(); },
102     nullptr,
__anon4a82fdd20a02() 103     [](const GemmArgs &args) { return new GemmInterleavedNoMerge<cls_sme2_interleaved_nomerge_bf16fp32_mopa_2VLx2VL, bfloat16, float>(args); }
104 },
105 #endif // ARM_COMPUTE_ENABLE_SME2
106 // gemm_bf16_interleaved
107 GemmImplementation<bfloat16, float>::with_estimate(
108     GemmMethod::GEMM_INTERLEAVED,
109     "sve_interleaved_bf16fp32_mmla_8x3VL",
__anon4a82fdd20b02() 110     [](const GemmArgs &args) { return args._ci->has_svebf16() && (args._Ksize>4); },
__anon4a82fdd20c02() 111     [](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_bf16fp32_mmla_8x3VL, bfloat16, float>::estimate_cycles<bfloat16>(args); },
__anon4a82fdd20d02() 112     [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_bf16fp32_mmla_8x3VL, bfloat16, float>(args); }
113 ),
114 GemmImplementation<bfloat16, float>::with_estimate(
115     GemmMethod::GEMM_HYBRID,
116     "sve_hybrid_bf16fp32_mmla_6x4VL",
__anon4a82fdd20e02() 117     [](const GemmArgs &args) { return args._ci->has_svebf16(); },
__anon4a82fdd20f02() 118     [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_bf16fp32_mmla_6x4VL, bfloat16, float>::estimate_cycles<bfloat16>(args); },
__anon4a82fdd21002() 119     [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_bf16fp32_mmla_6x4VL, bfloat16, float>(args); }
120 ),
121 GemmImplementation<bfloat16, float>::with_estimate(
122     GemmMethod::GEMM_HYBRID,
123     "sve_hybrid_bf16fp32_dot_6x4VL",
__anon4a82fdd21102() 124     [](const GemmArgs &args) { return args._ci->has_svebf16(); },
__anon4a82fdd21202() 125     [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_bf16fp32_dot_6x4VL, bfloat16, float>::estimate_cycles<bfloat16>(args); },
__anon4a82fdd21302() 126     [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_bf16fp32_dot_6x4VL, bfloat16, float>(args); }
127 ),
128 GemmImplementation<bfloat16, float>::with_estimate(
129     GemmMethod::GEMM_INTERLEAVED,
130     "sve_interleaved_bf16fp32_dot_8x3VL",
__anon4a82fdd21402() 131     [](const GemmArgs &args) { return args._ci->has_svebf16() && (args._Ksize>2); },
__anon4a82fdd21502() 132     [](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_bf16fp32_dot_8x3VL, bfloat16, float>::estimate_cycles<bfloat16>(args); },
__anon4a82fdd21602() 133     [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_bf16fp32_dot_8x3VL, bfloat16, float>(args); }
134 ),
135 #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
136 GemmImplementation<bfloat16, float>::with_estimate(
137     GemmMethod::GEMM_INTERLEAVED,
138     "sve_ffinterleaved_bf16fp32_mmla_8x3VL",
139     KernelWeightFormat::VL2VL_BL64,
__anon4a82fdd21702() 140     [](const GemmArgs &args) { return args._ci->has_svebf16(); },
__anon4a82fdd21802() 141     [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_sve_ffinterleaved_bf16fp32_mmla_8x3VL, bfloat16, float>::estimate_cycles<bfloat16>(args); },
__anon4a82fdd21902() 142     [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_sve_ffinterleaved_bf16fp32_mmla_8x3VL, bfloat16, float>(args); }
143 ),
144 GemmImplementation<bfloat16, float>::with_estimate(
145     GemmMethod::GEMM_INTERLEAVED,
146     "sve_ffhybrid_bf16fp32_mmla_6x4VL",
147     KernelWeightFormat::VL2VL_BL64,
__anon4a82fdd21a02() 148     [](const GemmArgs &args) { return args._ci->has_svebf16(); },
__anon4a82fdd21b02() 149     [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_bf16fp32_mmla_6x4VL, bfloat16, float>::estimate_cycles<bfloat16>(args); },
__anon4a82fdd21c02() 150     [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_bf16fp32_mmla_6x4VL, bfloat16, float>(args); }
151 ),
152 #endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
153 #endif // ARM_COMPUTE_ENABLE_SVE
154 GemmImplementation<bfloat16, float>::with_estimate(
155     GemmMethod::GEMM_HYBRID,
156     "a64_hybrid_bf16fp32_mmla_6x16",
__anon4a82fdd21d02() 157     [](const GemmArgs &args) { return args._ci->has_bf16(); },
__anon4a82fdd21e02() 158     [](const GemmArgs &args) { return GemmHybridIndirect<cls_a64_hybrid_bf16fp32_mmla_6x16, bfloat16, float>::estimate_cycles<bfloat16>(args); },
__anon4a82fdd21f02() 159     [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_bf16fp32_mmla_6x16, bfloat16, float>(args); }
160 ),
161 GemmImplementation<bfloat16, float>::with_estimate(
162     GemmMethod::GEMM_INTERLEAVED,
163     "a64_interleaved_bf16fp32_mmla_8x12",
__anon4a82fdd22002() 164     [](const GemmArgs &args) { return args._ci->has_bf16() && (args._Ksize>4); },
__anon4a82fdd22102() 165     [](const GemmArgs &args) { return GemmInterleaved<cls_a64_interleaved_bf16fp32_mmla_8x12, bfloat16, float>::estimate_cycles<bfloat16>(args); },
__anon4a82fdd22202() 166     [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_interleaved_bf16fp32_mmla_8x12, bfloat16, float>(args); }
167 ),
168 GemmImplementation<bfloat16, float>::with_estimate(
169     GemmMethod::GEMM_HYBRID,
170     "a64_hybrid_bf16fp32_dot_6x16",
__anon4a82fdd22302() 171     [](const GemmArgs &args) { return args._ci->has_bf16(); },
__anon4a82fdd22402() 172     [](const GemmArgs &args) { return GemmHybridIndirect<cls_a64_hybrid_bf16fp32_dot_6x16, bfloat16, float>::estimate_cycles<bfloat16>(args); },
__anon4a82fdd22502() 173     [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_bf16fp32_dot_6x16, bfloat16, float>(args); }
174 ),
175 GemmImplementation<bfloat16, float>::with_estimate(
176     GemmMethod::GEMM_INTERLEAVED,
177     "a64_interleaved_bf16fp32_dot_8x12",
__anon4a82fdd22602() 178     [](const GemmArgs &args) { return args._ci->has_bf16() && (args._Ksize>2); },
__anon4a82fdd22702() 179     [](const GemmArgs &args) { return GemmInterleaved<cls_a64_interleaved_bf16fp32_dot_8x12, bfloat16, float>::estimate_cycles<bfloat16>(args); },
__anon4a82fdd22802() 180     [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_interleaved_bf16fp32_dot_8x12, bfloat16, float>(args); }
181 ),
182 #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
183 GemmImplementation<bfloat16, float>::with_estimate(
184     GemmMethod::GEMM_INTERLEAVED,
185     "a64_ffinterleaved_bf16fp32_mmla_8x12",
186     KernelWeightFormat::VL256_BL64,
__anon4a82fdd22902() 187     [](const GemmArgs &args) { return args._ci->has_bf16(); },
__anon4a82fdd22a02() 188     [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_mmla_8x12, bfloat16, float>::estimate_cycles<bfloat16>(args); },
__anon4a82fdd22b02() 189     [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_mmla_8x12, bfloat16, float>(args); }
190 ),
191 GemmImplementation<bfloat16, float>::with_estimate(
192     GemmMethod::GEMM_INTERLEAVED,
193     "a64_ffhybrid_bf16fp32_mmla_6x16",
194     KernelWeightFormat::VL256_BL64,
__anon4a82fdd22c02() 195     [](const GemmArgs &args) { return args._ci->has_bf16(); },
__anon4a82fdd22d02() 196     [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_bf16fp32_mmla_6x16, bfloat16, float>::estimate_cycles<bfloat16>(args); },
__anon4a82fdd22e02() 197     [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_bf16fp32_mmla_6x16, bfloat16, float>(args); }
198 ),
199 GemmImplementation<bfloat16, float>::with_estimate(
200     GemmMethod::GEMM_INTERLEAVED,
201     "a64_ffinterleaved_bf16fp32_dot_8x12",
202     KernelWeightFormat::VL128_BL32,
__anon4a82fdd22f02() 203     [](const GemmArgs &args) { return args._ci->has_bf16(); },
__anon4a82fdd23002() 204     [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_dot_8x12, bfloat16, float>::estimate_cycles<bfloat16>(args); },
__anon4a82fdd23102() 205     [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_bf16fp32_dot_8x12, bfloat16, float>(args); }
206 ),
207 #endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
208 GemmImplementation<bfloat16, float>::with_estimate(
209     GemmMethod::GEMM_INTERLEAVED,
210     "a64_sgemm_8x12",
211     nullptr,
__anon4a82fdd23202() 212     [](const GemmArgs &args) { return GemmInterleaved<cls_a64_sgemm_8x12, bfloat16, float>::estimate_cycles<bfloat16>(args); },
__anon4a82fdd23302() 213     [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_sgemm_8x12, bfloat16, float>(args); }
214 ),
215 #endif // ARM_COMPUTE_ENABLE_BF16
216 #elif defined(__arm__)
217 {
218     GemmMethod::GEMM_INTERLEAVED,
219     "sgemm_8x6",
220     nullptr,
221     nullptr,
222     [](const GemmArgs &args) { return new GemmInterleaved<sgemm_8x6, bfloat16, float>(args); }
223 },
224 #else
225 # error "Unknown Architecture"
226 #endif
227 {
228     GemmMethod::DEFAULT,
229     "",
230     nullptr,
231     nullptr,
232     nullptr
233 }
234 };
235 
236 template<>
gemm_implementation_list()237 const GemmImplementation<bfloat16, float> *gemm_implementation_list<bfloat16, float>() {
238     return gemm_bf16_methods;
239 }
240 
241 /* Explicitly instantiate the external functions for these types. */
242 template UniqueGemmCommon<bfloat16, float> gemm<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
243 template bool has_opt_gemm<bfloat16, float, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
244 template KernelDescription get_gemm_method<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
245 template std::vector<KernelDescription> get_compatible_kernels<bfloat16, float, Nothing>(const GemmArgs &args, const Nothing &);
246 
247 } // namespace arm_gemm
248