1 /*
2  * Copyright (c) 2020-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "src/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedValhall.h"
25 
26 #include "arm_compute/core/CL/CLHelpers.h"
27 #include "arm_compute/core/CL/CLKernelLibrary.h"
28 #include "arm_compute/core/GPUTarget.h"
29 #include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h"
30 
31 #include <utility>
32 
33 namespace arm_compute
34 {
35 namespace opencl
36 {
37 namespace kernels
38 {
39 namespace gemm
40 {
ClGemmDefaultConfigReshapedValhall(GPUTarget gpu)41 ClGemmDefaultConfigReshapedValhall::ClGemmDefaultConfigReshapedValhall(GPUTarget gpu)
42     : IClGemmKernelConfig(gpu)
43 {
44 }
45 
configure(unsigned int m,unsigned int n,unsigned int k,unsigned int b,DataType data_type)46 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValhall::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
47 {
48     using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (ClGemmDefaultConfigReshapedValhall::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
49 
50     CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G77(&ClGemmDefaultConfigReshapedValhall::configure_G77_f32,
51                                                                     &ClGemmDefaultConfigReshapedValhall::configure_G77_f16,
52                                                                     &ClGemmDefaultConfigReshapedValhall::configure_G77_u8);
53 
54     CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G78(&ClGemmDefaultConfigReshapedValhall::configure_G78_f32,
55                                                                     &ClGemmDefaultConfigReshapedValhall::configure_G78_f16,
56                                                                     &ClGemmDefaultConfigReshapedValhall::configure_G77_u8);
57 
58     ConfigurationFunctionExecutorPtr func = nullptr;
59 
60     switch(_target)
61     {
62         case GPUTarget::G78:
63             func = configs_G78.get_function(data_type);
64             break;
65         case GPUTarget::G77:
66         default:
67             func = configs_G77.get_function(data_type);
68             break;
69     }
70 
71     ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not support for GEMM");
72     return (this->*func)(m, n, k, b);
73 }
74 
configure_G77_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)75 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValhall::configure_G77_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
76 {
77     ARM_COMPUTE_UNUSED(k);
78     ARM_COMPUTE_UNUSED(b);
79 
80     if(n <= 4)
81     {
82         return configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, 1, 0, 0, 1);
83     }
84     else
85     {
86         return configure_lhs_rhs_info(m, n, 5, 4, 4, 2, 16, 0, 1, 0, 1);
87     }
88 }
89 
configure_G77_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)90 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValhall::configure_G77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
91 {
92     ARM_COMPUTE_UNUSED(k);
93     ARM_COMPUTE_UNUSED(b);
94 
95     const float r_mn     = static_cast<float>(m) / static_cast<float>(n);
96     const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
97     const float r_mk     = static_cast<float>(m) / static_cast<float>(k);
98     const float r_nk     = static_cast<float>(n) / static_cast<float>(k);
99 
100     GEMMLHSMatrixInfo lhs_info_buf;
101     GEMMRHSMatrixInfo rhs_info_buf;
102     GEMMLHSMatrixInfo lhs_info_img;
103     GEMMRHSMatrixInfo rhs_info_img;
104 
105     std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 0);
106 
107     if(r_mk <= 0.11824845522642136)
108     {
109         if(workload <= 880.0)
110         {
111             return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, 0, 0, 1, 0, 0);
112         }
113         else
114         {
115             if(r_nk <= 0.42521367967128754)
116             {
117                 if(workload <= 1726.4000244140625)
118                 {
119                     return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 0);
120                 }
121                 else
122                 {
123                     std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1);
124 
125                     return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
126                                                std::make_pair(lhs_info_buf, rhs_info_buf),
127                                                n, k, b, DataType::F16);
128                 }
129             }
130             else
131             {
132                 if(workload <= 1241.6000366210938)
133                 {
134                     return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, 0, 0, 1, 0, 0);
135                 }
136                 else
137                 {
138                     return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 0);
139                 }
140             }
141         }
142     }
143     else
144     {
145         if(workload <= 11404.7998046875)
146         {
147             if(r_mk <= 1.0126488208770752)
148             {
149                 if(r_mn <= 2.545312523841858)
150                 {
151                     std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1);
152 
153                     return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
154                                                std::make_pair(lhs_info_buf, rhs_info_buf),
155                                                n, k, b, DataType::F16);
156                 }
157                 else
158                 {
159                     return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 4, 0, 0, 1, 0, 0);
160                 }
161             }
162             else
163             {
164                 if(workload <= 2881.199951171875)
165                 {
166                     std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, 0, 0, 1, 0, 1);
167 
168                     return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
169                                                std::make_pair(lhs_info_buf, rhs_info_buf),
170                                                n, k, b, DataType::F16);
171                 }
172                 else
173                 {
174                     std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1);
175 
176                     return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
177                                                std::make_pair(lhs_info_buf, rhs_info_buf),
178                                                n, k, b, DataType::F16);
179                 }
180             }
181         }
182         else
183         {
184             if(r_nk <= 0.5765306055545807)
185             {
186                 if(r_mn <= 6.010416746139526)
187                 {
188                     std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 0, 1, 1, 0, 1);
189 
190                     return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
191                                                std::make_pair(lhs_info_buf, rhs_info_buf),
192                                                n, k, b, DataType::F16);
193                 }
194                 else
195                 {
196                     std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 1, 0, 1, 0, 1);
197 
198                     return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
199                                                std::make_pair(lhs_info_buf, rhs_info_buf),
200                                                n, k, b, DataType::F16);
201                 }
202             }
203             else
204             {
205                 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, 1, 0, 1, 0, 1);
206 
207                 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
208                                            std::make_pair(lhs_info_buf, rhs_info_buf),
209                                            n, k, b, DataType::F16);
210             }
211         }
212     }
213 }
214 
/** GEMM reshaped LHS/RHS configuration for FP32 on Mali-G78.
 *
 * The branch structure below is an empirically tuned decision tree: the
 * thresholds on the normalized workload and on the m/n, m/k, n/k shape ratios
 * were derived from benchmarking, so the exact constants (and the order in
 * which they are tested) are significant and should not be reordered.
 *
 * @param[in] m Number of rows of the LHS matrix
 * @param[in] n Number of columns of the RHS matrix
 * @param[in] k Number of columns of the LHS matrix (rows of the RHS matrix)
 * @param[in] b Batch size
 *
 * @return A pair of LHS/RHS matrix reshape descriptors
 */
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValhall::configure_G78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
    // Decision-tree features: shape aspect ratios and a normalized workload estimate.
    const float r_mn     = static_cast<float>(m) / static_cast<float>(n);
    const float r_mk     = static_cast<float>(m) / static_cast<float>(k);
    const float r_nk     = static_cast<float>(n) / static_cast<float>(k);
    const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;

    // Small workloads.
    if(workload <= 1288.0000f)
    {
        if(workload <= 505.6000f)
        {
            if(r_mn <= 0.4466f)
            {
                if(r_nk <= 0.2384f)
                {
                    return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
                }
                else
                {
                    return configure_lhs_rhs_info(m, n, 2, 2, 4, 2, 2, 0, 0, 1, 0, 0);
                }
            }
            else
            {
                return configure_lhs_rhs_info(m, n, 2, 2, 4, 2, 2, 0, 0, 1, 0, 0);
            }
        }
        else
        {
            if(r_mn <= 0.2250f)
            {
                if(r_mn <= 0.1599f)
                {
                    return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
                }
                else
                {
                    return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                }
            }
            else
            {
                if(r_mk <= 0.7609f)
                {
                    if(r_mn <= 2.5453f)
                    {
                        if(workload <= 1089.6000f)
                        {
                            return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
                        }
                        else
                        {
                            return configure_lhs_rhs_info(m, n, 2, 4, 8, 2, 4, 0, 0, 1, 0, 1);
                        }
                    }
                    else
                    {
                        return configure_lhs_rhs_info(m, n, 2, 4, 16, 4, 4, 0, 0, 1, 0, 1);
                    }
                }
                else
                {
                    return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
                }
            }
        }
    }
    else
    {
        // Medium workloads.
        if(workload <= 5434.4001f)
        {
            if(workload <= 1603.2000f)
            {
                return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
            }
            else
            {
                if(r_nk <= 0.6192f)
                {
                    if(r_mn <= 16.1016f)
                    {
                        return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                    }
                    else
                    {
                        if(workload <= 2750.0000f)
                        {
                            return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                        }
                        else
                        {
                            if(r_mk <= 6.3151f)
                            {
                                return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1);
                            }
                            else
                            {
                                return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                            }
                        }
                    }
                }
                else
                {
                    if(r_mk <= 0.0387f)
                    {
                        return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1);
                    }
                    else
                    {
                        if(r_mk <= 2.5859f)
                        {
                            if(r_mk <= 0.2734f)
                            {
                                return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1);
                            }
                            else
                            {
                                return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                            }
                        }
                        else
                        {
                            return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                        }
                    }
                }
            }
        }
        else
        {
            // Large workloads.
            if(r_mk <= 25.7500f)
            {
                if(r_mk <= 0.3615f)
                {
                    if(r_mn <= 0.0913f)
                    {
                        if(r_mk <= 0.0683f)
                        {
                            return configure_lhs_rhs_info(m, n, 8, 4, 4, 4, 2, 0, 0, 1, 0, 1);
                        }
                        else
                        {
                            return configure_lhs_rhs_info(m, n, 2, 4, 8, 4, 4, 0, 0, 1, 0, 1);
                        }
                    }
                    else
                    {
                        return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                    }
                }
                else
                {
                    if(workload <= 11174.3999f)
                    {
                        if(r_mk <= 0.8047f)
                        {
                            return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                        }
                        else
                        {
                            if(workload <= 7185.5999f)
                            {
                                return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1);
                            }
                            else
                            {
                                return configure_lhs_rhs_info(m, n, 8, 4, 4, 4, 2, 0, 0, 1, 0, 1);
                            }
                        }
                    }
                    else
                    {
                        if(workload <= 17917.5000f)
                        {
                            if(r_mk <= 1.5078f)
                            {
                                return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                            }
                            else
                            {
                                return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 1, 0, 1);
                            }
                        }
                        else
                        {
                            if(workload <= 34449.6016f)
                            {
                                return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                            }
                            else
                            {
                                return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 4, 0, 0, 1, 0, 1);
                            }
                        }
                    }
                }
            }
            else
            {
                // Very tall LHS relative to k (r_mk > 25.75).
                if(r_mk <= 331.1111f)
                {
                    if(workload <= 53397.5996f)
                    {
                        if(r_mn <= 57.8063f)
                        {
                            return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                        }
                        else
                        {
                            return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1);
                        }
                    }
                    else
                    {
                        if(r_nk <= 0.9211f)
                        {
                            return configure_lhs_rhs_info(m, n, 8, 4, 4, 4, 2, 0, 0, 1, 0, 1);
                        }
                        else
                        {
                            return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1);
                        }
                    }
                }
                else
                {
                    if(workload <= 38070.4004f)
                    {
                        return configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1);
                    }
                    else
                    {
                        return configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                    }
                }
            }
        }
    }
}
455 
/** GEMM reshaped LHS/RHS configuration for FP16 on Mali-G78.
 *
 * Empirically tuned decision tree over the m/n and n/k shape ratios and a
 * normalized workload estimate; the exact threshold constants are significant
 * and should not be reordered. All leaves use an 8x4 LHS block; only the
 * remaining reshape parameters vary.
 *
 * @param[in] m Number of rows of the LHS matrix
 * @param[in] n Number of columns of the RHS matrix
 * @param[in] k Number of columns of the LHS matrix (rows of the RHS matrix)
 * @param[in] b Batch size
 *
 * @return A pair of LHS/RHS matrix reshape descriptors
 */
