/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/tasks/conv_generic.h"

#include <algorithm>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/util.h"
#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"

namespace tflite {
namespace gpu {

namespace {
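// Emits source for a cooperative copy of 'elements_to_upload' elements into
// local memory: each of the 'total_work_items' threads copies one element
// per pass at stride 'total_work_items', with a guarded tail for the
// remainder. Roughly (my reading of the generated string, assuming
// total_work_items = 32 and elements_to_upload = 36):
//   weights_cache[lid + 0] = filters_loc[lid + 0];
//   if (lid < 4) {
//     weights_cache[lid + 32] = filters_loc[lid + 32];
//   }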
std::string GenerateUploadByThreads(const std::string& local_ptr_name,
                                    const std::string& name, bool use_ptrs,
                                    const std::string& global_offset_name,
                                    const std::string& lid_name,
                                    int total_work_items,
                                    int elements_to_upload) {
  std::string c;
  std::string offset =
      global_offset_name.empty() ? "" : global_offset_name + " + ";
  const int groups = elements_to_upload / total_work_items;
  const int remainder = elements_to_upload % total_work_items;
  const std::string access_start = name + (use_ptrs ? "[" : ".Read(");
  const std::string access_end = use_ptrs ? "]" : ")";
  for (int i = 0; i < groups; ++i) {
    c += "    " + local_ptr_name + "[" + lid_name + " + " +
         std::to_string(total_work_items * i) + "] = " + access_start + offset +
         lid_name + " + " + std::to_string(total_work_items * i) + access_end +
         ";\n";
  }
  if (remainder != 0) {
    c += "    if (" + lid_name + " < " + std::to_string(remainder) + ") {\n";
    c += "      " + local_ptr_name + "[" + lid_name + " + " +
         std::to_string(total_work_items * groups) + "] = " + access_start +
         offset + lid_name + " + " + std::to_string(total_work_items * groups) +
         access_end + ";\n";
    c += "    }\n";
  }
  return c;
}

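// Emits a single async_work_group_copy() call that stages
// 'elements_to_upload' elements from global into local memory; the call
// must be reached by all work items in the group.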
std::string GenerateAsyncUpload(const std::string& local_ptr_name,
                                const std::string& global_ptr_name,
                                const std::string& global_offset_name,
                                int elements_to_upload) {
  std::string c;
  std::string offset =
      global_offset_name.empty() ? "" : " + " + global_offset_name;
  c += "    async_work_group_copy(" + local_ptr_name + ", " + global_ptr_name +
       offset + ", " + std::to_string(elements_to_upload) + ", 0);\n";
  return c;
}

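// Emits the code that maps global/group/local thread ids to destination
// coordinates (DST_X/DST_Y/DST_Z/DST_S and optionally B), honoring the
// requested work-group launch order and the linear_spatial/linear_all
// flattened-grid modes, then scales each coordinate by its block size.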
std::string GenerateBlockCoords(const int4& block_size,
                                const int3& work_group_launch_order,
                                bool linear_spatial, bool linear_all,
                                bool need_depth, bool need_batch) {
  std::string c;
  int3 launch_remap;
  launch_remap[work_group_launch_order.x] = 0;
  launch_remap[work_group_launch_order.y] = 1;
  launch_remap[work_group_launch_order.z] = 2;
  if (linear_all) {
    c += "  int linear_all = GLOBAL_ID_0;\n";
    if (need_batch) {
      c += "  int B = linear_all % args.task_size_b;\n";
      c += "  linear_all = linear_all / args.task_size_b;\n";
    }
    c += "  int DST_X = linear_all % args.task_size_x;\n";
    c += "  linear_all = linear_all / args.task_size_x;\n";
    c += "  int DST_Y = linear_all % args.task_size_y;\n";
    c += "  linear_all = linear_all / args.task_size_y;\n";
    if (need_depth) {
      c += "  int DST_Z = linear_all % args.task_size_z;\n";
      c += "  linear_all = linear_all / args.task_size_z;\n";
    }
    c += "  int DST_S = linear_all;\n";
  } else if (linear_spatial) {
    if (work_group_launch_order[0] == 0) {
      c += "  int linear_spatial = GLOBAL_ID_0;\n";
    } else {
      c += "  int linear_spatial = GROUP_ID_" +
           std::to_string(launch_remap[0]) + " * GROUP_SIZE_0 + LOCAL_ID_0;\n";
    }
    if (need_batch) {
      c += "  int B = linear_spatial % args.task_size_b;\n";
      c += "  linear_spatial = linear_spatial / args.task_size_b;\n";
    }
    c += "  int DST_X = linear_spatial % args.task_size_x;\n";
    c += "  linear_spatial = linear_spatial / args.task_size_x;\n";
    c += "  int DST_Y = linear_spatial % args.task_size_y;\n";
    c += "  linear_spatial = linear_spatial / args.task_size_y;\n";
    if (need_depth) {
      c += "  int DST_Z = linear_spatial;\n";
    }
    if (work_group_launch_order[1] == 1) {
      c += "  int DST_S = GLOBAL_ID_1;\n";
    } else {
      c += "  int DST_S = GROUP_ID_" + std::to_string(launch_remap[1]) +
           " * GROUP_SIZE_1 + LOCAL_ID_1;\n";
    }
  } else {
    if (work_group_launch_order[0] == 0) {
      c += "  int DST_X = GLOBAL_ID_0;\n";
    } else {
      c += "  int DST_X = GROUP_ID_" + std::to_string(launch_remap[0]) +
           " * GROUP_SIZE_0 + LOCAL_ID_0;\n";
    }
    if (need_batch) {
      c += "  int B = DST_X % args.task_size_b;\n";
      c += "  DST_X = DST_X / args.task_size_b;\n";
    }
    std::string global_id_1;
    if (work_group_launch_order[1] == 1) {
      global_id_1 = "GLOBAL_ID_1";
    } else {
      global_id_1 = "GROUP_ID_" + std::to_string(launch_remap[1]) +
                    " * GROUP_SIZE_1 + LOCAL_ID_1";
    }
    if (need_depth) {
      c += "  int linear_id_1 = " + global_id_1 + ";\n";
      c += "  int DST_Y = linear_id_1 % args.task_size_y;\n";
      c += "  int DST_Z = linear_id_1 / args.task_size_y;\n";
    } else {
      c += "  int DST_Y = " + global_id_1 + ";\n";
    }
    if (work_group_launch_order[2] == 2) {
      c += "  int DST_S = GLOBAL_ID_2;\n";
    } else {
      c += "  int DST_S = GROUP_ID_" + std::to_string(launch_remap[2]) +
           " * GROUP_SIZE_2 + LOCAL_ID_2;\n";
    }
  }
  if (block_size.x != 1) {
    c += "  DST_X *= " + std::to_string(block_size.x) + ";\n";
  }
  if (block_size.y != 1) {
    c += "  DST_Y *= " + std::to_string(block_size.y) + ";\n";
  }
  if (need_depth && block_size.z != 1) {
    c += "  DST_Z *= " + std::to_string(block_size.z) + ";\n";
  }
  if (block_size.w != 1) {
    c += "  DST_S *= " + std::to_string(block_size.w) + ";\n";
  }

  return c;
}
}  // namespace

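// 2D convolution. When attr.groups != 1 the output block size in slices
// (block_size.w) is reduced so that it evenly divides the per-group slice
// count, and the group sizes are exposed to the kernel as args.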
ConvGeneric::ConvGeneric(const OperationDef& definition,
                         const Convolution2DAttributes& attr,
                         const GpuInfo& gpu_info, const BHWC* dst_shape)
    : GPUOperation(definition),
      stride_(attr.strides.w, attr.strides.h, 1, 1),
      padding_(-attr.padding.prepended.w, -attr.padding.prepended.h, 0, 0),
      kernel_size_(attr.weights.shape.w, attr.weights.shape.h, 1, 1),
      dilation_(attr.dilations.w, attr.dilations.h, 1, 1),
      conv_params_(GuessBestParams(gpu_info, definition, attr, dst_shape)) {
  const int src_slices = DivideRoundUp(attr.weights.shape.i, 4);
  const int dst_slices = DivideRoundUp(attr.weights.shape.o, 4);
  if (attr.groups != 1) {
    conv_params_.groups_support = true;
    const int dst_group_slices = dst_slices / attr.groups;
    if (dst_group_slices % conv_params_.block_size.w != 0) {
      if (conv_params_.block_size.w == 4 && dst_group_slices % 2 == 0) {
        conv_params_.block_size.w = 2;
      } else {
        conv_params_.block_size.w = 1;
      }
    }
    args_.AddInt("src_group_size", src_slices);
    args_.AddInt("dst_group_size", dst_slices / attr.groups);
  }
}

ConvGeneric::ConvGeneric(const OperationDef& definition,
                         const Convolution2DAttributes& attr,
                         const BHWC& weights_shape, const GpuInfo& gpu_info,
                         const BHWC* dst_shape)
    : GPUOperation(definition),
      stride_(attr.strides.w, attr.strides.h, 1, 1),
      padding_(-attr.padding.prepended.w, -attr.padding.prepended.h, 0, 0),
      kernel_size_(weights_shape.w, weights_shape.h, 1, 1),
      dilation_(attr.dilations.w, attr.dilations.h, 1, 1),
      conv_params_(GuessBestParams(gpu_info, definition, attr, weights_shape,
                                   dst_shape)) {}

ConvGeneric::ConvGeneric(const OperationDef& definition,
                         const FullyConnectedAttributes& attr,
                         const GpuInfo& gpu_info, const BHWC* dst_shape)
    : GPUOperation(definition),
      stride_(1, 1, 1, 1),
      padding_(0, 0, 0, 0),
      kernel_size_(1, 1, 1, 1),
      dilation_(1, 1, 1, 1),
      conv_params_(GuessBestParams(gpu_info, definition, attr, dst_shape)) {}

ConvGeneric::ConvGeneric(const OperationDef& definition)
    : GPUOperation(definition),
      stride_(1, 1, 1, 1),
      padding_(0, 0, 0, 0),
      kernel_size_(1, 1, 1, 1),
      dilation_(1, 1, 1, 1) {}

ConvGeneric::ConvGeneric(ConvGeneric&& operation)
    : GPUOperation(std::move(operation)),
      stride_(operation.stride_),
      padding_(operation.padding_),
      kernel_size_(operation.kernel_size_),
      dilation_(operation.dilation_),
      conv_params_(operation.conv_params_) {}

ConvGeneric::ConvGeneric(const OperationDef& definition,
                         const Convolution3DAttributes& attr,
                         const GpuInfo& gpu_info, const BHWDC* dst_shape)
    : GPUOperation(definition),
      stride_(attr.strides.w, attr.strides.h, attr.strides.d, 1),
      padding_(-attr.padding.prepended.w, -attr.padding.prepended.h,
               -attr.padding.prepended.d, 0),
      kernel_size_(attr.weights.shape.w, attr.weights.shape.h,
                   attr.weights.shape.d, 1),
      dilation_(attr.dilations.w, attr.dilations.h, attr.dilations.d, 1),
      conv_params_(GuessBestParams(gpu_info, definition, attr, dst_shape)) {}

ConvGeneric& ConvGeneric::operator=(ConvGeneric&& operation) {
  if (this != &operation) {
    std::swap(stride_, operation.stride_);
    std::swap(padding_, operation.padding_);
    std::swap(kernel_size_, operation.kernel_size_);
    std::swap(dilation_, operation.dilation_);
    std::swap(conv_params_, operation.conv_params_);
    GPUOperation::operator=(std::move(operation));
  }
  return *this;
}

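// Builds the kernel source and per-backend compiler options. The grid is
// collapsed to 1D/2D for the linear_all/linear_spatial mappings, and a
// second src tensor, when present, supplies dynamic weights either as one
// buffer (I4O4/O4I4 layouts) or as four 2D textures.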
void ConvGeneric::GenerateCode(const GpuInfo& gpu_info) {
  if (conv_params_.linear_all) {
    grid_dimension_ = 1;
  } else if (conv_params_.linear_spatial) {
    grid_dimension_ = 2;
  }

  AddSrcTensor("src_tensor", definition_.src_tensors[0]);
  AddDstTensor("dst_tensor", definition_.dst_tensors[0]);
  if (definition_.src_tensors.size() == 2) {  // dynamic weights
    const DataType weights_type = definition_.GetDataType();
    if (conv_params_.weights_layout == WeightsLayout::kOSpatialIOGroupI4O4 ||
        conv_params_.weights_layout == WeightsLayout::kOSpatialIOGroupO4I4) {
      definition_.src_tensors[1] = {weights_type, TensorStorageType::BUFFER,
                                    Layout::HWC};
      BufferDescriptor desc;
      desc.element_type = weights_type;
      desc.element_size = 4;
      desc.memory_type = conv_params_.weights_upload_type ==
                                 ConvGeneric::WeightsUploadType::CONSTANT_MEM
                             ? MemoryType::CONSTANT
                             : MemoryType::GLOBAL;

      AddSrcBuffer("weights", desc);
    } else {
      TensorDescriptor desc{weights_type, TensorStorageType::TEXTURE_2D,
                            Layout::HW};
      definition_.src_tensors[1] = desc;
      definition_.src_tensors.push_back(desc);
      definition_.src_tensors.push_back(desc);
      definition_.src_tensors.push_back(desc);
      for (int i = 0; i < 4; ++i) {
        const std::string name = "weights" + std::to_string(i);
        AddSrcTensor(name, definition_.src_tensors[1 + i]);
      }
    }
  }

  code_ = GenerateConv(gpu_info, definition_, conv_params_);
  if (definition_.precision == CalculationsPrecision::F16 &&
      gpu_info.IsPowerVR()) {
    compiler_options_.push_back(CompilerOptions::kClFastRelaxedMath);
  }
  if (gpu_info.IsMali()) {
    compiler_options_.push_back(CompilerOptions::kClFastRelaxedMath);
  }
  if (conv_params_.IsPrivateMemBroadcast() && gpu_info.IsCL20OrHigher()) {
    compiler_options_.push_back(CompilerOptions::kCl20);
  }
  bool kernel_is_trivial =
      conv_params_.x_kernel_is_1 && conv_params_.y_kernel_is_1;
  if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) {
    kernel_is_trivial = kernel_is_trivial && conv_params_.z_kernel_is_1;
  }
  if (gpu_info.IsAdreno() && gpu_info.adreno_info.IsAdreno3xx() &&
      definition_.precision == CalculationsPrecision::F16 &&
      kernel_is_trivial) {
    compiler_options_.push_back(CompilerOptions::kAdrenoFullSimd);
  }
}

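// Publishes the per-dispatch task sizes (destination extents divided by the
// block sizes) that the generated code uses to decompose linearized ids.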
absl::Status ConvGeneric::BindArguments(ArgumentsBinder* args) {
  const int task_size_b = dst_[0]->Batch();
  const int task_size_x =
      DivideRoundUp(dst_[0]->Width(), conv_params_.block_size.x);
  const int task_size_y =
      DivideRoundUp(dst_[0]->Height(), conv_params_.block_size.y);
  const int task_size_z =
      DivideRoundUp(dst_[0]->Depth(), conv_params_.block_size.z);
  RETURN_IF_ERROR(args->SetInt("task_size_b", task_size_b));
  RETURN_IF_ERROR(args->SetInt("task_size_x", task_size_x));
  RETURN_IF_ERROR(args->SetInt("task_size_y", task_size_y));
  RETURN_IF_ERROR(args->SetInt("task_size_z", task_size_z));
  return absl::OkStatus();
}

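// Grid size mirrors the thread mapping chosen in GenerateBlockCoords:
// all task dimensions collapsed into one (linear_all), spatial dimensions
// collapsed with slices kept separate (linear_spatial), or a 3D grid.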
int3 ConvGeneric::GetGridSize() const {
  const int task_size_b = dst_[0]->Batch();
  const int task_size_x =
      DivideRoundUp(dst_[0]->Width(), conv_params_.block_size.x);
  const int task_size_y =
      DivideRoundUp(dst_[0]->Height(), conv_params_.block_size.y);
  const int task_size_z =
      DivideRoundUp(dst_[0]->Depth(), conv_params_.block_size.z);
  const int task_size_s =
      DivideRoundUp(dst_[0]->Slices(), conv_params_.block_size.w);

  if (conv_params_.linear_all) {
    return int3(
        task_size_x * task_size_b * task_size_y * task_size_z * task_size_s, 1,
        1);
  } else if (conv_params_.linear_spatial) {
    return int3(task_size_x * task_size_b * task_size_y * task_size_z,
                task_size_s, 1);
  } else {
    return int3(task_size_x * task_size_b, task_size_y * task_size_z,
                task_size_s);
  }
}

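// Work-group tuning. Local-memory upload modes and fixed-size kernels are
// pinned to the work-group size baked into the generated source; otherwise
// the generic convolution work-group search is used.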
void ConvGeneric::GetPossibleKernelWorkGroups(
    TuningType tuning_type, const GpuInfo& gpu_info,
    const KernelInfo& kernel_info, std::vector<int3>* work_groups) const {
  if (conv_params_.weights_upload_type ==
          WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP ||
      conv_params_.weights_upload_type ==
          WeightsUploadType::LOCAL_MEM_BY_THREADS ||
      conv_params_.fixed_work_group_size) {
    work_groups->push_back(work_group_size_);
    return;
  }
  GetPossibleWorkGroupsConv(tuning_type, gpu_info, kernel_info, grid_size_,
                            work_groups);
}

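// Emits the convolution kernel: per-block accumulators, source coordinate
// setup (with clamping or in-bounds masks where the storage type cannot
// zero-clamp), optional kernel-size loops, a slice loop that stages weights
// according to weights_upload_type, the multiply-accumulate core, and the
// final bias-add/write-back with out-of-bounds guards.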
std::string ConvGeneric::GenerateConv(const GpuInfo& gpu_info,
                                      const OperationDef& op_def,
                                      const ConvParams& conv_params) {
  const auto& src_def = op_def.src_tensors[0];

  auto generate_id = [&](const std::string& x, const std::string& y,
                         const std::string& z) {
    std::string id;
    if (src_def.HasAxis(Axis::WIDTH)) {
      id += "_w" + x;
    }
    if (src_def.HasAxis(Axis::HEIGHT)) {
      id += "_h" + y;
    }
    if (src_def.HasAxis(Axis::DEPTH)) {
      id += "_d" + z;
    }
    return id;
  };

  auto generate_id_full = [&](const std::string& x, const std::string& y,
                              const std::string& z, const std::string& s) {
    return generate_id(x, y, z) + "_s" + s;
  };

  auto generate_check = [&](const std::string& x, const std::string& y,
                            const std::string& z) {
    std::string check;
    const std::vector<Axis> axes{Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH};
    const std::vector<std::string> names{"in_x", "in_y", "in_z"};
    const std::vector<bool> is_1{conv_params_.x_kernel_is_1,
                                 conv_params_.y_kernel_is_1,
                                 conv_params_.z_kernel_is_1};
    const std::vector<std::string> coords{x, y, z};
    for (int i = 0; i < axes.size(); ++i) {
      const auto& axis = axes[i];
      if (src_def.HasAxis(axis) && !src_def.SupportsZeroClamp(axis, gpu_info) &&
          !is_1[i]) {
        if (!check.empty()) {
          check += " && ";
        }
        check += names[i] + coords[i];
      }
    }
    return check;
  };

  if (!conv_params_.x_kernel_is_1) {
    args_.AddInt("stride_x", stride_.x);
    args_.AddInt("padding_x", padding_.x);
    args_.AddInt("kernel_size_x", kernel_size_.x);
    args_.AddInt("dilation_x", dilation_.x);
  }
  if (!conv_params_.y_kernel_is_1) {
    args_.AddInt("stride_y", stride_.y);
    args_.AddInt("padding_y", padding_.y);
    args_.AddInt("kernel_size_y", kernel_size_.y);
    args_.AddInt("dilation_y", dilation_.y);
  }
  if (src_def.HasAxis(Axis::DEPTH) && !conv_params_.z_kernel_is_1) {
    args_.AddInt("stride_z", stride_.z);
    args_.AddInt("padding_z", padding_.z);
    args_.AddInt("kernel_size_z", kernel_size_.z);
    args_.AddInt("dilation_z", dilation_.z);
  }
  args_.AddInt("task_size_b");
  args_.AddInt("task_size_x");
  args_.AddInt("task_size_y");
  args_.AddInt("task_size_z");

  const int wg_total_size =
      work_group_size_.x * work_group_size_.y * work_group_size_.z;
  const std::string barrier =
      wg_total_size == 32 && gpu_info.IsWaveSizeEqualTo32()
          ? "SIMD_LOCAL_MEM_BARRIER"
          : "LOCAL_MEM_BARRIER";

  const bool need_local_mem =
      conv_params.weights_upload_type ==
          ConvGeneric::WeightsUploadType::LOCAL_MEM_BY_THREADS ||
      conv_params.weights_upload_type ==
          ConvGeneric::WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP;

  const int local_mem_size =
      conv_params.block_size.w * 4 * conv_params.src_depth_loop_size;

  const bool use_simd_broadcast = conv_params.IsPrivateMemBroadcast();
  const int simd_size = conv_params.simd_size;

  const bool late_oob_check = need_local_mem || use_simd_broadcast;

  const std::string weights_space =
      conv_params.weights_upload_type ==
              ConvGeneric::WeightsUploadType::CONSTANT_MEM
          ? "__constant"
          : "__global";

  const std::string weights_data_type =
      conv_params.weights_data_type == DataType::FLOAT32 ? "float4" : "half4";

  const std::string weights_global_ptr =
      weights_space + " " + weights_data_type + "*";

  std::string c;
  if (use_simd_broadcast && gpu_info.IsApiOpenCl()) {
    if (gpu_info.opencl_info.cl_version == OpenClVersion::kCl2_0) {
      c += "#pragma OPENCL EXTENSION cl_khr_subgroups : enable\n";
    } else if (gpu_info.SupportsExtension("cl_intel_subgroups")) {
      c += "#pragma OPENCL EXTENSION cl_intel_subgroups : enable\n";
    }
  }
  const int4 block_size = conv_params.block_size;
  if (conv_params.fixed_work_group_size && gpu_info.IsApiOpenCl()) {
    c += "__attribute__((reqd_work_group_size(" +
         std::to_string(work_group_size_.x) + ", " +
         std::to_string(work_group_size_.y) + ", " +
         std::to_string(work_group_size_.z) + ")))\n";
  }
  if (use_simd_broadcast && gpu_info.IsIntel() && gpu_info.IsApiOpenCl()) {
    c += "__attribute__((intel_reqd_sub_group_size(" +
         std::to_string(simd_size) + ")))\n";
  }
  std::string dst_oob_check;
  if (src_def.HasAxis(Axis::DEPTH)) {
    if (conv_params.linear_all) {
      dst_oob_check = "DST_S >= args.dst_tensor.Slices()";
    } else if (conv_params.linear_spatial) {
      dst_oob_check =
          "DST_Z >= args.dst_tensor.Depth() || DST_S >= "
          "args.dst_tensor.Slices()";
    } else {
      dst_oob_check =
          "DST_X >= args.dst_tensor.Width() || DST_Z >= "
          "args.dst_tensor.Depth() || DST_S >= args.dst_tensor.Slices()";
    }
  } else {
    if (conv_params.linear_all) {
      dst_oob_check = "DST_S >= args.dst_tensor.Slices()";
    } else if (conv_params.linear_spatial) {
      dst_oob_check =
          "DST_Y >= args.dst_tensor.Height() || DST_S >= "
          "args.dst_tensor.Slices()";
    } else {
      dst_oob_check =
          "DST_X >= args.dst_tensor.Width() || DST_Y >= "
          "args.dst_tensor.Height() || DST_S >= args.dst_tensor.Slices()";
    }
  }
  c += "MAIN_FUNCTION($0) {\n";
  c += GenerateBlockCoords(conv_params.block_size, work_group_launch_order_,
                           conv_params.linear_spatial, conv_params.linear_all,
                           src_def.HasAxis(Axis::DEPTH),
                           src_def.HasAxis(Axis::BATCH));
  if (src_def.HasAxis(Axis::BATCH)) {
    c += "  args.src_tensor.SetBatchRef(B);\n";
    c += "  args.dst_tensor.SetBatchRef(B);\n";
  }
  if (!conv_params.need_dst_loop) {
    c += "  DST_S = 0;\n";
  }
  c += "  if (DST_S >= args.dst_tensor.Slices()) return;\n";
  if (!late_oob_check) {
    c += "  if (" + dst_oob_check + ") {\n";
    c += "    return;\n";
    c += "  }\n";
  }
  if (conv_params.groups_support) {
    c += "      int conv_group_id = DST_S / args.dst_group_size;\n";
    c += "      int src_start_slice = conv_group_id * args.src_group_size;\n";
    c += "      int src_end_slice = src_start_slice + args.src_group_size;\n";
  }
  const std::string src_group_start_slice =
      conv_params.groups_support ? "src_start_slice" : "0";
  const std::string src_group_end_slice =
      conv_params.groups_support ? "src_end_slice" : "args.src_tensor.Slices()";
  const std::string src_group_slices = conv_params.groups_support
                                           ? "args.src_group_size"
                                           : "args.src_tensor.Slices()";
  if (conv_params.weights_upload_type ==
      ConvGeneric::WeightsUploadType::LOCAL_MEM_BY_THREADS) {
    if (conv_params.linear_spatial) {
      c += "  int lid = LOCAL_ID_0;\n";
    } else {
      c += "  int lid = LOCAL_ID_1 * " + std::to_string(work_group_size_.x) +
           " + LOCAL_ID_0;\n";
    }
  }
  if (use_simd_broadcast) {
    c += "  int simd_id = SUB_GROUP_LOCAL_ID;\n";
  }
  for (int s = 0; s < block_size.w; ++s) {
    const std::string sind = std::to_string(s);
    for (int z = 0; z < block_size.z; ++z) {
      const std::string zind = std::to_string(z);
      for (int y = 0; y < block_size.y; ++y) {
        const std::string yind = std::to_string(y);
        for (int x = 0; x < block_size.x; ++x) {
          const std::string xind = std::to_string(x);
          c += "  ACCUM_FLT4 r" + generate_id_full(xind, yind, zind, sind) +
               " = INIT_ACCUM_FLT4(0.0f);\n";
        }
      }
    }
  }
  if (!conv_params_.x_kernel_is_1) {
    for (int x = 0; x < block_size.x; ++x) {
      const std::string xind = std::to_string(x);
      const std::string xc = "(DST_X + " + xind + ")";
      c += "  int xc" + xind + " = " + xc +
           " * args.stride_x + args.padding_x;\n";
    }
  } else {
    for (int x = 0; x < block_size.x; ++x) {
      const std::string xind = std::to_string(x);
      c += "  int xc" + xind + " = DST_X + " + xind + ";\n";
      if (!src_def.CanReadOutOfBorder(Axis::WIDTH)) {
        c += "  xc" + xind + " = clamp(xc" + xind +
             ", 0, args.src_tensor.Width() - 1);\n";
      }
    }
  }
  if (!conv_params_.y_kernel_is_1) {
    for (int y = 0; y < block_size.y; ++y) {
      const std::string yind = std::to_string(y);
      const std::string yc = "(DST_Y + " + yind + ")";
      c += "  int yc" + yind + " = " + yc +
           " * args.stride_y + args.padding_y;\n";
    }
  } else {
    for (int y = 0; y < block_size.y; ++y) {
      const std::string yind = std::to_string(y);
      c += "  int yc" + yind + " = DST_Y + " + yind + ";\n";
      if (!src_def.CanReadOutOfBorder(Axis::HEIGHT)) {
        c += "  yc" + yind + " = clamp(yc" + yind +
             ", 0, args.src_tensor.Height() - 1);\n";
      }
    }
  }
  if (src_def.HasAxis(Axis::DEPTH)) {
    if (!conv_params_.z_kernel_is_1) {
      for (int z = 0; z < block_size.z; ++z) {
        const std::string zind = std::to_string(z);
        const std::string zc = "(DST_Z + " + zind + ")";
        c += "  int zc" + zind + " = " + zc +
             " * args.stride_z + args.padding_z;\n";
      }
    } else {
      for (int z = 0; z < block_size.z; ++z) {
        const std::string zind = std::to_string(z);
        c += "  int zc" + zind + " = DST_Z + " + zind + ";\n";
        if (!src_def.CanReadOutOfBorder(Axis::DEPTH)) {
          c += "  zc" + zind + " = clamp(zc" + zind +
               ", 0, args.src_tensor.Depth() - 1);\n";
        }
      }
    }
  }
  bool trivial_kernel_size =
      conv_params_.x_kernel_is_1 && conv_params_.y_kernel_is_1;
  if (src_def.HasAxis(Axis::DEPTH)) {
    trivial_kernel_size = trivial_kernel_size && conv_params_.z_kernel_is_1;
  }
  if (need_local_mem) {
    c += "  __local " + weights_data_type + " weights_cache[" +
         std::to_string(local_mem_size) + "];\n";
  } else if (conv_params.AreWeightsBuffer() &&
             gpu_info.SupportsPointersInKernels()) {
    c += "    " + weights_global_ptr + " weights_cache;\n";
  } else if (!trivial_kernel_size) {
    c += "  int filter_offset = 0;\n";
  }
  if (conv_params.AreWeightsBuffer()) {
    std::string offset;
    if (conv_params.different_weights_for_height) {
      offset = "(DST_S * args.src_tensor.Height() + DST_Y * " +
               std::to_string(block_size.w) +
               ") * 4 * args.src_tensor.Slices()";
    } else {
      std::string kernel_spatial_offset = "";
      if (!conv_params_.x_kernel_is_1) {
        kernel_spatial_offset += " * args.kernel_size_x";
      }
      if (!conv_params_.y_kernel_is_1) {
        kernel_spatial_offset += " * args.kernel_size_y";
      }
      if (src_def.HasAxis(Axis::DEPTH) && !conv_params_.z_kernel_is_1) {
        kernel_spatial_offset += " * args.kernel_size_z";
      }
      offset = "DST_S * 4 * " + src_group_slices + kernel_spatial_offset;
    }
    if (gpu_info.SupportsPointersInKernels()) {
      c += "  " + weights_global_ptr +
           " filters_loc = args.weights.GetPtr() + " + offset + ";\n";
    } else {
      c += "  int filters_offset = " + offset + ";\n";
    }
  }
  if (src_def.HasAxis(Axis::DEPTH) && !conv_params_.z_kernel_is_1) {
    c += "  for (int kz = 0; kz < args.kernel_size_z; ++kz) {\n";
    for (int z = 0; z < block_size.z; ++z) {
      const std::string zck = "zck" + std::to_string(z);
      c += "  int " + zck + " = kz * args.dilation_z + zc" +
           std::to_string(z) + ";\n";
      if (!src_def.SupportsZeroClamp(Axis::DEPTH, gpu_info)) {
        c += "  bool in_z" + std::to_string(z) + " = " + zck + " >= 0 && " +
             zck + " < args.src_tensor.Depth();\n";
        if (!src_def.CanReadOutOfBorder(Axis::DEPTH)) {
          c += "  " + zck + " = clamp(" + zck +
               ", 0, args.src_tensor.Depth() - 1);\n";
        }
      }
    }
  }
  if (!conv_params_.y_kernel_is_1) {
    c += "  for (int ky = 0; ky < args.kernel_size_y; ++ky) {\n";
    for (int y = 0; y < block_size.y; ++y) {
      const std::string yck = "yck" + std::to_string(y);
      c += "  int " + yck + " = ky * args.dilation_y + yc" + std::to_string(y) +
           ";\n";
      if (!src_def.SupportsZeroClamp(Axis::HEIGHT, gpu_info)) {
        c += "  bool in_y" + std::to_string(y) + " = " + yck + " >= 0 && " +
             yck + " < args.src_tensor.Height();\n";
        if (!src_def.CanReadOutOfBorder(Axis::HEIGHT)) {
          c += "  " + yck + " = clamp(" + yck +
               ", 0, args.src_tensor.Height() - 1);\n";
        }
      }
    }
  }
  if (!conv_params_.x_kernel_is_1) {
    c += "  for (int kx = 0; kx < args.kernel_size_x; ++kx) {\n";
    for (int x = 0; x < block_size.x; ++x) {
      const std::string xck = "xck" + std::to_string(x);
      c += "  int " + xck + " = kx * args.dilation_x + xc" +
           std::to_string(x) + ";\n";
      if (!src_def.SupportsZeroClamp(Axis::WIDTH, gpu_info)) {
        c += "  bool in_x" + std::to_string(x) + " = " + xck + " >= 0 && " +
             xck + " < args.src_tensor.Width();\n";
        if (!src_def.CanReadOutOfBorder(Axis::WIDTH)) {
          c += "  " + xck + " = clamp(" + xck +
               ", 0, args.src_tensor.Width() - 1);\n";
        }
      }
    }
  }
  const bool need_multiple_slice_strides =
      src_def.ReturnsZeroForNegOneRead(gpu_info) && !trivial_kernel_size;
  for (int z = 0; z < block_size.z; ++z) {
    const std::string zind = std::to_string(z);
    for (int y = 0; y < block_size.y; ++y) {
      const std::string yind = std::to_string(y);
      for (int x = 0; x < block_size.x; ++x) {
        const std::string xind = std::to_string(x);
        std::string xc = conv_params.x_kernel_is_1 ? "xc" + xind : "xck" + xind;
        std::string yc = conv_params.y_kernel_is_1 ? "yc" + yind : "yck" + yind;
        const std::string id = generate_id(xind, yind, zind);
        std::string coords = xc + ", " + yc;
        if (src_def.HasAxis(Axis::DEPTH)) {
          std::string zc =
              conv_params.z_kernel_is_1 ? "zc" + zind : "zck" + zind;
          coords += ", " + zc;
        }
        if (src_def.IsLinear()) {
          c += "  int addr" + id + " = args.src_tensor.GetAddress(" + coords +
               ", " + src_group_start_slice + ");\n";
          if (need_multiple_slice_strides) {
            const std::string check = generate_check(xind, yind, zind);
            c += "  addr" + id + " = select(-1, addr" + id + ", (" + check +
                 "));\n";
            c += "  int ds" + id +
                 " = select(0, args.src_tensor.SliceStride(), (" + check +
                 "));\n";
          }
        }
      }
    }
  }
  if (src_def.IsLinear() && !need_multiple_slice_strides) {
    c += "  int ds = args.src_tensor.SliceStride();\n";
  }

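  // Helper lambdas: declare_src() declares one source register per block
  // cell; read_src() loads it, either masking reads via select()/-1
  // addressing when the storage returns zero for a -1 read, or multiplying
  // by the in-bounds predicate (a conditional read on Mali) otherwise.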
  auto declare_src = [&]() {
    for (int z = 0; z < block_size.z; ++z) {
      const std::string zind = std::to_string(z);
      for (int y = 0; y < block_size.y; ++y) {
        const std::string yind = std::to_string(y);
        for (int x = 0; x < block_size.x; ++x) {
          const std::string xind = std::to_string(x);
          const std::string id = generate_id(xind, yind, zind);
          c += "    " + weights_data_type + " src" + id + ";\n";
        }
      }
    }
  };
  const bool conditional_read = gpu_info.IsMali();
  auto read_src = [&]() {
    const std::string cl_type = ToCLDataType(conv_params.weights_data_type);
    for (int z = 0; z < block_size.z; ++z) {
      const std::string zind = std::to_string(z);
      for (int y = 0; y < block_size.y; ++y) {
        const std::string yind = std::to_string(y);
        for (int x = 0; x < block_size.x; ++x) {
          const std::string xind = std::to_string(x);
          std::string id = generate_id(xind, yind, zind);
          const std::string check = generate_check(xind, yind, zind);
          std::string address;
          if (src_def.IsLinear()) {
            address = "addr" + id;
          } else {
            std::string xc =
                conv_params.x_kernel_is_1 ? "xc" + xind : "xck" + xind;
            std::string yc =
                conv_params.y_kernel_is_1 ? "yc" + yind : "yck" + yind;
            address = xc + ", " + yc;
            if (src_def.HasAxis(Axis::DEPTH)) {
              std::string zc =
                  conv_params.z_kernel_is_1 ? "zc" + zind : "zck" + zind;
              address += ", " + zc;
            }
            address += ", s";
          }
          if (src_def.ReturnsZeroForNegOneRead(gpu_info)) {
            c += "    src" + id + " = args.src_tensor.Read<" + cl_type + ">(" +
                 address + ");\n";
            const std::string ds = trivial_kernel_size ? "ds" : "ds" + id;
            c += "    " + address + " += " + ds + ";\n";
          } else {
            if (!check.empty()) {
              if (conditional_read) {
                c += "    src" + id + " = " + check +
                     " ? args.src_tensor.Read<" + cl_type + ">(" + address +
                     ") : INIT_FLT4(0.0f);\n";
              } else {
                c += "    src" + id + " = args.src_tensor.Read<" + cl_type +
                     ">(" + address + ") * INIT_FLT(" + check + ");\n";
              }
            } else {
              c += "    src" + id + " = args.src_tensor.Read<" + cl_type +
                   ">(" + address + ");\n";
            }
            if (src_def.IsLinear()) {
              c += "    " + address + " += ds;\n";
            }
          }
        }
      }
    }
  };
  bool use_fma = gpu_info.IsAMD() && gpu_info.IsApiOpenCl();
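  // conv_core(shared_offset) emits one multiply-accumulate step per block
  // cell. Weights come from a subgroup broadcast, the local/global
  // weights_cache, or per-step 'f<N>' texture reads; the I4O4 layout
  // accumulates per input channel (fma() on AMD OpenCL) while O4I4 uses dot
  // products. The F32_F16 path sums the four channel products and converts
  // them via TO_ACCUM_TYPE.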
  auto conv_core = [&](int shared_offset) {
    const std::string channels[] = {"x", "y", "z", "w"};
    for (int s = 0; s < block_size.w; ++s) {
      const std::string sind = std::to_string(s);
      if (op_def.precision != CalculationsPrecision::F32_F16) {
        for (int ch = 0; ch < 4; ++ch) {
          for (int z = 0; z < block_size.z; ++z) {
            const std::string zind = std::to_string(z);
            for (int y = 0; y < block_size.y; ++y) {
              const std::string yind = std::to_string(y);
              for (int x = 0; x < block_size.x; ++x) {
                const std::string xind = std::to_string(x);
                std::string R = "r" + generate_id_full(xind, yind, zind, sind);
                std::string S = "src" + generate_id(xind, yind, zind);
                if (use_simd_broadcast) {
                  int simd_id = (s * 4 + ch + shared_offset) / simd_size;
                  int thread_id = (s * 4 + ch + shared_offset) % simd_size;
                  std::string w_val_x = "SUB_GROUP_BROADCAST(simd_w" +
                                        std::to_string(simd_id) + ".x, " +
                                        std::to_string(thread_id) + "u)";
                  std::string w_val_y = "SUB_GROUP_BROADCAST(simd_w" +
                                        std::to_string(simd_id) + ".y, " +
                                        std::to_string(thread_id) + "u)";
                  std::string w_val_z = "SUB_GROUP_BROADCAST(simd_w" +
                                        std::to_string(simd_id) + ".z, " +
                                        std::to_string(thread_id) + "u)";
                  std::string w_val_w = "SUB_GROUP_BROADCAST(simd_w" +
                                        std::to_string(simd_id) + ".w, " +
                                        std::to_string(thread_id) + "u)";
                  if (GetWeightsDescription().IsI4O4()) {
                    c += "    " + R + ".x += " + w_val_x + " * " + S + "." +
                         channels[ch] + ";\n";
                    c += "    " + R + ".y += " + w_val_y + " * " + S + "." +
                         channels[ch] + ";\n";
                    c += "    " + R + ".z += " + w_val_z + " * " + S + "." +
                         channels[ch] + ";\n";
                    c += "    " + R + ".w += " + w_val_w + " * " + S + "." +
                         channels[ch] + ";\n";
                  } else {
                    c += "    " + R + "." + channels[ch] + " += " + w_val_x +
                         " * " + S + ".x;\n";
                    c += "    " + R + "." + channels[ch] + " += " + w_val_y +
                         " * " + S + ".y;\n";
                    c += "    " + R + "." + channels[ch] + " += " + w_val_z +
                         " * " + S + ".z;\n";
                    c += "    " + R + "." + channels[ch] + " += " + w_val_w +
                         " * " + S + ".w;\n";
                  }
                } else {
                  const std::string weight_id =
                      std::to_string(s * 4 + ch + shared_offset);
                  std::string w_val;
                  if (conv_params.AreWeightsBuffer()) {
                    if (gpu_info.SupportsPointersInKernels()) {
                      w_val = "weights_cache[" + weight_id + "]";
                    } else {
                      w_val = "args.weights.Read(filters_offset + " +
                              weight_id + ")";
                    }
                  } else {
                    w_val = "f" + weight_id;
                  }
                  if (GetWeightsDescription().IsI4O4()) {
                    if (use_fma) {
                      c += "    " + R + " = fma(" + w_val + ", " + S + "." +
                           channels[ch] + ", " + R + ");\n";
                    } else {
                      c += "    " + R + " += " + w_val + " * " + S + "." +
                           channels[ch] + ";\n";
                    }
                  } else {
                    c += "    " + R + "." + channels[ch] + " += dot(" + w_val +
                         ", " + S + ");\n";
                  }
                }
              }
            }
          }
        }
      } else {  // F32_F16 precision
        for (int z = 0; z < block_size.z; ++z) {
          const std::string zind = std::to_string(z);
          for (int y = 0; y < block_size.y; ++y) {
            const std::string yind = std::to_string(y);
            for (int x = 0; x < block_size.x; ++x) {
              const std::string xind = std::to_string(x);
              std::string R = "r" + generate_id_full(xind, yind, zind, sind);
              std::string S = "src" + generate_id(xind, yind, zind);
              std::vector<std::string> F(4);
              for (int i = 0; i < 4; ++i) {
                std::string weight_id =
                    std::to_string(s * 4 + i + shared_offset);
                if (conv_params.AreWeightsBuffer()) {
                  if (gpu_info.SupportsPointersInKernels()) {
                    F[i] = "weights_cache[" + weight_id + "]";
                  } else {
                    F[i] =
                        "args.weights.Read(filters_offset + " + weight_id + ")";
                  }
                } else {
                  F[i] = "f" + weight_id;
                }
              }
              if (GetWeightsDescription().IsI4O4()) {
                c += "    " + R + " += TO_ACCUM_TYPE(" + S + ".x * " + F[0] +
                     " + " + S + ".y * " + F[1] + " + " + S + ".z * " + F[2] +
                     " + " + S + ".w * " + F[3] + ");\n";
              } else {
                c += "    " + R + ".x += dot(" + S + ", " + F[0] + ");\n";
                c += "    " + R + ".y += dot(" + S + ", " + F[1] + ");\n";
                c += "    " + R + ".z += dot(" + S + ", " + F[2] + ");\n";
                c += "    " + R + ".w += dot(" + S + ", " + F[3] + ");\n";
              }
            }
          }
        }
      }
    }
  };

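  // Main loop over source slices: stage weights for this step (async copy,
  // cooperative upload, subgroup registers, or a plain pointer/offset),
  // read sources, run the MAC core src_depth_loop_size times, then advance
  // the weights pointer/offset.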
945   c += "  int s = " + src_group_start_slice + ";\n";
946   if (conv_params.need_src_loop) {
947     c += "  do {\n";
948   }
949   declare_src();
950   const int total_work_items =
951       work_group_size_.x * work_group_size_.y * work_group_size_.z;
952   if (conv_params.weights_upload_type ==
953       ConvGeneric::WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP) {
954     c += GenerateAsyncUpload("weights_cache", "filters_loc",
955                              /*global_offset_name*/ "", local_mem_size);
956   } else if (conv_params.weights_upload_type ==
957              ConvGeneric::WeightsUploadType::LOCAL_MEM_BY_THREADS) {
958     if (gpu_info.IsApiMetal() && wg_total_size == 32 &&
959         gpu_info.IsWaveSizeEqualTo32()) {
960       c += "    SIMDGROUP_BARRIER(mem_flags::mem_none);\n";
961     } else {
962       c += "    " + barrier + ";\n";
963     }
964     if (gpu_info.SupportsPointersInKernels()) {
965       c += GenerateUploadByThreads(
966           "weights_cache", "filters_loc", /*use_ptrs*/ true,
967           /*global_offset_name*/ "", "lid", total_work_items, local_mem_size);
968     } else {
969       c += GenerateUploadByThreads("weights_cache", "args.weights",
970                                    /*use_ptrs*/ false, "filters_offset", "lid",
971                                    total_work_items, local_mem_size);
972     }
973   } else if (use_simd_broadcast) {
974     int parts = local_mem_size / simd_size;
975     int reminder = local_mem_size % simd_size;
976     const std::string read_start = gpu_info.SupportsPointersInKernels()
977                                        ? "filters_loc["
978                                        : "args.weights.Read(filters_offset + ";
979     const std::string read_end =
980         gpu_info.SupportsPointersInKernels() ? "]" : ")";
981     for (int i = 0; i < parts; ++i) {
982       const std::string weights_index =
983           "simd_id + " + std::to_string(i * simd_size);
984       c += "    FLT4 simd_w" + std::to_string(i) + " = " + read_start +
985            weights_index + read_end + ";\n";
986     }
987     if (reminder) {
988       const std::string weights_index =
989           "simd_id + " + std::to_string(parts * simd_size);
990       c += "    FLT4 simd_w" + std::to_string(parts) + ";\n";
991       c += "    if (simd_id < " + std::to_string(reminder) + ") {\n";
992       c += "      simd_w" + std::to_string(parts) + " = " + read_start +
993            weights_index + read_end + ";\n";
994       c += "    }\n";
995     }
996   } else if (conv_params.AreWeightsBuffer()) {  // GLOBAL_MEM/CONSTANT_MEM
997     if (gpu_info.SupportsPointersInKernels()) {
998       c += "    weights_cache = filters_loc;\n";
999     }
1000   } else {  // TEXTURES_MEM
1001     for (int dst_s = 0; dst_s < block_size.w; ++dst_s) {
1002       std::string f_y = trivial_kernel_size ? "s" : "filter_offset";
1003       if (trivial_kernel_size && conv_params.groups_support) {
1004         f_y = "s - src_start_slice";
1005       }
1006       if (conv_params.different_weights_for_height) {
1007         f_y = "DST_Y * args.src_tensor.Slices() + s";
1008       }
1009       c += absl::Substitute(
1010           R"(    FLT4 f$2 = args.weights0.Read(DST_S + $0, $1);
1011     FLT4 f$3 = args.weights1.Read(DST_S + $0, $1);
1012     FLT4 f$4 = args.weights2.Read(DST_S + $0, $1);
1013     FLT4 f$5 = args.weights3.Read(DST_S + $0, $1);
1014 )",
1015           dst_s, f_y, dst_s * 4 + 0, dst_s * 4 + 1, dst_s * 4 + 2,
1016           dst_s * 4 + 3);
1017     }
1018     if (!trivial_kernel_size) {
1019       c += "    filter_offset++;\n";
1020     }
1021   }
1022   read_src();
1023   c += "    s += 1;\n";
1024   if (conv_params.weights_upload_type ==
1025       ConvGeneric::WeightsUploadType::LOCAL_MEM_BY_THREADS) {
1026     c += "    " + barrier + ";\n";
1027   }
1028   conv_core(0);
1029   for (int i = 1; i < conv_params.src_depth_loop_size; ++i) {
1030     read_src();
1031     conv_core(i * block_size.w * 4);
1032     c += "    s += 1;\n";
1033   }
1034   if (conv_params.AreWeightsBuffer()) {
1035     if (gpu_info.SupportsPointersInKernels()) {
1036       c += "    filters_loc += " + std::to_string(local_mem_size) + ";\n";
1037     } else {
1038       c += "    filters_offset += " + std::to_string(local_mem_size) + ";\n";
1039     }
1040   }
1041   if (conv_params.need_src_loop) {
1042     c += "  } while (s < " + src_group_end_slice + ");\n";
1043   }
1044   if (!conv_params.x_kernel_is_1) {
1045     c += "  };\n";
1046   }
1047   if (!conv_params.y_kernel_is_1) {
1048     c += "  };\n";
1049   }
1050   if (src_def.HasAxis(Axis::DEPTH) && !conv_params_.z_kernel_is_1) {
1051     c += "  };\n";
1052   }
1053   if (conv_params.AreWeightsBuffer()) {
1054     if (conv_params.weights_upload_type ==
1055         ConvGeneric::WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP) {
1056       c += GenerateAsyncUpload("weights_cache", "args.biases.GetPtr()", "DST_S",
1057                                block_size.w);
1058     } else if (conv_params.weights_upload_type ==
1059                ConvGeneric::WeightsUploadType::LOCAL_MEM_BY_THREADS) {
1060       c += "  " + barrier + ";\n";
1061       c += GenerateUploadByThreads("weights_cache", "args.biases",
1062                                    /*use_ptrs*/ false, "DST_S", "lid",
1063                                    total_work_items, block_size.w);
1064       c += "  " + barrier + ";\n";
1065     } else if (gpu_info.SupportsPointersInKernels()) {
1066       c += "  weights_cache = args.biases.GetPtr() + DST_S;\n";
1067     }
1068   }
1069   if (late_oob_check) {
1070     c += "  if (" + dst_oob_check + ") {\n";
1071     c += "    return;\n";
1072     c += "  }\n";
1073   }
1074 
1075   auto generate_dst_check = [&](int x, int y, int z) {
1076     std::string check;
1077     const std::vector<Axis> axes{Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH};
1078     const std::vector<std::string> names{"Width()", "Height()", "Depth()"};
1079     std::vector<std::string> coords(3);
1080     coords[0] = "DST_X + " + std::to_string(x);
1081     coords[1] = "DST_Y + " + std::to_string(y);
1082     coords[2] = "DST_Z + " + std::to_string(z);
1083     const std::vector<int> ids{x, y, z};
1084     for (int i = 0; i < axes.size(); ++i) {
1085       const auto& axis = axes[i];
1086       if (src_def.HasAxis(axis) && ids[i] != 0) {
1087         if (!check.empty()) {
1088           check += " && ";
1089         }
1090         check += coords[i] + " < args.dst_tensor." + names[i];
1091       }
1092     }
1093     return check;
1094   };
1095 
1096   for (int s = 0; s < block_size.w; ++s) {
1097     const std::string sind = std::to_string(s);
1098     c += "  if (DST_S + " + sind + " >= args.dst_tensor.Slices()) return;\n";
1099     c += "  {\n";
1100     if (conv_params.AreWeightsBuffer() &&
1101         gpu_info.SupportsPointersInKernels()) {
1102       c += "    FLT4 bias_val = TO_FLT4(weights_cache[" + sind + "]);\n";
1103     } else {
1104       c += "    FLT4 bias_val = args.biases.Read(DST_S + " + sind + ");\n";
1105     }
1106     for (int z = 0; z < block_size.z; ++z) {
1107       const std::string zind = std::to_string(z);
1108       for (int y = 0; y < block_size.y; ++y) {
1109         const std::string yind = std::to_string(y);
1110         for (int x = 0; x < block_size.x; ++x) {
1111           const std::string xind = std::to_string(x);
1112           const std::string id = generate_id_full(xind, yind, zind, sind);
1113           const std::string check = generate_dst_check(x, y, z);
1114           std::string coords = "DST_X + " + xind + ", DST_Y + " + yind;
1115           if (src_def.HasAxis(Axis::DEPTH)) {
1116             coords += ", DST_Z + " + zind;
1117           }
1118           coords += ", DST_S + " + sind;
1119           if (!check.empty()) {
1120             c += "  if (" + check + ") {\n";
1121           } else {
1122             c += "  {\n";
1123           }
1124           c += "    FLT4 res = TO_FLT4(r" + id + ") + bias_val;\n";
1125           c += "    args.dst_tensor.Write(res, " + coords + ");\n";
1126           c += "  }\n";
1127         }
1128       }
1129     }
1130     c += "  }\n";
1131   }
1132   c += "}\n";
1133   return c;
1134 }
1135 
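// Work-group-count estimators for the three thread mappings. For the 3D
// mapping, groups = ceil(grid_x / wg.x) * ceil(grid_y / wg.y) *
// ceil(grid_z / wg.z), with grid_x = ceil(W / block.x) * B,
// grid_y = ceil(H / block.y), grid_z = ceil(slices / block.w).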
int GetGroupsCount(const BHWC& dst_shape, const int3& wg_size,
                   const int4& block_size) {
  const int dst_slices = DivideRoundUp(dst_shape.c, 4);

  int grid_x = DivideRoundUp(dst_shape.w, block_size.x) * dst_shape.b;
  int grid_y = DivideRoundUp(dst_shape.h, block_size.y);
  int grid_z = DivideRoundUp(dst_slices, block_size.w);

  return DivideRoundUp(grid_x, wg_size.x) * DivideRoundUp(grid_y, wg_size.y) *
         DivideRoundUp(grid_z, wg_size.z);
}

int GetGroupsCountForLinearWH(const BHWC& dst_shape, const int3& wg_size,
                              const int4& block_size) {
  const int dst_slices = DivideRoundUp(dst_shape.c, 4);

  int grid_x = DivideRoundUp(dst_shape.w, block_size.x) * dst_shape.b;
  int grid_y = DivideRoundUp(dst_shape.h, block_size.y);
  int grid_z = DivideRoundUp(dst_slices, block_size.w);

  return DivideRoundUp(grid_x * grid_y, wg_size.x) *
         DivideRoundUp(grid_z, wg_size.y);
}

int GetGroupsCountForLinearWHS(const BHWC& dst_shape, const int3& wg_size,
                               const int4& block_size) {
  const int dst_slices = DivideRoundUp(dst_shape.c, 4);

  int grid_x = DivideRoundUp(dst_shape.w, block_size.x) * dst_shape.b;
  int grid_y = DivideRoundUp(dst_shape.h, block_size.y);
  int grid_z = DivideRoundUp(dst_slices, block_size.w);

  return DivideRoundUp(grid_x * grid_y * grid_z, wg_size.x);
}

bool IsKernelXIs1(const Convolution2DAttributes& attr) {
  return attr.weights.shape.w == 1 && attr.strides.w == 1 &&
         attr.dilations.w == 1 && attr.padding.prepended.w == 0 &&
         attr.padding.appended.w == 0;
}

bool IsKernelYIs1(const Convolution2DAttributes& attr) {
  return attr.weights.shape.h == 1 && attr.strides.h == 1 &&
         attr.dilations.h == 1 && attr.padding.prepended.h == 0 &&
         attr.padding.appended.h == 0;
}

int GetMaximumPossibleWavesCount(const AppleInfo& apple_info,
                                 const BHWC& dst_shape) {
  if (apple_info.IsLocalMemoryPreferredOverGlobal()) {
    return GetGroupsCountForLinearWH(dst_shape, {32, 1, 1}, int4(1, 1, 1, 1));
  } else {
    return GetGroupsCountForLinearWHS(dst_shape, {32, 1, 1}, int4(1, 1, 1, 1));
  }
}

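// Picks a total block size (output values per thread) from the available
// parallelism: 8/4/2/1 as the maximum wave count falls below 64x/32x/16x
// the compute-unit count.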
int GetRecommendedBlockSize(const AppleInfo& apple_info,
                            const BHWC& dst_shape) {
  const int max_waves = GetMaximumPossibleWavesCount(apple_info, dst_shape);
  const int cu_count = apple_info.GetComputeUnitsCount();
  if (max_waves >= cu_count * 64) {
    return 8;
  } else if (max_waves >= cu_count * 32) {
    return 4;
  } else if (max_waves >= cu_count * 16) {
    return 2;
  } else {
    return 1;
  }
}

struct WorkGroupSizeOption {
  enum class ThreadMapping { kDefault, kLinearSpatial, kLinearAll };
  int3 work_group_size;
  int work_groups_count;
  ThreadMapping thread_mapping;
  float penalty = 1.0f;
};

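// Builds one candidate: the work-group size, its thread mapping, the
// work-group count from the matching estimator above, and a caller-supplied
// penalty factor.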
CreateWorkGroupSizeOption(const int3 & work_group_size,WorkGroupSizeOption::ThreadMapping mapping_type,float penalty,const BHWC & dst_shape,const int4 & block_size)1215 WorkGroupSizeOption CreateWorkGroupSizeOption(
1216     const int3& work_group_size,
1217     WorkGroupSizeOption::ThreadMapping mapping_type, float penalty,
1218     const BHWC& dst_shape, const int4& block_size) {
1219   WorkGroupSizeOption wg;
1220   wg.work_group_size = work_group_size;
1221   wg.thread_mapping = mapping_type;
1222   wg.penalty = penalty;
1223   if (mapping_type == WorkGroupSizeOption::ThreadMapping::kDefault) {
1224     wg.work_groups_count =
1225         GetGroupsCount(dst_shape, work_group_size, block_size);
1226   } else if (mapping_type ==
1227              WorkGroupSizeOption::ThreadMapping::kLinearSpatial) {
1228     wg.work_groups_count =
1229         GetGroupsCountForLinearWH(dst_shape, work_group_size, block_size);
1230   } else if (mapping_type == WorkGroupSizeOption::ThreadMapping::kLinearAll) {
1231     wg.work_groups_count =
1232         GetGroupsCountForLinearWHS(dst_shape, work_group_size, block_size);
1233   }
1234   return wg;
1235 }
1236 
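// Tuning for Apple A7/A8 GPUs (local memory preferred): weights are staged
// through threadgroup memory, and the cheapest work-group option is selected
// by estimated cost = work_groups_count * penalty * work-group volume.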
1237 ConvGeneric::ConvParams GetConvParamsForA7A8(const AppleInfo& apple_info,
1238                                              bool x_kernel_is_1,
1239                                              bool y_kernel_is_1, int src_slices,
1240                                              const BHWC& dst_shape) {
1241   const int dst_slices = DivideRoundUp(dst_shape.c, 4);
1242   int blk_total_size = GetRecommendedBlockSize(apple_info, dst_shape);
1243   int3 block_size = int3(1, 1, 1);
1244   if (blk_total_size >= 4 && (dst_slices % 4 == 0 || dst_slices >= 16)) {
1245     block_size.z = 4;
1246     blk_total_size /= 4;
1247   } else if (blk_total_size >= 2 && (dst_slices % 2 == 0 || dst_slices >= 4)) {
1248     block_size.z = 2;
1249     blk_total_size /= 2;
1250   }
1251   if (blk_total_size >= 4) {
1252     block_size.x = 2;
1253     block_size.y = 2;
1254     blk_total_size /= 4;
1255   } else if (blk_total_size >= 2) {
1256     if (dst_shape.w % 2 != 0 && dst_shape.h % 2 == 0) {
1257       block_size.y = 2;
1258     } else {
1259       block_size.x = 2;
1260     }
1261     blk_total_size /= 2;
1262   }
1263 
1264   ConvGeneric::ConvParams params;
1265   params.weights_upload_type =
1266       ConvGeneric::WeightsUploadType::LOCAL_MEM_BY_THREADS;
1267   params.x_kernel_is_1 = x_kernel_is_1;
1268   params.y_kernel_is_1 = y_kernel_is_1;
1269   params.src_depth_loop_size = 1;
1270   params.block_size.x = block_size.x;
1271   params.block_size.y = block_size.y;
1272   params.block_size.z = 1;
1273   params.block_size.w = block_size.z;
1274   params.weights_layout = WeightsLayout::kOSpatialIOGroupO4I4;
1275 
1276   std::vector<WorkGroupSizeOption> options;
1277   options.push_back(CreateWorkGroupSizeOption(
1278       {8, 4, 1}, WorkGroupSizeOption::ThreadMapping::kDefault, 1.0f, dst_shape,
1279       params.block_size));
1280   options.push_back(CreateWorkGroupSizeOption(
1281       {4, 4, 1}, WorkGroupSizeOption::ThreadMapping::kDefault, 1.01f, dst_shape,
1282       params.block_size));
1283   options.push_back(CreateWorkGroupSizeOption(
1284       {4, 2, 1}, WorkGroupSizeOption::ThreadMapping::kDefault, 1.25f, dst_shape,
1285       params.block_size));
1286   options.push_back(CreateWorkGroupSizeOption(
1287       {32, 1, 1}, WorkGroupSizeOption::ThreadMapping::kLinearSpatial, 1.0f,
1288       dst_shape, params.block_size));
1289   options.push_back(CreateWorkGroupSizeOption(
1290       {16, 1, 1}, WorkGroupSizeOption::ThreadMapping::kLinearSpatial, 1.01f,
1291       dst_shape, params.block_size));
1292   options.push_back(CreateWorkGroupSizeOption(
1293       {8, 1, 1}, WorkGroupSizeOption::ThreadMapping::kLinearSpatial, 1.25f,
1294       dst_shape, params.block_size));
1295   options.push_back(CreateWorkGroupSizeOption(
1296       {32, 1, 1}, WorkGroupSizeOption::ThreadMapping::kLinearAll, 3.1 * 1.0f,
1297       dst_shape, params.block_size));
1298   options.push_back(CreateWorkGroupSizeOption(
1299       {16, 1, 1}, WorkGroupSizeOption::ThreadMapping::kLinearAll, 3.1 * 1.01f,
1300       dst_shape, params.block_size));
1301   options.push_back(CreateWorkGroupSizeOption(
1302       {8, 1, 1}, WorkGroupSizeOption::ThreadMapping::kLinearAll, 3.1 * 1.25f,
1303       dst_shape, params.block_size));
1304 
1305   float optimum = options[0].work_groups_count * options[0].penalty *
1306                   options[0].work_group_size.x * options[0].work_group_size.y *
1307                   options[0].work_group_size.z;
1308   int optimum_index = 0;
1309   for (int i = 1; i < options.size(); ++i) {
1310     float local_optimum = options[i].work_groups_count * options[i].penalty *
1311                           options[i].work_group_size.x *
1312                           options[i].work_group_size.y *
1313                           options[i].work_group_size.z;
1314     if (local_optimum < optimum) {
1315       optimum = local_optimum;
1316       optimum_index = i;
1317     }
1318   }
1319 
1320   WorkGroupSizeOption optimum_wg = options[optimum_index];
1321   if (optimum_wg.thread_mapping ==
1322       WorkGroupSizeOption::ThreadMapping::kLinearSpatial) {
1323     params.linear_spatial = true;
1324     params.linear_all = false;
1325     params.work_group_size = optimum_wg.work_group_size;
1326     params.work_group_launch_order = int3(1, 0, 2);
1327   } else if (optimum_wg.thread_mapping ==
1328              WorkGroupSizeOption::ThreadMapping::kLinearAll) {
1329     params.linear_spatial = false;
1330     params.linear_all = true;
1331     params.work_group_size = optimum_wg.work_group_size;
1332     params.work_group_launch_order = int3(0, 1, 2);
1333     params.weights_upload_type = ConvGeneric::WeightsUploadType::GLOBAL_MEM;
1334   } else {
1335     // default 3D workgroup
1336     params.linear_spatial = false;
1337     params.linear_all = false;
1338     params.work_group_size = optimum_wg.work_group_size;
1339     params.work_group_launch_order = int3(2, 0, 1);
1340   }
1341   int total_elements = params.block_size.x * params.block_size.y *
1342                        params.block_size.z * params.block_size.w;
1343   if (total_elements == 1) {
1344     if (src_slices % 4 == 0) {
1345       params.src_depth_loop_size = 4;
1346     } else if (src_slices % 2 == 0) {
1347       params.src_depth_loop_size = 2;
1348     }
1349   } else if (total_elements == 2) {
1350     if (src_slices % 2 == 0) {
1351       params.src_depth_loop_size = 2;
1352     }
1353   }
1354   if (params.src_depth_loop_size == src_slices) {
1355     params.need_src_loop = false;
1356   }
1357   if (params.block_size.w == dst_slices) {
1358     params.need_dst_loop = false;
1359   }
1360   const bool use_filters_constants =
1361       !params.need_dst_loop && !params.need_src_loop && params.x_kernel_is_1 &&
1362       params.y_kernel_is_1;
1363   if (use_filters_constants) {
1364     params.weights_upload_type = ConvGeneric::WeightsUploadType::CONSTANT_MEM;
1365   }
1366 
1367   return params;
1368 }
1369 
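// Tuning for Apple A9 and newer: weights default to global memory, and the
// linear spatial/all mappings replace the default 3D mapping when they yield
// fewer work groups (with a small tolerance on non-Bionic chips).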
1370 ConvGeneric::ConvParams GetConvParamsForA9AndHigher(const AppleInfo& apple_info,
1371                                                     bool x_kernel_is_1,
1372                                                     bool y_kernel_is_1,
1373                                                     int src_slices,
1374                                                     const BHWC& dst_shape) {
1375   const int dst_slices = DivideRoundUp(dst_shape.c, 4);
1376   int blk_total_size = GetRecommendedBlockSize(apple_info, dst_shape);
1377   int3 block_size = int3(1, 1, 1);
1378   if (blk_total_size >= 2 && apple_info.IsBionic()) {
1379     if (dst_shape.h % 2 != 0 && dst_shape.w % 2 == 0) {
1380       block_size.x = 2;
1381     } else {
1382       block_size.y = 2;
1383     }
1384     blk_total_size /= 2;
1385   }
1386   if (blk_total_size >= 4 && (dst_slices % 4 == 0 || dst_slices >= 16)) {
1387     block_size.z = 4;
1388     blk_total_size /= 4;
1389   } else if (blk_total_size >= 2 && (dst_slices % 2 == 0 || dst_slices >= 4)) {
1390     block_size.z = 2;
1391     blk_total_size /= 2;
1392   }
1393   if (blk_total_size >= 4 && dst_slices == 3) {
1394     block_size.z = 3;
1395     blk_total_size /= 4;
1396   }
1397 
1398   ConvGeneric::ConvParams params;
1399   params.weights_upload_type = ConvGeneric::WeightsUploadType::GLOBAL_MEM;
1400   params.x_kernel_is_1 = x_kernel_is_1;
1401   params.y_kernel_is_1 = y_kernel_is_1;
1402   params.src_depth_loop_size = 1;
1403   params.block_size.x = block_size.x;
1404   params.block_size.y = block_size.y;
1405   params.block_size.z = 1;
1406   params.block_size.w = block_size.z;
1407   params.linear_spatial = false;
1408   params.linear_all = false;
1409   params.work_group_size = int3(8, 4, 1);
1410   params.work_group_launch_order = int3(2, 0, 1);
1411   params.weights_layout = WeightsLayout::kOSpatialIOGroupO4I4;
1412   int g1 = GetGroupsCount(dst_shape, params.work_group_size, params.block_size);
1413   int g2 = GetGroupsCountForLinearWH(dst_shape, {32, 1, 1}, params.block_size);
1414   int g3 = GetGroupsCountForLinearWHS(dst_shape, {32, 1, 1}, params.block_size);
1415   if (g2 < g1) {
1416     params.linear_spatial = true;
1417     params.work_group_size = int3(32, 1, 1);
1418     params.work_group_launch_order = int3(0, 1, 2);
1419   }
1420   float precise_threshold = apple_info.IsBionic() ? 1.0f : 1.04f;
1421   float precise_ratio = static_cast<float>(g2) / static_cast<float>(g3);
1422   if (precise_ratio > precise_threshold) {
1423     params.linear_spatial = false;
1424     params.linear_all = true;
1425     params.work_group_size = int3(32, 1, 1);
1426   }
1427   int total_elements = params.block_size.x * params.block_size.y *
1428                        params.block_size.z * params.block_size.w;
1429   if (total_elements == 1) {
1430     if (src_slices % 4 == 0) {
1431       params.src_depth_loop_size = 4;
1432     } else if (src_slices % 2 == 0) {
1433       params.src_depth_loop_size = 2;
1434     }
1435   } else if (total_elements == 2) {
1436     if (src_slices % 2 == 0) {
1437       params.src_depth_loop_size = 2;
1438     }
1439   }
1440   if (params.src_depth_loop_size == src_slices) {
1441     params.need_src_loop = false;
1442   }
1443   if (params.block_size.w == dst_slices) {
1444     params.need_dst_loop = false;
1445   }
1446   const bool use_filters_constants =
1447       !params.need_dst_loop && !params.need_src_loop && params.x_kernel_is_1 &&
1448       params.y_kernel_is_1;
1449   if (use_filters_constants) {
1450     params.weights_upload_type = ConvGeneric::WeightsUploadType::CONSTANT_MEM;
1451   }
1452 
1453   return params;
1454 }
1455 
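// Chooses the A7/A8 or A9+ tuner depending on whether the GPU prefers local
// memory over global.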
1456 ConvGeneric::ConvParams ConvGeneric::GuessBestParamsApple(
1457     const GpuInfo& gpu_info, const OperationDef& definition, int src_depth,
1458     int dst_depth, bool x_kernel_is_1, bool y_kernel_is_1,
1459     bool different_weights_for_height, const BHWC& dst_shape) {
1460   if (gpu_info.apple_info.IsLocalMemoryPreferredOverGlobal()) {
1461     return GetConvParamsForA7A8(gpu_info.apple_info, x_kernel_is_1,
1462                                 y_kernel_is_1, src_depth, dst_shape);
1463   } else {
1464     return GetConvParamsForA9AndHigher(gpu_info.apple_info, x_kernel_is_1,
1465                                        y_kernel_is_1, src_depth, dst_shape);
1466   }
1467 }
1468 
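// Vendor-specific tuning: block size, work-group shape/launch order and the
// weights upload strategy are chosen per GPU family (Nvidia, PowerVR, AMD,
// Mali, Adreno, Intel, Apple), with a generic fallback at the end.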
1469 ConvGeneric::ConvParams ConvGeneric::GuessBestParams(
1470     const GpuInfo& gpu_info, const OperationDef& definition, int src_depth,
1471     int dst_depth, bool x_kernel_is_1, bool y_kernel_is_1,
1472     bool different_weights_for_height, const BHWC* dst_shape) {
1473   ConvParams conv_params;
1474   conv_params.linear_spatial = false;
1475   conv_params.linear_all = false;
1476   conv_params.block_size = int4(1, 1, 1, 1);
1477   conv_params.weights_data_type =
1478       DeduceDataTypeFromPrecision(definition.precision);
1479   conv_params.x_kernel_is_1 = x_kernel_is_1;
1480   conv_params.y_kernel_is_1 = y_kernel_is_1;
1481   conv_params.different_weights_for_height = different_weights_for_height;
1482   if (gpu_info.IsNvidia()) {
1483     if (different_weights_for_height) {
1484       work_group_size_ = int3(32, 1, 1);
1485       work_group_launch_order_ = int3(2, 0, 1);
1486       conv_params.fixed_work_group_size = true;
1487     } else {
1488       conv_params.linear_spatial = true;
1489       work_group_size_ = int3(32, 1, 1);
1490       work_group_launch_order_ = int3(1, 0, 2);
1491       conv_params.fixed_work_group_size = true;
1492     }
1493     conv_params.block_size = int4(2, 1, 1, 4);
1494     conv_params.src_depth_loop_size = 1;
1495     conv_params.weights_upload_type = WeightsUploadType::LOCAL_MEM_BY_THREADS;
1496     if (dst_depth % 4 == 0 || dst_depth >= 8) {
1497       conv_params.block_size.w = 4;
1498     } else if (dst_depth % 2 == 0 || dst_depth >= 4) {
1499       conv_params.block_size.w = 2;
1500     } else {
1501       conv_params.block_size.w = dst_depth;
1502     }
1503     if (dst_shape) {
1504       int task_size = dst_shape->w * dst_shape->b * dst_shape->h * dst_depth;
1505       float task_size_per_cu =
1506           static_cast<float>(task_size) / gpu_info.GetComputeUnitsCount();
1507       int block_size = conv_params.block_size.x * conv_params.block_size.y *
1508                        conv_params.block_size.w;
1509       float threads_per_cu = task_size_per_cu / block_size;
1510       float warps_per_cu = threads_per_cu / 32 /*warp_size*/;
1511       if (warps_per_cu < 8.0f) {
1512         conv_params.block_size.x = 1;
1513       }
1514       if (warps_per_cu < 4.0f && conv_params.block_size.w >= 4) {
1515         conv_params.block_size.w /= 2;
1516       }
1517       if (warps_per_cu < 2.0f && conv_params.block_size.w >= 2) {
1518         conv_params.block_size.w /= 2;
1519       }
1520     }
1521     if (src_depth % 2 == 0) {
1522       conv_params.src_depth_loop_size = 2;
1523     }
1524     if (src_depth % 4 == 0 && conv_params.block_size.w <= 2) {
1525       conv_params.src_depth_loop_size = 4;
1526     }
1527   } else if (gpu_info.IsPowerVR()) {
1528     if (different_weights_for_height) {
1529       work_group_size_ = int3(32, 1, 1);
1530       work_group_launch_order_ = int3(2, 0, 1);
1531       conv_params.fixed_work_group_size = true;
1532     } else {
1533       conv_params.linear_spatial = true;
1534       work_group_size_ = int3(32, 1, 1);
1535       work_group_launch_order_ = int3(1, 0, 2);
1536       conv_params.fixed_work_group_size = true;
1537     }
1538     conv_params.block_size = int4(1, 1, 1, 4);
1539     conv_params.src_depth_loop_size = 1;
1540     conv_params.weights_upload_type =
1541         WeightsUploadType::LOCAL_MEM_ASYNC_SUBGROUP;
1542     if (dst_depth % 8 == 0 || dst_depth >= 32) {
1543       conv_params.block_size.w = 8;
1544     } else if (dst_depth % 4 == 0 || dst_depth >= 8) {
1545       conv_params.block_size.w = 4;
1546     } else if (dst_depth % 2 == 0 || dst_depth >= 4) {
1547       conv_params.block_size.w = 2;
1548     } else {
1549       conv_params.block_size.w = dst_depth;
1550     }
1551     if (definition.precision == CalculationsPrecision::F16) {
1552       conv_params.block_size.w = std::min(4, conv_params.block_size.w);
1553       if (src_depth % 2 == 0) {
1554         conv_params.src_depth_loop_size = 2;
1555       }
1556       if (src_depth % 4 == 0 && conv_params.block_size.w <= 2) {
1557         conv_params.src_depth_loop_size = 4;
1558       }
1559       if (conv_params.block_size.w == 1) {
1560         if (src_depth % 2 == 0) {
1561           conv_params.src_depth_loop_size = 2;
1562         }
1563         if (src_depth % 4 == 0) {
1564           conv_params.src_depth_loop_size = 4;
1565         }
1566         if (src_depth <= 8) {
1567           conv_params.src_depth_loop_size = src_depth;
1568         }
1569       }
1570       conv_params.block_size.x = 2;
1571     }
1572   } else if (gpu_info.IsAMD()) {
1573     work_group_size_ = int3(8, 4, 1);
1574     work_group_launch_order_ = int3(0, 1, 2);
1575     conv_params.fixed_work_group_size = false;
1576 
1577     if (gpu_info.IsApiOpenCl()) {
1578       conv_params.weights_upload_type = WeightsUploadType::CONSTANT_MEM;
1579     } else {
1580       conv_params.weights_upload_type = WeightsUploadType::GLOBAL_MEM;
1581     }
1582     if (dst_depth % 4 == 0 || dst_depth >= 8) {
1583       conv_params.block_size = int4(2, 2, 1, 4);
1584     } else if (dst_depth % 2 == 0 || dst_depth >= 4) {
1585       conv_params.block_size = int4(4, 2, 1, 2);
1586     } else {
1587       conv_params.block_size = int4(4, 4, 1, 1);
1588     }
1589     auto reduce_block_size_wzyx = [](int4* block_size) {
1590       if (block_size->w % 2 == 0) {
1591         block_size->w /= 2;
1592       } else if (block_size->z % 2 == 0) {
1593         block_size->z /= 2;
1594       } else if (block_size->y % 2 == 0) {
1595         block_size->y /= 2;
1596       } else if (block_size->x % 2 == 0) {
1597         block_size->x /= 2;
1598       }
1599     };
1600     if (definition_.precision != CalculationsPrecision::F16) {
1601       reduce_block_size_wzyx(&conv_params.block_size);
1602     }
1603     if (dst_shape) {
1604       int task_size = dst_shape->w * dst_shape->b * dst_shape->h * dst_depth;
1605       float task_size_per_cu =
1606           static_cast<float>(task_size) / gpu_info.GetComputeUnitsCount();
1607       int block_size = conv_params.block_size.x * conv_params.block_size.y *
1608                        conv_params.block_size.w;
1609       float threads_per_cu = task_size_per_cu / block_size;
1610       float warps_per_cu = threads_per_cu / 64;
1611       if (warps_per_cu < 4.0f) {
1612         reduce_block_size_wzyx(&conv_params.block_size);
1613       }
1614       if (warps_per_cu < 2.0f) {
1615         reduce_block_size_wzyx(&conv_params.block_size);
1616       }
1617       if (warps_per_cu < 1.0f) {
1618         reduce_block_size_wzyx(&conv_params.block_size);
1619       }
1620       if (warps_per_cu < 0.5f) {
1621         reduce_block_size_wzyx(&conv_params.block_size);
1622       }
1623     }
1624     int block_size = conv_params.block_size.x * conv_params.block_size.y *
1625                      conv_params.block_size.w;
1626     conv_params.src_depth_loop_size = 1;
1627     if (block_size <= 4 && src_depth % 2 == 0) {
1628       conv_params.src_depth_loop_size = 2;
1629     }
1630     if (block_size <= 2 && src_depth % 4 == 0) {
1631       conv_params.src_depth_loop_size = 4;
1632     }
1633     if (block_size <= 1 && src_depth % 8 == 0) {
1634       conv_params.src_depth_loop_size = 8;
1635     }
1636   } else if (gpu_info.IsMali()) {
1637     int block_size = 2;
1638     if (dst_shape) {
1639       int task_size = dst_shape->w * dst_shape->b * dst_shape->h * dst_depth;
1640       block_size = GetRecommendedBlockSizeForConv(
1641           gpu_info, definition.precision, task_size);
1642     }
1643     if (!x_kernel_is_1 || !y_kernel_is_1) {
1644       if (gpu_info.mali_info.IsMidgard() || gpu_info.mali_info.IsBifrost()) {
1645         block_size = std::min(block_size, 4);
1646       }
1647     }
1648     if (block_size == 8) {
1649       if (dst_depth == 1 || dst_depth == 3) {
1650         conv_params.block_size = int4(2, 2, 1, 1);
1651       } else {
1652         conv_params.block_size = int4(2, 2, 1, 2);
1653       }
1654     } else if (block_size == 4) {
1655       if (dst_depth == 1 || dst_depth == 3) {
1656         conv_params.block_size = int4(2, 2, 1, 1);
1657       } else {
1658         conv_params.block_size = int4(2, 1, 1, 1);
1659         if (definition.precision == CalculationsPrecision::F32 &&
1660             gpu_info.mali_info.IsValhall()) {
1661           conv_params.block_size.y = 2;
1662         } else {
1663           conv_params.block_size.w = 2;
1664         }
1665       }
1666     } else if (block_size == 2) {
1667       conv_params.block_size = int4(2, 1, 1, 1);
1668     } else {
1669       conv_params.block_size = int4(1, 1, 1, 1);
1670     }
1671     conv_params.src_depth_loop_size = 1;
1672     MaliInfo mali_info = gpu_info.mali_info;
1673     if (src_depth % 2 == 0 && block_size <= 2 && !mali_info.IsMidgard()) {
1674       conv_params.src_depth_loop_size = 2;
1675     }
1676     if (src_depth % 4 == 0 && block_size == 1 && !mali_info.IsMidgard() &&
1677         definition.precision == CalculationsPrecision::F16) {
1678       conv_params.src_depth_loop_size = 4;
1679     }
1680     work_group_size_ = int3(4, 4, 1);
1681     work_group_launch_order_ = int3(0, 1, 2);
1682     conv_params.fixed_work_group_size = false;
1683     conv_params.weights_upload_type = WeightsUploadType::GLOBAL_MEM;
1684   } else if (gpu_info.IsAdreno()) {
1685     if (dst_shape) {
1686       const int wave_size = gpu_info.adreno_info.GetWaveSize(
1687           definition.precision == CalculationsPrecision::F16);
1688       const double task_size =
1689           1.0 * dst_shape->w * dst_shape->b * dst_shape->h * dst_depth;
1690       const double waves =
1691           task_size / gpu_info.GetComputeUnitsCount() / wave_size;
1692       if (waves <= 6.0f) {
1693         conv_params.block_size = int4(1, 1, 1, 1);
1694       } else if (waves <= 12.0f) {
1695         conv_params.block_size = int4(2, 1, 1, 1);
1696       } else if (waves <= 24.0f) {
1697         conv_params.block_size = int4(2, 1, 1, 2);
1698       } else {
1699         conv_params.block_size = int4(2, 2, 1, 2);
1700       }
1701     } else {
1702       conv_params.block_size = int4(2, 2, 1, 2);
1703     }
1704     if (gpu_info.adreno_info.IsAdreno3xx()) {
1705       if (definition.precision == CalculationsPrecision::F16) {
1706         conv_params.block_size = int4(2, 2, 1, 2);
1707       } else if (definition.precision == CalculationsPrecision::F32_F16) {
1708         conv_params.block_size = int4(2, 1, 1, 2);
1709       } else {  // F32
1710         conv_params.block_size = int4(2, 2, 1, 1);
1711       }
1712     }
1713     work_group_size_ = int3(8, 2, 1);
1714     work_group_launch_order_ = int3(0, 1, 2);
1715     conv_params.fixed_work_group_size = false;
1716     conv_params.src_depth_loop_size = 1;
1717     conv_params.weights_upload_type = WeightsUploadType::TEXTURES_MEM_X4;
1718   } else if (gpu_info.IsIntel()) {
1719     if (different_weights_for_height) {
1720       work_group_size_ = int3(16, 1, 1);
1721       work_group_launch_order_ = int3(0, 1, 2);
1722       conv_params.fixed_work_group_size = true;
1723     } else {
1724       conv_params.linear_spatial = true;
1725       work_group_size_ = int3(16, 1, 1);
1726       work_group_launch_order_ = int3(0, 1, 2);
1727       conv_params.fixed_work_group_size = true;
1728     }
1729     conv_params.block_size = int4(1, 1, 1, 4);
1730     conv_params.src_depth_loop_size = 1;
1731     conv_params.weights_upload_type = WeightsUploadType::LOCAL_MEM_BY_THREADS;
1732     if (gpu_info.IsApiMetal() &&
1733         definition.precision != CalculationsPrecision::F32_F16 &&
1734         gpu_info.metal_info.IsMslVersionEqualOrHigher(2)) {
1735       conv_params.weights_upload_type =
1736           WeightsUploadType::PRIVATE_MEM_SIMD_BROADCAST;
1737       conv_params.simd_size = 8;
1738     }
1739     if (gpu_info.IsApiOpenCl()) {
1740       const int kSubGroupSize = 16;
1741       const bool supports_subgroups =
1742           gpu_info.SupportsExtension("cl_khr_subgroups") ||
1743           gpu_info.SupportsExtension("cl_intel_subgroups");
1744       if (definition.precision != CalculationsPrecision::F32_F16 &&
1745           supports_subgroups &&
1746           gpu_info.SupportsExtension("cl_intel_required_subgroup_size") &&
1747           gpu_info.SupportsSubGroupWithSize(kSubGroupSize)) {
1748         conv_params.weights_upload_type =
1749             WeightsUploadType::PRIVATE_MEM_SIMD_BROADCAST;
1750         conv_params.simd_size = kSubGroupSize;
1751       }
1752     }
1753     if (dst_depth % 4 == 0 || dst_depth >= 8) {
1754       conv_params.block_size.w = 4;
1755     } else if (dst_depth % 2 == 0 || dst_depth >= 4) {
1756       conv_params.block_size.w = 2;
1757     } else {
1758       conv_params.block_size.w = dst_depth;
1759     }
1760     if (src_depth % 2 == 0) {
1761       conv_params.src_depth_loop_size = 2;
1762     }
1763     if (src_depth % 4 == 0 && conv_params.block_size.w <= 2) {
1764       conv_params.src_depth_loop_size = 4;
1765     }
1766   } else if (gpu_info.IsApple()) {
1767     BHWC output_shape = BHWC(1, 32, 32, 128);
1768     if (dst_shape) {
1769       output_shape = *dst_shape;
1770     }
1771     conv_params = GuessBestParamsApple(
1772         gpu_info, definition, src_depth, dst_depth, x_kernel_is_1,
1773         y_kernel_is_1, different_weights_for_height, output_shape);
1774     conv_params.fixed_work_group_size = true;
1775     work_group_size_ = conv_params.work_group_size;
1776     work_group_launch_order_ = conv_params.work_group_launch_order;
1777     conv_params.weights_data_type =
1778         DeduceDataTypeFromPrecision(definition.precision);
1779     conv_params.x_kernel_is_1 = x_kernel_is_1;
1780     conv_params.y_kernel_is_1 = y_kernel_is_1;
1781     conv_params.different_weights_for_height = different_weights_for_height;
1782   } else {
1783     conv_params.block_size = int4(1, 1, 1, 4);
1784     work_group_size_ = int3(8, 2, 1);
1785     work_group_launch_order_ = int3(0, 1, 2);
1786     conv_params.fixed_work_group_size = false;
1787     conv_params.src_depth_loop_size = 1;
1788     conv_params.weights_upload_type = WeightsUploadType::GLOBAL_MEM;
1789     if (dst_depth % 4 == 0 || dst_depth >= 8) {
1790       conv_params.block_size.w = 4;
1791     } else if (dst_depth % 2 == 0 || dst_depth >= 4) {
1792       conv_params.block_size.w = 2;
1793     } else {
1794       conv_params.block_size.w = dst_depth;
1795     }
1796     if (src_depth % 2 == 0) {
1797       conv_params.src_depth_loop_size = 2;
1798     }
1799     if (src_depth % 4 == 0 && conv_params.block_size.w <= 2) {
1800       conv_params.src_depth_loop_size = 4;
1801     }
1802   }
1803   if (conv_params.AreWeightsBuffer()) {
1804     if (gpu_info.IsApple()) {
1805       conv_params.weights_layout = WeightsLayout::kOSpatialIOGroupO4I4;
1806     } else {
1807       conv_params.weights_layout = WeightsLayout::kOSpatialIOGroupI4O4;
1808     }
1809   } else {
1810     if (gpu_info.IsApple()) {
1811       conv_params.weights_layout =
1812           WeightsLayout::k2DX4O4YIsSpatialIAndXIsOOGroupI4;
1813     } else {
1814       conv_params.weights_layout =
1815           WeightsLayout::k2DX4I4YIsSpatialIAndXIsOOGroupO4;
1816     }
1817   }
1818 
1819   return conv_params;
1820 }
1821 
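// Overload that derives src/dst depth and the kernel-is-1 flags from 2D
// convolution attributes.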
1822 ConvGeneric::ConvParams ConvGeneric::GuessBestParams(
1823     const GpuInfo& gpu_info, const OperationDef& definition,
1824     const Convolution2DAttributes& attr, const BHWC* dst_shape) {
1825   const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
1826   const int src_depth = DivideRoundUp(attr.weights.shape.i, 4);
1827   const bool x_kernel_is_1 = attr.weights.shape.w == 1 && attr.strides.w == 1 &&
1828                              attr.dilations.w == 1 &&
1829                              attr.padding.prepended.w == 0 &&
1830                              attr.padding.appended.w == 0;
1831   const bool y_kernel_is_1 = attr.weights.shape.h == 1 && attr.strides.h == 1 &&
1832                              attr.dilations.h == 1 &&
1833                              attr.padding.prepended.h == 0 &&
1834                              attr.padding.appended.h == 0;
1835   return GuessBestParams(gpu_info, definition, src_depth, dst_depth,
1836                          x_kernel_is_1, y_kernel_is_1, false, dst_shape);
1837 }
1838 
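// 3D convolutions reuse the 2D heuristics by folding depth into height
// (shape.h = h * d); z_kernel_is_1 is filled in afterwards.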
1839 ConvGeneric::ConvParams ConvGeneric::GuessBestParams(
1840     const GpuInfo& gpu_info, const OperationDef& definition,
1841     const Convolution3DAttributes& attr, const BHWDC* dst_shape) {
1842   const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
1843   const int src_depth = DivideRoundUp(attr.weights.shape.i, 4);
1844   const bool x_kernel_is_1 = attr.weights.shape.w == 1 && attr.strides.w == 1 &&
1845                              attr.dilations.w == 1 &&
1846                              attr.padding.prepended.w == 0 &&
1847                              attr.padding.appended.w == 0;
1848   const bool y_kernel_is_1 = attr.weights.shape.h == 1 && attr.strides.h == 1 &&
1849                              attr.dilations.h == 1 &&
1850                              attr.padding.prepended.h == 0 &&
1851                              attr.padding.appended.h == 0;
1852   const bool z_kernel_is_1 = attr.weights.shape.d == 1 && attr.strides.d == 1 &&
1853                              attr.dilations.d == 1 &&
1854                              attr.padding.prepended.d == 0 &&
1855                              attr.padding.appended.d == 0;
1856 
1857   ConvGeneric::ConvParams result;
1858   BHWC shape;
1859   if (dst_shape) {
1860     shape.b = dst_shape->b;
1861     shape.h = dst_shape->h * dst_shape->d;
1862     shape.w = dst_shape->w;
1863     shape.c = dst_shape->c;
1864     result = GuessBestParams(gpu_info, definition, src_depth, dst_depth,
1865                              x_kernel_is_1, y_kernel_is_1, false, &shape);
1866   } else {
1867     result = GuessBestParams(gpu_info, definition, src_depth, dst_depth,
1868                              x_kernel_is_1, y_kernel_is_1, false, nullptr);
1869   }
1870   result.z_kernel_is_1 = z_kernel_is_1;
1871   return result;
1872 }
1873 
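// Variant for runtime-supplied weights: depths come from weights_shape
// (batch = O, channels = I) instead of the attributes' weight tensor.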
1874 ConvGeneric::ConvParams ConvGeneric::GuessBestParams(
1875     const GpuInfo& gpu_info, const OperationDef& definition,
1876     const Convolution2DAttributes& attr, const BHWC& weights_shape,
1877     const BHWC* dst_shape) {
1878   const int dst_depth = DivideRoundUp(weights_shape.b, 4);
1879   const int src_depth = DivideRoundUp(weights_shape.c, 4);
1880   const bool x_kernel_is_1 =
1881       weights_shape.w == 1 && attr.strides.w == 1 && attr.dilations.w == 1 &&
1882       attr.padding.prepended.w == 0 && attr.padding.appended.w == 0;
1883   const bool y_kernel_is_1 =
1884       weights_shape.h == 1 && attr.strides.h == 1 && attr.dilations.h == 1 &&
1885       attr.padding.prepended.h == 0 && attr.padding.appended.h == 0;
1886   return GuessBestParams(gpu_info, definition, src_depth, dst_depth,
1887                          x_kernel_is_1, y_kernel_is_1, false, dst_shape);
1888 }
1889 
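// Fully connected is treated as a 1x1 convolution; the Y dimensions of both
// the work group and the block are folded into X.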
1890 ConvGeneric::ConvParams ConvGeneric::GuessBestParams(
1891     const GpuInfo& gpu_info, const OperationDef& definition,
1892     const FullyConnectedAttributes& attr, const BHWC* dst_shape) {
1893   const int dst_depth = DivideRoundUp(attr.weights.shape.o, 4);
1894   const int src_depth = DivideRoundUp(attr.weights.shape.i, 4);
1895   ConvGeneric::ConvParams params = GuessBestParams(
1896       gpu_info, definition, src_depth, dst_depth, true, true, false, dst_shape);
1897   work_group_size_.x *= work_group_size_.y;
1898   work_group_size_.y = 1;
1899   params.block_size.x *= params.block_size.y;
1900   params.block_size.y = 1;
1901   return params;
1902 }
1903 
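// Pointwise variant with per-row weights (different_weights_for_height is
// true); likewise folds Y into X.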
1904 ConvGeneric::ConvParams ConvGeneric::GuessBestParamsPointwise(
1905     const GpuInfo& gpu_info, const OperationDef& definition,
1906     const OHWI& weights_shape, const BHWC* dst_shape) {
1907   const int dst_depth = DivideRoundUp(weights_shape.o, 4);
1908   const int src_depth = DivideRoundUp(weights_shape.i, 4);
1909   ConvGeneric::ConvParams params = GuessBestParams(
1910       gpu_info, definition, src_depth, dst_depth, true, true, true, dst_shape);
1911   params.block_size.x *= params.block_size.y;
1912   params.block_size.y = 1;
1913   work_group_size_.x *= work_group_size_.y;
1914   work_group_size_.y = 1;
1915   return params;
1916 }
1917 
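// Factory helpers below construct the operation, generate its code and
// upload weights/biases. Illustrative usage (variable names assumed by this
// sketch):
//   ConvGeneric conv =
//       CreateConvGeneric(gpu_info, op_def, attr, &dst_shape);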
1918 ConvGeneric CreateConvGeneric(const GpuInfo& gpu_info,
1919                               const OperationDef& definition,
1920                               const Convolution2DAttributes& attr,
1921                               const BHWC* dst_shape) {
1922   ConvGeneric result(definition, attr, gpu_info, dst_shape);
1923   result.GenerateCode(gpu_info);
1924   result.UploadData(attr.weights, attr.bias);
1925   return result;
1926 }
1927 
1928 ConvGeneric CreateConvGeneric(const GpuInfo& gpu_info,
1929                               const OperationDef& definition,
1930                               const FullyConnectedAttributes& attr,
1931                               const BHWC* dst_shape) {
1932   ConvGeneric result(definition, attr, gpu_info, dst_shape);
1933   result.GenerateCode(gpu_info);
1934   result.UploadData(attr.weights, attr.bias);
1935   return result;
1936 }
1937 
1938 ConvGeneric CreateConvGenericDynamicWeights(const GpuInfo& gpu_info,
1939                                             const OperationDef& definition,
1940                                             const Convolution2DAttributes& attr,
1941                                             const BHWC& weights_shape,
1942                                             const BHWC* dst_shape) {
1943   ConvGeneric result(definition, attr, weights_shape, gpu_info, dst_shape);
1944   result.GenerateCode(gpu_info);
1945   result.UploadBias(attr.bias);
1946   return result;
1947 }
1948 
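// Batched MatMul is expressed as a pointwise convolution; biases are
// synthesized as zeros since MatMul has none.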
1949 ConvGeneric CreateConvGenericBatchedMatMul(const GpuInfo& gpu_info,
1950                                            const OperationDef& definition,
1951                                            const OHWI& weights_shape,
1952                                            const BHWC* dst_shape) {
1953   ConvGeneric result(definition);
1954   result.conv_params_ = result.GuessBestParamsPointwise(
1955       gpu_info, definition, weights_shape, dst_shape);
1956   result.GenerateCode(gpu_info);
1957   tflite::gpu::Tensor<Linear, DataType::FLOAT32> biases;
1958   biases.shape = Linear(weights_shape.o);
1959   biases.data.resize(weights_shape.o, 0.0f);
1960   result.UploadBias(biases);
1961   return result;
1962 }
1963 
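// Winograd 4x4-to-6x6 path: weights are transformed for Winograd before
// upload.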
1964 ConvGeneric CreateConvGenericWino4x4To6x6(const GpuInfo& gpu_info,
1965                                           const OperationDef& definition,
1966                                           const Convolution2DAttributes& attr,
1967                                           const BHWC* dst_shape) {
1968   ConvGeneric result(definition);
1969   result.conv_params_ = result.GuessBestParamsPointwise(
1970       gpu_info, definition, attr.weights.shape, dst_shape);
1971   result.GenerateCode(gpu_info);
1972   result.UploadDataForWinograd4x4To6x6(attr.weights);
1973   return result;
1974 }
1975 
1976 ConvGeneric CreateConvGeneric3D(const GpuInfo& gpu_info,
1977                                 const OperationDef& definition,
1978                                 const Convolution3DAttributes& attr,
1979                                 const BHWDC* dst_shape) {
1980   ConvGeneric result(definition, attr, gpu_info, dst_shape);
1981   result.GenerateCode(gpu_info);
1982   result.UploadWeights(attr.weights);
1983   result.UploadBias(attr.bias);
1984   return result;
1985 }
1986 
1987 }  // namespace gpu
1988 }  // namespace tflite
1989