1 /*
2  * Copyright (c) 2020-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "src/gpu/cl/kernels/gemm/reshaped_only_rhs/ClGemmDefaultConfigReshapedRhsOnlyValhall.h"
25 
26 #include "arm_compute/core/CL/CLHelpers.h"
27 #include "arm_compute/core/CL/CLKernelLibrary.h"
28 #include "arm_compute/core/GPUTarget.h"
29 #include "arm_compute/core/TensorInfo.h"
30 #include "arm_compute/core/TensorShape.h"
31 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
32 
33 #include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h"
34 #include "src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h"
35 
36 #include <utility>
37 
38 namespace arm_compute
39 {
40 namespace opencl
41 {
42 namespace kernels
43 {
44 namespace gemm
45 {
46 using namespace arm_compute::misc::shape_calculator;
47 
ClGemmDefaultConfigReshapedRhsOnlyValhall(GPUTarget gpu)48 ClGemmDefaultConfigReshapedRhsOnlyValhall::ClGemmDefaultConfigReshapedRhsOnlyValhall(GPUTarget gpu)
49     : IClGemmKernelConfig(gpu)
50 {
51 }
52 
configure(unsigned int m,unsigned int n,unsigned int k,unsigned int b,DataType data_type)53 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
54 {
55     using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (ClGemmDefaultConfigReshapedRhsOnlyValhall::*)(unsigned int m, unsigned int n, unsigned int k,
56                                              unsigned int b);
57 
58     CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G77(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32,
59                                                                     &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f16,
60                                                                     &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
61 
62     CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G78(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f32,
63                                                                     &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16,
64                                                                     &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
65 
66     CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G715(&ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32,
67                                                                      &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16,
68                                                                      &ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8);
69 
70     ConfigurationFunctionExecutorPtr func = nullptr;
71 
72     switch(_target)
73     {
74         case GPUTarget::G78:
75             func = configs_G78.get_function(data_type);
76             break;
77         case GPUTarget::G715:
78         case GPUTarget::G615:
79             func = configs_G715.get_function(data_type);
80             break;
81         case GPUTarget::G77:
82         default:
83             func = configs_G77.get_function(data_type);
84             break;
85     }
86 
87     ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not support for GEMM");
88     return (this->*func)(m, n, k, b);
89 }
90 
configure_G77_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)91 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
92 {
93     if(m == 1)
94     {
95         const float r_mn = static_cast<float>(m) / static_cast<float>(n);
96         const float r_mk = static_cast<float>(m) / static_cast<float>(k);
97 
98         if(r_mk <= 0.0064484127797186375)
99         {
100             if(r_mn <= 0.0028273810748942196)
101             {
102                 GEMMLHSMatrixInfo lhs_info_buf;
103                 GEMMRHSMatrixInfo rhs_info_buf;
104                 GEMMLHSMatrixInfo lhs_info_img;
105                 GEMMRHSMatrixInfo rhs_info_img;
106 
107                 const unsigned int h0 = std::max(n / 4, 1U);
108                 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 16, 0, 1, 0, 0, 1);
109                 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 1, 4, 4, 1, h0, 0, 1, 0, 1, 0);
110 
111                 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
112                                            std::make_pair(lhs_info_buf, rhs_info_buf),
113                                            n, k, b, DataType::F32);
114             }
115             else
116             {
117                 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 8, 0, 1, 0, 0, 0);
118             }
119         }
120         else
121         {
122             if(r_mk <= 0.020312500186264515)
123             {
124                 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 4, 0, 1, 0, 0, 0);
125             }
126             else
127             {
128                 return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, 16, 0, 1, 0, 1, 0);
129             }
130         }
131     }
132     else
133     {
134         const float r_mn     = static_cast<float>(m) / static_cast<float>(n);
135         const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
136         const float r_mk     = static_cast<float>(m) / static_cast<float>(k);
137 
138         if(workload <= 1999.2000122070312)
139         {
140             if(workload <= 747.1999816894531)
141             {
142                 return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0);
143             }
144             else
145             {
146                 GEMMLHSMatrixInfo lhs_info_buf;
147                 GEMMRHSMatrixInfo rhs_info_buf;
148                 GEMMLHSMatrixInfo lhs_info_img;
149                 GEMMRHSMatrixInfo rhs_info_img;
150                 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 2, 0, 0, 0, 1, 1);
151                 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0);
152 
153                 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
154                                            std::make_pair(lhs_info_buf, rhs_info_buf),
155                                            n, k, b, DataType::F32);
156             }
157         }
158         else
159         {
160             if(r_mn <= 0.03348214365541935)
161             {
162                 if(r_mk <= 0.028125000186264515)
163                 {
164                     return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0);
165                 }
166                 else
167                 {
168                     GEMMLHSMatrixInfo lhs_info_buf;
169                     GEMMRHSMatrixInfo rhs_info_buf;
170                     GEMMLHSMatrixInfo lhs_info_img;
171                     GEMMRHSMatrixInfo rhs_info_img;
172                     std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 2, 0, 0, 0, 1, 1);
173                     std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 8, 0, 1, 0, 1, 0);
174 
175                     return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
176                                                std::make_pair(lhs_info_buf, rhs_info_buf),
177                                                n, k, b, DataType::F32);
178                 }
179             }
180             else
181             {
182                 GEMMLHSMatrixInfo lhs_info_buf;
183                 GEMMRHSMatrixInfo rhs_info_buf;
184                 GEMMLHSMatrixInfo lhs_info_img;
185                 GEMMRHSMatrixInfo rhs_info_img;
186                 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 2, 0, 1, 0, 0, 1);
187                 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 16, 0, 1, 0, 1, 0);
188 
189                 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
190                                            std::make_pair(lhs_info_buf, rhs_info_buf),
191                                            n, k, b, DataType::F32);
192             }
193         }
194     }
195 }
196 
configure_G77_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)197 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
198 {
199     ARM_COMPUTE_UNUSED(k);
200     ARM_COMPUTE_UNUSED(b);
201 
202     if(m == 1)
203     {
204         const unsigned int h0 = std::max(n / 2, 1U);
205         if(n <= 836.0)
206         {
207             return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, h0, 0, 1, 0, 1, 0);
208         }
209         else
210         {
211             return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, h0, 0, 1, 0, 1, 0);
212         }
213     }
214     else if(m < 128)
215     {
216         const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
217         if(k >= 512)
218         {
219             return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, 0, 1, 0, 0);
220         }
221         else
222         {
223             return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, 0, 1, 0, 0);
224         }
225     }
226     else
227     {
228         const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
229         if(n >= 64)
230         {
231             return configure_lhs_rhs_info(m, n, 4, 8, 4, 1, h0, 0, 1, 0, 0);
232         }
233         else
234         {
235             if(k >= 512)
236             {
237                 return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, 0, 1, 0, 0);
238             }
239             else
240             {
241                 return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, h0, 0, 1, 0, 0);
242             }
243         }
244     }
245 }
246 
configure_G77_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)247 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
248 {
249     ARM_COMPUTE_UNUSED(k);
250     ARM_COMPUTE_UNUSED(b);
251 
252     if(m == 1)
253     {
254         const unsigned int h0 = std::max(n / 2, 1U);
255         return configure_lhs_rhs_info(m, n, 1, 4, 16, 1, h0, 0, 1, 0, 1);
256     }
257     else
258     {
259         const int h0 = std::max(std::min(static_cast<int>(n / 4), static_cast<int>(256)), static_cast<int>(1));
260         if(m >= 28)
261         {
262             return configure_lhs_rhs_info(m, n, 4, 4, 16, 1, h0, 0, 1, 0, 1);
263         }
264         else
265         {
266             return configure_lhs_rhs_info(m, n, 2, 4, 16, 1, h0, 0, 1, 0, 1);
267         }
268     }
269 }
270 
configure_G78_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)271 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
272 {
273     const float r_mn     = static_cast<float>(m) / static_cast<float>(n);
274     const float r_mk     = static_cast<float>(m) / static_cast<float>(k);
275     const float r_nk     = static_cast<float>(n) / static_cast<float>(k);
276     const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
277 
278     if(m == 1)
279     {
280         if(workload <= 278.7000f)
281         {
282             if(workload <= 7.5000f)
283             {
284                 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0);
285             }
286             else
287             {
288                 if(r_mn <= 0.0031f)
289                 {
290                     if(workload <= 256.6000f)
291                     {
292                         if(workload <= 16.7500f)
293                         {
294                             if(r_nk <= 1.6671f)
295                             {
296                                 return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
297                             }
298                             else
299                             {
300                                 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0);
301                             }
302                         }
303                         else
304                         {
305                             return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
306                         }
307                     }
308                     else
309                     {
310                         return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
311                     }
312                 }
313                 else
314                 {
315                     if(r_mk <= 0.0027f)
316                     {
317                         if(r_mk <= 0.0014f)
318                         {
319                             return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
320                         }
321                         else
322                         {
323                             if(workload <= 8.9500f)
324                             {
325                                 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0);
326                             }
327                             else
328                             {
329                                 return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
330                             }
331                         }
332                     }
333                     else
334                     {
335                         if(workload <= 14.1500f)
336                         {
337                             return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0);
338                         }
339                         else
340                         {
341                             if(r_mk <= 0.0041f)
342                             {
343                                 return configure_lhs_rhs_info(m, n, 1, 2, 2, 1, 32, 0, 0, 0, 1, 0);
344                             }
345                             else
346                             {
347                                 return configure_lhs_rhs_info(m, n, 1, 2, 8, 1, 2, 0, 1, 1, 0, 0);
348                             }
349                         }
350                     }
351                 }
352             }
353         }
354         else
355         {
356             if(workload <= 363.7000f)
357             {
358                 if(r_mk <= 0.0031f)
359                 {
360                     return configure_lhs_rhs_info(m, n, 1, 4, 2, 1, 32, 0, 1, 0, 1, 0);
361                 }
362                 else
363                 {
364                     return configure_lhs_rhs_info(m, n, 1, 4, 4, 1, 32, 0, 1, 0, 1, 0);
365                 }
366             }
367             else
368             {
369                 return configure_lhs_rhs_info(m, n, 1, 4, 2, 1, 32, 0, 1, 0, 1, 0);
370             }
371         }
372     }
373     else
374     {
375         if(workload <= 1384.8000f)
376         {
377             if(workload <= 704.0000f)
378             {
379                 return configure_lhs_rhs_info(m, n, 2, 2, 4, 1, 32, 0, 1, 0, 1, 0);
380             }
381             else
382             {
383                 return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 4, 0, 0, 0, 1, 1);
384             }
385         }
386         else
387         {
388             if(workload <= 16761.6006f)
389             {
390                 if(r_mn <= 187.1250f)
391                 {
392                     return configure_lhs_rhs_info(m, n, 4, 4, 4, 1, 16, 0, 0, 0, 1, 1);
393                 }
394                 else
395                 {
396                     return configure_lhs_rhs_info(m, n, 2, 4, 8, 1, 4, 0, 0, 0, 1, 1);
397                 }
398             }
399             else
400             {
401                 if(r_mk <= 432.4630f)
402                 {
403                     return configure_lhs_rhs_info(m, n, 5, 4, 4, 1, 16, 0, 0, 0, 1, 1);
404                 }
405                 else
406                 {
407                     return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 16, 0, 1, 0, 1, 1);
408                 }
409             }
410         }
411     }
412 }
413 
configure_G78_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)414 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
415 {
416     const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
417     const float r_mn = static_cast<float>(m) / static_cast<float>(n);
418     const float r_mk = static_cast<float>(m) / static_cast<float>(k);
419     const float r_nk = static_cast<float>(n) / static_cast<float>(k);
420 
421     if(m == 1)
422     {
423         if(r_mn <= 0.0045f)
424         {
425             if(workload <= 278.7000f)
426             {
427                 return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 8, 0, 0, 0, 1, 1);
428             }
429             else
430             {
431                 return configure_lhs_rhs_info(m, n, 1, 4, 8, 1, 32, 0, 0, 1, 0, 0);
432             }
433         }
434         else
435         {
436             return configure_lhs_rhs_info(m, n, 1, 2, 16, 1, 8, 0, 0, 1, 0, 0);
437         }
438     }
439     else
440     {
441         if(workload <= 1384.8000f)
442         {
443             if(r_nk <= 0.8333f)
444             {
445                 if(r_mk <= 0.9119f)
446                 {
447                     return configure_lhs_rhs_info(m, n, 2, 2, 16, 1, 4, 0, 1, 0, 1, 1);
448                 }
449                 else
450                 {
451                     if(r_nk <= 0.1181f)
452                     {
453                         return configure_lhs_rhs_info(m, n, 2, 2, 8, 1, 32, 0, 0, 1, 0, 0);
454                     }
455                     else
456                     {
457                         return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 0);
458                     }
459                 }
460             }
461             else
462             {
463                 if(r_mk <= 1.0013f)
464                 {
465                     return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 1);
466                 }
467                 else
468                 {
469                     return configure_lhs_rhs_info(m, n, 5, 4, 8, 1, 4, 0, 1, 1, 0, 1);
470                 }
471             }
472         }
473         else
474         {
475             if(workload <= 11404.7998f)
476             {
477                 if(r_mk <= 2.2884f)
478                 {
479                     if(r_nk <= 0.9286f)
480                     {
481                         return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 4, 0, 1, 1, 0, 1);
482                     }
483                     else
484                     {
485                         return configure_lhs_rhs_info(m, n, 4, 4, 8, 1, 32, 0, 1, 1, 0, 1);
486                     }
487                 }
488                 else
489                 {
490                     return configure_lhs_rhs_info(m, n, 5, 4, 8, 1, 4, 0, 1, 1, 0, 1);
491                 }
492             }
493             else
494             {
495                 if(r_nk <= 1.1926f)
496                 {
497                     if(r_mn <= 1385.7917f)
498                     {
499                         return configure_lhs_rhs_info(m, n, 6, 4, 8, 1, 4, 0, 1, 1, 0, 1);
500                     }
501                     else
502                     {
503                         return configure_lhs_rhs_info(m, n, 2, 8, 8, 1, 32, 0, 1, 1, 0, 0);
504                     }
505                 }
506                 else
507                 {
508                     return configure_lhs_rhs_info(m, n, 6, 4, 8, 1, 32, 0, 1, 1, 0, 1);
509                 }
510             }
511         }
512     }
513 }
514 
configure_G715_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)515 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
516 {
517     unsigned int best_m0;
518     unsigned int best_n0;
519 
520     if(is_mmul_kernel_preferred(m, n, k, b, DataType::F32, best_m0, best_n0))
521     {
522         return configure_lhs_rhs_info(m, n, best_m0, best_n0, 1, 1, 4, false, true, false, false, true);
523     }
524     else
525     {
526         return configure_G77_f32(m, n, k, b);
527     }
528 }
529 
configure_G715_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)530 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedRhsOnlyValhall::configure_G715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
531 {
532     unsigned int best_m0;
533     unsigned int best_n0;
534 
535     if(is_mmul_kernel_preferred(m, n, k, b, DataType::F16, best_m0, best_n0))
536     {
537         return configure_lhs_rhs_info(m, n, best_m0, best_n0, 1, 1, 4, false, true, false, false, true);
538     }
539     else
540     {
541         return configure_G78_f16(m, n, k, b);
542     }
543 }
544 } // namespace gemm
545 } // namespace kernels
546 } // namespace opencl
547 } // namespace arm_compute
548