1 /*
2 * Copyright (c) 2020-2022 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24 #include "src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h"
25
26 #include "arm_compute/core/CL/CLHelpers.h"
27 #include "arm_compute/core/CL/CLKernelLibrary.h"
28 #include "arm_compute/core/GPUTarget.h"
29 #include "arm_compute/core/TensorInfo.h"
30 #include "arm_compute/core/TensorShape.h"
31 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
32
33 #include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h"
34 #include "src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h"
35
36 #include <utility>
37
38 namespace arm_compute
39 {
40 namespace opencl
41 {
42 namespace kernels
43 {
44 namespace gemm
45 {
46 using namespace arm_compute::misc::shape_calculator;
47
ClGemmDefaultConfigReshapedRhsOnlyValhall(GPUTarget gpu)48 ClGemmDefaultConfigReshapedRhsOnlyValhall::ClGemmDefaultConfigReshapedRhsOnlyValhall(GPUTarget gpu)
49 : IClGemmKernelConfig(gpu)
50 {
51 }
52
configure(unsigned int m,unsigned int n,unsigned int k,unsigned int b,DataType data_type)53 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
54 {
55 using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (ClGemmDefaultConfigReshapedRhsOnlyValhall::*)(unsigned int m, unsigned int n, unsigned int k,
56 unsigned int b);
57
58 CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G77(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32,
59 &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f16,
60 &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
61
62 CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G78(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f32,
63 &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16,
64 &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
65
66 CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G715(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32,
67 &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16,
68 &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
69
70 ConfigurationFunctionExecutorPtr func = nullptr;
71
72 switch(_target)
73 {
74 case GPUTarget::G78:
75 func = configs_G78.get_function(data_type);
76 break;
77 case GPUTarget::G715:
78 case GPUTarget::G615:
79 func = configs_G715.get_function(data_type);
80 break;
81 case GPUTarget::G77:
82 default:
83 func = configs_G77.get_function(data_type);
84 break;
85 }
86
87 ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not support for GEMM");
88 return (this->*func)(m, n, k, b);
89 }
90
configure_G77_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)91 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
92 {
93 if(m == 1)
94 {
95 const float r_mn = static_cast<float>(m) / static_cast<float>(n);
96 const float r_mk = static_cast<float>(m) / static_cast<float>(k);
97
98 if(r_mk <= 0.0064484127797186375)
99 {
100 if(r_mn <= 0.0028273810748942196)
101 {
102 GEMMLHSMatrixInfo lhs_info_buf;
103 GEMMRHSMatrixInfo rhs_info_buf;
104 GEMMLHSMatrixInfo lhs_info_img;
105 GEMMRHSMatrixInfo rhs_info_img;
106
107 const unsigned int h0 = std::max(n / 4, 1U);
108 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, 0, 1, 0, 0, 1);
109 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, 0, 1, 0, 1, 0);
110
111 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
112 std::make_pair(lhs_info_buf, rhs_info_buf),
113 n, k, b, DataType::F32);
114 }
115 else
116 {
117 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 8, 0, 1, 0, 0, 0);
118 }
119 }
120 else
121 {
122 if(r_mk <= 0.020312500186264515)
123 {
124 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, 0, 1, 0, 0, 0);
125 }
126 else
127 {
128 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, 0, 1, 0, 1, 0);
129 }
130 }
131 }
132 else
133 {
134 const float r_mn = static_cast<float>(m) / static_cast<float>(n);
135 const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
136 const float r_mk = static_cast<float>(m) / static_cast<float>(k);
137
138 if(workload <= 1999.2000122070312)
139 {
140 if(workload <= 747.1999816894531)
141 {
142 return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0);
143 }
144 else
145 {
146 GEMMLHSMatrixInfo lhs_info_buf;
147 GEMMRHSMatrixInfo rhs_info_buf;
148 GEMMLHSMatrixInfo lhs_info_img;
149 GEMMRHSMatrixInfo rhs_info_img;
150 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 2, 0, 0, 0, 1, 1);
151 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0);
152
153 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
154 std::make_pair(lhs_info_buf, rhs_info_buf),
155 n, k, b, DataType::F32);
156 }
157 }
158 else
159 {
160 if(r_mn <= 0.03348214365541935)
161 {
162 if(r_mk <= 0.028125000186264515)
163 {
164 return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0);
165 }
166 else
167 {
168 GEMMLHSMatrixInfo lhs_info_buf;
169 GEMMRHSMatrixInfo rhs_info_buf;
170 GEMMLHSMatrixInfo lhs_info_img;
171 GEMMRHSMatrixInfo rhs_info_img;
172 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 2, 0, 0, 0, 1, 1);
173 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0);
174
175 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
176 std::make_pair(lhs_info_buf, rhs_info_buf),
177 n, k, b, DataType::F32);
178 }
179 }
180 else
181 {
182 GEMMLHSMatrixInfo lhs_info_buf;
183 GEMMRHSMatrixInfo rhs_info_buf;
184 GEMMLHSMatrixInfo lhs_info_img;
185 GEMMRHSMatrixInfo rhs_info_img;
186 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, 0, 1, 0, 0, 1);
187 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 16, 0, 1, 0, 1, 0);
188
189 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
190 std::make_pair(lhs_info_buf, rhs_info_buf),
191 n, k, b, DataType::F32);
192 }
193 }
194 }
195 }
196
configure_G77_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)197 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
198 {
199 ARM_COMPUTE_UNUSED(k);
200 ARM_COMPUTE_UNUSED(b);
201
202 if(m == 1)
203 {
204 const unsigned int h0 = std::max(n / 2, 1U);
205 if(n <= 836.0)
206 {
207 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, 0, 1, 0, 1, 0);
208 }
209 else
210 {
211 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, 0, 1, 0, 1, 0);
212 }
213 }
214 else if(m < 128)
215 {
216 const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
217 if(k >= 512)
218 {
219 return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, 0, 1, 0, 0);
220 }
221 else
222 {
223 return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, 0, 1, 0, 0);
224 }
225 }
226 else
227 {
228 const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
229 if(n >= 64)
230 {
231 return configure_lhs_rhs_info(m, n, 4, 8, 4, 1, h0, 0, 1, 0, 0);
232 }
233 else
234 {
235 if(k >= 512)
236 {
237 return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, 0, 1, 0, 0);
238 }
239 else
240 {
241 return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, 0, 1, 0, 0);
242 }
243 }
244 }
245 }
246
configure_G77_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)247 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
248 {
249 ARM_COMPUTE_UNUSED(k);
250 ARM_COMPUTE_UNUSED(b);
251
252 if(m == 1)
253 {
254 const unsigned int h0 = std::max(n / 2, 1U);
255 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, 0, 1, 0, 1);
256 }
257 else
258 {
259 const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
260 if(m >= 28)
261 {
262 return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, 0, 1, 0, 1);
263 }
264 else
265 {
266 return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, 0, 1, 0, 1);
267 }
268 }
269 }
270
configure_G78_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)271 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
272 {
273 const float r_mn = static_cast<float>(m) / static_cast<float>(n);
274 const float r_mk = static_cast<float>(m) / static_cast<float>(k);
275 const float r_nk = static_cast<float>(n) / static_cast<float>(k);
276 const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
277
278 if(m == 1)
279 {
280 if(workload <= 278.7000f)
281 {
282 if(workload <= 7.5000f)
283 {
284 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0);
285 }
286 else
287 {
288 if(r_mn <= 0.0031f)
289 {
290 if(workload <= 256.6000f)
291 {
292 if(workload <= 16.7500f)
293 {
294 if(r_nk <= 1.6671f)
295 {
296 return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
297 }
298 else
299 {
300 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0);
301 }
302 }
303 else
304 {
305 return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
306 }
307 }
308 else
309 {
310 return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
311 }
312 }
313 else
314 {
315 if(r_mk <= 0.0027f)
316 {
317 if(r_mk <= 0.0014f)
318 {
319 return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
320 }
321 else
322 {
323 if(workload <= 8.9500f)
324 {
325 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0);
326 }
327 else
328 {
329 return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
330 }
331 }
332 }
333 else
334 {
335 if(workload <= 14.1500f)
336 {
337 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0);
338 }
339 else
340 {
341 if(r_mk <= 0.0041f)
342 {
343 return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
344 }
345 else
346 {
347 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0);
348 }
349 }
350 }
351 }
352 }
353 }
354 else
355 {
356 if(workload <= 363.7000f)
357 {
358 if(r_mk <= 0.0031f)
359 {
360 return configure_lhs_rhs_info(m, n, 1, 4, 2, 1, 32, 0, 1, 0, 1, 0);
361 }
362 else
363 {
364 return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, 32, 0, 1, 0, 1, 0);
365 }
366 }
367 else
368 {
369 return configure_lhs_rhs_info(m, n, 1, 4, 2, 1, 32, 0, 1, 0, 1, 0);
370 }
371 }
372 }
373 else
374 {
375 if(workload <= 1384.8000f)
376 {
377 if(workload <= 704.0000f)
378 {
379 return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 32, 0, 1, 0, 1, 0);
380 }
381 else
382 {
383 return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 4, 0, 0, 0, 1, 1);
384 }
385 }
386 else
387 {
388 if(workload <= 16761.6006f)
389 {
390 if(r_mn <= 187.1250f)
391 {
392 return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 16, 0, 0, 0, 1, 1);
393 }
394 else
395 {
396 return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 4, 0, 0, 0, 1, 1);
397 }
398 }
399 else
400 {
401 if(r_mk <= 432.4630f)
402 {
403 return configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 16, 0, 0, 0, 1, 1);
404 }
405 else
406 {
407 return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 16, 0, 1, 0, 1, 1);
408 }
409 }
410 }
411 }
412 }
413
configure_G78_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)414 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
415 {
416 const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
417 const float r_mn = static_cast<float>(m) / static_cast<float>(n);
418 const float r_mk = static_cast<float>(m) / static_cast<float>(k);
419 const float r_nk = static_cast<float>(n) / static_cast<float>(k);
420
421 if(m == 1)
422 {
423 if(r_mn <= 0.0045f)
424 {
425 if(workload <= 278.7000f)
426 {
427 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 8, 0, 0, 0, 1, 1);
428 }
429 else
430 {
431 return configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 32, 0, 0, 1, 0, 0);
432 }
433 }
434 else
435 {
436 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 8, 0, 0, 1, 0, 0);
437 }
438 }
439 else
440 {
441 if(workload <= 1384.8000f)
442 {
443 if(r_nk <= 0.8333f)
444 {
445 if(r_mk <= 0.9119f)
446 {
447 return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, 4, 0, 1, 0, 1, 1);
448 }
449 else
450 {
451 if(r_nk <= 0.1181f)
452 {
453 return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 32, 0, 0, 1, 0, 0);
454 }
455 else
456 {
457 return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0);
458 }
459 }
460 }
461 else
462 {
463 if(r_mk <= 1.0013f)
464 {
465 return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 1);
466 }
467 else
468 {
469 return configure_lhs_rhs_info(m, n, 5, 4, 8, 1, 4, 0, 1, 1, 0, 1);
470 }
471 }
472 }
473 else
474 {
475 if(workload <= 11404.7998f)
476 {
477 if(r_mk <= 2.2884f)
478 {
479 if(r_nk <= 0.9286f)
480 {
481 return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 4, 0, 1, 1, 0, 1);
482 }
483 else
484 {
485 return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 1);
486 }
487 }
488 else
489 {
490 return configure_lhs_rhs_info(m, n, 5, 4, 8, 1, 4, 0, 1, 1, 0, 1);
491 }
492 }
493 else
494 {
495 if(r_nk <= 1.1926f)
496 {
497 if(r_mn <= 1385.7917f)
498 {
499 return configure_lhs_rhs_info(m, n, 6, 4, 8, 1, 4, 0, 1, 1, 0, 1);
500 }
501 else
502 {
503 return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 32, 0, 1, 1, 0, 0);
504 }
505 }
506 else
507 {
508 return configure_lhs_rhs_info(m, n, 6, 4, 8, 1, 32, 0, 1, 1, 0, 1);
509 }
510 }
511 }
512 }
513 }
514
configure_G715_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)515 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
516 {
517 unsigned int best_m0;
518 unsigned int best_n0;
519
520 if(is_mmul_kernel_preferred(m, n, k, b, DataType::F32, best_m0, best_n0))
521 {
522 return configure_lhs_rhs_info(m, n, best_m0, best_n0, 1, 1, 4, false, true, false, false, true);
523 }
524 else
525 {
526 return configure_G77_f32(m, n, k, b);
527 }
528 }
529
configure_G715_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)530 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
531 {
532 unsigned int best_m0;
533 unsigned int best_n0;
534
535 if(is_mmul_kernel_preferred(m, n, k, b, DataType::F16, best_m0, best_n0))
536 {
537 return configure_lhs_rhs_info(m, n, best_m0, best_n0, 1, 1, 4, false, true, false, false, true);
538 }
539 else
540 {
541 return configure_G78_f16(m, n, k, b);
542 }
543 }
544 } // namespace gemm
545 } // namespace kernels
546 } // namespace opencl
547 } // namespace arm_compute
548