xref: /aosp_15_r20/external/ComputeLibrary/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2017-2020, 2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifdef __aarch64__
25 
26 #include "arm_gemm.hpp"
27 #include "gemm_common.hpp"
28 #include "gemm_implementation.hpp"
29 #include "gemm_interleaved.hpp"
30 #include "gemm_hybrid.hpp"
31 #include "gemm_hybrid_indirect.hpp"
32 
33 #include "kernels/a64_gemm_u16_8x12.hpp"
34 #include "kernels/a64_gemm_u8_4x4.hpp"
35 #include "kernels/a64_gemm_u8_8x12.hpp"
36 #include "kernels/a64_hybrid_u8u32_dot_6x16.hpp"
37 #include "kernels/a64_hybrid_u8u32_mmla_6x16.hpp"
38 #include "kernels/a64_interleaved_u8u32_mmla_8x12.hpp"
39 #include "kernels/a64_smallK_hybrid_u8u32_dot_6x4.hpp"
40 #include "kernels/a64_smallK_hybrid_u8u32_dot_8x4.hpp"
41 
42 #include "kernels/sve_hybrid_u8u32_dot_6x4VL.hpp"
43 #include "kernels/sve_hybrid_u8u32_mmla_6x4VL.hpp"
44 #include "kernels/sve_interleaved_u8u32_dot_8x3VL.hpp"
45 #include "kernels/sve_interleaved_u8u32_mmla_8x3VL.hpp"
46 #include "kernels/sve_smallK_hybrid_u8u32_dot_8x1VL.hpp"
47 
48 namespace arm_gemm {
49 
// Table of candidate GEMM kernels for uint8_t operands with uint32_t results.
// It is handed out by gemm_implementation_list<uint8_t, uint32_t>() in this file.
//
// Every entry supplies, in this order:
//   - a GemmMethod value,
//   - the kernel name as a string literal,
//   - a callable on `const GemmArgs &` testing CPU features and/or problem
//     shape (nullptr in some entries),
//   - a second callable on `const GemmArgs &`: entries built through
//     with_estimate() return <Wrapper>::estimate_cycles<uint32_t>(args), while
//     brace-initialised entries return a boolean expression,
//   - a factory callable that `new`s the wrapper object (GemmHybridIndirect,
//     GemmHybrid or GemmInterleaved) for that kernel.
//
// NOTE(review): the precise meaning of the two predicate/estimate slots
// (presumably "is supported" and "is recommended"/cycle estimate) is defined
// by GemmImplementation, which is not visible here - confirm in
// gemm_implementation.hpp.
// NOTE(review): entry order presumably influences which kernel is picked -
// confirm against the selection code before reordering anything.
50 static const GemmImplementation<uint8_t, uint32_t> gemm_u8_methods[] = {
51 #ifdef ARM_COMPUTE_ENABLE_SVE
// Entries down to the matching #endif exist only when ARM_COMPUTE_ENABLE_SVE is defined.
52 GemmImplementation<uint8_t, uint32_t>::with_estimate(
53     GemmMethod::GEMM_HYBRID,
54     "sve_hybrid_u8u32_mmla_6x4VL",
__anonedbc98db0102() 55     [](const GemmArgs &args) { return args._ci->has_svei8mm(); },
__anonedbc98db0202() 56     [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_u8u32_mmla_6x4VL, uint8_t, uint32_t>::estimate_cycles<uint32_t>(args); },
__anonedbc98db0302() 57     [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_u8u32_mmla_6x4VL, uint8_t, uint32_t>(args); }
58 ),
59 GemmImplementation<uint8_t, uint32_t>::with_estimate(
60     GemmMethod::GEMM_INTERLEAVED,
61     "sve_interleaved_u8u32_mmla_8x3VL",
__anonedbc98db0402() 62     [](const GemmArgs &args) { return args._ci->has_svei8mm() && (args._Ksize>8); },
__anonedbc98db0502() 63     [](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_u8u32_mmla_8x3VL, uint8_t, uint32_t>::estimate_cycles<uint32_t>(args); },
__anonedbc98db0602() 64     [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_u8u32_mmla_8x3VL, uint8_t, uint32_t>(args); }
65 ),
// Small-K SVE kernel: the first callable requires has_sve(), _Ksize <= 64 and
// no indirect input; the second is false whenever has_svei8mm() or has_i8mm()
// reports true.
66 {
67     GemmMethod::GEMM_HYBRID,
68     "sve_smallK_hybrid_u8u32_dot_8x1VL",
__anonedbc98db0702() 69     [](const GemmArgs &args) { return args._ci->has_sve() && args._Ksize<=64 && !args._indirect_input; },
__anonedbc98db0802() 70     [](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); },
__anonedbc98db0902() 71     [](const GemmArgs &args) { return new GemmHybrid<cls_sve_smallK_hybrid_u8u32_dot_8x1VL, uint8_t, uint32_t>(args); }
72 },
73 GemmImplementation<uint8_t, uint32_t>::with_estimate(
74     GemmMethod::GEMM_HYBRID,
75     "sve_hybrid_u8u32_dot_6x4VL",
__anonedbc98db0a02() 76     [](const GemmArgs &args) { return args._ci->has_sve(); },
__anonedbc98db0b02() 77     [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_u8u32_dot_6x4VL, uint8_t, uint32_t>::estimate_cycles<uint32_t>(args); },
__anonedbc98db0c02() 78     [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_u8u32_dot_6x4VL, uint8_t, uint32_t>(args); }
79 ),
80 GemmImplementation<uint8_t, uint32_t>::with_estimate(
81     GemmMethod::GEMM_INTERLEAVED,
82     "sve_interleaved_u8u32_dot_8x3VL",
__anonedbc98db0d02() 83     [](const GemmArgs &args) { return args._ci->has_sve() && (args._Ksize>4); },
__anonedbc98db0e02() 84     [](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_u8u32_dot_8x3VL, uint8_t, uint32_t>::estimate_cycles<uint32_t>(args); },
__anonedbc98db0f02() 85     [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_u8u32_dot_8x3VL, uint8_t, uint32_t>(args); }
86 ),
87 #endif // ARM_COMPUTE_ENABLE_SVE
// Entries from here on are not guarded by ARM_COMPUTE_ENABLE_SVE.
88 GemmImplementation<uint8_t, uint32_t>::with_estimate(
89     GemmMethod::GEMM_INTERLEAVED,
90     "a64_interleaved_u8u32_mmla_8x12",
__anonedbc98db1002() 91     [](const GemmArgs &args) { return args._ci->has_i8mm() && (args._Ksize>8); },
__anonedbc98db1102() 92     [](const GemmArgs &args) { return GemmInterleaved<cls_a64_interleaved_u8u32_mmla_8x12, uint8_t, uint32_t>::estimate_cycles<uint32_t>(args); },
__anonedbc98db1202() 93     [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_interleaved_u8u32_mmla_8x12, uint8_t, uint32_t>(args); }
94 ),
95 GemmImplementation<uint8_t, uint32_t>::with_estimate(
96     GemmMethod::GEMM_HYBRID,
97     "a64_hybrid_u8u32_mmla_6x16",
__anonedbc98db1302() 98     [](const GemmArgs &args) { return args._ci->has_i8mm(); },
__anonedbc98db1402() 99     [](const GemmArgs &args) { return GemmHybridIndirect<cls_a64_hybrid_u8u32_mmla_6x16, uint8_t, uint32_t>::estimate_cycles<uint32_t>(args); },
__anonedbc98db1502() 100     [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_u8u32_mmla_6x16, uint8_t, uint32_t>(args); }
101 ),
// The two small-K dot-product kernels below split the K range between them:
// 8x4 takes _Ksize <= 32, 6x4 takes 32 < _Ksize <= 64. Both additionally
// require has_dotprod(), _Nsize divisible by 4 and no indirect input, and
// both have a second callable that is false when has_svei8mm() or has_i8mm()
// reports true.
102 {
103     GemmMethod::GEMM_HYBRID,
104     "a64_smallK_hybrid_u8u32_dot_8x4",
__anonedbc98db1602() 105     [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize<=32) && !args._indirect_input; },
__anonedbc98db1702() 106     [](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); },
__anonedbc98db1802() 107     [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_u8u32_dot_8x4, uint8_t, uint32_t>(args); }
108 },
109 {
110     GemmMethod::GEMM_HYBRID,
111     "a64_smallK_hybrid_u8u32_dot_6x4",
__anonedbc98db1902() 112     [](const GemmArgs &args) { return args._ci->has_dotprod() && (args._Nsize % 4 == 0) && (args._Ksize>32) && (args._Ksize<=64) && !args._indirect_input; },
__anonedbc98db1a02() 113     [](const GemmArgs &args) { return !(args._ci->has_svei8mm() || args._ci->has_i8mm()); },
__anonedbc98db1b02() 114     [](const GemmArgs &args) { return new GemmHybrid<cls_a64_smallK_hybrid_u8u32_dot_6x4, uint8_t, uint32_t>(args); }
115 },
// u16 kernel: first callable is nullptr; the second is true only when
// get_cpu_model() is CPUModel::A53 and _Msize > 4.
116 {
117     GemmMethod::GEMM_INTERLEAVED,
118     "a64_gemm_u16_8x12",
119     nullptr,
__anonedbc98db1c02() 120     [](const GemmArgs &args) { return args._ci->get_cpu_model() == CPUModel::A53 && args._Msize > 4; },
__anonedbc98db1d02() 121     [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_gemm_u16_8x12, uint8_t, uint32_t>(args); },
122 },
123 GemmImplementation<uint8_t, uint32_t>::with_estimate(
124     GemmMethod::GEMM_HYBRID,
125     "a64_hybrid_u8u32_dot_6x16",
__anonedbc98db1e02() 126     [](const GemmArgs &args) { return args._ci->has_dotprod(); },
__anonedbc98db1f02() 127     [](const GemmArgs &args) { return GemmHybridIndirect<cls_a64_hybrid_u8u32_dot_6x16, uint8_t, uint32_t>::estimate_cycles<uint32_t>(args); },
__anonedbc98db2002() 128     [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_u8u32_dot_6x16, uint8_t, uint32_t>(args); }
129 ),
130 GemmImplementation<uint8_t, uint32_t>::with_estimate(
131     GemmMethod::GEMM_INTERLEAVED,
132     "a64_gemm_u8_8x12",
__anonedbc98db2102() 133     [](const GemmArgs &args) { return args._ci->has_dotprod(); },
__anonedbc98db2202() 134     [](const GemmArgs &args) { return GemmInterleaved<cls_a64_gemm_u8_8x12, uint8_t, uint32_t>::estimate_cycles<uint32_t>(args); },
__anonedbc98db2302() 135     [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_gemm_u8_8x12, uint8_t, uint32_t>(args); }
136 ),
// This entry passes nullptr as its first callable, i.e. it carries no CPU
// feature test of its own (presumably the generic fallback - confirm).
137 GemmImplementation<uint8_t, uint32_t>::with_estimate(
138     GemmMethod::GEMM_INTERLEAVED,
139     "a64_gemm_u8_4x4",
140     nullptr,
__anonedbc98db2402() 141     [](const GemmArgs &args) { return GemmInterleaved<cls_a64_gemm_u8_4x4, uint8_t, uint32_t>::estimate_cycles<uint32_t>(args); },
__anonedbc98db2502() 142     [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_gemm_u8_4x4, uint8_t, uint32_t>(args); }
143 ),
// Final entry: GemmMethod::DEFAULT, empty name and all callables null
// (presumably the end-of-table marker consumers scan for - confirm).
144 {
145     GemmMethod::DEFAULT,
146     "",
147     nullptr,
148     nullptr,
149     nullptr
150 }
151 };
152 
// Explicit specialisation of gemm_implementation_list for <uint8_t, uint32_t>.
// Returns a pointer to the first element of the file-local gemm_u8_methods
// table; the array itself is static const, so callers receive read-only access.
153 template<>
gemm_implementation_list()154 const GemmImplementation<uint8_t, uint32_t> *gemm_implementation_list<uint8_t, uint32_t>() {
155     return gemm_u8_methods;
156 }
157 
158 /* Explicitly instantiate the external functions for these types. */
// Instantiated here for <uint8_t, uint32_t, Nothing>: gemm(), has_opt_gemm()
// and get_compatible_kernels(). The template definitions are not in this file
// (presumably provided by gemm_implementation.hpp - confirm).
// NOTE(review): `Nothing` looks like a placeholder for "no output stage
// parameters" - confirm against its declaration.
159 template UniqueGemmCommon<uint8_t, uint32_t> gemm<uint8_t, uint32_t, Nothing>(const GemmArgs &args, const Nothing &);
160 template bool has_opt_gemm<uint8_t, uint32_t, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
161 template std::vector<KernelDescription> get_compatible_kernels<uint8_t, uint32_t, Nothing> (const GemmArgs &args, const Nothing &);
162 
163 } // namespace arm_gemm
164 
165 #endif // __aarch64__
166