1 /*
2 * Copyright (c) 2020-2021 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24 #include "src/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedValhall.h"
25
26 #include "arm_compute/core/CL/CLHelpers.h"
27 #include "arm_compute/core/CL/CLKernelLibrary.h"
28 #include "arm_compute/core/GPUTarget.h"
29 #include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h"
30
31 #include <utility>
32
33 namespace arm_compute
34 {
35 namespace opencl
36 {
37 namespace kernels
38 {
39 namespace gemm
40 {
ClGemmDefaultConfigReshapedValhall(GPUTarget gpu)41 ClGemmDefaultConfigReshapedValhall::ClGemmDefaultConfigReshapedValhall(GPUTarget gpu)
42 : IClGemmKernelConfig(gpu)
43 {
44 }
45
configure(unsigned int m,unsigned int n,unsigned int k,unsigned int b,DataType data_type)46 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValhall::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
47 {
48 using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (ClGemmDefaultConfigReshapedValhall::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
49
50 CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G77(&ClGemmDefaultConfigReshapedValhall::configure_G77_f32,
51 &ClGemmDefaultConfigReshapedValhall::configure_G77_f16,
52 &ClGemmDefaultConfigReshapedValhall::configure_G77_u8);
53
54 CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G78(&ClGemmDefaultConfigReshapedValhall::configure_G78_f32,
55 &ClGemmDefaultConfigReshapedValhall::configure_G78_f16,
56 &ClGemmDefaultConfigReshapedValhall::configure_G77_u8);
57
58 ConfigurationFunctionExecutorPtr func = nullptr;
59
60 switch(_target)
61 {
62 case GPUTarget::G78:
63 func = configs_G78.get_function(data_type);
64 break;
65 case GPUTarget::G77:
66 default:
67 func = configs_G77.get_function(data_type);
68 break;
69 }
70
71 ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not support for GEMM");
72 return (this->*func)(m, n, k, b);
73 }
74
configure_G77_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)75 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValhall::configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
76 {
77 ARM_COMPUTE_UNUSED(k);
78 ARM_COMPUTE_UNUSED(b);
79
80 if(n <= 4)
81 {
82 return configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, 1, 0, 0, 1);
83 }
84 else
85 {
86 return configure_lhs_rhs_info(m, n, 5, 4, 4, 2, 16, 0, 1, 0, 1);
87 }
88 }
89
configure_G77_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)90 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValhall::configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
91 {
92 ARM_COMPUTE_UNUSED(k);
93 ARM_COMPUTE_UNUSED(b);
94
95 const float r_mn = static_cast<float>(m) / static_cast<float>(n);
96 const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
97 const float r_mk = static_cast<float>(m) / static_cast<float>(k);
98 const float r_nk = static_cast<float>(n) / static_cast<float>(k);
99
100 GEMMLHSMatrixInfo lhs_info_buf;
101 GEMMRHSMatrixInfo rhs_info_buf;
102 GEMMLHSMatrixInfo lhs_info_img;
103 GEMMRHSMatrixInfo rhs_info_img;
104
105 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 0);
106
107 if(r_mk <= 0.11824845522642136)
108 {
109 if(workload <= 880.0)
110 {
111 return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, 0, 0, 1, 0, 0);
112 }
113 else
114 {
115 if(r_nk <= 0.42521367967128754)
116 {
117 if(workload <= 1726.4000244140625)
118 {
119 return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 0);
120 }
121 else
122 {
123 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1);
124
125 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
126 std::make_pair(lhs_info_buf, rhs_info_buf),
127 n, k, b, DataType::F16);
128 }
129 }
130 else
131 {
132 if(workload <= 1241.6000366210938)
133 {
134 return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, 0, 0, 1, 0, 0);
135 }
136 else
137 {
138 return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 0);
139 }
140 }
141 }
142 }
143 else
144 {
145 if(workload <= 11404.7998046875)
146 {
147 if(r_mk <= 1.0126488208770752)
148 {
149 if(r_mn <= 2.545312523841858)
150 {
151 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1);
152
153 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
154 std::make_pair(lhs_info_buf, rhs_info_buf),
155 n, k, b, DataType::F16);
156 }
157 else
158 {
159 return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, 0, 0, 1, 0, 0);
160 }
161 }
162 else
163 {
164 if(workload <= 2881.199951171875)
165 {
166 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, 0, 0, 1, 0, 1);
167
168 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
169 std::make_pair(lhs_info_buf, rhs_info_buf),
170 n, k, b, DataType::F16);
171 }
172 else
173 {
174 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1);
175
176 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
177 std::make_pair(lhs_info_buf, rhs_info_buf),
178 n, k, b, DataType::F16);
179 }
180 }
181 }
182 else
183 {
184 if(r_nk <= 0.5765306055545807)
185 {
186 if(r_mn <= 6.010416746139526)
187 {
188 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1);
189
190 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
191 std::make_pair(lhs_info_buf, rhs_info_buf),
192 n, k, b, DataType::F16);
193 }
194 else
195 {
196 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 1, 0, 1, 0, 1);
197
198 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
199 std::make_pair(lhs_info_buf, rhs_info_buf),
200 n, k, b, DataType::F16);
201 }
202 }
203 else
204 {
205 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 1, 0, 1, 0, 1);
206
207 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
208 std::make_pair(lhs_info_buf, rhs_info_buf),
209 n, k, b, DataType::F16);
210 }
211 }
212 }
213 }
214
configure_G78_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)215 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValhall::configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
216 {
217 const float r_mn = static_cast<float>(m) / static_cast<float>(n);
218 const float r_mk = static_cast<float>(m) / static_cast<float>(k);
219 const float r_nk = static_cast<float>(n) / static_cast<float>(k);
220 const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
221
222 if(workload <= 1288.0000f)
223 {
224 if(workload <= 505.6000f)
225 {
226 if(r_mn <= 0.4466f)
227 {
228 if(r_nk <= 0.2384f)
229 {
230 return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
231 }
232 else
233 {
234 return configure_lhs_rhs_info(m, n, 2, 2, 4, 2, 2, 0, 0, 1, 0, 0);
235 }
236 }
237 else
238 {
239 return configure_lhs_rhs_info(m, n, 2, 2, 4, 2, 2, 0, 0, 1, 0, 0);
240 }
241 }
242 else
243 {
244 if(r_mn <= 0.2250f)
245 {
246 if(r_mn <= 0.1599f)
247 {
248 return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
249 }
250 else
251 {
252 return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
253 }
254 }
255 else
256 {
257 if(r_mk <= 0.7609f)
258 {
259 if(r_mn <= 2.5453f)
260 {
261 if(workload <= 1089.6000f)
262 {
263 return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
264 }
265 else
266 {
267 return configure_lhs_rhs_info(m, n, 2, 4, 8, 2, 4, 0, 0, 1, 0, 1);
268 }
269 }
270 else
271 {
272 return configure_lhs_rhs_info(m, n, 2, 4, 16, 4, 4, 0, 0, 1, 0, 1);
273 }
274 }
275 else
276 {
277 return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
278 }
279 }
280 }
281 }
282 else
283 {
284 if(workload <= 5434.4001f)
285 {
286 if(workload <= 1603.2000f)
287 {
288 return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
289 }
290 else
291 {
292 if(r_nk <= 0.6192f)
293 {
294 if(r_mn <= 16.1016f)
295 {
296 return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
297 }
298 else
299 {
300 if(workload <= 2750.0000f)
301 {
302 return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
303 }
304 else
305 {
306 if(r_mk <= 6.3151f)
307 {
308 return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1);
309 }
310 else
311 {
312 return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
313 }
314 }
315 }
316 }
317 else
318 {
319 if(r_mk <= 0.0387f)
320 {
321 return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1);
322 }
323 else
324 {
325 if(r_mk <= 2.5859f)
326 {
327 if(r_mk <= 0.2734f)
328 {
329 return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1);
330 }
331 else
332 {
333 return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
334 }
335 }
336 else
337 {
338 return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
339 }
340 }
341 }
342 }
343 }
344 else
345 {
346 if(r_mk <= 25.7500f)
347 {
348 if(r_mk <= 0.3615f)
349 {
350 if(r_mn <= 0.0913f)
351 {
352 if(r_mk <= 0.0683f)
353 {
354 return configure_lhs_rhs_info(m, n, 8, 4, 4, 4, 2, 0, 0, 1, 0, 1);
355 }
356 else
357 {
358 return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
359 }
360 }
361 else
362 {
363 return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
364 }
365 }
366 else
367 {
368 if(workload <= 11174.3999f)
369 {
370 if(r_mk <= 0.8047f)
371 {
372 return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
373 }
374 else
375 {
376 if(workload <= 7185.5999f)
377 {
378 return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1);
379 }
380 else
381 {
382 return configure_lhs_rhs_info(m, n, 8, 4, 4, 4, 2, 0, 0, 1, 0, 1);
383 }
384 }
385 }
386 else
387 {
388 if(workload <= 17917.5000f)
389 {
390 if(r_mk <= 1.5078f)
391 {
392 return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
393 }
394 else
395 {
396 return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1);
397 }
398 }
399 else
400 {
401 if(workload <= 34449.6016f)
402 {
403 return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
404 }
405 else
406 {
407 return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 4, 0, 0, 1, 0, 1);
408 }
409 }
410 }
411 }
412 }
413 else
414 {
415 if(r_mk <= 331.1111f)
416 {
417 if(workload <= 53397.5996f)
418 {
419 if(r_mn <= 57.8063f)
420 {
421 return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
422 }
423 else
424 {
425 return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1);
426 }
427 }
428 else
429 {
430 if(r_nk <= 0.9211f)
431 {
432 return configure_lhs_rhs_info(m, n, 8, 4, 4, 4, 2, 0, 0, 1, 0, 1);
433 }
434 else
435 {
436 return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1);
437 }
438 }
439 }
440 else
441 {
442 if(workload <= 38070.4004f)
443 {
444 return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1);
445 }
446 else
447 {
448 return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
449 }
450 }
451 }
452 }
453 }
454 }
455
configure_G78_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)456 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValhall::configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
457 {
458 const float r_mn = static_cast<float>(m) / static_cast<float>(n);
459 const float r_nk = static_cast<float>(n) / static_cast<float>(k);
460 const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
461
462 if(workload <= 801.6000f)
463 {
464 return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1);
465 }
466 else
467 {
468 if(r_mn <= 0.1211f)
469 {
470 if(workload <= 3296.0000f)
471 {
472 return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
473 }
474 else
475 {
476 if(r_nk <= 1.0625f)
477 {
478 return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
479 }
480 else
481 {
482 return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 4, 0, 0, 1, 0, 1);
483 }
484 }
485 }
486 else
487 {
488 if(workload <= 5068.8000f)
489 {
490 return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1);
491 }
492 else
493 {
494 if(r_nk <= 0.2361f)
495 {
496 if(workload <= 12630.0000f)
497 {
498 return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1);
499 }
500 else
501 {
502 return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 1, 0, 0, 1, 0, 1);
503 }
504 }
505 else
506 {
507 if(workload <= 178790.3984f)
508 {
509 return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
510 }
511 else
512 {
513 return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1);
514 }
515 }
516 }
517 }
518 }
519 }
520
configure_G77_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)521 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValhall::configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
522 {
523 ARM_COMPUTE_UNUSED(k);
524 ARM_COMPUTE_UNUSED(b);
525
526 if(n <= 4)
527 {
528 return configure_lhs_rhs_info(m, n, 4, 2, 16, 4, 1, 0, 0, 0, 1);
529 }
530 else
531 {
532 return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2, 0, 1, 0, 1);
533 }
534 }
535 } // namespace gemm
536 } // namespace kernels
537 } // namespace opencl
538 } // namespace arm_compute
539