xref: /aosp_15_r20/external/ComputeLibrary/src/runtime/CL/gemm/CLGEMMDefaultTypeBifrost.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2020-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "src/runtime/CL/gemm/CLGEMMDefaultTypeBifrost.h"
25 
26 #include "arm_compute/core/CL/CLHelpers.h"
27 #include "arm_compute/core/CL/CLKernelLibrary.h"
28 #include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h"
29 
30 #include <map>
31 #include <utility>
32 
33 namespace arm_compute
34 {
35 namespace cl_gemm
36 {
CLGEMMDefaultTypeBifrost(GPUTarget gpu)37 CLGEMMDefaultTypeBifrost::CLGEMMDefaultTypeBifrost(GPUTarget gpu)
38     : ICLGEMMKernelSelection(gpu)
39 {
40 }
41 
select_kernel(const CLGEMMKernelSelectionParams & params)42 CLGEMMKernelType CLGEMMDefaultTypeBifrost::select_kernel(const CLGEMMKernelSelectionParams &params)
43 {
44     // _target could be used in the future to have a dedicated heuristic for each GPU IP
45     ARM_COMPUTE_UNUSED(_target);
46 
47     using FunctionExecutorPtr = CLGEMMKernelType (CLGEMMDefaultTypeBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
48 
49     // Default configurations for Bifrost architectures
50     static std::map<DataType, FunctionExecutorPtr> gemm_default_configs =
51     {
52         { DataType::F32, &CLGEMMDefaultTypeBifrost::default_f32 },
53         { DataType::F16, &CLGEMMDefaultTypeBifrost::default_f16 },
54         { DataType::QASYMM8, &CLGEMMDefaultTypeBifrost::default_q8 },
55         { DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeBifrost::default_q8 },
56         { DataType::QSYMM8, &CLGEMMDefaultTypeBifrost::default_q8 },
57         { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeBifrost::default_q8 }
58     };
59 
60     // Mali-G71 configurations
61     static std::map<DataType, FunctionExecutorPtr> gemm_g71_configs =
62     {
63         { DataType::F32, &CLGEMMDefaultTypeBifrost::default_f32 },
64         { DataType::F16, &CLGEMMDefaultTypeBifrost::g71_f16 },
65         { DataType::QASYMM8, &CLGEMMDefaultTypeBifrost::default_q8 },
66         { DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeBifrost::default_q8 },
67         { DataType::QSYMM8, &CLGEMMDefaultTypeBifrost::default_q8 },
68         { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeBifrost::default_q8 }
69     };
70 
71     // Mali-G52 configurations
72     static std::map<DataType, FunctionExecutorPtr> gemm_g52_configs =
73     {
74         { DataType::F32, &CLGEMMDefaultTypeBifrost::g52_f32 },
75         { DataType::F16, &CLGEMMDefaultTypeBifrost::g52_f16 },
76         { DataType::QASYMM8, &CLGEMMDefaultTypeBifrost::default_q8 },
77         { DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeBifrost::default_q8 },
78         { DataType::QSYMM8, &CLGEMMDefaultTypeBifrost::default_q8 },
79         { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeBifrost::default_q8 }
80     };
81 
82     // Mali-G76 configurations
83     static std::map<DataType, FunctionExecutorPtr> gemm_g76_configs =
84     {
85         { DataType::F32, &CLGEMMDefaultTypeBifrost::g76_f32 },
86         { DataType::F16, &CLGEMMDefaultTypeBifrost::g76_f16 },
87         { DataType::QASYMM8, &CLGEMMDefaultTypeBifrost::default_q8 },
88         { DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeBifrost::default_q8 },
89         { DataType::QSYMM8, &CLGEMMDefaultTypeBifrost::default_q8 },
90         { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeBifrost::default_q8 }
91     };
92 
93     const DataType data_type = params.data_type;
94 
95     switch(_target)
96     {
97         case GPUTarget::G71:
98             if(gemm_g71_configs.find(data_type) != gemm_g71_configs.end())
99             {
100                 return (this->*gemm_g71_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant);
101             }
102             ARM_COMPUTE_ERROR("Not supported data type");
103         case GPUTarget::G76:
104             if(gemm_g76_configs.find(data_type) != gemm_g76_configs.end())
105             {
106                 return (this->*gemm_g76_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant);
107             }
108             ARM_COMPUTE_ERROR("Not supported data type");
109         case GPUTarget::G52:
110             if(gemm_g52_configs.find(data_type) != gemm_g52_configs.end())
111             {
112                 return (this->*gemm_g52_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant);
113             }
114             ARM_COMPUTE_ERROR("Not supported data type");
115         default:
116             if(gemm_default_configs.find(data_type) != gemm_default_configs.end())
117             {
118                 return (this->*gemm_default_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant);
119             }
120             ARM_COMPUTE_ERROR("Not supported data type");
121     }
122 }
123 
default_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b,bool is_rhs_constant)124 CLGEMMKernelType CLGEMMDefaultTypeBifrost::default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
125 {
126     ARM_COMPUTE_UNUSED(b);
127 
128     CLGEMMKernelType gemm_type = CLGEMMKernelType::NATIVE;
129 
130     if(is_rhs_constant)
131     {
132         if((m > 1) && (n < 16))
133         {
134             gemm_type = CLGEMMKernelType::RESHAPED;
135         }
136         else if(m == 1)
137         {
138             gemm_type = CLGEMMKernelType::RESHAPED_ONLY_RHS;
139         }
140         else
141         {
142             if((k > 256) && (m > 4))
143             {
144                 constexpr float alpha = 3.2f;
145                 constexpr float fact0 = 1.51f;
146                 constexpr float fact1 = 1.66f;
147                 constexpr float ops   = 12.0f;
148                 const float     scale = k > 1024 ? 1.07f : 1.0f;
149                 gemm_type             = (alpha + ((n * fact0) / ops) < ((fact1 * n * scale) / ops)) ? CLGEMMKernelType::RESHAPED : CLGEMMKernelType::RESHAPED_ONLY_RHS;
150             }
151             else
152             {
153                 gemm_type = CLGEMMKernelType::RESHAPED_ONLY_RHS;
154             }
155         }
156 
157         const auto workload = static_cast<float>((m * n) / 20.0f);
158 
159         gemm_type = ((workload > 1600.0f) && (gemm_type == CLGEMMKernelType::RESHAPED)) ? CLGEMMKernelType::RESHAPED : gemm_type;
160     }
161 
162     return gemm_type;
163 }
164 
default_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b,bool is_rhs_constant)165 CLGEMMKernelType CLGEMMDefaultTypeBifrost::default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
166 {
167     ARM_COMPUTE_UNUSED(n, k, b);
168 
169     if(is_rhs_constant)
170     {
171         if(m == 1)
172         {
173             return CLGEMMKernelType::RESHAPED_ONLY_RHS;
174         }
175         else
176         {
177             return CLGEMMKernelType::RESHAPED;
178         }
179     }
180     else
181     {
182         return CLGEMMKernelType::NATIVE;
183     }
184 }
185 
default_q8(unsigned int m,unsigned int n,unsigned int k,unsigned int b,bool is_rhs_constant)186 CLGEMMKernelType CLGEMMDefaultTypeBifrost::default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
187 {
188     ARM_COMPUTE_UNUSED(m, n, k, b);
189 
190     if(is_rhs_constant)
191     {
192         return CLGEMMKernelType::RESHAPED_ONLY_RHS;
193     }
194     else
195     {
196         return CLGEMMKernelType::NATIVE;
197     }
198 }
199 
g76_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b,bool is_rhs_constant)200 CLGEMMKernelType CLGEMMDefaultTypeBifrost::g76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
201 {
202     ARM_COMPUTE_UNUSED(b);
203 
204     if(!is_rhs_constant)
205     {
206         return CLGEMMKernelType::NATIVE;
207     }
208     if(m == 1)
209     {
210         return CLGEMMKernelType::RESHAPED_ONLY_RHS;
211     }
212     if(k <= 496)
213     {
214         if(n <= 544)
215         {
216             return CLGEMMKernelType::RESHAPED_ONLY_RHS;
217         }
218         else
219         {
220             return CLGEMMKernelType::RESHAPED;
221         }
222     }
223     else
224     {
225         if(k <= 588)
226         {
227             if(k <= 552)
228             {
229                 if(m <= 148)
230                 {
231                     return CLGEMMKernelType::RESHAPED_ONLY_RHS;
232                 }
233                 else
234                 {
235                     if(m <= 278)
236                     {
237                         return CLGEMMKernelType::RESHAPED;
238                     }
239                     else
240                     {
241                         return CLGEMMKernelType::RESHAPED_ONLY_RHS;
242                     }
243                 }
244             }
245             else
246             {
247                 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
248             }
249         }
250         else
251         {
252             return CLGEMMKernelType::RESHAPED;
253         }
254     }
255 }
256 
g52_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b,bool is_rhs_constant)257 CLGEMMKernelType CLGEMMDefaultTypeBifrost::g52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
258 {
259     ARM_COMPUTE_UNUSED(b);
260 
261     if(!is_rhs_constant)
262     {
263         return CLGEMMKernelType::NATIVE;
264     }
265 
266     if(m == 1)
267     {
268         return CLGEMMKernelType::RESHAPED_ONLY_RHS;
269     }
270 
271     const float r_mn  = static_cast<float>(m) / static_cast<float>(n);
272     const float r_mk  = static_cast<float>(m) / static_cast<float>(k);
273     const float r_nk  = static_cast<float>(n) / static_cast<float>(k);
274     const float r_mnk = static_cast<float>(m) / (static_cast<float>(n) * static_cast<float>(k));
275 
276     if(r_mn <= 1.5469f)
277     {
278         if(r_mk <= 0.8766f)
279         {
280             if(r_mk <= 0.0211f)
281             {
282                 if(r_mnk <= 77.5833f)
283                 {
284                     return CLGEMMKernelType::RESHAPED_ONLY_RHS;
285                 }
286                 else
287                 {
288                     return CLGEMMKernelType::RESHAPED;
289                 }
290             }
291             else
292             {
293                 if(r_nk <= 0.0832f)
294                 {
295                     return CLGEMMKernelType::RESHAPED_ONLY_RHS;
296                 }
297                 else
298                 {
299                     return CLGEMMKernelType::RESHAPED;
300                 }
301             }
302         }
303         else
304         {
305             if(r_mnk <= 193.0000f)
306             {
307                 if(r_mn <= 0.9948f)
308                 {
309                     if(r_mk <= 2.5453f)
310                     {
311                         return CLGEMMKernelType::RESHAPED;
312                     }
313                     else
314                     {
315                         return CLGEMMKernelType::RESHAPED_ONLY_RHS;
316                     }
317                 }
318                 else
319                 {
320                     return CLGEMMKernelType::RESHAPED_ONLY_RHS;
321                 }
322             }
323             else
324             {
325                 return CLGEMMKernelType::RESHAPED;
326             }
327         }
328     }
329     else
330     {
331         if(r_mn <= 17.7370f)
332         {
333             if(r_mnk <= 1391.2875f)
334             {
335                 if(r_mk <= 2.9724f)
336                 {
337                     return CLGEMMKernelType::RESHAPED_ONLY_RHS;
338                 }
339                 else
340                 {
341                     if(r_mnk <= 470.0000f)
342                     {
343                         return CLGEMMKernelType::RESHAPED_ONLY_RHS;
344                     }
345                     else
346                     {
347                         return CLGEMMKernelType::RESHAPED;
348                     }
349                 }
350             }
351             else
352             {
353                 if(r_nk <= 0.1381f)
354                 {
355                     if(r_mnk <= 9040.5000f)
356                     {
357                         return CLGEMMKernelType::RESHAPED_ONLY_RHS;
358                     }
359                     else
360                     {
361                         return CLGEMMKernelType::RESHAPED;
362                     }
363                 }
364                 else
365                 {
366                     if(r_mn <= 5.6790f)
367                     {
368                         return CLGEMMKernelType::RESHAPED;
369                     }
370                     else
371                     {
372                         return CLGEMMKernelType::RESHAPED_ONLY_RHS;
373                     }
374                 }
375             }
376         }
377         else
378         {
379             return CLGEMMKernelType::RESHAPED_ONLY_RHS;
380         }
381     }
382 }
383 
g76_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b,bool is_rhs_constant)384 CLGEMMKernelType CLGEMMDefaultTypeBifrost::g76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
385 {
386     ARM_COMPUTE_UNUSED(b);
387 
388     if(!is_rhs_constant)
389     {
390         return CLGEMMKernelType::NATIVE;
391     }
392 
393     if(m == 1)
394     {
395         return CLGEMMKernelType::RESHAPED_ONLY_RHS;
396     }
397 
398     const float r_mn = static_cast<float>(m) / static_cast<float>(n);
399     const float r_nk = static_cast<float>(n) / static_cast<float>(k);
400 
401     if(k <= 212)
402     {
403         return CLGEMMKernelType::RESHAPED_ONLY_RHS;
404     }
405     else
406     {
407         if(r_nk <= 0.4990234375f)
408         {
409             if(k <= 1392)
410             {
411                 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
412             }
413             else
414             {
415                 if(m <= 325)
416                 {
417                     return CLGEMMKernelType::RESHAPED_ONLY_RHS;
418                 }
419                 else
420                 {
421                     return CLGEMMKernelType::RESHAPED;
422                 }
423             }
424         }
425         else
426         {
427             if(k <= 471)
428             {
429                 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
430             }
431             else
432             {
433                 if(r_mn <= 0.04475911520421505f)
434                 {
435                     return CLGEMMKernelType::RESHAPED;
436                 }
437                 else
438                 {
439                     return CLGEMMKernelType::RESHAPED_ONLY_RHS;
440                 }
441             }
442         }
443     }
444 }
445 
g52_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b,bool is_rhs_constant)446 CLGEMMKernelType CLGEMMDefaultTypeBifrost::g52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
447 {
448     if(!is_rhs_constant)
449     {
450         return CLGEMMKernelType::NATIVE;
451     }
452 
453     if(m == 1)
454     {
455         return CLGEMMKernelType::RESHAPED_ONLY_RHS;
456     }
457 
458     if(n <= 127.0000f)
459     {
460         if(n <= 63.5000f)
461         {
462             return CLGEMMKernelType::RESHAPED_ONLY_RHS;
463         }
464         else
465         {
466             if(m <= 3616.0000f)
467             {
468                 if(b <= 18.5000f)
469                 {
470                     if(m <= 2970.5000f)
471                     {
472                         return CLGEMMKernelType::RESHAPED_ONLY_RHS;
473                     }
474                     else
475                     {
476                         if(k <= 104.0000f)
477                         {
478                             return CLGEMMKernelType::RESHAPED_ONLY_RHS;
479                         }
480                         else
481                         {
482                             return CLGEMMKernelType::RESHAPED;
483                         }
484                     }
485                 }
486                 else
487                 {
488                     return CLGEMMKernelType::RESHAPED;
489                 }
490             }
491             else
492             {
493                 return CLGEMMKernelType::RESHAPED;
494             }
495         }
496     }
497     else
498     {
499         if(m <= 12.5000f)
500         {
501             return CLGEMMKernelType::RESHAPED_ONLY_RHS;
502         }
503         else
504         {
505             if(k <= 104.0000f)
506             {
507                 if(b <= 18.5000f)
508                 {
509                     if(m <= 490.0000f)
510                     {
511                         if(n <= 272.0000f)
512                         {
513                             return CLGEMMKernelType::RESHAPED_ONLY_RHS;
514                         }
515                         else
516                         {
517                             return CLGEMMKernelType::RESHAPED;
518                         }
519                     }
520                     else
521                     {
522                         return CLGEMMKernelType::RESHAPED;
523                     }
524                 }
525                 else
526                 {
527                     return CLGEMMKernelType::RESHAPED;
528                 }
529             }
530             else
531             {
532                 if(m <= 226.0000f)
533                 {
534                     if(n <= 140.0000f)
535                     {
536                         if(m <= 179.5000f)
537                         {
538                             return CLGEMMKernelType::RESHAPED;
539                         }
540                         else
541                         {
542                             return CLGEMMKernelType::RESHAPED_ONLY_RHS;
543                         }
544                     }
545                     else
546                     {
547                         return CLGEMMKernelType::RESHAPED;
548                     }
549                 }
550                 else
551                 {
552                     return CLGEMMKernelType::RESHAPED;
553                 }
554             }
555         }
556     }
557 }
558 
g71_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b,bool is_rhs_constant)559 CLGEMMKernelType CLGEMMDefaultTypeBifrost::g71_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
560 {
561     ARM_COMPUTE_UNUSED(b);
562     ARM_COMPUTE_UNUSED(n);
563     ARM_COMPUTE_UNUSED(k);
564 
565     if(is_rhs_constant)
566     {
567         if(m == 1)
568         {
569             return CLGEMMKernelType::RESHAPED_ONLY_RHS;
570         }
571         else
572         {
573             return CLGEMMKernelType::RESHAPED;
574         }
575     }
576     else
577     {
578         return CLGEMMKernelType::NATIVE;
579     }
580 }
581 } // namespace cl_gemm
582 } // namespace arm_compute
583