1 /*
2 * Copyright (c) 2020-2021 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24 #include "src/runtime/CL/gemm/CLGEMMDefaultTypeBifrost.h"
25
26 #include "arm_compute/core/CL/CLHelpers.h"
27 #include "arm_compute/core/CL/CLKernelLibrary.h"
28 #include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h"
29
30 #include <map>
31 #include <utility>
32
33 namespace arm_compute
34 {
35 namespace cl_gemm
36 {
CLGEMMDefaultTypeBifrost(GPUTarget gpu)37 CLGEMMDefaultTypeBifrost::CLGEMMDefaultTypeBifrost(GPUTarget gpu)
38 : ICLGEMMKernelSelection(gpu)
39 {
40 }
41
select_kernel(const CLGEMMKernelSelectionParams & params)42 CLGEMMKernelType CLGEMMDefaultTypeBifrost::select_kernel(const CLGEMMKernelSelectionParams ¶ms)
43 {
44 // _target could be used in the future to have a dedicated heuristic for each GPU IP
45 ARM_COMPUTE_UNUSED(_target);
46
47 using FunctionExecutorPtr = CLGEMMKernelType (CLGEMMDefaultTypeBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
48
49 // Default configurations for Bifrost architectures
50 static std::map<DataType, FunctionExecutorPtr> gemm_default_configs =
51 {
52 { DataType::F32, &CLGEMMDefaultTypeBifrost::default_f32 },
53 { DataType::F16, &CLGEMMDefaultTypeBifrost::default_f16 },
54 { DataType::QASYMM8, &CLGEMMDefaultTypeBifrost::default_q8 },
55 { DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeBifrost::default_q8 },
56 { DataType::QSYMM8, &CLGEMMDefaultTypeBifrost::default_q8 },
57 { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeBifrost::default_q8 }
58 };
59
60 // Mali-G71 configurations
61 static std::map<DataType, FunctionExecutorPtr> gemm_g71_configs =
62 {
63 { DataType::F32, &CLGEMMDefaultTypeBifrost::default_f32 },
64 { DataType::F16, &CLGEMMDefaultTypeBifrost::g71_f16 },
65 { DataType::QASYMM8, &CLGEMMDefaultTypeBifrost::default_q8 },
66 { DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeBifrost::default_q8 },
67 { DataType::QSYMM8, &CLGEMMDefaultTypeBifrost::default_q8 },
68 { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeBifrost::default_q8 }
69 };
70
71 // Mali-G52 configurations
72 static std::map<DataType, FunctionExecutorPtr> gemm_g52_configs =
73 {
74 { DataType::F32, &CLGEMMDefaultTypeBifrost::g52_f32 },
75 { DataType::F16, &CLGEMMDefaultTypeBifrost::g52_f16 },
76 { DataType::QASYMM8, &CLGEMMDefaultTypeBifrost::default_q8 },
77 { DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeBifrost::default_q8 },
78 { DataType::QSYMM8, &CLGEMMDefaultTypeBifrost::default_q8 },
79 { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeBifrost::default_q8 }
80 };
81
82 // Mali-G76 configurations
83 static std::map<DataType, FunctionExecutorPtr> gemm_g76_configs =
84 {
85 { DataType::F32, &CLGEMMDefaultTypeBifrost::g76_f32 },
86 { DataType::F16, &CLGEMMDefaultTypeBifrost::g76_f16 },
87 { DataType::QASYMM8, &CLGEMMDefaultTypeBifrost::default_q8 },
88 { DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeBifrost::default_q8 },
89 { DataType::QSYMM8, &CLGEMMDefaultTypeBifrost::default_q8 },
90 { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeBifrost::default_q8 }
91 };
92
93 const DataType data_type = params.data_type;
94
95 switch(_target)
96 {
97 case GPUTarget::G71:
98 if(gemm_g71_configs.find(data_type) != gemm_g71_configs.end())
99 {
100 return (this->*gemm_g71_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant);
101 }
102 ARM_COMPUTE_ERROR("Not supported data type");
103 case GPUTarget::G76:
104 if(gemm_g76_configs.find(data_type) != gemm_g76_configs.end())
105 {
106 return (this->*gemm_g76_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant);
107 }
108 ARM_COMPUTE_ERROR("Not supported data type");
109 case GPUTarget::G52:
110 if(gemm_g52_configs.find(data_type) != gemm_g52_configs.end())
111 {
112 return (this->*gemm_g52_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant);
113 }
114 ARM_COMPUTE_ERROR("Not supported data type");
115 default:
116 if(gemm_default_configs.find(data_type) != gemm_default_configs.end())
117 {
118 return (this->*gemm_default_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant);
119 }
120 ARM_COMPUTE_ERROR("Not supported data type");
121 }
122 }
123
default_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b,bool is_rhs_constant)124 CLGEMMKernelType CLGEMMDefaultTypeBifrost::default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
125 {
126 ARM_COMPUTE_UNUSED(b);
127
128 CLGEMMKernelType gemm_type = CLGEMMKernelType::NATIVE;
129
130 if(is_rhs_constant)
131 {
132 if((m > 1) && (n < 16))
133 {
134 gemm_type = CLGEMMKernelType::RESHAPED;
135 }
136 else if(m == 1)
137 {
138 gemm_type = CLGEMMKernelType::RESHAPED_ONLY_RHS;
139 }
140 else
141 {
142 if((k > 256) && (m > 4))
143 {
144 constexpr float alpha = 3.2f;
145 constexpr float fact0 = 1.51f;
146 constexpr float fact1 = 1.66f;
147 constexpr float ops = 12.0f;
148 const float scale = k > 1024 ? 1.07f : 1.0f;
149 gemm_type = (alpha + ((n * fact0) / ops) < ((fact1 * n * scale) / ops)) ? CLGEMMKernelType::RESHAPED : CLGEMMKernelType::RESHAPED_ONLY_RHS;
150 }
151 else
152 {
153 gemm_type = CLGEMMKernelType::RESHAPED_ONLY_RHS;
154 }
155 }
156
157 const auto workload = static_cast<float>((m * n) / 20.0f);
158
159 gemm_type = ((workload > 1600.0f) && (gemm_type == CLGEMMKernelType::RESHAPED)) ? CLGEMMKernelType::RESHAPED : gemm_type;
160 }
161
162 return gemm_type;
163 }
164
default_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b,bool is_rhs_constant)165 CLGEMMKernelType CLGEMMDefaultTypeBifrost::default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
166 {
167 ARM_COMPUTE_UNUSED(n, k, b);
168
169 if(is_rhs_constant)
170 {
171 if(m == 1)
172 {
173 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
174 }
175 else
176 {
177 return CLGEMMKernelType::RESHAPED;
178 }
179 }
180 else
181 {
182 return CLGEMMKernelType::NATIVE;
183 }
184 }
185
default_q8(unsigned int m,unsigned int n,unsigned int k,unsigned int b,bool is_rhs_constant)186 CLGEMMKernelType CLGEMMDefaultTypeBifrost::default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
187 {
188 ARM_COMPUTE_UNUSED(m, n, k, b);
189
190 if(is_rhs_constant)
191 {
192 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
193 }
194 else
195 {
196 return CLGEMMKernelType::NATIVE;
197 }
198 }
199
g76_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b,bool is_rhs_constant)200 CLGEMMKernelType CLGEMMDefaultTypeBifrost::g76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
201 {
202 ARM_COMPUTE_UNUSED(b);
203
204 if(!is_rhs_constant)
205 {
206 return CLGEMMKernelType::NATIVE;
207 }
208 if(m == 1)
209 {
210 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
211 }
212 if(k <= 496)
213 {
214 if(n <= 544)
215 {
216 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
217 }
218 else
219 {
220 return CLGEMMKernelType::RESHAPED;
221 }
222 }
223 else
224 {
225 if(k <= 588)
226 {
227 if(k <= 552)
228 {
229 if(m <= 148)
230 {
231 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
232 }
233 else
234 {
235 if(m <= 278)
236 {
237 return CLGEMMKernelType::RESHAPED;
238 }
239 else
240 {
241 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
242 }
243 }
244 }
245 else
246 {
247 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
248 }
249 }
250 else
251 {
252 return CLGEMMKernelType::RESHAPED;
253 }
254 }
255 }
256
g52_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b,bool is_rhs_constant)257 CLGEMMKernelType CLGEMMDefaultTypeBifrost::g52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
258 {
259 ARM_COMPUTE_UNUSED(b);
260
261 if(!is_rhs_constant)
262 {
263 return CLGEMMKernelType::NATIVE;
264 }
265
266 if(m == 1)
267 {
268 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
269 }
270
271 const float r_mn = static_cast<float>(m) / static_cast<float>(n);
272 const float r_mk = static_cast<float>(m) / static_cast<float>(k);
273 const float r_nk = static_cast<float>(n) / static_cast<float>(k);
274 const float r_mnk = static_cast<float>(m) / (static_cast<float>(n) * static_cast<float>(k));
275
276 if(r_mn <= 1.5469f)
277 {
278 if(r_mk <= 0.8766f)
279 {
280 if(r_mk <= 0.0211f)
281 {
282 if(r_mnk <= 77.5833f)
283 {
284 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
285 }
286 else
287 {
288 return CLGEMMKernelType::RESHAPED;
289 }
290 }
291 else
292 {
293 if(r_nk <= 0.0832f)
294 {
295 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
296 }
297 else
298 {
299 return CLGEMMKernelType::RESHAPED;
300 }
301 }
302 }
303 else
304 {
305 if(r_mnk <= 193.0000f)
306 {
307 if(r_mn <= 0.9948f)
308 {
309 if(r_mk <= 2.5453f)
310 {
311 return CLGEMMKernelType::RESHAPED;
312 }
313 else
314 {
315 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
316 }
317 }
318 else
319 {
320 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
321 }
322 }
323 else
324 {
325 return CLGEMMKernelType::RESHAPED;
326 }
327 }
328 }
329 else
330 {
331 if(r_mn <= 17.7370f)
332 {
333 if(r_mnk <= 1391.2875f)
334 {
335 if(r_mk <= 2.9724f)
336 {
337 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
338 }
339 else
340 {
341 if(r_mnk <= 470.0000f)
342 {
343 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
344 }
345 else
346 {
347 return CLGEMMKernelType::RESHAPED;
348 }
349 }
350 }
351 else
352 {
353 if(r_nk <= 0.1381f)
354 {
355 if(r_mnk <= 9040.5000f)
356 {
357 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
358 }
359 else
360 {
361 return CLGEMMKernelType::RESHAPED;
362 }
363 }
364 else
365 {
366 if(r_mn <= 5.6790f)
367 {
368 return CLGEMMKernelType::RESHAPED;
369 }
370 else
371 {
372 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
373 }
374 }
375 }
376 }
377 else
378 {
379 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
380 }
381 }
382 }
383
g76_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b,bool is_rhs_constant)384 CLGEMMKernelType CLGEMMDefaultTypeBifrost::g76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
385 {
386 ARM_COMPUTE_UNUSED(b);
387
388 if(!is_rhs_constant)
389 {
390 return CLGEMMKernelType::NATIVE;
391 }
392
393 if(m == 1)
394 {
395 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
396 }
397
398 const float r_mn = static_cast<float>(m) / static_cast<float>(n);
399 const float r_nk = static_cast<float>(n) / static_cast<float>(k);
400
401 if(k <= 212)
402 {
403 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
404 }
405 else
406 {
407 if(r_nk <= 0.4990234375f)
408 {
409 if(k <= 1392)
410 {
411 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
412 }
413 else
414 {
415 if(m <= 325)
416 {
417 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
418 }
419 else
420 {
421 return CLGEMMKernelType::RESHAPED;
422 }
423 }
424 }
425 else
426 {
427 if(k <= 471)
428 {
429 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
430 }
431 else
432 {
433 if(r_mn <= 0.04475911520421505f)
434 {
435 return CLGEMMKernelType::RESHAPED;
436 }
437 else
438 {
439 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
440 }
441 }
442 }
443 }
444 }
445
g52_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b,bool is_rhs_constant)446 CLGEMMKernelType CLGEMMDefaultTypeBifrost::g52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
447 {
448 if(!is_rhs_constant)
449 {
450 return CLGEMMKernelType::NATIVE;
451 }
452
453 if(m == 1)
454 {
455 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
456 }
457
458 if(n <= 127.0000f)
459 {
460 if(n <= 63.5000f)
461 {
462 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
463 }
464 else
465 {
466 if(m <= 3616.0000f)
467 {
468 if(b <= 18.5000f)
469 {
470 if(m <= 2970.5000f)
471 {
472 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
473 }
474 else
475 {
476 if(k <= 104.0000f)
477 {
478 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
479 }
480 else
481 {
482 return CLGEMMKernelType::RESHAPED;
483 }
484 }
485 }
486 else
487 {
488 return CLGEMMKernelType::RESHAPED;
489 }
490 }
491 else
492 {
493 return CLGEMMKernelType::RESHAPED;
494 }
495 }
496 }
497 else
498 {
499 if(m <= 12.5000f)
500 {
501 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
502 }
503 else
504 {
505 if(k <= 104.0000f)
506 {
507 if(b <= 18.5000f)
508 {
509 if(m <= 490.0000f)
510 {
511 if(n <= 272.0000f)
512 {
513 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
514 }
515 else
516 {
517 return CLGEMMKernelType::RESHAPED;
518 }
519 }
520 else
521 {
522 return CLGEMMKernelType::RESHAPED;
523 }
524 }
525 else
526 {
527 return CLGEMMKernelType::RESHAPED;
528 }
529 }
530 else
531 {
532 if(m <= 226.0000f)
533 {
534 if(n <= 140.0000f)
535 {
536 if(m <= 179.5000f)
537 {
538 return CLGEMMKernelType::RESHAPED;
539 }
540 else
541 {
542 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
543 }
544 }
545 else
546 {
547 return CLGEMMKernelType::RESHAPED;
548 }
549 }
550 else
551 {
552 return CLGEMMKernelType::RESHAPED;
553 }
554 }
555 }
556 }
557 }
558
g71_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b,bool is_rhs_constant)559 CLGEMMKernelType CLGEMMDefaultTypeBifrost::g71_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
560 {
561 ARM_COMPUTE_UNUSED(b);
562 ARM_COMPUTE_UNUSED(n);
563 ARM_COMPUTE_UNUSED(k);
564
565 if(is_rhs_constant)
566 {
567 if(m == 1)
568 {
569 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
570 }
571 else
572 {
573 return CLGEMMKernelType::RESHAPED;
574 }
575 }
576 else
577 {
578 return CLGEMMKernelType::NATIVE;
579 }
580 }
581 } // namespace cl_gemm
582 } // namespace arm_compute
583