/*
 * Copyright (c) 2019-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24 #include "src/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedBifrost.h"
25
26 #include "arm_compute/core/CL/CLHelpers.h"
27 #include "arm_compute/core/CL/CLKernelLibrary.h"
28 #include "arm_compute/core/GPUTarget.h"
29 #include "arm_compute/core/TensorInfo.h"
30 #include "arm_compute/core/TensorShape.h"
31 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
32 #include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h"
33
34 #include <utility>
35
36 namespace arm_compute
37 {
38 namespace opencl
39 {
40 namespace kernels
41 {
42 namespace gemm
43 {
44 using namespace arm_compute::misc::shape_calculator;
45
// Constructor: forwards the GPU target to the base class, which stores it in
// _target; configure() later dispatches on that target.
ClGemmDefaultConfigReshapedBifrost::ClGemmDefaultConfigReshapedBifrost(GPUTarget gpu)
    : IClGemmKernelConfig(gpu)
{
}
50
configure(unsigned int m,unsigned int n,unsigned int k,unsigned int b,DataType data_type)51 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
52 {
53 using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (ClGemmDefaultConfigReshapedBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
54
55 CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G7x(&ClGemmDefaultConfigReshapedBifrost::configure_G7x_f32,
56 &ClGemmDefaultConfigReshapedBifrost::configure_G7x_f16,
57 &ClGemmDefaultConfigReshapedBifrost::configure_G7x_u8);
58
59 CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G52(&ClGemmDefaultConfigReshapedBifrost::configure_G52_f32,
60 &ClGemmDefaultConfigReshapedBifrost::configure_G52_f16,
61 &ClGemmDefaultConfigReshapedBifrost::configure_G7x_u8);
62
63 CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G76(&ClGemmDefaultConfigReshapedBifrost::configure_G76_f32,
64 &ClGemmDefaultConfigReshapedBifrost::configure_G76_f16,
65 &ClGemmDefaultConfigReshapedBifrost::configure_G76_u8);
66
67 ConfigurationFunctionExecutorPtr func = nullptr;
68
69 switch(_target)
70 {
71 case GPUTarget::G76:
72 func = configs_G76.get_function(data_type);
73 break;
74 case GPUTarget::G52:
75 func = configs_G52.get_function(data_type);
76 break;
77 default:
78 func = configs_G7x.get_function(data_type);
79 break;
80 }
81
82 ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not support for GEMM");
83 return (this->*func)(m, n, k, b);
84 }
85
configure_G7x_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)86 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
87 {
88 ARM_COMPUTE_UNUSED(k);
89 ARM_COMPUTE_UNUSED(b);
90
91 if(n <= 4)
92 {
93 return configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, true, false, false, true);
94 }
95 else
96 {
97 return configure_lhs_rhs_info(m, n, 5, 4, 4, 2, 16, false, true, false, true);
98 }
99 }
100
configure_G7x_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)101 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
102 {
103 ARM_COMPUTE_UNUSED(k);
104 ARM_COMPUTE_UNUSED(b);
105
106 if(n <= 4)
107 {
108 return configure_lhs_rhs_info(m, n, 4, 2, 8, 8, 2, true, true, true, false);
109 }
110 else
111 {
112 return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2, true, true, true, false);
113 }
114 }
115
configure_G7x_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)116 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
117 {
118 ARM_COMPUTE_UNUSED(k);
119 ARM_COMPUTE_UNUSED(b);
120
121 if(dot8_supported(CLKernelLibrary::get().get_device()))
122 {
123 if(n <= 4)
124 {
125 return configure_lhs_rhs_info(m, n, 4, 2, 16, 2, 2, true, false, false, true);
126 }
127 else
128 {
129 return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2, true, false, false, true);
130 }
131 }
132 else
133 {
134 if(n <= 4)
135 {
136 return configure_lhs_rhs_info(m, n, 4, 2, 8, 2, 2, true, false, false, true);
137 }
138 else
139 {
140 return configure_lhs_rhs_info(m, n, 6, 4, 4, 2, 2, true, true, false, true);
141 }
142 }
143 }
144
configure_G52_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)145 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
146 {
147 const float r_mn = static_cast<float>(m) / static_cast<float>(n);
148 const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
149 const float r_mk = static_cast<float>(m) / static_cast<float>(k);
150 const float r_nk = static_cast<float>(n) / static_cast<float>(k);
151
152 GEMMLHSMatrixInfo lhs_info_buf;
153 GEMMRHSMatrixInfo rhs_info_buf;
154 GEMMLHSMatrixInfo lhs_info_img;
155 GEMMRHSMatrixInfo rhs_info_img;
156
157 if(workload <= 274.4000f)
158 {
159 if(r_nk <= 0.7461f)
160 {
161 if(r_mn <= 21.1667f)
162 {
163 return configure_lhs_rhs_info(m, n, 4, 2, 4, 4, 4, false, true, true, false, false);
164 }
165 else
166 {
167 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
168 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
169
170 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
171 std::make_pair(lhs_info_buf, rhs_info_buf),
172 n, k, b, DataType::F32);
173 }
174 }
175 else
176 {
177 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
178 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
179
180 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
181 std::make_pair(lhs_info_buf, rhs_info_buf),
182 n, k, b, DataType::F32);
183 }
184 }
185 else
186 {
187 if(r_mk <= 17.3926f)
188 {
189 if(workload <= 542.4000f)
190 {
191 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
192 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
193
194 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
195 std::make_pair(lhs_info_buf, rhs_info_buf),
196 n, k, b, DataType::F32);
197 }
198 else
199 {
200 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, true);
201 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, false);
202
203 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
204 std::make_pair(lhs_info_buf, rhs_info_buf),
205 n, k, b, DataType::F32);
206 }
207 }
208 else
209 {
210 if(r_nk <= 0.5463f)
211 {
212 if(workload <= 11767.6001f)
213 {
214 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
215 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
216
217 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
218 std::make_pair(lhs_info_buf, rhs_info_buf),
219 n, k, b, DataType::F32);
220 }
221 else
222 {
223 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, true);
224 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, false);
225
226 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
227 std::make_pair(lhs_info_buf, rhs_info_buf),
228 n, k, b, DataType::F32);
229 }
230 }
231 else
232 {
233 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
234 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
235
236 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
237 std::make_pair(lhs_info_buf, rhs_info_buf),
238 n, k, b, DataType::F32);
239 }
240 }
241 }
242 }
243
configure_G52_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)244 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
245 {
246 ARM_COMPUTE_UNUSED(k);
247
248 const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
249
250 if(workload <= 323.4000f)
251 {
252 return configure_lhs_rhs_info(m, n, 2, 2, 8, 4, 8, false, false, false, true, false);
253 }
254 else
255 {
256 return configure_lhs_rhs_info(m, n, 4, 8, 4, 2, 2, true, true, true, false, false);
257 }
258 }
259
configure_G76_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)260 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
261 {
262 ARM_COMPUTE_UNUSED(k);
263 ARM_COMPUTE_UNUSED(b);
264
265 GEMMLHSMatrixInfo lhs_info_buf;
266 GEMMRHSMatrixInfo rhs_info_buf;
267 GEMMLHSMatrixInfo lhs_info_img;
268 GEMMRHSMatrixInfo rhs_info_img;
269
270 // Get lhs_info/rhs_info in case of OpenCL buffer
271 if(n <= 4)
272 {
273 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, true, false, false, true);
274 }
275 else
276 {
277 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 2, 8, 16, false, false, false, true);
278 }
279
280 // Get lhs_info/rhs_info in case of OpenCL image
281 // Condition on the GPU workload
282 if((m / 4) * (n / 4) >= 2560)
283 {
284 // Big workload
285 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 8, true, true, true, false, true);
286 }
287 else
288 {
289 // Small workload
290 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 1, true, true, true, false, true);
291 }
292
293 const TensorInfo tensor_rhs_info(TensorShape(n, k, b), 1, DataType::F32);
294 const TensorShape shape = compute_rhs_reshaped_shape(tensor_rhs_info, rhs_info_img);
295 const TensorInfo tensor_reshaped_info(shape, 1, DataType::F32);
296
297 // In case of vector by matrix with few work-items, we use the OpenCL buffer rather than the OpenCL image2d
298 const bool use_cl_image2d = (n <= 4) ? false : true;
299
300 if(bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d)
301 {
302 return std::make_pair(lhs_info_img, rhs_info_img);
303 }
304 else
305 {
306 return std::make_pair(lhs_info_buf, rhs_info_buf);
307 }
308 }
309
configure_G76_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)310 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
311 {
312 const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
313 const float r_mk = static_cast<float>(m) / static_cast<float>(k);
314
315 if(workload <= 1595.2000f)
316 {
317 if(r_mk <= 2.1044f)
318 {
319 if(workload <= 870.4000f)
320 {
321 return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 2, true, false, true, false, false);
322 }
323 else
324 {
325 return configure_lhs_rhs_info(m, n, 4, 2, 4, 2, 2, false, false, true, false, false);
326 }
327 }
328 else
329 {
330 return configure_lhs_rhs_info(m, n, 4, 2, 4, 2, 2, false, false, true, false, false);
331 }
332 }
333 else
334 {
335 return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2, true, true, true, false, false);
336 }
337 }
338
configure_G76_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)339 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
340 {
341 ARM_COMPUTE_UNUSED(k);
342 ARM_COMPUTE_UNUSED(b);
343
344 if(n <= 4)
345 {
346 return configure_lhs_rhs_info(m, n, 4, 2, 16, 4, 1, false, false, false, true);
347 }
348 else
349 {
350 return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2, false, true, false, true);
351 }
352 }
353 } // namespace gemm
354 } // namespace kernels
355 } // namespace opencl
356 } // namespace arm_compute
357