/*
 * Copyright (c) 2019-2020, 2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifdef __aarch64__

#include "arm_gemm.hpp"

#include "kernels/a64_gemm_u16_8x12.hpp"
#include "kernels/a64_gemm_u8_4x4.hpp"
#include "kernels/a64_gemm_u8_8x12.hpp"
#include "kernels/a64_hybrid_u8qa_dot_4x16.hpp"
#include "kernels/a64_hybrid_u8qa_mmla_4x16.hpp"
#include "kernels/a64_hybrid_u8u32_dot_6x16.hpp"
#include "kernels/a64_hybrid_u8u32_mmla_6x16.hpp"
#include "kernels/a64_interleaved_u8u32_mmla_8x12.hpp"
#include "kernels/a64_smallK_hybrid_u8u32_dot_6x4.hpp"
#include "kernels/a64_smallK_hybrid_u8u32_dot_8x4.hpp"

#ifdef ARM_COMPUTE_ENABLE_SVE
#ifdef ARM_COMPUTE_ENABLE_SME2
#include "kernels/sme2_gemv_u8qa_dot_16VL.hpp"
#include "kernels/sme2_interleaved_nomerge_u8q_mopa_1VLx4VL.hpp"
#include "kernels/sme2_interleaved_nomerge_u8q_mopa_2VLx2VL.hpp"
#include "kernels/sme2_interleaved_nomerge_u8q_mopa_4VLx1VL.hpp"
#endif // ARM_COMPUTE_ENABLE_SME2

#include "kernels/sve_hybrid_u8qa_dot_4x4VL.hpp"
#include "kernels/sve_hybrid_u8qa_mmla_4x4VL.hpp"
#include "kernels/sve_hybrid_u8u32_dot_6x4VL.hpp"
#include "kernels/sve_hybrid_u8u32_mmla_6x4VL.hpp"
#include "kernels/sve_interleaved_u8u32_dot_8x3VL.hpp"
#include "kernels/sve_interleaved_u8u32_mmla_8x3VL.hpp"
#include "kernels/sve_smallK_hybrid_u8u32_dot_8x1VL.hpp"
#endif // ARM_COMPUTE_ENABLE_SVE

#include "gemm_hybrid_indirect.hpp"
#include "gemm_hybrid_quantized.hpp"
#include "gemm_hybrid_quantized_inline.hpp"
#include "gemm_interleaved.hpp"
#include "gemv_pretransposed.hpp"
#include "quantize_wrapper.hpp"

namespace arm_gemm {

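// Candidate implementations for quantized (asymmetric uint8) GEMM, listed in priority order.
// Each entry gives the selection method, the kernel name, an optional predicate saying whether
// the kernel can be used for the given arguments/quantization parameters, either a second
// predicate or a cycle estimate (entries built with with_estimate()) used to rank candidates,
// and a factory lambda that instantiates the chosen GEMM object.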
static const GemmImplementation<uint8_t, uint8_t, Requantize32> gemm_quint8_methods[] =
{
#ifdef ARM_COMPUTE_ENABLE_SVE
#ifdef ARM_COMPUTE_ENABLE_SME2
// SME kernels
{
    GemmMethod::GEMM_HYBRID,
    "sme2_gemv_u8qa_dot_16VL",
    [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && quant_hybrid_asymmetric(qp) && args._Msize == 1 && !args._indirect_input && args._nbatches == 1;  },
    nullptr,
    [](const GemmArgs &args, const Requantize32 &qp) { return new GemvPretransposed<cls_sme2_gemv_u8qa_dot_16VL, uint8_t, uint8_t, Requantize32>(args, qp); }
},
{
    GemmMethod::GEMM_INTERLEAVED,
    "sme2_interleaved_nomerge_u8q_mopa_1VLx4VL",
    [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && args._maxthreads == 1 && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
    [](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length<uint32_t>();
                               return args._Msize <= VL || (2*VL < args._Msize && args._Msize <= 3*VL); },
    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_u8q_mopa_1VLx4VL, uint8_t, uint8_t>(args, qp); }
},
{
    GemmMethod::GEMM_INTERLEAVED,
    "sme2_interleaved_nomerge_u8q_mopa_4VLx1VL",
    [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && args._maxthreads == 1 && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
    [](const GemmArgs &args, const Requantize32 &) { const auto VL = sme::get_vector_length<int32_t>();
                               return args._Nsize <= VL || (2*VL < args._Nsize && args._Nsize <= 3*VL); },
    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_u8q_mopa_4VLx1VL, uint8_t, uint8_t>(args, qp); }
},
{
    GemmMethod::GEMM_INTERLEAVED,
    "sme2_interleaved_nomerge_u8q_mopa_2VLx2VL",
    [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sme2() && args._maxthreads == 1 && ((qp.per_channel_requant && (qp.per_channel_left_shifts == nullptr)) || (!qp.per_channel_requant && (qp.per_layer_left_shift == 0)));},
    nullptr,
    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedPretransposedNoMergeQuantizedInline<cls_sme2_interleaved_nomerge_u8q_mopa_2VLx2VL, uint8_t, uint8_t>(args, qp); }
},
#endif // ARM_COMPUTE_ENABLE_SME2
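// SVE kernels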
GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "sve_hybrid_u8qa_mmla_4x4VL",
    [](const GemmArgs &args, const Requantize32 &qp) { return quant_hybrid_asymmetric(qp) && args._ci->has_sve2() && args._ci->has_svei8mm(); },
    [](const GemmArgs &args, const Requantize32 &) { return GemmHybridIndirect<cls_sve_hybrid_u8qa_mmla_4x4VL, uint8_t, uint8_t, Requantize32>::estimate_cycles<uint8_t>(args); },
    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_sve_hybrid_u8qa_mmla_4x4VL, uint8_t, uint8_t, Requantize32>(args, qp); }
),
GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "sve_interleaved_u8u32_mmla_8x3VL",
    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_svei8mm() && (args._Ksize>8); },
    [](const GemmArgs &args, const Requantize32 &) { return GemmInterleavedQuantized<cls_sve_interleaved_u8u32_mmla_8x3VL, uint8_t, uint8_t>::estimate_cycles<uint8_t>(args); },
    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_sve_interleaved_u8u32_mmla_8x3VL, uint8_t, uint8_t>(args, qp); }
),
GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "sve_hybrid_u8u32_mmla_6x4VL",
    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_svei8mm(); },
    [](const GemmArgs &args, const Requantize32 &) { return GemmHybridIndirect<cls_sve_hybrid_u8u32_mmla_6x4VL, uint8_t, uint8_t, Requantize32, true>::estimate_cycles<uint8_t>(args); },
    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_sve_hybrid_u8u32_mmla_6x4VL, uint8_t, uint8_t, Requantize32, true>(args, qp); }
),
{
    GemmMethod::GEMM_HYBRID_QUANTIZED,
    "sve_smallK_hybrid_u8u32_dot_8x1VL",
    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_sve() && args._Ksize<=64 && !args._indirect_input; },
    [](const GemmArgs &args, const Requantize32 &) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); },
    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<cls_sve_smallK_hybrid_u8u32_dot_8x1VL, uint8_t, uint8_t>(args, qp); }
},
GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "sve_hybrid_u8qa_dot_4x4VL",
    [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_sve2() && quant_hybrid_asymmetric(qp); },
    [](const GemmArgs &args, const Requantize32 &) { return GemmHybridIndirect<cls_sve_hybrid_u8qa_dot_4x4VL, uint8_t, uint8_t, Requantize32>::estimate_cycles<uint8_t>(args); },
    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_sve_hybrid_u8qa_dot_4x4VL, uint8_t, uint8_t, Requantize32>(args, qp); }
),
GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "sve_hybrid_u8u32_dot_6x4VL",
    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_sve(); },
    [](const GemmArgs &args, const Requantize32 &) { return GemmHybridIndirect<cls_sve_hybrid_u8u32_dot_6x4VL, uint8_t, uint8_t, Requantize32, true>::estimate_cycles<uint8_t>(args); },
    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_sve_hybrid_u8u32_dot_6x4VL, uint8_t, uint8_t, Requantize32, true>(args, qp); }
),
GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "sve_interleaved_u8u32_dot_8x3VL",
    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_sve() && (args._Ksize>4); },
    [](const GemmArgs &args, const Requantize32 &) { return GemmInterleavedQuantized<cls_sve_interleaved_u8u32_dot_8x3VL, uint8_t, uint8_t>::estimate_cycles<uint8_t>(args); },
    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_sve_interleaved_u8u32_dot_8x3VL, uint8_t, uint8_t>(args, qp); }
),
#endif // ARM_COMPUTE_ENABLE_SVE
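// AArch64 NEON kernels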
GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "a64_hybrid_u8qa_mmla_4x16",
    [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_i8mm() && quant_hybrid_asymmetric(qp); },
    [](const GemmArgs &args, const Requantize32 &) { return GemmHybridIndirect<cls_a64_hybrid_u8qa_mmla_4x16, uint8_t, uint8_t, Requantize32>::estimate_cycles<uint8_t>(args); },
    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_a64_hybrid_u8qa_mmla_4x16, uint8_t, uint8_t, Requantize32>(args, qp); }
),
GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "a64_interleaved_u8u32_mmla_8x12",
    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_i8mm() && (args._Ksize>8); },
    [](const GemmArgs &args, const Requantize32 &) { return GemmInterleavedQuantized<cls_a64_interleaved_u8u32_mmla_8x12, uint8_t, uint8_t>::estimate_cycles<uint8_t>(args); },
    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_a64_interleaved_u8u32_mmla_8x12, uint8_t, uint8_t>(args, qp); }
),
GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "a64_hybrid_u8u32_mmla_6x16",
    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_i8mm(); },
    [](const GemmArgs &args, const Requantize32 &) { return GemmHybridIndirect<cls_a64_hybrid_u8u32_mmla_6x16, uint8_t, uint8_t, Requantize32, true>::estimate_cycles<uint8_t>(args); },
    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_a64_hybrid_u8u32_mmla_6x16, uint8_t, uint8_t, Requantize32, true>(args, qp); }
),
{
    GemmMethod::GEMM_HYBRID_QUANTIZED,
    "a64_smallK_hybrid_u8u32_dot_8x4",
    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input; },
    [](const GemmArgs &args, const Requantize32 &) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); },
    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<cls_a64_smallK_hybrid_u8u32_dot_8x4, uint8_t, uint8_t>(args, qp); }
},
{
    GemmMethod::GEMM_HYBRID_QUANTIZED,
    "a64_smallK_hybrid_u8u32_dot_6x4",
    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input; },
    [](const GemmArgs &args, const Requantize32 &) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); },
    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridQuantized<cls_a64_smallK_hybrid_u8u32_dot_6x4, uint8_t, uint8_t>(args, qp); }
},
{
    GemmMethod::GEMM_INTERLEAVED,
    "a64_gemm_u16_8x12",
    nullptr,
    [](const GemmArgs &args, const Requantize32 &) { return args._ci->get_cpu_model() == CPUModel::A53 && args._Msize > 4; },
    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_a64_gemm_u16_8x12, uint8_t, uint8_t>(args, qp); },
},
GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "a64_hybrid_u8qa_dot_4x16",
    [](const GemmArgs &args, const Requantize32 &qp) { return args._ci->has_dotprod() && quant_hybrid_asymmetric(qp); },
    [](const GemmArgs &args, const Requantize32 &) { return GemmHybridIndirect<cls_a64_hybrid_u8qa_dot_4x16, uint8_t, uint8_t, Requantize32>::estimate_cycles<uint8_t>(args); },
    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_a64_hybrid_u8qa_dot_4x16, uint8_t, uint8_t, Requantize32>(args, qp); }
),
GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "a64_hybrid_u8u32_dot_6x16",
    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod(); },
    [](const GemmArgs &args, const Requantize32 &) { return GemmHybridIndirect<cls_a64_hybrid_u8u32_dot_6x16, uint8_t, uint8_t, Requantize32, true>::estimate_cycles<uint8_t>(args); },
    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmHybridIndirect<cls_a64_hybrid_u8u32_dot_6x16, uint8_t, uint8_t, Requantize32, true>(args, qp); }
),
GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "a64_gemm_u8_8x12",
    [](const GemmArgs &args, const Requantize32 &) { return args._ci->has_dotprod(); },
    [](const GemmArgs &args, const Requantize32 &) { return GemmInterleavedQuantized<cls_a64_gemm_u8_8x12, uint8_t, uint8_t>::estimate_cycles<uint8_t>(args); },
    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_a64_gemm_u8_8x12, uint8_t, uint8_t>(args, qp); }
),
GemmImplementation<uint8_t, uint8_t, Requantize32>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "a64_gemm_u8_4x4",
    nullptr,
    [](const GemmArgs &args, const Requantize32 &) { return GemmInterleavedQuantized<cls_a64_gemm_u8_4x4, uint8_t, uint8_t>::estimate_cycles<uint8_t>(args); },
    [](const GemmArgs &args, const Requantize32 &qp) { return new GemmInterleavedQuantized<cls_a64_gemm_u8_4x4, uint8_t, uint8_t>(args, qp); }
),
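// Fallbacks: a generic wrapper that runs an unquantized u8->u32 GEMM and requantizes the
// output separately, followed by the list terminator.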
{
    GemmMethod::QUANTIZE_WRAPPER,
    "quantized_wrapper",
    [](const GemmArgs &args, const Requantize32 &) { return !args._indirect_input; },
    [](const GemmArgs &, const Requantize32 &) { return false; },
    [](const GemmArgs &args, const Requantize32 &qp) { return new QuantizeWrapper<uint8_t, uint8_t, uint32_t>(args, qp); }
},
{
    GemmMethod::DEFAULT,
    "",
    nullptr,
    nullptr,
    nullptr
}
};

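// Expose this list to the generic implementation-selection code for the
// uint8_t / uint8_t / Requantize32 combination.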
template<>
const GemmImplementation<uint8_t, uint8_t, Requantize32> *gemm_implementation_list<uint8_t, uint8_t, Requantize32>() {
    return gemm_quint8_methods;
}

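// Explicit instantiations of the public entry points for this type combination.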
template UniqueGemmCommon<uint8_t, uint8_t> gemm<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);
template bool has_opt_gemm<uint8_t, uint8_t, Requantize32>(WeightFormat &weight_format, const GemmArgs &args, const Requantize32 &os);
template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint8_t, Requantize32>(const GemmArgs &args, const Requantize32 &os);

} // namespace arm_gemm

#endif // __aarch64__