/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/tasks/conv_metal_simd.h"

#include <cmath>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"

namespace tflite {
namespace gpu {
namespace {
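// Emits shader code that derives the destination coordinates (DST_X, DST_Y,
// optionally DST_Z and the batch index B, plus the slice index DST_S) from
// the dispatch ids, honoring the work-group launch order remap and the
// optional linearized spatial indexing.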
std::string GenerateDstCoords(const int3& work_group_launch_order,
                              bool linear_spatial, bool need_depth,
                              bool need_batch) {
  std::string c;
  int3 launch_remap;
  launch_remap[work_group_launch_order.x] = 0;
  launch_remap[work_group_launch_order.y] = 1;
  launch_remap[work_group_launch_order.z] = 2;
  if (linear_spatial) {
    if (work_group_launch_order[0] == 0) {
      c += "  int linear_spatial = GLOBAL_ID_0;\n";
    } else {
      c += "  int linear_spatial = GROUP_ID_" +
           std::to_string(launch_remap[0]) + " * GROUP_SIZE_0 + LOCAL_ID_0;\n";
    }
    if (need_batch) {
      c += "  int B = linear_spatial % args.dst_tensor.Batch();\n";
      c += "  linear_spatial = linear_spatial / args.dst_tensor.Batch();\n";
    }
    if (need_depth) {
      c += "  int DST_X = linear_spatial % args.dst_tensor.Width();\n";
      c += "  linear_spatial = linear_spatial / args.dst_tensor.Width();\n";
      c += "  int DST_Y = linear_spatial % args.dst_tensor.Height();\n";
      c += "  int DST_Z = linear_spatial / args.dst_tensor.Height();\n";
    } else {
      c += "  int DST_Y = linear_spatial / args.dst_tensor.Width();\n";
      c += "  int DST_X = linear_spatial % args.dst_tensor.Width();\n";
    }
    if (work_group_launch_order[1] == 1) {
      c += "  int DST_S = GLOBAL_ID_1;\n";
    } else {
      c += "  int DST_S = GROUP_ID_" + std::to_string(launch_remap[1]) +
           " * GROUP_SIZE_1 + LOCAL_ID_1;\n";
    }
  } else {
    if (work_group_launch_order[0] == 0) {
      c += "  int DST_X = GLOBAL_ID_0;\n";
    } else {
      c += "  int DST_X = GROUP_ID_" + std::to_string(launch_remap[0]) +
           " * GROUP_SIZE_0 + LOCAL_ID_0;\n";
    }
    if (need_batch) {
      c += "  int B = DST_X % args.dst_tensor.Batch();\n";
      c += "  DST_X = DST_X / args.dst_tensor.Batch();\n";
    }
    std::string global_id_1;
    if (work_group_launch_order[1] == 1) {
      global_id_1 = "GLOBAL_ID_1";
    } else {
      global_id_1 = "GROUP_ID_" + std::to_string(launch_remap[1]) +
                    " * GROUP_SIZE_1 + LOCAL_ID_1";
    }
    if (need_depth) {
      c += "  int linear_id_1 = " + global_id_1 + ";\n";
      c += "  int DST_Z = linear_id_1 / args.dst_tensor.Height();\n";
      c += "  int DST_Y = linear_id_1 % args.dst_tensor.Height();\n";
    } else {
      c += "  int DST_Y = " + global_id_1 + ";\n";
    }
    if (work_group_launch_order[2] == 2) {
      c += "  int DST_S = GLOBAL_ID_2;\n";
    } else {
      c += "  int DST_S = GROUP_ID_" + std::to_string(launch_remap[2]) +
           " * GROUP_SIZE_2 + LOCAL_ID_2;\n";
    }
  }

  return c;
}

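// Generates the Metal Shading Language source for the 1x1 convolution. The
// kernel cooperatively stages source values (and, for larger threadgroups,
// weights) in threadgroup memory and accumulates 8x8 output tiles with
// simdgroup_multiply_accumulate.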
std::string GenerateConvolution(
    const OperationDef& definition,
    const ConvolutionMetalSimd::ConvParams& conv_params) {
  std::string c;
  c += "#define MMA simdgroup_multiply_accumulate\n";
  const int spatial_threads = conv_params.GetSpatialThreadsCount();
  c += "#define SPATIAL_THREADS " + std::to_string(spatial_threads) + "\n";
  c += "MAIN_FUNCTION($0) {\n";
  c += GenerateDstCoords(conv_params.work_group_launch_order,
                         conv_params.linear_spatial,
                         definition.src_tensors[0].HasAxis(Axis::DEPTH),
                         definition.src_tensors[0].HasAxis(Axis::BATCH));
  if (definition.src_tensors[0].HasAxis(Axis::BATCH)) {
    c += "  args.src_tensor.SetBatchRef(B);\n";
    c += "  args.dst_tensor.SetBatchRef(B);\n";
  }
  if (conv_params.slices_per_thread != 1) {
    c += "  DST_S *= " + std::to_string(conv_params.slices_per_thread) + ";\n";
  }
  c += "  device FLT4* f_offseted = args.weights.GetPtr() + DST_S * 4 * "
       "args.src_tensor.Slices();\n";
  const bool cache_weight = spatial_threads > 32;
  const int src_x4_slices = conv_params.GetX4SlicesCount();
  const int src_x8_slices = src_x4_slices / 2;
  const int dst_x4_slices = conv_params.slices_per_thread;
  const int dst_x8_slices = dst_x4_slices / 2;
  const int weights_tiles8x8_per_spatial = src_x8_slices * dst_x8_slices;
  const int weights_flt4_per_spatial = weights_tiles8x8_per_spatial * 8 * 8 / 4;
  if (conv_params.linear_spatial) {
    c += "  int spatial_id = LOCAL_ID_0;\n";
    c += "  int slice_id = LOCAL_ID_1;\n";
  } else {
    c += "  int spatial_id = LOCAL_ID_1 * GROUP_SIZE_0 + LOCAL_ID_0;\n";
    c += "  int slice_id = LOCAL_ID_2;\n";
  }
  c += "  int tid = slice_id * SPATIAL_THREADS + spatial_id;\n";
  if (cache_weight) {
    c += "  threadgroup FLT4 tmp_w[" +
         std::to_string(weights_flt4_per_spatial *
                        conv_params.GetX4SlicesCount()) +
         "];\n";
    c += "  threadgroup FLT* tmp_w_x1 = (threadgroup FLT*)tmp_w;\n";
    c += "  tmp_w_x1 += " + std::to_string(weights_flt4_per_spatial * 4) +
         " * slice_id;\n";
    c += "  threadgroup FLT4* tmp_w_x4 = (threadgroup FLT4*)tmp_w_x1;\n\n";
  } else {
    c += "  device FLT* f_offseted_x1 = (device FLT*)f_offseted;\n\n";
  }
  c += "  threadgroup FLT4 tmp_src[SPATIAL_THREADS * " +
       std::to_string(src_x4_slices) + "];\n";
  c += "  threadgroup FLT* tmp_src_x1 = (threadgroup FLT*)tmp_src;\n\n";
  c += "  // sp - spatial dimensions, ch - channels dimension\n";
  c += "  // indexing relative to simdgroup\n";
  for (int sp = 0; sp < 32; sp += 8) {
    const std::string sp_start = std::to_string(sp);
    const std::string sp_end = std::to_string(sp + 8);
    const std::string dst_name = "dst_sp" + sp_start + "_" + sp_end;
    for (int slice = 0; slice < dst_x8_slices; slice += 1) {
      const std::string sl_start = std::to_string(slice * 8);
      const std::string sl_end = std::to_string(slice * 8 + 8);
      c += "  simdgroup_matrix<FLT, 8, 8> " + dst_name + "_ch" + sl_start +
           "_" + sl_end + "(0.0f);\n";
    }
  }
  if (spatial_threads > 32) {
    c += "  int spatial_group = spatial_id / 32;\n";
    c += "  tmp_src_x1 += 8 * 8 * 4 * spatial_group;\n";
  }
  c += R"(
  int c_x = min(DST_X, args.src_tensor.Width() - 1);
  int c_y = min(DST_Y, args.src_tensor.Height() - 1);
)";
  if (definition.src_tensors[0].IsLinear()) {
    c +=
        "  int src_address = args.src_tensor.GetAddress(c_x, c_y, slice_id);\n";
  }
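  // The tid -> tid2 remap below interleaves the FLT4 values written by the
  // different slice_id threads so that, within each group of two x4 source
  // slices, the 8 channels of a spatial position land contiguously in
  // tmp_src, matching the 8x8 tile layout read by simdgroup_load.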
  c += R"(
  int tid2 = 0;
  if (tid < SPATIAL_THREADS) {
    tid2 = tid * 2 + 0;
  } else if (tid < SPATIAL_THREADS * 2) {
    tid2 = (tid - SPATIAL_THREADS) * 2 + 1;
  })";
  for (int src_s = 1; src_s < src_x8_slices; ++src_s) {
    c += " else if (tid < SPATIAL_THREADS * " + std::to_string(src_s * 2 + 1) +
         ") {\n";
    c += "    tid2 = (tid - SPATIAL_THREADS * " + std::to_string(src_s * 2) +
         ") * 2 + 0 + SPATIAL_THREADS * " + std::to_string(src_s * 2) + ";\n";
    c += "  } else if (tid < SPATIAL_THREADS * " +
         std::to_string(src_s * 2 + 2) + ") {\n";
    c += "    tid2 = (tid - SPATIAL_THREADS * " +
         std::to_string(src_s * 2 + 1) + ") * 2 + 1 + SPATIAL_THREADS * " +
         std::to_string(src_s * 2) + ";\n";
    c += "  }";
  }
  c += "\n\n";
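  // Main loop: consume the source slices in groups of src_x4_slices. Each
  // iteration loads the weight tiles for the group, stages the source FLT4s
  // in tmp_src, and accumulates every (spatial, dst channel) 8x8 tile with
  // MMA.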
  c += "  for (int s = 0; s < args.src_tensor.Slices(); s += " +
       std::to_string(src_x4_slices) + ") {\n";
  for (int src_s = 0; src_s < src_x8_slices; ++src_s) {
    const std::string src_range =
        "i" + std::to_string(src_s * 8) + "_" + std::to_string(src_s * 8 + 8);
    for (int dst_s = 0; dst_s < dst_x8_slices; ++dst_s) {
      const std::string dst_range =
          "o" + std::to_string(dst_s * 8) + "_" + std::to_string(dst_s * 8 + 8);
      const std::string w_name = "w_" + dst_range + "_" + src_range;
      c += "    simdgroup_matrix<FLT, 8, 8> " + w_name + ";\n";
    }
  }
  c += "    threadgroup_barrier(mem_flags::mem_threadgroup);\n";
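  // With more than one simdgroup per threadgroup (spatial_threads > 32) the
  // weights are first copied cooperatively into threadgroup memory (tmp_w);
  // otherwise each simdgroup loads its tiles directly from device memory.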
  if (cache_weight) {
    const int groups = weights_flt4_per_spatial / spatial_threads;
    const int reminder = weights_flt4_per_spatial % spatial_threads;
    for (int i = 0; i < groups; ++i) {
      c += "    tmp_w_x4[spatial_id + " + std::to_string(spatial_threads * i) +
           "] = f_offseted[spatial_id + " +
           std::to_string(spatial_threads * i) + "];\n";
    }
    if (reminder != 0) {
      c += "    if (spatial_id < " + std::to_string(reminder) + ") {\n";
      c += "      tmp_w_x4[spatial_id + " +
           std::to_string(spatial_threads * groups) +
           "] = f_offseted[spatial_id + " +
           std::to_string(spatial_threads * groups) + "];\n";
      c += "    }\n";
    }
  } else {
    for (int src_s = 0; src_s < src_x8_slices; ++src_s) {
      const std::string src_range =
          "i" + std::to_string(src_s * 8) + "_" + std::to_string(src_s * 8 + 8);
      for (int dst_s = 0; dst_s < dst_x8_slices; ++dst_s) {
        const std::string dst_range = "o" + std::to_string(dst_s * 8) + "_" +
                                      std::to_string(dst_s * 8 + 8);
        const std::string w_name = "w_" + dst_range + "_" + src_range;
        c += "    simdgroup_load(" + w_name + ", f_offseted_x1 + " +
             std::to_string((src_s * dst_x8_slices + dst_s) * 64) + ", 8);\n";
      }
    }
  }
  if (definition.src_tensors[0].IsLinear()) {
    c += "    tmp_src[tid2] = args.src_tensor.Read(src_address);\n";
  } else {
    c += "    tmp_src[tid2] = args.src_tensor.Read(c_x, c_y, s + slice_id);\n";
  }
  if (cache_weight) {
    c += "    f_offseted += 16 * " +
         std::to_string(src_x8_slices * dst_x8_slices) + ";\n";
  } else {
    c += "    f_offseted_x1 += 64 * " +
         std::to_string(src_x8_slices * dst_x8_slices) + ";\n";
  }
  if (definition.src_tensors[0].IsLinear()) {
    c += "    src_address += args.src_tensor.SliceStride() * " +
         std::to_string(src_x4_slices) + ";\n";
  }
  c += "    threadgroup_barrier(mem_flags::mem_threadgroup);\n";
  if (cache_weight) {
    for (int src_s = 0; src_s < src_x8_slices; ++src_s) {
      const std::string src_range =
          "i" + std::to_string(src_s * 8) + "_" + std::to_string(src_s * 8 + 8);
      for (int dst_s = 0; dst_s < dst_x8_slices; ++dst_s) {
        const std::string dst_range = "o" + std::to_string(dst_s * 8) + "_" +
                                      std::to_string(dst_s * 8 + 8);
        const std::string w_name = "w_" + dst_range + "_" + src_range;
        c += "    simdgroup_load(" + w_name + ", tmp_w_x1 + " +
             std::to_string((src_s * dst_x8_slices + dst_s) * 64) + ", 8);\n";
      }
    }
  }
  c += "    simdgroup_matrix<FLT, 8, 8> mat_src;\n";
  const int spatial_x8_count = spatial_threads / 8;
  for (int src_s = 0; src_s < src_x8_slices; ++src_s) {
    const std::string src_s_range =
        std::to_string(src_s * 8) + "_" + std::to_string(src_s * 8 + 8);
    for (int sp = 0; sp < 32; sp += 8) {
      const std::string sp_range =
          std::to_string(sp) + "_" + std::to_string(sp + 8);
      const int src_tile_offset = src_s * spatial_x8_count + (sp / 8);
      c += "    simdgroup_load(mat_src, tmp_src_x1 + " +
           std::to_string(src_tile_offset * 64) + ", 8);  // loading sp[" +
           sp_range + "] src_ch[" + src_s_range + "]\n";
      for (int dst_s = 0; dst_s < dst_x8_slices; ++dst_s) {
        const std::string dst_s_range =
            std::to_string(dst_s * 8) + "_" + std::to_string(dst_s * 8 + 8);
        const std::string dst_name = "dst_sp" + sp_range + "_ch" + dst_s_range;
        const std::string w_name = "w_o" + dst_s_range + "_i" + src_s_range;
        c += "    MMA(" + dst_name + ", mat_src, " + w_name + ", " + dst_name +
             ");\n";
      }
    }
  }
  c += "  }\n";
  for (int slice = 0; slice < dst_x8_slices * 2; slice += 1) {
    c += "  FLT4 r" + std::to_string(slice) + " = INIT_FLT4(0.0);\n";
  }
  c += "  // transferring from simdgroup memory to private registers.\n";
  c += "  const int kSpatialGroupsCount = " + std::to_string(src_x4_slices) +
       ";\n";
  c += "  for (int i = 0; i < kSpatialGroupsCount; ++i) {\n";
  c += "    int spatial_id = tid - i * SPATIAL_THREADS;\n";
  c += "    bool current_spatial_group = spatial_id >= 0 && spatial_id < "
       "SPATIAL_THREADS;\n";
  for (int dst_s = 0; dst_s < dst_x8_slices; ++dst_s) {
    const std::string dst_range =
        "ch" + std::to_string(dst_s * 8) + "_" + std::to_string(dst_s * 8 + 8);
    c += "    threadgroup_barrier(mem_flags::mem_threadgroup);\n";
    c += "    if (current_spatial_group) {\n";
    c += "      simdgroup_store(dst_sp0_8_" + dst_range + ", tmp_src_x1, 8);\n";
    c += "      simdgroup_store(dst_sp8_16_" + dst_range +
         ", tmp_src_x1 + 64, 8);\n";
    c += "      simdgroup_store(dst_sp16_24_" + dst_range +
         ", tmp_src_x1 + 64 * 2, 8);\n";
    c += "      simdgroup_store(dst_sp24_32_" + dst_range +
         ", tmp_src_x1 + 64 * 3, 8);\n";
    c += "    }\n";
    c += "    threadgroup_barrier(mem_flags::mem_threadgroup);\n";
    c += "    if (current_spatial_group) {\n";
    c += "      r" + std::to_string(dst_s * 2 + 0) +
         " += tmp_src[spatial_id * 2 + 0];\n";
    c += "      r" + std::to_string(dst_s * 2 + 1) +
         " += tmp_src[spatial_id * 2 + 1];\n";
    c += "    }\n";
  }
  c += "  }\n";
  c += "  if (DST_X >= args.dst_tensor.Width() || DST_Y >= "
       "args.dst_tensor.Height()) {\n";
  c += "    return;\n";
  c += "  }\n";
  for (int slice = 0; slice < dst_x8_slices * 2; slice += 1) {
    const std::string dst_s = "DST_S + " + std::to_string(slice);
    const std::string r_name = "r" + std::to_string(slice);
    c += "  if (" + dst_s + " < args.dst_tensor.Slices()) {\n";
    c += "    " + r_name + " += args.biases.Read(" + dst_s + ");\n";
    c += "    args.dst_tensor.Write(" + r_name + ", DST_X, DST_Y, " + dst_s +
         ");\n";
    c += "  }\n";
  }
  c += "}\n";
  return c;
}

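// Repacks OI-ordered weights into the layout consumed by the kernel:
// [o_group][i_slice][o slice within group][sub_i][sub_o], zero-padding
// channels that fall outside the original O/I extents. Called below with
// vec_size = 8 and o_group_size = dst_x8_slices.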
void OIToVecOIOGroupIO(const std::vector<float>& src, int o_size, int i_size,
                       int vec_size, int o_group_size,
                       std::vector<float>* dst) {
  int o_slices = DivideRoundUp(o_size, vec_size);
  int i_slices = DivideRoundUp(i_size, vec_size);
  int o_groups = DivideRoundUp(o_slices, o_group_size);
  dst->resize(o_slices * vec_size * i_slices * vec_size);
  for (int os = 0; os < o_groups; ++os) {
    for (int is = 0; is < i_slices; ++is) {
      for (int o_group = 0; o_group < o_group_size; ++o_group) {
        for (int sub_o = 0; sub_o < vec_size; ++sub_o) {
          for (int sub_i = 0; sub_i < vec_size; ++sub_i) {
            float value = 0.0f;
            int i_ch = is * vec_size + sub_i;
            int o_ch = (os * o_group_size + o_group) * vec_size + sub_o;
            if (i_ch < i_size && o_ch < o_size) {
              value = src[o_ch * i_size + i_ch];
            }
            (*dst)[(((os * i_slices + is) * o_group_size + o_group) * vec_size +
                    sub_i) *
                       vec_size +
                   sub_o] = value;
          }
        }
      }
    }
  }
}

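// Converts the OHWI float weights (1x1 kernels only) into the GPU buffer
// layout described above, storing them as float or half depending on
// weights_type.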
std::vector<uint8_t> ReorderWeightsForConv(
    const tflite::gpu::Tensor<OHWI, DataType::FLOAT32>& weights,
    const DataType& weights_type, int dst_x8_slices) {
  std::vector<float> weights_gpu;
  OIToVecOIOGroupIO(weights.data, weights.shape.o, weights.shape.i, 8,
                    dst_x8_slices, &weights_gpu);
  std::vector<uint8_t> result(weights_gpu.size() * SizeOf(weights_type));
  if (weights_type == DataType::FLOAT32) {
    float* gpu_data = reinterpret_cast<float*>(result.data());
    for (int i = 0; i < weights_gpu.size(); ++i) {
      gpu_data[i] = weights_gpu[i];
    }
  } else {
    half* gpu_data = reinterpret_cast<half*>(result.data());
    for (int i = 0; i < weights_gpu.size(); ++i) {
      gpu_data[i] = weights_gpu[i];
    }
  }
  return result;
}

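// Packs the bias vector into a buffer of output_size elements (float or
// half), zero-padding entries beyond the original bias length.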
std::vector<uint8_t> ReorderBiasesForConv(
    const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& biases,
    const DataType& biases_type, int output_size) {
  std::vector<uint8_t> result(output_size * SizeOf(biases_type));
  if (biases_type == DataType::FLOAT32) {
    float* gpu_data = reinterpret_cast<float*>(result.data());
    for (int i = 0; i < output_size; ++i) {
      gpu_data[i] = i < biases.shape.v ? biases.data[i] : 0.0f;
    }
  } else {
    half* gpu_data = reinterpret_cast<half*>(result.data());
    for (int i = 0; i < output_size; ++i) {
      gpu_data[i] = i < biases.shape.v ? biases.data[i] : 0.0f;
    }
  }
  return result;
}

std::vector<int2> Get2DWorkgroupsEqualTo32() {
  return {{8, 4}, {16, 2}, {4, 8}, {32, 1}, {2, 16}, {1, 32}};
}

int Get2dGroupsCount(const BHWC& dst_shape, const int2 group_size) {
  int x_groups = DivideRoundUp(dst_shape.w * dst_shape.b, group_size.x);
  int y_groups = DivideRoundUp(dst_shape.h, group_size.y);
  return x_groups * y_groups;
}

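// Picks, among the 2D work-group shapes with 32 threads, the one that
// produces the fewest work groups for the given destination shape.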
int2 GetOptimalGroupSize(const BHWC& dst_shape) {
  const auto base_work_groups = Get2DWorkgroupsEqualTo32();
  int min_2d_work_groups = Get2dGroupsCount(dst_shape, base_work_groups[0]);
  int min_index = 0;
  for (int i = 1; i < base_work_groups.size(); ++i) {
    int groups_count = Get2dGroupsCount(dst_shape, base_work_groups[i]);
    if (groups_count < min_2d_work_groups) {
      min_2d_work_groups = groups_count;
      min_index = i;
    }
  }
  return base_work_groups[min_index];
}

}  // namespace

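// Dispatch grid: spatial extent (width * batch, height, depth) by the number
// of output-slice groups (Slices() / slices_per_thread), collapsed to two
// dimensions when the spatial index is linearized.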
int3 ConvolutionMetalSimd::GetGridSize() const {
  const int task_size_x = dst_[0]->Width() * dst_[0]->Batch();
  const int task_size_y = dst_[0]->Height();
  const int task_size_z = dst_[0]->Depth();
  const int task_size_s =
      DivideRoundUp(dst_[0]->Slices(), params_.slices_per_thread);
  if (params_.linear_spatial) {
    return int3(task_size_x * task_size_y * task_size_z, task_size_s, 1);
  } else {
    return int3(task_size_x, task_size_y * task_size_z, task_size_s);
  }
}

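// Builds the operation: chooses between a linear 32x4 work group and the
// best 2D 32-thread work group (with 4 slice threads), fixes
// slices_per_thread at 4, generates the shader, and attaches the reordered
// weight and bias buffers (weights stay a runtime input when they are
// dynamic).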
ConvolutionMetalSimd CreateConvolutionMetalSimd(
    const OperationDef& definition, const BHWC& dst_shape,
    const Convolution2DAttributes& attr, const GpuInfo& gpu_info) {
  ConvolutionMetalSimd desc(definition);
  const int2 optimal_2d_group_size = GetOptimalGroupSize(dst_shape);
  const int groups2d_count = Get2dGroupsCount(dst_shape, optimal_2d_group_size);
  const int groups1d_count =
      DivideRoundUp(dst_shape.w * dst_shape.b * dst_shape.h, 32);
  if (groups1d_count < groups2d_count) {
    desc.params_.work_group_size = int3(32, 4, 1);
    desc.params_.work_group_launch_order = int3(0, 1, 2);
    desc.params_.linear_spatial = true;
  } else {
    desc.params_.work_group_size =
        int3(optimal_2d_group_size.x, optimal_2d_group_size.y, 4);
    desc.params_.work_group_launch_order = int3(0, 1, 2);
    desc.params_.linear_spatial = false;
  }
  desc.params_.slices_per_thread = 4;
  desc.params_.x_kernel_is_1 = true;
  desc.params_.y_kernel_is_1 = true;
  desc.params_.z_kernel_is_1 = true;
  desc.code_ = GenerateConvolution(definition, desc.params_);

  desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
  desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);

  auto weights_type = DeduceDataTypeFromPrecision(definition.precision);

  MemoryType mem_type = MemoryType::GLOBAL;

  if (definition.src_tensors.size() == 2) {
    // dynamic weights
    BufferDescriptor weights_desc;
    weights_desc.element_type = definition.src_tensors[1].GetDataType();
    weights_desc.element_size = 4;
    weights_desc.memory_type = mem_type;
    desc.AddSrcBuffer("weights", weights_desc);
  } else {
    BufferDescriptor weights_desc;
    weights_desc.element_type = weights_type;
    weights_desc.element_size = 4;
    weights_desc.memory_type = mem_type;
    weights_desc.data = ReorderWeightsForConv(
        attr.weights, weights_type, desc.params_.slices_per_thread / 2);
    weights_desc.size = weights_desc.data.size();
    desc.args_.AddObject(
        "weights", std::make_unique<BufferDescriptor>(std::move(weights_desc)));
  }

  BufferDescriptor bias_desc;
  bias_desc.element_type = weights_type;
  bias_desc.element_size = 4;
  bias_desc.memory_type = mem_type;
  bias_desc.data = ReorderBiasesForConv(attr.bias, weights_type,
                                        AlignByN(attr.weights.shape.o, 4 * 4));
  bias_desc.size = bias_desc.data.size();
  desc.args_.AddObject(
      "biases", std::make_unique<BufferDescriptor>(std::move(bias_desc)));

  desc.work_group_size_ = desc.params_.work_group_size;
  desc.work_group_launch_order_ = desc.params_.work_group_launch_order;
  if (desc.params_.linear_spatial) {
    desc.grid_dimension_ = 2;
  } else {
    desc.grid_dimension_ = 3;
  }

  return desc;
}

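// The kernel requires an Apple GPU with simdgroup matmul support, a genuine
// 1x1 convolution (no stride, dilation, padding, or groups), input slices
// divisible by 4, and output slices divisible by 16.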
bool IsConvolutionMetalSimdSupported(const GpuInfo& gpu_info,
                                     const OperationDef& definition,
                                     const Convolution2DAttributes& attr) {
  if (!gpu_info.IsApple() || !gpu_info.metal_info.IsSIMDMatMulSupported() ||
      !gpu_info.apple_info.IsSIMDMatMulSupported()) {
    return false;
  }
  const bool genuine_1x1 =
      attr.weights.shape.w == 1 && attr.weights.shape.h == 1 &&
      attr.dilations.w == 1 && attr.dilations.h == 1 && attr.strides.w == 1 &&
      attr.strides.h == 1 && attr.padding.prepended.w == 0 &&
      attr.padding.prepended.h == 0 && attr.padding.appended.w == 0 &&
      attr.padding.appended.h == 0 && attr.groups == 1;
  const int src_slices = DivideRoundUp(attr.weights.shape.i, 4);
  const int dst_slices = DivideRoundUp(attr.weights.shape.o, 4);
  return genuine_1x1 && src_slices % 4 == 0 && dst_slices % 16 == 0;
}

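// Heuristic: the task is worth running with this kernel only if at least
// 62.5% of each 32-wide wave maps to real spatial work and each compute unit
// gets at least 8 waves of work.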
bool IsGoodTaskSizeForAppleConvSimd(const BHWC& dst_shape,
                                    const GpuInfo& gpu_info) {
  const uint64_t task_size_spatial = dst_shape.b * dst_shape.h * dst_shape.w;
  const uint64_t wave_size = 32;
  const double useful_part = static_cast<double>(task_size_spatial) /
                             AlignByN(task_size_spatial, wave_size);
  if (useful_part < 0.625) {
    return false;
  }
  const double task_size_slices = DivideRoundUp(dst_shape.c, 16);
  const double task_size = task_size_spatial * task_size_slices;
  const double task_size_per_cu = task_size / gpu_info.GetComputeUnitsCount();
  const double waves_per_cu = task_size_per_cu / wave_size;
  return waves_per_cu >= 8.0;
}

}  // namespace gpu
}  // namespace tflite