xref: /aosp_15_r20/external/ComputeLibrary/src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2020-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "src/runtime/CL/gemm/CLGEMMDefaultTypeValhall.h"
25 
26 #include "arm_compute/core/CL/CLHelpers.h"
27 #include "arm_compute/core/CL/CLKernelLibrary.h"
28 #include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h"
29 
30 #include <map>
31 #include <utility>
32 
33 namespace arm_compute
34 {
35 namespace cl_gemm
36 {
CLGEMMDefaultTypeValhall(GPUTarget gpu)37 CLGEMMDefaultTypeValhall::CLGEMMDefaultTypeValhall(GPUTarget gpu)
38     : ICLGEMMKernelSelection(gpu)
39 {
40 }
41 
select_kernel(const CLGEMMKernelSelectionParams & params)42 CLGEMMKernelType CLGEMMDefaultTypeValhall::select_kernel(const CLGEMMKernelSelectionParams &params)
43 {
44     // _target could be used in the future to have a dedicated heuristic for each GPU IP
45     ARM_COMPUTE_UNUSED(_target);
46 
47     using FunctionExecutorPtr = CLGEMMKernelType (CLGEMMDefaultTypeValhall::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant);
48 
49     // Default configurations for Valhall architectures
50     static std::map<DataType, FunctionExecutorPtr> gemm_default_configs =
51     {
52         { DataType::F32, &CLGEMMDefaultTypeValhall::default_f32 },
53         { DataType::F16, &CLGEMMDefaultTypeValhall::default_f16 },
54         { DataType::QASYMM8, &CLGEMMDefaultTypeValhall::default_q8 },
55         { DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeValhall::default_q8 },
56         { DataType::QSYMM8, &CLGEMMDefaultTypeValhall::default_q8 },
57         { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeValhall::default_q8 }
58     };
59 
60     // Mali-G77 configurations
61     static std::map<DataType, FunctionExecutorPtr> gemm_g77_configs =
62     {
63         { DataType::F32, &CLGEMMDefaultTypeValhall::default_f32 },
64         { DataType::F16, &CLGEMMDefaultTypeValhall::g77_f16 },
65         { DataType::QASYMM8, &CLGEMMDefaultTypeValhall::default_q8 },
66         { DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeValhall::default_q8 },
67         { DataType::QSYMM8, &CLGEMMDefaultTypeValhall::default_q8 },
68         { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeValhall::default_q8 }
69     };
70 
71     // Mali-G78 configurations
72     static std::map<DataType, FunctionExecutorPtr> gemm_g78_configs =
73     {
74         { DataType::F32, &CLGEMMDefaultTypeValhall::g78_f32 },
75         { DataType::F16, &CLGEMMDefaultTypeValhall::g78_f16 },
76         { DataType::QASYMM8, &CLGEMMDefaultTypeValhall::default_q8 },
77         { DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeValhall::default_q8 },
78         { DataType::QSYMM8, &CLGEMMDefaultTypeValhall::default_q8 },
79         { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeValhall::default_q8 }
80     };
81 
82     // Mali-G715 and Mali-G615 configurations
83     static std::map<DataType, FunctionExecutorPtr> gemm_g715_configs =
84     {
85         { DataType::F32, &CLGEMMDefaultTypeValhall::g715_f32 },
86         { DataType::F16, &CLGEMMDefaultTypeValhall::g715_f16 },
87         { DataType::QASYMM8, &CLGEMMDefaultTypeValhall::default_q8 },
88         { DataType::QASYMM8_SIGNED, &CLGEMMDefaultTypeValhall::default_q8 },
89         { DataType::QSYMM8, &CLGEMMDefaultTypeValhall::default_q8 },
90         { DataType::QSYMM8_PER_CHANNEL, &CLGEMMDefaultTypeValhall::default_q8 }
91     };
92 
93     const DataType data_type = params.data_type;
94 
95     switch(_target)
96     {
97         case GPUTarget::G715:
98         case GPUTarget::G615:
99             if(gemm_g715_configs.find(data_type) != gemm_g715_configs.end())
100             {
101                 return (this->*gemm_g715_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant);
102             }
103             ARM_COMPUTE_ERROR("Not supported data type");
104         case GPUTarget::G78:
105             if(gemm_g78_configs.find(data_type) != gemm_g78_configs.end())
106             {
107                 return (this->*gemm_g78_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant);
108             }
109             ARM_COMPUTE_ERROR("Not supported data type");
110         case GPUTarget::G77:
111             if(gemm_g77_configs.find(data_type) != gemm_g77_configs.end())
112             {
113                 return (this->*gemm_g77_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant);
114             }
115             ARM_COMPUTE_ERROR("Not supported data type");
116         default:
117             if(gemm_default_configs.find(data_type) != gemm_default_configs.end())
118             {
119                 return (this->*gemm_default_configs[data_type])(params.m, params.n, params.k, params.b, params.is_rhs_constant);
120             }
121             ARM_COMPUTE_ERROR("Not supported data type");
122     }
123 }
124 
default_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b,bool is_rhs_constant)125 CLGEMMKernelType CLGEMMDefaultTypeValhall::default_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
126 {
127     ARM_COMPUTE_UNUSED(m, n, k, b);
128 
129     return is_rhs_constant ? CLGEMMKernelType::RESHAPED_ONLY_RHS : CLGEMMKernelType::NATIVE;
130 }
131 
default_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b,bool is_rhs_constant)132 CLGEMMKernelType CLGEMMDefaultTypeValhall::default_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
133 {
134     ARM_COMPUTE_UNUSED(m, n, k, b);
135 
136     return is_rhs_constant ? CLGEMMKernelType::RESHAPED_ONLY_RHS : CLGEMMKernelType::NATIVE;
137 }
138 
g77_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b,bool is_rhs_constant)139 CLGEMMKernelType CLGEMMDefaultTypeValhall::g77_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
140 {
141     if(!is_rhs_constant)
142     {
143         return CLGEMMKernelType::NATIVE;
144     }
145 
146     if(m == 1)
147     {
148         return CLGEMMKernelType::RESHAPED_ONLY_RHS;
149     }
150 
151     const float r_mn     = static_cast<float>(m) / static_cast<float>(n);
152     const float r_mk     = static_cast<float>(m) / static_cast<float>(k);
153     const float r_nk     = static_cast<float>(n) / static_cast<float>(k);
154     const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
155 
156     if(r_mk <= 0.6817956566810608)
157     {
158         if(workload <= 801.6000061035156)
159         {
160             return CLGEMMKernelType::RESHAPED_ONLY_RHS;
161         }
162         else
163         {
164             if(r_mn <= 0.0839829258620739)
165             {
166                 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
167             }
168             else
169             {
170                 if(r_mk <= 0.24917218834161758)
171                 {
172                     return CLGEMMKernelType::RESHAPED;
173                 }
174                 else
175                 {
176                     if(workload <= 2551.75)
177                     {
178                         return CLGEMMKernelType::RESHAPED_ONLY_RHS;
179                     }
180                     else
181                     {
182                         if(workload <= 5061.574951171875)
183                         {
184                             return CLGEMMKernelType::RESHAPED_ONLY_RHS;
185                         }
186                         else
187                         {
188                             return CLGEMMKernelType::RESHAPED;
189                         }
190                     }
191                 }
192             }
193         }
194     }
195     else
196     {
197         if(r_mk <= 4.849947690963745)
198         {
199             if(workload <= 17618.4501953125)
200             {
201                 if(workload <= 5224.699951171875)
202                 {
203                     return CLGEMMKernelType::RESHAPED_ONLY_RHS;
204                 }
205                 else
206                 {
207                     if(r_nk <= 0.7933054566383362)
208                     {
209                         return CLGEMMKernelType::RESHAPED;
210                     }
211                     else
212                     {
213                         return CLGEMMKernelType::RESHAPED_ONLY_RHS;
214                     }
215                 }
216             }
217             else
218             {
219                 if(workload <= 20275.2001953125)
220                 {
221                     return CLGEMMKernelType::RESHAPED;
222                 }
223                 else
224                 {
225                     if(r_mk <= 3.07421875)
226                     {
227                         return CLGEMMKernelType::RESHAPED_ONLY_RHS;
228                     }
229                     else
230                     {
231                         return CLGEMMKernelType::RESHAPED;
232                     }
233                 }
234             }
235         }
236         else
237         {
238             return CLGEMMKernelType::RESHAPED_ONLY_RHS;
239         }
240     }
241 }
242 
default_q8(unsigned int m,unsigned int n,unsigned int k,unsigned int b,bool is_rhs_constant)243 CLGEMMKernelType CLGEMMDefaultTypeValhall::default_q8(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
244 {
245     ARM_COMPUTE_UNUSED(m, n, k, b);
246 
247     if(is_rhs_constant)
248     {
249         return CLGEMMKernelType::RESHAPED_ONLY_RHS;
250     }
251     else
252     {
253         return CLGEMMKernelType::NATIVE;
254     }
255 }
256 
g78_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b,bool is_rhs_constant)257 CLGEMMKernelType CLGEMMDefaultTypeValhall::g78_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
258 {
259     ARM_COMPUTE_UNUSED(b);
260 
261     if(!is_rhs_constant)
262     {
263         return CLGEMMKernelType::NATIVE;
264     }
265 
266     if(m == 1)
267     {
268         return CLGEMMKernelType::RESHAPED_ONLY_RHS;
269     }
270 
271     if(n <= 272.0000f)
272     {
273         return CLGEMMKernelType::RESHAPED_ONLY_RHS;
274     }
275     else
276     {
277         if(k <= 471.0000f)
278         {
279             return CLGEMMKernelType::RESHAPED_ONLY_RHS;
280         }
281         else
282         {
283             if(m <= 72.5000f)
284             {
285                 return CLGEMMKernelType::RESHAPED_ONLY_RHS;
286             }
287             else
288             {
289                 if(m <= 90.5000f)
290                 {
291                     return CLGEMMKernelType::RESHAPED;
292                 }
293                 else
294                 {
295                     if(k <= 2448.0000f)
296                     {
297                         if(n <= 756.0000f)
298                         {
299                             return CLGEMMKernelType::RESHAPED_ONLY_RHS;
300                         }
301                         else
302                         {
303                             return CLGEMMKernelType::RESHAPED;
304                         }
305                     }
306                     else
307                     {
308                         return CLGEMMKernelType::RESHAPED;
309                     }
310                 }
311             }
312         }
313     }
314 }
315 
g78_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b,bool is_rhs_constant)316 CLGEMMKernelType CLGEMMDefaultTypeValhall::g78_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
317 {
318     ARM_COMPUTE_UNUSED(m, n, k, b);
319 
320     if(!is_rhs_constant)
321     {
322         return CLGEMMKernelType::NATIVE;
323     }
324 
325     return CLGEMMKernelType::RESHAPED_ONLY_RHS;
326 }
327 
g715_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b,bool is_rhs_constant)328 CLGEMMKernelType CLGEMMDefaultTypeValhall::g715_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
329 {
330     if(!is_rhs_constant)
331     {
332         return default_f32(m, n, k, b, is_rhs_constant);
333     }
334 
335     unsigned int best_m0;
336     unsigned int best_n0;
337 
338     if(opencl::kernels::gemm::is_mmul_kernel_preferred(m, n, k, b, DataType::F32, best_m0, best_n0))
339     {
340         return CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL;
341     }
342     else
343     {
344         return default_f32(m, n, k, b, is_rhs_constant);
345     }
346 }
347 
g715_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b,bool is_rhs_constant)348 CLGEMMKernelType CLGEMMDefaultTypeValhall::g715_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b, bool is_rhs_constant)
349 {
350     if(!is_rhs_constant)
351     {
352         return g78_f16(m, n, k, b, is_rhs_constant);
353     }
354 
355     unsigned int best_m0;
356     unsigned int best_n0;
357 
358     if(opencl::kernels::gemm::is_mmul_kernel_preferred(m, n, k, b, DataType::F16, best_m0, best_n0))
359     {
360         return CLGEMMKernelType::RESHAPED_ONLY_RHS_MMUL;
361     }
362     else
363     {
364         return g78_f16(m, n, k, b, is_rhs_constant);
365     }
366 }
367 
368 } // namespace cl_gemm
369 } // namespace arm_compute
370