/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/tasks/conv_metal_simd.h"

#include <cmath>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/gpu_info.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/common/util.h"

namespace tflite {
namespace gpu {
namespace {
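// Emits the shader code that computes the destination coordinates
// (DST_X/DST_Y/DST_S, plus DST_Z and B when the depth/batch axes are present)
// from the dispatch IDs, honoring the requested work-group launch order and
// the linear-spatial indexing mode.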
std::string GenerateDstCoords(const int3& work_group_launch_order,
                              bool linear_spatial, bool need_depth,
                              bool need_batch) {
  std::string c;
  int3 launch_remap;
  launch_remap[work_group_launch_order.x] = 0;
  launch_remap[work_group_launch_order.y] = 1;
  launch_remap[work_group_launch_order.z] = 2;
  if (linear_spatial) {
    if (work_group_launch_order[0] == 0) {
      c += "  int linear_spatial = GLOBAL_ID_0;\n";
    } else {
      c += "  int linear_spatial = GROUP_ID_" +
           std::to_string(launch_remap[0]) + " * GROUP_SIZE_0 + LOCAL_ID_0;\n";
    }
    if (need_batch) {
      c += "  int B = linear_spatial % args.dst_tensor.Batch();\n";
      c += "  linear_spatial = linear_spatial / args.dst_tensor.Batch();\n";
    }
    if (need_depth) {
      c += "  int DST_X = linear_spatial % args.dst_tensor.Width();\n";
      c += "  linear_spatial = linear_spatial / args.dst_tensor.Width();\n";
      c += "  int DST_Y = linear_spatial % args.dst_tensor.Height();\n";
      c += "  int DST_Z = linear_spatial / args.dst_tensor.Height();\n";
    } else {
      c += "  int DST_Y = linear_spatial / args.dst_tensor.Width();\n";
      c += "  int DST_X = linear_spatial % args.dst_tensor.Width();\n";
    }
    if (work_group_launch_order[1] == 1) {
      c += "  int DST_S = GLOBAL_ID_1;\n";
    } else {
      c += "  int DST_S = GROUP_ID_" + std::to_string(launch_remap[1]) +
           " * GROUP_SIZE_1 + LOCAL_ID_1;\n";
    }
  } else {
    if (work_group_launch_order[0] == 0) {
      c += "  int DST_X = GLOBAL_ID_0;\n";
    } else {
      c += "  int DST_X = GROUP_ID_" + std::to_string(launch_remap[0]) +
           " * GROUP_SIZE_0 + LOCAL_ID_0;\n";
    }
    if (need_batch) {
      c += "  int B = DST_X % args.dst_tensor.Batch();\n";
      c += "  DST_X = DST_X / args.dst_tensor.Batch();\n";
    }
    std::string global_id_1;
    if (work_group_launch_order[1] == 1) {
      global_id_1 = "GLOBAL_ID_1";
    } else {
      global_id_1 = "GROUP_ID_" + std::to_string(launch_remap[1]) +
                    " * GROUP_SIZE_1 + LOCAL_ID_1";
    }
    if (need_depth) {
      c += "  int linear_id_1 = " + global_id_1 + ";\n";
      c += "  int DST_Z = linear_id_1 / args.dst_tensor.Height();\n";
      c += "  int DST_Y = linear_id_1 % args.dst_tensor.Height();\n";
    } else {
      c += "  int DST_Y = " + global_id_1 + ";\n";
    }
    if (work_group_launch_order[2] == 2) {
      c += "  int DST_S = GLOBAL_ID_2;\n";
    } else {
      c += "  int DST_S = GROUP_ID_" + std::to_string(launch_remap[2]) +
           " * GROUP_SIZE_2 + LOCAL_ID_2;\n";
    }
  }

  return c;
}

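// Generates the Metal compute shader for the 1x1 convolution. The main loop
// walks the source slices in groups of GetX4SlicesCount(), stages activations
// (and, for work groups larger than one simdgroup, the weights) in threadgroup
// memory, and accumulates 8x8 tiles with simdgroup_multiply_accumulate.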
std::string GenerateConvolution(
    const OperationDef& definition,
    const ConvolutionMetalSimd::ConvParams& conv_params) {
  std::string c;
  c += "#define MMA simdgroup_multiply_accumulate\n";
  const int spatial_threads = conv_params.GetSpatialThreadsCount();
  c += "#define SPATIAL_THREADS " + std::to_string(spatial_threads) + "\n";
  c += "MAIN_FUNCTION($0) {\n";
  c += GenerateDstCoords(conv_params.work_group_launch_order,
                         conv_params.linear_spatial,
                         definition.src_tensors[0].HasAxis(Axis::DEPTH),
                         definition.src_tensors[0].HasAxis(Axis::BATCH));
  if (definition.src_tensors[0].HasAxis(Axis::BATCH)) {
    c += "  args.src_tensor.SetBatchRef(B);\n";
    c += "  args.dst_tensor.SetBatchRef(B);\n";
  }
  if (conv_params.slices_per_thread != 1) {
    c += "  DST_S *= " + std::to_string(conv_params.slices_per_thread) + ";\n";
  }
  c += "  device FLT4* f_offseted = args.weights.GetPtr() + DST_S * 4 * "
       "args.src_tensor.Slices();\n";
  const bool cache_weight = spatial_threads > 32;
  const int src_x4_slices = conv_params.GetX4SlicesCount();
  const int src_x8_slices = src_x4_slices / 2;
  const int dst_x4_slices = conv_params.slices_per_thread;
  const int dst_x8_slices = dst_x4_slices / 2;
  const int weights_tiles8x8_per_spatial = src_x8_slices * dst_x8_slices;
  const int weights_flt4_per_spatial = weights_tiles8x8_per_spatial * 8 * 8 / 4;
  if (conv_params.linear_spatial) {
    c += "  int spatial_id = LOCAL_ID_0;\n";
    c += "  int slice_id = LOCAL_ID_1;\n";
  } else {
    c += "  int spatial_id = LOCAL_ID_1 * GROUP_SIZE_0 + LOCAL_ID_0;\n";
    c += "  int slice_id = LOCAL_ID_2;\n";
  }
  c += "  int tid = slice_id * SPATIAL_THREADS + spatial_id;\n";
  if (cache_weight) {
    c += "  threadgroup FLT4 tmp_w[" +
         std::to_string(weights_flt4_per_spatial *
                        conv_params.GetX4SlicesCount()) +
         "];\n";
    c += "  threadgroup FLT* tmp_w_x1 = (threadgroup FLT*)tmp_w;\n";
    c += "  tmp_w_x1 += " + std::to_string(weights_flt4_per_spatial * 4) +
         " * slice_id;\n";
    c += "  threadgroup FLT4* tmp_w_x4 = (threadgroup FLT4*)tmp_w_x1;\n\n";
  } else {
    c += "  device FLT* f_offseted_x1 = (device FLT*)f_offseted;\n\n";
  }
  c += "  threadgroup FLT4 tmp_src[SPATIAL_THREADS * " +
       std::to_string(src_x4_slices) + "];\n";
  c += "  threadgroup FLT* tmp_src_x1 = (threadgroup FLT*)tmp_src;\n\n";
  c += "  // sp - spatial dimensions, ch - channels dimension\n";
  c += "  // indexing relative to simdgroup\n";
  for (int sp = 0; sp < 32; sp += 8) {
    const std::string sp_start = std::to_string(sp);
    const std::string sp_end = std::to_string(sp + 8);
    const std::string dst_name = "dst_sp" + sp_start + "_" + sp_end;
    for (int slice = 0; slice < dst_x8_slices; slice += 1) {
      const std::string sl_start = std::to_string(slice * 8);
      const std::string sl_end = std::to_string(slice * 8 + 8);
      c += "  simdgroup_matrix<FLT, 8, 8> " + dst_name + "_ch" + sl_start +
           "_" + sl_end + "(0.0f);\n";
    }
  }
  if (spatial_threads > 32) {
    c += "  int spatial_group = spatial_id / 32;\n";
    c += "  tmp_src_x1 += 8 * 8 * 4 * spatial_group;\n";
  }
  c += R"(
  int c_x = min(DST_X, args.src_tensor.Width() - 1);
  int c_y = min(DST_Y, args.src_tensor.Height() - 1);
)";
  if (definition.src_tensors[0].IsLinear()) {
    c += "  int src_address = args.src_tensor.GetAddress(c_x, c_y, "
         "slice_id);\n";
  }
  c += R"(
  int tid2 = 0;
  if (tid < SPATIAL_THREADS) {
    tid2 = tid * 2 + 0;
  } else if (tid < SPATIAL_THREADS * 2) {
    tid2 = (tid - SPATIAL_THREADS) * 2 + 1;
  })";
  for (int src_s = 1; src_s < src_x8_slices; ++src_s) {
    c += " else if (tid < SPATIAL_THREADS * " + std::to_string(src_s * 2 + 1) +
         ") {\n";
    c += "    tid2 = (tid - SPATIAL_THREADS * " + std::to_string(src_s * 2) +
         ") * 2 + 0 + SPATIAL_THREADS * " + std::to_string(src_s * 2) + ";\n";
    c += "  } else if (tid < SPATIAL_THREADS * " +
         std::to_string(src_s * 2 + 2) + ") {\n";
    c += "    tid2 = (tid - SPATIAL_THREADS * " +
         std::to_string(src_s * 2 + 1) + ") * 2 + 1 + SPATIAL_THREADS * " +
         std::to_string(src_s * 2) + ";\n";
    c += "  }";
  }
  c += "\n\n";
  c += "  for (int s = 0; s < args.src_tensor.Slices(); s += " +
       std::to_string(src_x4_slices) + ") {\n";
  for (int src_s = 0; src_s < src_x8_slices; ++src_s) {
    const std::string src_range =
        "i" + std::to_string(src_s * 8) + "_" + std::to_string(src_s * 8 + 8);
    for (int dst_s = 0; dst_s < dst_x8_slices; ++dst_s) {
      const std::string dst_range =
          "o" + std::to_string(dst_s * 8) + "_" + std::to_string(dst_s * 8 + 8);
      const std::string w_name = "w_" + dst_range + "_" + src_range;
      c += "    simdgroup_matrix<FLT, 8, 8> " + w_name + ";\n";
    }
  }
  c += "    threadgroup_barrier(mem_flags::mem_threadgroup);\n";
  if (cache_weight) {
    const int groups = weights_flt4_per_spatial / spatial_threads;
    const int reminder = weights_flt4_per_spatial % spatial_threads;
    for (int i = 0; i < groups; ++i) {
      c += "    tmp_w_x4[spatial_id + " + std::to_string(spatial_threads * i) +
           "] = f_offseted[spatial_id + " +
           std::to_string(spatial_threads * i) + "];\n";
    }
    if (reminder != 0) {
      c += "    if (spatial_id < " + std::to_string(reminder) + ") {\n";
      c += "      tmp_w_x4[spatial_id + " +
           std::to_string(spatial_threads * groups) +
           "] = f_offseted[spatial_id + " +
           std::to_string(spatial_threads * groups) + "];\n";
      c += "    }\n";
    }
  } else {
    for (int src_s = 0; src_s < src_x8_slices; ++src_s) {
      const std::string src_range =
          "i" + std::to_string(src_s * 8) + "_" + std::to_string(src_s * 8 + 8);
      for (int dst_s = 0; dst_s < dst_x8_slices; ++dst_s) {
        const std::string dst_range = "o" + std::to_string(dst_s * 8) + "_" +
                                      std::to_string(dst_s * 8 + 8);
        const std::string w_name = "w_" + dst_range + "_" + src_range;
        c += "    simdgroup_load(" + w_name + ", f_offseted_x1 + " +
             std::to_string((src_s * dst_x8_slices + dst_s) * 64) + ", 8);\n";
      }
    }
  }
  if (definition.src_tensors[0].IsLinear()) {
    c += "    tmp_src[tid2] = args.src_tensor.Read(src_address);\n";
  } else {
    c += "    tmp_src[tid2] = args.src_tensor.Read(c_x, c_y, s + slice_id);\n";
  }
  if (cache_weight) {
    c += "    f_offseted += 16 * " +
         std::to_string(src_x8_slices * dst_x8_slices) + ";\n";
  } else {
    c += "    f_offseted_x1 += 64 * " +
         std::to_string(src_x8_slices * dst_x8_slices) + ";\n";
  }
  if (definition.src_tensors[0].IsLinear()) {
    c += "    src_address += args.src_tensor.SliceStride() * " +
         std::to_string(src_x4_slices) + ";\n";
  }
  c += "    threadgroup_barrier(mem_flags::mem_threadgroup);\n";
  if (cache_weight) {
    for (int src_s = 0; src_s < src_x8_slices; ++src_s) {
      const std::string src_range =
          "i" + std::to_string(src_s * 8) + "_" + std::to_string(src_s * 8 + 8);
      for (int dst_s = 0; dst_s < dst_x8_slices; ++dst_s) {
        const std::string dst_range = "o" + std::to_string(dst_s * 8) + "_" +
                                      std::to_string(dst_s * 8 + 8);
        const std::string w_name = "w_" + dst_range + "_" + src_range;
        c += "    simdgroup_load(" + w_name + ", tmp_w_x1 + " +
             std::to_string((src_s * dst_x8_slices + dst_s) * 64) + ", 8);\n";
      }
    }
  }
  c += "    simdgroup_matrix<FLT, 8, 8> mat_src;\n";
  const int spatial_x8_count = spatial_threads / 8;
  for (int src_s = 0; src_s < src_x8_slices; ++src_s) {
    const std::string src_s_range =
        std::to_string(src_s * 8) + "_" + std::to_string(src_s * 8 + 8);
    for (int sp = 0; sp < 32; sp += 8) {
      const std::string sp_range =
          std::to_string(sp) + "_" + std::to_string(sp + 8);
      const int src_tile_offset = src_s * spatial_x8_count + (sp / 8);
      c += "    simdgroup_load(mat_src, tmp_src_x1 + " +
           std::to_string(src_tile_offset * 64) + ", 8); // loading sp[" +
           sp_range + "] src_ch[" + src_s_range + "]\n";
      for (int dst_s = 0; dst_s < dst_x8_slices; ++dst_s) {
        const std::string dst_s_range =
            std::to_string(dst_s * 8) + "_" + std::to_string(dst_s * 8 + 8);
        const std::string dst_name = "dst_sp" + sp_range + "_ch" + dst_s_range;
        const std::string w_name = "w_o" + dst_s_range + "_i" + src_s_range;
        c += "    MMA(" + dst_name + ", mat_src, " + w_name + ", " +
             dst_name + ");\n";
      }
    }
  }
  c += "  }\n";
  for (int slice = 0; slice < dst_x8_slices * 2; slice += 1) {
    c += "  FLT4 r" + std::to_string(slice) + " = INIT_FLT4(0.0);\n";
  }
  c += "  // transferring from simdgroup memory to private registers.\n";
  c += "  const int kSpatialGroupsCount = " + std::to_string(src_x4_slices) +
       ";\n";
  c += "  for (int i = 0; i < kSpatialGroupsCount; ++i) {\n";
  c += "    int spatial_id = tid - i * SPATIAL_THREADS;\n";
  c += "    bool current_spatial_group = spatial_id >= 0 && spatial_id < "
       "SPATIAL_THREADS;\n";
  for (int dst_s = 0; dst_s < dst_x8_slices; ++dst_s) {
    const std::string dst_range =
        "ch" + std::to_string(dst_s * 8) + "_" + std::to_string(dst_s * 8 + 8);
    c += "    threadgroup_barrier(mem_flags::mem_threadgroup);\n";
    c += "    if (current_spatial_group) {\n";
    c += "      simdgroup_store(dst_sp0_8_" + dst_range + ", tmp_src_x1, 8);\n";
    c += "      simdgroup_store(dst_sp8_16_" + dst_range +
         ", tmp_src_x1 + 64, 8);\n";
    c += "      simdgroup_store(dst_sp16_24_" + dst_range +
         ", tmp_src_x1 + 64 * 2, 8);\n";
    c += "      simdgroup_store(dst_sp24_32_" + dst_range +
         ", tmp_src_x1 + 64 * 3, 8);\n";
    c += "    }\n";
    c += "    threadgroup_barrier(mem_flags::mem_threadgroup);\n";
    c += "    if (current_spatial_group) {\n";
    c += "      r" + std::to_string(dst_s * 2 + 0) +
         " += tmp_src[spatial_id * 2 + 0];\n";
    c += "      r" + std::to_string(dst_s * 2 + 1) +
         " += tmp_src[spatial_id * 2 + 1];\n";
    c += "    }\n";
  }
  c += "  }\n";
  c += "  if (DST_X >= args.dst_tensor.Width() || DST_Y >= "
       "args.dst_tensor.Height()) {\n";
  c += "    return;\n";
  c += "  }\n";
  for (int slice = 0; slice < dst_x8_slices * 2; slice += 1) {
    const std::string dst_s = "DST_S + " + std::to_string(slice);
    const std::string r_name = "r" + std::to_string(slice);
    c += "  if (" + dst_s + " < args.dst_tensor.Slices()) {\n";
    c += "    " + r_name + " += args.biases.Read(" + dst_s + ");\n";
    c += "    args.dst_tensor.Write(" + r_name + ", DST_X, DST_Y, " + dst_s +
         ");\n";
    c += "  }\n";
  }
  c += "}\n";
  return c;
}

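// Repacks row-major OI weights into the tiled layout used by the shader:
// for every group of o_group_size output slices and every input slice, a
// vec_size x vec_size tile is written with the input channel as the row and
// the output channel as the column; out-of-range channels are zero-padded.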
void OIToVecOIOGroupIO(const std::vector<float>& src, int o_size, int i_size,
                       int vec_size, int o_group_size,
                       std::vector<float>* dst) {
  int o_slices = DivideRoundUp(o_size, vec_size);
  int i_slices = DivideRoundUp(i_size, vec_size);
  int o_groups = DivideRoundUp(o_slices, o_group_size);
  dst->resize(o_slices * vec_size * i_slices * vec_size);
  for (int os = 0; os < o_groups; ++os) {
    for (int is = 0; is < i_slices; ++is) {
      for (int o_group = 0; o_group < o_group_size; ++o_group) {
        for (int sub_o = 0; sub_o < vec_size; ++sub_o) {
          for (int sub_i = 0; sub_i < vec_size; ++sub_i) {
            float value = 0.0f;
            int i_ch = is * vec_size + sub_i;
            int o_ch = (os * o_group_size + o_group) * vec_size + sub_o;
            if (i_ch < i_size && o_ch < o_size) {
              value = src[o_ch * i_size + i_ch];
            }
            (*dst)[(((os * i_slices + is) * o_group_size + o_group) * vec_size +
                    sub_i) *
                       vec_size +
                   sub_o] = value;
          }
        }
      }
    }
  }
}

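// Converts OHWI weights of a 1x1 convolution into the tiled 8x8 layout above
// and serializes them as float or half, depending on weights_type.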
std::vector<uint8_t> ReorderWeightsForConv(
    const tflite::gpu::Tensor<OHWI, DataType::FLOAT32>& weights,
    const DataType& weights_type, int dst_x8_slices) {
  std::vector<float> weights_gpu;
  OIToVecOIOGroupIO(weights.data, weights.shape.o, weights.shape.i, 8,
                    dst_x8_slices, &weights_gpu);
  std::vector<uint8_t> result(weights_gpu.size() * SizeOf(weights_type));
  if (weights_type == DataType::FLOAT32) {
    float* gpu_data = reinterpret_cast<float*>(result.data());
    for (int i = 0; i < weights_gpu.size(); ++i) {
      gpu_data[i] = weights_gpu[i];
    }
  } else {
    half* gpu_data = reinterpret_cast<half*>(result.data());
    for (int i = 0; i < weights_gpu.size(); ++i) {
      gpu_data[i] = weights_gpu[i];
    }
  }
  return result;
}

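// Copies biases into a buffer of output_size elements (float or half),
// zero-filling the tail past the actual bias count.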
std::vector<uint8_t> ReorderBiasesForConv(
    const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& biases,
    const DataType& biases_type, int output_size) {
  std::vector<uint8_t> result(output_size * SizeOf(biases_type));
  if (biases_type == DataType::FLOAT32) {
    float* gpu_data = reinterpret_cast<float*>(result.data());
    for (int i = 0; i < output_size; ++i) {
      gpu_data[i] = i < biases.shape.v ? biases.data[i] : 0.0f;
    }
  } else {
    half* gpu_data = reinterpret_cast<half*>(result.data());
    for (int i = 0; i < output_size; ++i) {
      gpu_data[i] = i < biases.shape.v ? biases.data[i] : 0.0f;
    }
  }
  return result;
}

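// Candidate 2D work-group shapes that all contain exactly 32 threads
// (one Apple simdgroup).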
std::vector<int2> Get2DWorkgroupsEqualTo32() {
  return {{8, 4}, {16, 2}, {4, 8}, {32, 1}, {2, 16}, {1, 32}};
}

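// Number of work groups needed to cover the destination spatially with the
// given 2D group size (batch is folded into width).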
int Get2dGroupsCount(const BHWC& dst_shape, const int2 group_size) {
  int x_groups = DivideRoundUp(dst_shape.w * dst_shape.b, group_size.x);
  int y_groups = DivideRoundUp(dst_shape.h, group_size.y);
  return x_groups * y_groups;
}

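// Picks the candidate 2D work-group shape that covers dst_shape with the
// fewest work groups.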
int2 GetOptimalGroupSize(const BHWC& dst_shape) {
  const auto base_work_groups = Get2DWorkgroupsEqualTo32();
  int min_2d_work_groups = Get2dGroupsCount(dst_shape, base_work_groups[0]);
  int min_index = 0;
  for (int i = 1; i < base_work_groups.size(); ++i) {
    int groups_count = Get2dGroupsCount(dst_shape, base_work_groups[i]);
    if (groups_count < min_2d_work_groups) {
      min_2d_work_groups = groups_count;
      min_index = i;
    }
  }
  return base_work_groups[min_index];
}

}  // namespace

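// Grid covers width * batch, height and depth spatially, plus the destination
// slices divided by slices_per_thread; in linear_spatial mode the spatial
// dimensions are collapsed into the first grid axis.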
int3 ConvolutionMetalSimd::GetGridSize() const {
  const int task_size_x = dst_[0]->Width() * dst_[0]->Batch();
  const int task_size_y = dst_[0]->Height();
  const int task_size_z = dst_[0]->Depth();
  const int task_size_s =
      DivideRoundUp(dst_[0]->Slices(), params_.slices_per_thread);
  if (params_.linear_spatial) {
    return int3(task_size_x * task_size_y * task_size_z, task_size_s, 1);
  } else {
    return int3(task_size_x, task_size_y * task_size_z, task_size_s);
  }
}

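// Sets up the simdgroup-based 1x1 convolution: picks a linear or 2D spatial
// work-group layout (whichever needs fewer work groups), uses 4 destination
// slices per thread, generates the shader, and attaches weights and biases
// (weights come from a second source tensor when provided at runtime).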
ConvolutionMetalSimd CreateConvolutionMetalSimd(
    const OperationDef& definition, const BHWC& dst_shape,
    const Convolution2DAttributes& attr, const GpuInfo& gpu_info) {
  ConvolutionMetalSimd desc(definition);
  const int2 optimal_2d_group_size = GetOptimalGroupSize(dst_shape);
  const int groups2d_count = Get2dGroupsCount(dst_shape, optimal_2d_group_size);
  const int groups1d_count =
      DivideRoundUp(dst_shape.w * dst_shape.b * dst_shape.h, 32);
  if (groups1d_count < groups2d_count) {
    desc.params_.work_group_size = int3(32, 4, 1);
    desc.params_.work_group_launch_order = int3(0, 1, 2);
    desc.params_.linear_spatial = true;
  } else {
    desc.params_.work_group_size =
        int3(optimal_2d_group_size.x, optimal_2d_group_size.y, 4);
    desc.params_.work_group_launch_order = int3(0, 1, 2);
    desc.params_.linear_spatial = false;
  }
  desc.params_.slices_per_thread = 4;
  desc.params_.x_kernel_is_1 = true;
  desc.params_.y_kernel_is_1 = true;
  desc.params_.z_kernel_is_1 = true;
  desc.code_ = GenerateConvolution(definition, desc.params_);

  desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
  desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);

  auto weights_type = DeduceDataTypeFromPrecision(definition.precision);

  MemoryType mem_type = MemoryType::GLOBAL;

  if (definition.src_tensors.size() == 2) {
    // dynamic weights
    BufferDescriptor weights_desc;
    weights_desc.element_type = definition.src_tensors[1].GetDataType();
    weights_desc.element_size = 4;
    weights_desc.memory_type = mem_type;
    desc.AddSrcBuffer("weights", weights_desc);
  } else {
    BufferDescriptor weights_desc;
    weights_desc.element_type = weights_type;
    weights_desc.element_size = 4;
    weights_desc.memory_type = mem_type;
    weights_desc.data = ReorderWeightsForConv(
        attr.weights, weights_type, desc.params_.slices_per_thread / 2);
    weights_desc.size = weights_desc.data.size();
    desc.args_.AddObject(
        "weights", std::make_unique<BufferDescriptor>(std::move(weights_desc)));
  }

  BufferDescriptor bias_desc;
  bias_desc.element_type = weights_type;
  bias_desc.element_size = 4;
  bias_desc.memory_type = mem_type;
  bias_desc.data = ReorderBiasesForConv(attr.bias, weights_type,
                                        AlignByN(attr.weights.shape.o, 4 * 4));
  bias_desc.size = bias_desc.data.size();
  desc.args_.AddObject(
      "biases", std::make_unique<BufferDescriptor>(std::move(bias_desc)));

  desc.work_group_size_ = desc.params_.work_group_size;
  desc.work_group_launch_order_ = desc.params_.work_group_launch_order;
  if (desc.params_.linear_spatial) {
    desc.grid_dimension_ = 2;
  } else {
    desc.grid_dimension_ = 3;
  }

  return desc;
}

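// The kernel requires an Apple GPU with simdgroup matmul support, a genuine
// 1x1 convolution (unit strides/dilation, no padding, no groups), a source
// slice count divisible by 4, and a destination slice count divisible by 16.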
bool IsConvolutionMetalSimdSupported(const GpuInfo& gpu_info,
                                     const OperationDef& definition,
                                     const Convolution2DAttributes& attr) {
  if (!gpu_info.IsApple() || !gpu_info.metal_info.IsSIMDMatMulSupported() ||
      !gpu_info.apple_info.IsSIMDMatMulSupported()) {
    return false;
  }
  const bool genuine_1x1 =
      attr.weights.shape.w == 1 && attr.weights.shape.h == 1 &&
      attr.dilations.w == 1 && attr.dilations.h == 1 && attr.strides.w == 1 &&
      attr.strides.h == 1 && attr.padding.prepended.w == 0 &&
      attr.padding.prepended.h == 0 && attr.padding.appended.w == 0 &&
      attr.padding.appended.h == 0 && attr.groups == 1;
  const int src_slices = DivideRoundUp(attr.weights.shape.i, 4);
  const int dst_slices = DivideRoundUp(attr.weights.shape.o, 4);
  return genuine_1x1 && src_slices % 4 == 0 && dst_slices % 16 == 0;
}

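// Heuristic: the spatial task size should waste less than 37.5% of a 32-wide
// wave, and there should be at least 8 waves of work per compute unit.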
bool IsGoodTaskSizeForAppleConvSimd(const BHWC& dst_shape,
                                    const GpuInfo& gpu_info) {
  const uint64_t task_size_spatial = dst_shape.b * dst_shape.h * dst_shape.w;
  const uint64_t wave_size = 32;
  const double useful_part = static_cast<double>(task_size_spatial) /
                             AlignByN(task_size_spatial, wave_size);
  if (useful_part < 0.625) {
    return false;
  }
  const double task_size_slices = DivideRoundUp(dst_shape.c, 16);
  const double task_size = task_size_spatial * task_size_slices;
  const double task_size_per_cu = task_size / gpu_info.GetComputeUnitsCount();
  const double waves_per_cu = task_size_per_cu / wave_size;
  return waves_per_cu >= 8.0;
}

}  // namespace gpu
}  // namespace tflite