std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValhall::configure_G78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
{
    // Decision-tree features.
    const float r_mn     = static_cast<float>(m) / static_cast<float>(n);
    const float r_nk     = static_cast<float>(n) / static_cast<float>(k);
    const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;

    if(workload <= 801.6000f)
    {
        return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1);
    }
    else
    {
        // Flat LHS relative to RHS width (m much smaller than n).
        if(r_mn <= 0.1211f)
        {
            if(workload <= 3296.0000f)
            {
                return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
            }
            else
            {
                if(r_nk <= 1.0625f)
                {
                    return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                }
                else
                {
                    return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 4, 0, 0, 1, 0, 1);
                }
            }
        }
        else
        {
            if(workload <= 5068.8000f)
            {
                return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1);
            }
            else
            {
                if(r_nk <= 0.2361f)
                {
                    if(workload <= 12630.0000f)
                    {
                        return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1);
                    }
                    else
                    {
                        return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 1, 0, 0, 1, 0, 1);
                    }
                }
                else
                {
                    if(workload <= 178790.3984f)
                    {
                        return configure_lhs_rhs_info(m, n, 8, 4, 4, 2, 2, 0, 0, 1, 0, 1);
                    }
                    else
                    {
                        return configure_lhs_rhs_info(m, n, 8, 4, 4, 1, 1, 0, 0, 1, 0, 1);
                    }
                }
            }
        }
    }
}
520 
configure_G77_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)521 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedValhall::configure_G77_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
522 {
523     ARM_COMPUTE_UNUSED(k);
524     ARM_COMPUTE_UNUSED(b);
525 
526     if(n <= 4)
527     {
528         return configure_lhs_rhs_info(m, n, 4, 2, 16, 4, 1, 0, 0, 0, 1);
529     }
530     else
531     {
532         return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2, 0, 1, 0, 1);
533     }
534 }
535 } // namespace gemm
536 } // namespace kernels
537 } // namespace opencl
538 } // namespace arm_compute
539