xref: /aosp_15_r20/external/ComputeLibrary/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2017-2020, 2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 
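// This can only be built for AArch64 targets where FP16 kernels are available: either FP16_KERNELS is
// defined explicitly, or the compiler targets FP16 vector arithmetic (__ARM_FEATURE_FP16_VECTOR_ARITHMETIC).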
26 #if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC))
27 
28 #include "arm_gemm.hpp"
29 
30 #include "gemm_common.hpp"
31 #include "gemm_hybrid.hpp"
32 #include "gemm_hybrid_indirect.hpp"
33 #include "gemm_implementation.hpp"
34 #include "gemm_interleaved.hpp"
35 
36 #include "kernels/a32_sgemm_8x6.hpp"
37 #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
38 #include "kernels/a64_ffhybrid_fp16_mla_6x32.hpp"
39 #include "kernels/a64_ffinterleaved_fp16_mla_8x24.hpp"
40 #endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
41 #include "kernels/a64_hgemm_8x24.hpp"
42 #include "kernels/a64_hybrid_fp16_mla_6x32.hpp"
43 #include "kernels/a64_sgemm_8x12.hpp"
44 #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
45 #include "kernels/sve_ffhybrid_fp16_mla_6x4VL.hpp"
46 #include "kernels/sve_ffinterleaved_fp16_mla_8x3VL.hpp"
47 #endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
48 #include "kernels/sve_hybrid_fp16_mla_6x4VL.hpp"
49 #include "kernels/sve_interleaved_fp16_mla_8x3VL.hpp"
50 
51 namespace arm_gemm {
52 
53 static const GemmImplementation<__fp16, __fp16> gemm_fp16_methods[] = {
54 #ifdef ARM_COMPUTE_ENABLE_SVE
55 GemmImplementation<__fp16, __fp16>::with_estimate(
56     GemmMethod::GEMM_HYBRID,
57     "sve_hybrid_fp16_mla_6x4VL",
__anonfd50b0e00102() 58     [](const GemmArgs &args) { return args._ci->has_sve(); },
__anonfd50b0e00202() 59     [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_fp16_mla_6x4VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
__anonfd50b0e00302() 60     [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_fp16_mla_6x4VL, __fp16, __fp16>(args); }
61 ),
62 GemmImplementation<__fp16, __fp16>::with_estimate(
63     GemmMethod::GEMM_INTERLEAVED,
64     "sve_interleaved_fp16_mla_8x3VL",
__anonfd50b0e00402() 65     [](const GemmArgs &args) { return args._ci->has_sve() && (args._Ksize > 4); },
__anonfd50b0e00502() 66     [](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
__anonfd50b0e00602() 67     [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, __fp16>(args); }
68 ),
69 #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
70 GemmImplementation<__fp16, __fp16>::with_estimate(
71     GemmMethod::GEMM_INTERLEAVED,
72     "sve_ffinterleaved_fp16_mla_8x3VL",
73     KernelWeightFormat::VL1VL_BL16,
__anonfd50b0e00702() 74     [](const GemmArgs &args) { return args._ci->has_sve(); },
__anonfd50b0e00802() 75     [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_sve_ffinterleaved_fp16_mla_8x3VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
__anonfd50b0e00902() 76     [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_sve_ffinterleaved_fp16_mla_8x3VL, __fp16, __fp16>(args); }
77 ),
78 GemmImplementation<__fp16, __fp16>::with_estimate(
79     GemmMethod::GEMM_HYBRID,
80     "sve_ffhybrid_fp16_mla_6x4VL",
81     KernelWeightFormat::VL1VL_BL16,
__anonfd50b0e00a02() 82     [](const GemmArgs &args) { return args._ci->has_sve(); },
__anonfd50b0e00b02() 83     [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp16_mla_6x4VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
__anonfd50b0e00c02() 84     [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp16_mla_6x4VL, __fp16, __fp16>(args); }
85 ),
86 #endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
87 #endif // ARM_COMPUTE_ENABLE_SVE
88 #if defined(__aarch64__)
89 GemmImplementation<__fp16, __fp16>::with_estimate(
90     GemmMethod::GEMM_HYBRID,
91     "a64_hybrid_fp16_mla_6x32",
__anonfd50b0e00d02() 92     [](const GemmArgs &args) { return args._ci->has_fp16(); },
__anonfd50b0e00e02() 93     [](const GemmArgs &args) { return GemmHybridIndirect<cls_a64_hybrid_fp16_mla_6x32, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
__anonfd50b0e00f02() 94     [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_fp16_mla_6x32, __fp16, __fp16>(args); }
95 ),
96 GemmImplementation<__fp16, __fp16>::with_estimate(
97     GemmMethod::GEMM_INTERLEAVED,
98     "a64_hgemm_8x24",
__anonfd50b0e01002() 99     [](const GemmArgs &args) { return args._ci->has_fp16(); },
__anonfd50b0e01102() 100     [](const GemmArgs &args) { return GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
__anonfd50b0e01202() 101     [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16>(args); }
102 ),
103 #ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
104 GemmImplementation<__fp16, __fp16>::with_estimate(
105     GemmMethod::GEMM_INTERLEAVED,
106     "a64_ffinterleaved_fp16_mla_8x24",
107     KernelWeightFormat::VL128_BL16,
__anonfd50b0e01302() 108     [](const GemmArgs &args) { return args._ci->has_fp16(); },
__anonfd50b0e01402() 109     [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_fp16_mla_8x24, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
__anonfd50b0e01502() 110     [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_fp16_mla_8x24, __fp16, __fp16>(args); }
111 ),
112 GemmImplementation<__fp16, __fp16>::with_estimate(
113     GemmMethod::GEMM_HYBRID,
114     "a64_ffhybrid_fp16_mla_6x32",
115     KernelWeightFormat::VL128_BL16,
__anonfd50b0e01602() 116     [](const GemmArgs &args) { return args._ci->has_fp16(); },
__anonfd50b0e01702() 117     [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp16_mla_6x32, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
__anonfd50b0e01802() 118     [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp16_mla_6x32, __fp16, __fp16>(args); }
119 ),
120 #endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
121 {
122     GemmMethod::GEMM_INTERLEAVED,
123     "a64_sgemm_8x12",
124     nullptr,
__anonfd50b0e01902() 125     [](const GemmArgs &args) { return !args._ci->has_fp16(); },
__anonfd50b0e01a02() 126     [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_sgemm_8x12, __fp16, __fp16>(args); }
127 },
128 #elif defined(__arm__)
129 {
130     GemmMethod::GEMM_INTERLEAVED,
131     "sgemm_8x6",
132     nullptr,
133     nullptr,
134     [](const GemmArgs &args) { return new GemmInterleaved<sgemm_8x6, __fp16, __fp16>(args); }
135 },
136 #else // not AArch64 or AArch32
137 # error Unknown Architecture
138 #endif
139 {
140     GemmMethod::DEFAULT,
141     "",
142     nullptr,
143     nullptr,
144     nullptr,
145 }
146 };
147 
148 template<>
gemm_implementation_list()149 const GemmImplementation<__fp16, __fp16> *gemm_implementation_list<__fp16, __fp16>() {
150     return gemm_fp16_methods;
151 }
152 
153 /* Explicitly instantiate the external functions for these types. */
154 template UniqueGemmCommon<__fp16, __fp16> gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
155 template bool has_opt_gemm<__fp16, __fp16, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
156 template KernelDescription get_gemm_method<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
157 template std::vector<KernelDescription> get_compatible_kernels<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
158 
159 } // namespace arm_gemm
160 
161 #endif // defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC))
162