/*
 * Copyright (c) 2017-2020, 2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

// This can only be built if the target/compiler supports FP16 arguments.
#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC))

#include "arm_gemm.hpp"

#include "gemm_common.hpp"
#include "gemm_hybrid.hpp"
#include "gemm_hybrid_indirect.hpp"
#include "gemm_implementation.hpp"
#include "gemm_interleaved.hpp"

#include "kernels/a32_sgemm_8x6.hpp"
#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_ffhybrid_fp16_mla_6x32.hpp"
#include "kernels/a64_ffinterleaved_fp16_mla_8x24.hpp"
#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/a64_hgemm_8x24.hpp"
#include "kernels/a64_hybrid_fp16_mla_6x32.hpp"
#include "kernels/a64_sgemm_8x12.hpp"
#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/sve_ffhybrid_fp16_mla_6x4VL.hpp"
#include "kernels/sve_ffinterleaved_fp16_mla_8x3VL.hpp"
#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#include "kernels/sve_hybrid_fp16_mla_6x4VL.hpp"
#include "kernels/sve_interleaved_fp16_mla_8x3VL.hpp"

namespace arm_gemm {

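// Candidate table for FP16 GEMM. Each entry names a strategy (interleaved or
// hybrid), the kernel it wraps, an optional fixed-format weight layout, a
// restriction predicate over GemmArgs (e.g. has_sve()/has_fp16()), a cycle
// estimator used to rank candidates, and a lambda that instantiates the chosen
// wrapper. The list is terminated by the GemmMethod::DEFAULT sentinel; the
// selection logic itself lives in gemm_implementation.hpp, not here.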
static const GemmImplementation<__fp16, __fp16> gemm_fp16_methods[] = {
#ifdef ARM_COMPUTE_ENABLE_SVE
GemmImplementation<__fp16, __fp16>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "sve_hybrid_fp16_mla_6x4VL",
    [](const GemmArgs &args) { return args._ci->has_sve(); },
    [](const GemmArgs &args) { return GemmHybridIndirect<cls_sve_hybrid_fp16_mla_6x4VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_sve_hybrid_fp16_mla_6x4VL, __fp16, __fp16>(args); }
),
GemmImplementation<__fp16, __fp16>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "sve_interleaved_fp16_mla_8x3VL",
    [](const GemmArgs &args) { return args._ci->has_sve() && (args._Ksize > 4); },
    [](const GemmArgs &args) { return GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
    [](const GemmArgs &args) { return new GemmInterleaved<cls_sve_interleaved_fp16_mla_8x3VL, __fp16, __fp16>(args); }
),
#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
GemmImplementation<__fp16, __fp16>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "sve_ffinterleaved_fp16_mla_8x3VL",
    KernelWeightFormat::VL1VL_BL16,
    [](const GemmArgs &args) { return args._ci->has_sve(); },
    [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_sve_ffinterleaved_fp16_mla_8x3VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
    [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_sve_ffinterleaved_fp16_mla_8x3VL, __fp16, __fp16>(args); }
),
GemmImplementation<__fp16, __fp16>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "sve_ffhybrid_fp16_mla_6x4VL",
    KernelWeightFormat::VL1VL_BL16,
    [](const GemmArgs &args) { return args._ci->has_sve(); },
    [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp16_mla_6x4VL, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_sve_ffhybrid_fp16_mla_6x4VL, __fp16, __fp16>(args); }
),
#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
#endif // ARM_COMPUTE_ENABLE_SVE
#if defined(__aarch64__)
GemmImplementation<__fp16, __fp16>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "a64_hybrid_fp16_mla_6x32",
    [](const GemmArgs &args) { return args._ci->has_fp16(); },
    [](const GemmArgs &args) { return GemmHybridIndirect<cls_a64_hybrid_fp16_mla_6x32, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirect<cls_a64_hybrid_fp16_mla_6x32, __fp16, __fp16>(args); }
),
GemmImplementation<__fp16, __fp16>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "a64_hgemm_8x24",
    [](const GemmArgs &args) { return args._ci->has_fp16(); },
    [](const GemmArgs &args) { return GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
    [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_hgemm_8x24, __fp16, __fp16>(args); }
),
#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
GemmImplementation<__fp16, __fp16>::with_estimate(
    GemmMethod::GEMM_INTERLEAVED,
    "a64_ffinterleaved_fp16_mla_8x24",
    KernelWeightFormat::VL128_BL16,
    [](const GemmArgs &args) { return args._ci->has_fp16(); },
    [](const GemmArgs &args) { return GemmInterleavedFixedFormat<cls_a64_ffinterleaved_fp16_mla_8x24, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
    [](const GemmArgs &args) { return new GemmInterleavedFixedFormat<cls_a64_ffinterleaved_fp16_mla_8x24, __fp16, __fp16>(args); }
),
GemmImplementation<__fp16, __fp16>::with_estimate(
    GemmMethod::GEMM_HYBRID,
    "a64_ffhybrid_fp16_mla_6x32",
    KernelWeightFormat::VL128_BL16,
    [](const GemmArgs &args) { return args._ci->has_fp16(); },
    [](const GemmArgs &args) { return GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp16_mla_6x32, __fp16, __fp16>::estimate_cycles<__fp16>(args); },
    [](const GemmArgs &args) { return new GemmHybridIndirectFixedFormat<cls_a64_ffhybrid_fp16_mla_6x32, __fp16, __fp16>(args); }
),
#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
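// Fallback entry for AArch64 CPUs without FP16 arithmetic: it wraps the FP32
// interleaved kernel, with the __fp16 <-> float conversion expected to happen
// in the interleave/merge transforms (a note on intent; the conversion details
// live in gemm_interleaved.hpp and the transform code, not in this file).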
{
    GemmMethod::GEMM_INTERLEAVED,
    "a64_sgemm_8x12",
    nullptr,
    [](const GemmArgs &args) { return !args._ci->has_fp16(); },
    [](const GemmArgs &args) { return new GemmInterleaved<cls_a64_sgemm_8x12, __fp16, __fp16>(args); }
},
#elif defined(__arm__)
{
    GemmMethod::GEMM_INTERLEAVED,
    "sgemm_8x6",
    nullptr,
    nullptr,
    [](const GemmArgs &args) { return new GemmInterleaved<sgemm_8x6, __fp16, __fp16>(args); }
},
#else // not AArch64 or AArch32
# error Unknown Architecture
#endif
{
    GemmMethod::DEFAULT,
    "",
    nullptr,
    nullptr,
    nullptr,
}
};

template<>
const GemmImplementation<__fp16, __fp16> *gemm_implementation_list<__fp16, __fp16>() {
    return gemm_fp16_methods;
}

/* Explicitly instantiate the external functions for these types. */
template UniqueGemmCommon<__fp16, __fp16> gemm<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
template bool has_opt_gemm<__fp16, __fp16, Nothing>(WeightFormat &weight_format, const GemmArgs &args, const Nothing &);
template KernelDescription get_gemm_method<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
template std::vector<KernelDescription> get_compatible_kernels<__fp16, __fp16, Nothing>(const GemmArgs &args, const Nothing &);
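
// Hypothetical usage sketch (not part of this file): given a GemmArgs already
// populated by the caller (its construction is declared in arm_gemm.hpp and
// omitted here), the entry points instantiated above can be exercised as:
//
//   void fp16_gemm_example(const arm_gemm::GemmArgs &args) {
//       // Kernels whose restriction predicates pass for these arguments.
//       auto kernels = arm_gemm::get_compatible_kernels<__fp16, __fp16, arm_gemm::Nothing>(args, {});
//       // The kernel the selection heuristics would actually pick.
//       auto method  = arm_gemm::get_gemm_method<__fp16, __fp16, arm_gemm::Nothing>(args, {});
//       // Instantiate the chosen GEMM object for this problem.
//       auto gemm    = arm_gemm::gemm<__fp16, __fp16, arm_gemm::Nothing>(args, {});
//   }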

} // namespace arm_gemm

#endif // defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC))
162