1 /*
2  * Copyright (c) 2019-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "src/gpu/cl/kernels/gemm/reshaped/ClGemmDefaultConfigReshapedBifrost.h"
25 
26 #include "arm_compute/core/CL/CLHelpers.h"
27 #include "arm_compute/core/CL/CLKernelLibrary.h"
28 #include "arm_compute/core/GPUTarget.h"
29 #include "arm_compute/core/TensorInfo.h"
30 #include "arm_compute/core/TensorShape.h"
31 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
32 #include "src/gpu/cl/kernels/gemm/ClGemmHelpers.h"
33 
34 #include <utility>
35 
36 namespace arm_compute
37 {
38 namespace opencl
39 {
40 namespace kernels
41 {
42 namespace gemm
43 {
44 using namespace arm_compute::misc::shape_calculator;
45 
ClGemmDefaultConfigReshapedBifrost(GPUTarget gpu)46 ClGemmDefaultConfigReshapedBifrost::ClGemmDefaultConfigReshapedBifrost(GPUTarget gpu)
47     : IClGemmKernelConfig(gpu)
48 {
49 }
50 
configure(unsigned int m,unsigned int n,unsigned int k,unsigned int b,DataType data_type)51 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure(unsigned int m, unsigned int n, unsigned int k, unsigned int b, DataType data_type)
52 {
53     using ConfigurationFunctionExecutorPtr = std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> (ClGemmDefaultConfigReshapedBifrost::*)(unsigned int m, unsigned int n, unsigned int k, unsigned int b);
54 
55     CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G7x(&ClGemmDefaultConfigReshapedBifrost::configure_G7x_f32,
56                                                                     &ClGemmDefaultConfigReshapedBifrost::configure_G7x_f16,
57                                                                     &ClGemmDefaultConfigReshapedBifrost::configure_G7x_u8);
58 
59     CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G52(&ClGemmDefaultConfigReshapedBifrost::configure_G52_f32,
60                                                                     &ClGemmDefaultConfigReshapedBifrost::configure_G52_f16,
61                                                                     &ClGemmDefaultConfigReshapedBifrost::configure_G7x_u8);
62 
63     CLGEMMConfigArray<ConfigurationFunctionExecutorPtr> configs_G76(&ClGemmDefaultConfigReshapedBifrost::configure_G76_f32,
64                                                                     &ClGemmDefaultConfigReshapedBifrost::configure_G76_f16,
65                                                                     &ClGemmDefaultConfigReshapedBifrost::configure_G76_u8);
66 
67     ConfigurationFunctionExecutorPtr func = nullptr;
68 
69     switch(_target)
70     {
71         case GPUTarget::G76:
72             func = configs_G76.get_function(data_type);
73             break;
74         case GPUTarget::G52:
75             func = configs_G52.get_function(data_type);
76             break;
77         default:
78             func = configs_G7x.get_function(data_type);
79             break;
80     }
81 
82     ARM_COMPUTE_ERROR_ON_MSG(func == nullptr, "Data type not support for GEMM");
83     return (this->*func)(m, n, k, b);
84 }
85 
configure_G7x_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)86 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G7x_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
87 {
88     ARM_COMPUTE_UNUSED(k);
89     ARM_COMPUTE_UNUSED(b);
90 
91     if(n <= 4)
92     {
93         return configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, true, false, false, true);
94     }
95     else
96     {
97         return configure_lhs_rhs_info(m, n, 5, 4, 4, 2, 16, false, true, false, true);
98     }
99 }
100 
configure_G7x_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)101 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G7x_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
102 {
103     ARM_COMPUTE_UNUSED(k);
104     ARM_COMPUTE_UNUSED(b);
105 
106     if(n <= 4)
107     {
108         return configure_lhs_rhs_info(m, n, 4, 2, 8, 8, 2, true, true, true, false);
109     }
110     else
111     {
112         return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2, true, true, true, false);
113     }
114 }
115 
configure_G7x_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)116 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G7x_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
117 {
118     ARM_COMPUTE_UNUSED(k);
119     ARM_COMPUTE_UNUSED(b);
120 
121     if(dot8_supported(CLKernelLibrary::get().get_device()))
122     {
123         if(n <= 4)
124         {
125             return configure_lhs_rhs_info(m, n, 4, 2, 16, 2, 2, true, false, false, true);
126         }
127         else
128         {
129             return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2, true, false, false, true);
130         }
131     }
132     else
133     {
134         if(n <= 4)
135         {
136             return configure_lhs_rhs_info(m, n, 4, 2, 8, 2, 2, true, false, false, true);
137         }
138         else
139         {
140             return configure_lhs_rhs_info(m, n, 6, 4, 4, 2, 2, true, true, false, true);
141         }
142     }
143 }
144 
configure_G52_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)145 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G52_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
146 {
147     const float r_mn     = static_cast<float>(m) / static_cast<float>(n);
148     const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
149     const float r_mk     = static_cast<float>(m) / static_cast<float>(k);
150     const float r_nk     = static_cast<float>(n) / static_cast<float>(k);
151 
152     GEMMLHSMatrixInfo lhs_info_buf;
153     GEMMRHSMatrixInfo rhs_info_buf;
154     GEMMLHSMatrixInfo lhs_info_img;
155     GEMMRHSMatrixInfo rhs_info_img;
156 
157     if(workload <= 274.4000f)
158     {
159         if(r_nk <= 0.7461f)
160         {
161             if(r_mn <= 21.1667f)
162             {
163                 return configure_lhs_rhs_info(m, n, 4, 2, 4, 4, 4, false, true, true, false, false);
164             }
165             else
166             {
167                 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
168                 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
169 
170                 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
171                                            std::make_pair(lhs_info_buf, rhs_info_buf),
172                                            n, k, b, DataType::F32);
173             }
174         }
175         else
176         {
177             std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
178             std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
179 
180             return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
181                                        std::make_pair(lhs_info_buf, rhs_info_buf),
182                                        n, k, b, DataType::F32);
183         }
184     }
185     else
186     {
187         if(r_mk <= 17.3926f)
188         {
189             if(workload <= 542.4000f)
190             {
191                 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
192                 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
193 
194                 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
195                                            std::make_pair(lhs_info_buf, rhs_info_buf),
196                                            n, k, b, DataType::F32);
197             }
198             else
199             {
200                 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, true);
201                 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, false);
202 
203                 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
204                                            std::make_pair(lhs_info_buf, rhs_info_buf),
205                                            n, k, b, DataType::F32);
206             }
207         }
208         else
209         {
210             if(r_nk <= 0.5463f)
211             {
212                 if(workload <= 11767.6001f)
213                 {
214                     std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
215                     std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
216 
217                     return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
218                                                std::make_pair(lhs_info_buf, rhs_info_buf),
219                                                n, k, b, DataType::F32);
220                 }
221                 else
222                 {
223                     std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, true);
224                     std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 1, true, true, false, true, false);
225 
226                     return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
227                                                std::make_pair(lhs_info_buf, rhs_info_buf),
228                                                n, k, b, DataType::F32);
229                 }
230             }
231             else
232             {
233                 std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, true);
234                 std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 4, 4, 2, true, true, false, true, false);
235 
236                 return select_lhs_rhs_info(std::make_pair(lhs_info_img, rhs_info_img),
237                                            std::make_pair(lhs_info_buf, rhs_info_buf),
238                                            n, k, b, DataType::F32);
239             }
240         }
241     }
242 }
243 
configure_G52_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)244 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G52_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
245 {
246     ARM_COMPUTE_UNUSED(k);
247 
248     const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
249 
250     if(workload <= 323.4000f)
251     {
252         return configure_lhs_rhs_info(m, n, 2, 2, 8, 4, 8, false, false, false, true, false);
253     }
254     else
255     {
256         return configure_lhs_rhs_info(m, n, 4, 8, 4, 2, 2, true, true, true, false, false);
257     }
258 }
259 
configure_G76_f32(unsigned int m,unsigned int n,unsigned int k,unsigned int b)260 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G76_f32(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
261 {
262     ARM_COMPUTE_UNUSED(k);
263     ARM_COMPUTE_UNUSED(b);
264 
265     GEMMLHSMatrixInfo lhs_info_buf;
266     GEMMRHSMatrixInfo rhs_info_buf;
267     GEMMLHSMatrixInfo lhs_info_img;
268     GEMMRHSMatrixInfo rhs_info_img;
269 
270     // Get lhs_info/rhs_info in case of OpenCL buffer
271     if(n <= 4)
272     {
273         std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 2, 8, 16, 16, true, false, false, true);
274     }
275     else
276     {
277         std::tie(lhs_info_buf, rhs_info_buf) = configure_lhs_rhs_info(m, n, 4, 4, 2, 8, 16, false, false, false, true);
278     }
279 
280     // Get lhs_info/rhs_info in case of OpenCL image
281     // Condition on the GPU workload
282     if((m / 4) * (n / 4) >= 2560)
283     {
284         // Big workload
285         std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 4, 4, 4, 2, 8, true, true, true, false, true);
286     }
287     else
288     {
289         // Small workload
290         std::tie(lhs_info_img, rhs_info_img) = configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 1, true, true, true, false, true);
291     }
292 
293     const TensorInfo  tensor_rhs_info(TensorShape(n, k, b), 1, DataType::F32);
294     const TensorShape shape = compute_rhs_reshaped_shape(tensor_rhs_info, rhs_info_img);
295     const TensorInfo  tensor_reshaped_info(shape, 1, DataType::F32);
296 
297     // In case of vector by matrix with few work-items, we use the OpenCL buffer rather than the OpenCL image2d
298     const bool use_cl_image2d = (n <= 4) ? false : true;
299 
300     if(bool(validate_image2d_support_on_rhs(tensor_reshaped_info, rhs_info_img)) && use_cl_image2d)
301     {
302         return std::make_pair(lhs_info_img, rhs_info_img);
303     }
304     else
305     {
306         return std::make_pair(lhs_info_buf, rhs_info_buf);
307     }
308 }
309 
configure_G76_f16(unsigned int m,unsigned int n,unsigned int k,unsigned int b)310 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G76_f16(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
311 {
312     const float workload = (static_cast<float>(m) * static_cast<float>(n) * static_cast<float>(b)) / 20.0f;
313     const float r_mk     = static_cast<float>(m) / static_cast<float>(k);
314 
315     if(workload <= 1595.2000f)
316     {
317         if(r_mk <= 2.1044f)
318         {
319             if(workload <= 870.4000f)
320             {
321                 return configure_lhs_rhs_info(m, n, 2, 4, 4, 1, 2, true, false, true, false, false);
322             }
323             else
324             {
325                 return configure_lhs_rhs_info(m, n, 4, 2, 4, 2, 2, false, false, true, false, false);
326             }
327         }
328         else
329         {
330             return configure_lhs_rhs_info(m, n, 4, 2, 4, 2, 2, false, false, true, false, false);
331         }
332     }
333     else
334     {
335         return configure_lhs_rhs_info(m, n, 4, 8, 4, 4, 2, true, true, true, false, false);
336     }
337 }
338 
configure_G76_u8(unsigned int m,unsigned int n,unsigned int k,unsigned int b)339 std::pair<GEMMLHSMatrixInfo, GEMMRHSMatrixInfo> ClGemmDefaultConfigReshapedBifrost::configure_G76_u8(unsigned int m, unsigned int n, unsigned int k, unsigned int b)
340 {
341     ARM_COMPUTE_UNUSED(k);
342     ARM_COMPUTE_UNUSED(b);
343 
344     if(n <= 4)
345     {
346         return configure_lhs_rhs_info(m, n, 4, 2, 16, 4, 1, false, false, false, true);
347     }
348     else
349     {
350         return configure_lhs_rhs_info(m, n, 4, 4, 16, 2, 2, false, true, false, true);
351     }
352 }
353 } // namespace gemm
354 } // namespace kernels
355 } // namespace opencl
356 } // namespace arm_compute
357