/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/tasks/winograd.h"

#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_format.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
#include "tensorflow/lite/delegates/gpu/common/winograd_util.h"

namespace tflite {
namespace gpu {
namespace {
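// Packs a float vector into a BufferDescriptor that is emitted as a constant
// kernel-space buffer, converting the values to half precision when the
// requested element type is not FLOAT32.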
void VectorToKernelBufferDesc(const std::vector<float>& data,
                              DataType data_type,
                              BufferDescriptor* buffer_desc) {
  buffer_desc->element_type = data_type;
  buffer_desc->element_size = 1;
  buffer_desc->memory_type = MemoryType::CONSTANT;
  buffer_desc->attributes.push_back("kernel_global_space");
  buffer_desc->size = SizeOf(data_type) * data.size();
  buffer_desc->data.resize(buffer_desc->size);
  if (data_type == DataType::FLOAT32) {
    memcpy(buffer_desc->data.data(), data.data(), buffer_desc->size);
  } else {
    half* hf_ptr = reinterpret_cast<half*>(buffer_desc->data.data());
    for (int i = 0; i < data.size(); ++i) {
      hf_ptr[i] = data[i];
    }
  }
}
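// Generates a kernel computing the Winograd F(4x4, 3x3) input transform: each
// work-item loads a 6x6 input patch d and writes the 36 values of Bt * d * B,
// one transformed value per destination row and one tile per destination
// column. Bt is the 6x6 transform matrix stored row-major, so args.Bt.Read(k)
// returns Bt[k / 6][k % 6].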
std::string GetKernelWinograd4x4To36(const GpuInfo& gpu_info,
                                     const OperationDef& op_def) {
  std::string c;
  const auto src_desc = op_def.src_tensors[0];
  c += R"(
MAIN_FUNCTION($0) {
  int X = GLOBAL_ID_0 * 4;
  int Y = GLOBAL_ID_1 * 4;
  int S = GLOBAL_ID_2;

  if (GLOBAL_ID_0 >= args.tiles_x || GLOBAL_ID_1 >= args.tiles_y) return;

  FLT4 I[6][6];
  for (int y = 0; y < 6; ++y) {
    for (int x = 0; x < 6; ++x) {
      I[y][x] = INIT_FLT4(0.0f);
    }
  }
)";
  if (src_desc.IsLinear()) {
    c += "  int src_base = args.src_tensor.GetAddress(0, 0, S);\n";
  }
  for (int y = 0; y < 6; ++y) {
    const std::string s_y = std::to_string(y);
    c += "  {\n";
    c += "    int coord_y = Y + " + s_y + " + args.padding_y;\n";
    if (!src_desc.SupportsZeroClamp(Axis::HEIGHT, gpu_info)) {
      c += "    bool in_y = coord_y >= 0 && coord_y < "
           "args.src_tensor.Height();\n";
      c += "    coord_y = clamp(coord_y, 0, args.src_tensor.Height() - 1);\n";
    }
    if (src_desc.IsLinear()) {
      c += "    int src_address_y = src_base + coord_y * "
           "args.src_tensor.Width();\n";
    }
    for (int x = 0; x < 6; ++x) {
      const std::string s_x = std::to_string(x);
      c += "    {\n";
      c += "      int coord_x = X + " + s_x + " + args.padding_x;\n";
      if (!src_desc.SupportsZeroClamp(Axis::WIDTH, gpu_info)) {
        c += "      bool in_x = coord_x >= 0 && coord_x < "
             "args.src_tensor.Width();\n";
        c += "      coord_x = clamp(coord_x, 0, args.src_tensor.Width() - 1);\n";
      }
      std::string multiplier;
      if (!src_desc.SupportsZeroClamp(Axis::WIDTH, gpu_info) &&
          !src_desc.SupportsZeroClamp(Axis::HEIGHT, gpu_info)) {
        multiplier = " * INIT_FLT(in_y && in_x)";
      } else if (!src_desc.SupportsZeroClamp(Axis::WIDTH, gpu_info)) {
        multiplier = " * INIT_FLT(in_x)";
      } else if (!src_desc.SupportsZeroClamp(Axis::HEIGHT, gpu_info)) {
        multiplier = " * INIT_FLT(in_y)";
      }
      if (src_desc.IsLinear()) {
        c += "      FLT4 src = args.src_tensor.Read(src_address_y + coord_x)" +
             multiplier + ";\n";
      } else {
        c += "      FLT4 src = args.src_tensor.Read(coord_x, coord_y, S)" +
             multiplier + ";\n";
      }
      c += "      I[0][" + s_x + "] += args.Bt.Read(" + std::to_string(y) +
           ") * src;\n";
      c += "      I[1][" + s_x + "] += args.Bt.Read(" + std::to_string(y + 6) +
           ") * src;\n";
      c += "      I[2][" + s_x + "] += args.Bt.Read(" + std::to_string(y + 12) +
           ") * src;\n";
      c += "      I[3][" + s_x + "] += args.Bt.Read(" + std::to_string(y + 18) +
           ") * src;\n";
      c += "      I[4][" + s_x + "] += args.Bt.Read(" + std::to_string(y + 24) +
           ") * src;\n";
      c += "      I[5][" + s_x + "] += args.Bt.Read(" + std::to_string(y + 30) +
           ") * src;\n";
      c += "    }\n";
    }
    c += "  }\n";
  }

  c += R"(
  int dst_x = GLOBAL_ID_1 * args.tiles_x + GLOBAL_ID_0;
  for (int y = 0; y < 6; ++y) {
    FLT4 value = I[y][0] + args.Bt.Read(2) * I[y][2] + args.Bt.Read(4) * I[y][4];
    args.dst_tensor.Write(value, dst_x, y * 6 + 0, S);
    value = args.Bt.Read(7) * I[y][1] + args.Bt.Read(8) * I[y][2] + args.Bt.Read(9) * I[y][3] + args.Bt.Read(10) * I[y][4];
    args.dst_tensor.Write(value, dst_x, y * 6 + 1, S);
    value = args.Bt.Read(13) * I[y][1] + args.Bt.Read(14) * I[y][2] + args.Bt.Read(15) * I[y][3] + args.Bt.Read(16) * I[y][4];
    args.dst_tensor.Write(value, dst_x, y * 6 + 2, S);
    value = args.Bt.Read(19) * I[y][1] + args.Bt.Read(20) * I[y][2] + args.Bt.Read(21) * I[y][3] + args.Bt.Read(22) * I[y][4];
    args.dst_tensor.Write(value, dst_x, y * 6 + 3, S);
    value = args.Bt.Read(25) * I[y][1] + args.Bt.Read(26) * I[y][2] + args.Bt.Read(27) * I[y][3] + args.Bt.Read(28) * I[y][4];
    args.dst_tensor.Write(value, dst_x, y * 6 + 4, S);
    value = args.Bt.Read(31) * I[y][1] + args.Bt.Read(33) * I[y][3] + I[y][5];
    args.dst_tensor.Write(value, dst_x, y * 6 + 5, S);
  }
}
)";
  return c;
}

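// Generates a kernel computing the Winograd F(4x4, 3x3) output transform: the
// 36 transformed values of a tile are reduced to a 4x4 output block via
// At * m * A, with the per-channel bias added before each bounds-checked
// write. At is the 4x6 transform matrix stored row-major, so args.At.Read(k)
// returns At[k / 6][k % 6].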
std::string GetKernelWinograd36To4x4(const OperationDef& op_def) {
  std::string c;
  const auto src_desc = op_def.src_tensors[0];

  c += R"(
MAIN_FUNCTION($0) {
  int tile_id = GLOBAL_ID_0;
  int Z = GLOBAL_ID_2;
  int tiles_count_x = (args.dst_tensor.Width() + 3) / 4;
  int tile_x = (tile_id % tiles_count_x) * 4;
  int tile_y = (tile_id / tiles_count_x) * 4;
  if (tile_x >= args.dst_tensor.Width() || tile_y >= args.dst_tensor.Height()) return;

  FLT4 I[4][6];
  for (int y = 0; y < 4; ++y) {
    for (int x = 0; x < 6; ++x) {
      I[y][x] = INIT_FLT4(0.0f);
    }
  }
)";
  if (src_desc.IsLinear()) {
    c += R"(
  int src_address = args.src_tensor.GetAddress(tile_id, 0, Z);
  for (int y = 0; y < 6; ++y) {
    for (int x = 0; x < 6; ++x, src_address += args.src_tensor.Width()) {
      FLT4 src = args.src_tensor.Read(src_address);
      I[0][x] += src * args.At.Read(y);
      I[1][x] += src * args.At.Read(y + 6);
      I[2][x] += src * args.At.Read(y + 12);
      I[3][x] += src * args.At.Read(y + 18);
    }
  }
)";
  } else {
    c += R"(
  for (int y = 0; y < 6; ++y) {
    for (int x = 0; x < 6; ++x) {
      FLT4 src = args.src_tensor.Read(tile_id, y * 6 + x, Z);
      I[0][x] += src * args.At.Read(y);
      I[1][x] += src * args.At.Read(y + 6);
      I[2][x] += src * args.At.Read(y + 12);
      I[3][x] += src * args.At.Read(y + 18);
    }
  }
)";
  }
  c += R"(

  FLT4 bias_val = args.biases.Read(Z);
  for (int y = 0; y < 4; ++y) {
    FLT4 t0 = I[y][1] + I[y][2];
    FLT4 t1 = I[y][3] + I[y][4];
    if (tile_x < args.dst_tensor.Width() && tile_y + y < args.dst_tensor.Height()) {
      FLT4 value = I[y][0] + t0 + t1 + bias_val;
      args.dst_tensor.Write(value, tile_x, tile_y + y, Z);
    }
    FLT4 t2 = I[y][1] - I[y][2];
    FLT4 t3 = I[y][3] - I[y][4];
    if (tile_x + 1 < args.dst_tensor.Width() && tile_y + y < args.dst_tensor.Height()) {
      FLT4 value = t2 * args.At.Read(7) + t3 * args.At.Read(9) + bias_val;
      args.dst_tensor.Write(value, tile_x + 1, tile_y + y, Z);
    }
    if (tile_x + 2 < args.dst_tensor.Width() && tile_y + y < args.dst_tensor.Height()) {
      FLT4 value = t0 * args.At.Read(13) + t1 * args.At.Read(15) + bias_val;
      args.dst_tensor.Write(value, tile_x + 2, tile_y + y, Z);
    }
    if (tile_x + 3 < args.dst_tensor.Width() && tile_y + y < args.dst_tensor.Height()) {
      FLT4 value = t2 * args.At.Read(19) + t3 * args.At.Read(21) + I[y][5] + bias_val;
      args.dst_tensor.Write(value, tile_x + 3, tile_y + y, Z);
    }
  }
}
)";
  return c;
}
}  // namespace

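// One work-item per 4x4 output tile per slice. The "- 2" is the two pixels a
// 3x3 convolution trims from each padded dimension, so the tiles exactly
// cover the convolution output.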
int3 Winograd4x4To36::GetGridSize() const {
  int new_width =
      src_[0]->Width() + padding_.prepended.w + padding_.appended.w - 2;
  int new_height =
      src_[0]->Height() + padding_.prepended.h + padding_.appended.h - 2;
  int tiles_x = DivideRoundUp(new_width, 4);
  int tiles_y = DivideRoundUp(new_height, 4);
  return int3(tiles_x, tiles_y, src_[0]->Slices());
}

absl::Status Winograd4x4To36::BindArguments(ArgumentsBinder* args) {
  int new_width =
      src_[0]->Width() + padding_.prepended.w + padding_.appended.w - 2;
  int new_height =
      src_[0]->Height() + padding_.prepended.h + padding_.appended.h - 2;
  int tiles_x = DivideRoundUp(new_width, 4);
  int tiles_y = DivideRoundUp(new_height, 4);
  RETURN_IF_ERROR(args->SetInt("tiles_x", tiles_x));
  RETURN_IF_ERROR(args->SetInt("tiles_y", tiles_y));
  return absl::OkStatus();
}

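// Builds the monolithic input-transform operation: one work-item computes a
// full 6x6 transformed patch. Bt is uploaded once as a constant kernel-space
// buffer, while tiles_x/tiles_y are bound at runtime because they depend on
// the source tensor shape.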
Winograd4x4To36 CreateWinograd4x4To36(const OperationDef& definition,
                                      const Padding2D& padding,
                                      const GpuInfo& gpu_info) {
  Winograd4x4To36 desc(definition, padding);
  desc.code_ = GetKernelWinograd4x4To36(gpu_info, definition);

  desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
  desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);

  desc.args_.AddInt("padding_x", -padding.prepended.w);
  desc.args_.AddInt("padding_y", -padding.prepended.h);
  desc.args_.AddInt("tiles_x");
  desc.args_.AddInt("tiles_y");

  BufferDescriptor buffer_desc;
  VectorToKernelBufferDesc(BtMatrixForWinograd4x4To6x6(),
                           definition.GetDataType(), &buffer_desc);
  desc.args_.AddObject(
      "Bt", std::make_unique<BufferDescriptor>(std::move(buffer_desc)));

  desc.work_group_size_ = int3(8, 4, 1);
  return desc;
}

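// A minimal usage sketch of the monolithic pair (hypothetical caller code and
// names; in the real delegate this wiring is done by the convolution
// selector, which also runs the per-tile matrix multiply between the two
// transforms):
//
//   Winograd4x4To36 to36 = CreateWinograd4x4To36(def_to36, padding, gpu_info);
//   // ... multiply the 36 transformed rows by the transformed filters ...
//   Winograd36To4x4 to4x4 = CreateWinograd36To4x4(def_to4x4, biases);
//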
Winograd4x4To36TileX6::Winograd4x4To36TileX6(const OperationDef& definition,
                                             const Padding2D& padding,
                                             const GpuInfo& gpu_info)
    : GPUOperation(definition), padding_(padding) {
  work_group_size_ = int3(32, 1, 1);
  code_ = GetWinograd4x4To36TileX6Code(definition_, gpu_info);
  if (gpu_info.IsAdreno()) {
    compiler_options_.push_back(CompilerOptions::kAdrenoMoreWaves);
  }
  if (definition_.precision == CalculationsPrecision::F16 &&
      gpu_info.IsPowerVR()) {
    compiler_options_.push_back(CompilerOptions::kClFastRelaxedMath);
  }
}

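// Code generator for the tiled input-transform variant: each work-item
// produces one of the six transformed rows (DST_Y in [0, 6)) of a single
// tile. The vertical pass reads bt_non_uniform, whose rows are padded to
// eight floats so one row is two aligned FLT4 reads; the horizontal pass
// reads the flat Bt buffer at row-major index row * 6 + col.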
std::string Winograd4x4To36TileX6::GetWinograd4x4To36TileX6Code(
    const OperationDef& op_def, const GpuInfo& gpu_info) {
  std::string c;
  const auto& src_desc = op_def.src_tensors[0];
  AddSrcTensor("src_tensor", op_def.src_tensors[0]);
  AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
  args_.AddInt("padding_x");
  args_.AddInt("padding_y");
  args_.AddInt("tiles_total");
  args_.AddInt("tiles_x");

  c += "MAIN_FUNCTION($0) {\n";
  c += "  int DST_X = GLOBAL_ID_0;\n";
  c += "  int DST_Y = GLOBAL_ID_1;\n";
  c += "  int DST_Z = GLOBAL_ID_2;\n";
  c += "  if (DST_X >= args.tiles_total || DST_Y >= 6 || DST_Z >= "
       "args.dst_tensor.Slices()) {\n";
  c += "    return; \n";
  c += "  }\n";
  c += "  int tile_x = (DST_X % args.tiles_x) * 4;\n";
  c += "  int tile_y = (DST_X / args.tiles_x) * 4;\n";
  c += "  FLT4 I0, I1, I2, I3, I4, I5;\n";
  c += "  FLT bt_ar[6];\n";
  c += "  FLT4 t0 = args.bt_non_uniform.Read(DST_Y * 2 + 0);\n";
  c += "  FLT4 t1 = args.bt_non_uniform.Read(DST_Y * 2 + 1);\n";
  c += "  DST_Y *= 6;\n";
  c += "  bt_ar[0] = t0.x;\n";
  c += "  bt_ar[1] = t0.y;\n";
  c += "  bt_ar[2] = t0.z;\n";
  c += "  bt_ar[3] = t0.w;\n";
  c += "  bt_ar[4] = t1.x;\n";
  c += "  bt_ar[5] = t1.y;\n";
  auto read_src = [&](const std::string& src, const std::string& xs) {
    std::string read_statement;
    if (src_desc.IsLinear()) {
      read_statement = "args.src_tensor.Read(src_a_" + xs + " + offset)";
    } else {
      read_statement = "args.src_tensor.Read(xc" + xs + ", yc, DST_Z)";
    }
    std::string multiplier;
    if (!src_desc.SupportsZeroClamp(Axis::WIDTH, gpu_info)) {
      if (!(src_desc.IsLinear() &&
            src_desc.ReturnsZeroForNegOneRead(gpu_info))) {
        multiplier = " * m" + xs + "_x";
      }
    }
    c += "    FLT4 " + src + " = " + read_statement + multiplier + ";\n";
  };
  for (int x = 0; x < 6; ++x) {
    const std::string xs = std::to_string(x);
    c += "  int xc" + xs + " = tile_x + args.padding_x + " + xs + ";\n";
    if (!src_desc.SupportsZeroClamp(Axis::WIDTH, gpu_info)) {
      c += "  bool inx" + xs + " = (xc" + xs + " >= 0 && xc" + xs +
           " < args.src_tensor.Width());\n";
      c += "  FLT m" + xs + "_x = INIT_FLT(inx" + xs + ");\n";
      c += "  xc" + xs + " = clamp(xc" + xs +
           ", 0, args.src_tensor.Width() - 1);\n";
    }
    if (src_desc.IsLinear()) {
      c += "  int src_a_" + xs + " = args.src_tensor.GetAddress(xc" + xs +
           ", 0, DST_Z);\n";
      if (src_desc.ReturnsZeroForNegOneRead(gpu_info)) {
        c += "  src_a_" + xs +
             " = select(-args.src_tensor.Width() * args.src_tensor.Height(), "
             "src_a_" +
             xs + ", inx" + xs + ");\n";
      }
    }
  }
  const bool manual_unroll =
      !(op_def.precision == CalculationsPrecision::F32 && gpu_info.IsMali());
  if (manual_unroll) {
    c += "  {\n";
    c += "    int yc = tile_y + args.padding_y;\n";
    if (!src_desc.SupportsZeroClamp(Axis::HEIGHT, gpu_info)) {
      c += "    bool iny = (yc >= 0 && yc < args.src_tensor.Height());\n";
      c += "    yc = clamp(yc, 0, args.src_tensor.Height() - 1);\n";
      c += "    int offset = select(0, yc * args.src_tensor.Width(), iny);\n";
      c += "    FLT bt = bt_ar[0] * INIT_FLT(iny);\n";
    } else {
      c += "    FLT bt = bt_ar[0];\n";
    }
    for (int x = 0; x < 6; ++x) {
      const std::string xs = std::to_string(x);
      const std::string src = "src" + xs;
      read_src(src, xs);
      c += "    I" + xs + " = bt * " + src + ";\n";
    }
    c += "  }\n";
    for (int y = 1; y < 6; ++y) {
      const std::string ys = std::to_string(y);
      c += "  {\n";
      c += "    int yc = tile_y + args.padding_y + (" + ys + ");\n";
      if (!src_desc.SupportsZeroClamp(Axis::HEIGHT, gpu_info)) {
        c += "    bool iny = (yc >= 0 && yc < args.src_tensor.Height());\n";
        c += "    yc = clamp(yc, 0, args.src_tensor.Height() - 1);\n";
        c += "    int offset = select(0, yc * args.src_tensor.Width(), iny);\n";
        c += "    FLT bt = bt_ar[" + ys + "] * INIT_FLT(iny);\n";
      } else {
        c += "    FLT bt = bt_ar[" + ys + "];\n";
      }
      for (int x = 0; x < 6; ++x) {
        const std::string xs = std::to_string(x);
        const std::string src = "src" + xs;
        read_src(src, xs);
        c += "    I" + xs + " += bt * " + src + ";\n";
      }
      c += "  }\n";
    }
  } else {
    c += "  I0 = INIT_FLT4(0.0f);\n";
    c += "  I1 = INIT_FLT4(0.0f);\n";
    c += "  I2 = INIT_FLT4(0.0f);\n";
    c += "  I3 = INIT_FLT4(0.0f);\n";
    c += "  I4 = INIT_FLT4(0.0f);\n";
    c += "  I5 = INIT_FLT4(0.0f);\n";
    c += "  for (int y = 0; y < 6; ++y) {\n";
    c += "    int yc = tile_y + args.padding_y + y;\n";
    if (!src_desc.SupportsZeroClamp(Axis::HEIGHT, gpu_info)) {
      c += "    bool iny = (yc >= 0 && yc < args.src_tensor.Height());\n";
      c += "    yc = clamp(yc, 0, args.src_tensor.Height() - 1);\n";
      c += "    int offset = select(0, yc * args.src_tensor.Width(), iny);\n";
      c += "    FLT bt = bt_ar[y] * INIT_FLT(iny);\n";
    } else {
      c += "    FLT bt = bt_ar[y];\n";
    }
    for (int x = 0; x < 6; ++x) {
      const std::string xs = std::to_string(x);
      const std::string src = "src" + xs;
      read_src(src, xs);
      c += "    I" + xs + " += bt * " + src + ";\n";
    }
    c += "  }\n";
  }
  c += "  {\n";
  c += "    FLT4 r0 = I0 + args.Bt.Read(2) * I2 + args.Bt.Read(4) * I4;\n";
  c += "    args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
  c += "    DST_Y++;\n";
  c += "  }\n";
  c += "  {\n";
  c += "    FLT4 r0 = args.Bt.Read(7) * I1 + args.Bt.Read(8) * I2 + "
       "args.Bt.Read(9) * I3 + args.Bt.Read(10) * I4;\n";
  c += "    args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
  c += "    DST_Y++;\n";
  c += "  }\n";
  c += "  {\n";
  c += "    FLT4 r0 = args.Bt.Read(13) * I1 + args.Bt.Read(14) * I2 + "
       "args.Bt.Read(15) * I3 + args.Bt.Read(16) * I4;\n";
  c += "    args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
  c += "    DST_Y++;\n";
  c += "  }\n";
  c += "  {\n";
  c += "    FLT4 r0 = args.Bt.Read(19) * I1 + args.Bt.Read(20) * I2 + "
       "args.Bt.Read(21) * I3 + args.Bt.Read(22) * I4;\n";
  c += "    args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
  c += "    DST_Y++;\n";
  c += "  }\n";
  c += "  {\n";
  c += "    FLT4 r0 = args.Bt.Read(25) * I1 + args.Bt.Read(26) * I2 + "
       "args.Bt.Read(27) * I3 + args.Bt.Read(28) * I4;\n";
  c += "    args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
  c += "    DST_Y++;\n";
  c += "  }\n";
  c += "  {\n";
  c += "    FLT4 r0 = args.Bt.Read(31) * I1 + args.Bt.Read(33) * I3 + I5;\n";
  c += "    args.dst_tensor.Write(r0, DST_X, DST_Y, DST_Z);\n";
  c += "    DST_Y++;\n";
  c += "  }\n";
  c += "}\n";
  return c;
}

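// Bt is uploaded twice: as a linear tensor with rows padded from 6 to 8
// floats so the kernel fetches a row with two aligned FLT4 reads
// (bt_non_uniform), and as a flat 36-element constant buffer (Bt) for the
// horizontal pass.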
void Winograd4x4To36TileX6::UploadBt() {
  tflite::gpu::Tensor<Linear, DataType::FLOAT32> bt_aligned;
  bt_aligned.shape = Linear(6 * 8);
  bt_aligned.data.resize(6 * 8);
  auto bt_mat = BtMatrixForWinograd4x4To6x6();
  for (int y = 0; y < 6; ++y) {
    for (int x = 0; x < 6; ++x) {
      bt_aligned.data[y * 8 + x] = bt_mat[y * 6 + x];
    }
    bt_aligned.data[y * 8 + 6] = 0.0f;
    bt_aligned.data[y * 8 + 7] = 0.0f;
  }

  TensorDescriptor bt_tensor_desc = CreateConstantLinearTensorDescriptor(
      definition_.src_tensors[0].GetDataType(),
      definition_.src_tensors[0].GetStorageType(), bt_aligned);
  args_.AddObject("bt_non_uniform", std::make_unique<TensorDescriptor>(
                                        std::move(bt_tensor_desc)));

  BufferDescriptor buffer_desc;
  VectorToKernelBufferDesc(bt_mat, definition_.GetDataType(), &buffer_desc);
  args_.AddObject("Bt",
                  std::make_unique<BufferDescriptor>(std::move(buffer_desc)));
}

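// The candidates appear to be ordered from largest to smallest; the first one
// that fits the device's reported work-group size limit is used.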
int3 Winograd4x4To36TileX6::SelectBestWorkGroup(
    const KernelInfo& kernel_info) const {
  const std::vector<int3> wgs = {{8, 6, 4}, {8, 6, 2}, {4, 6, 2},
                                 {4, 6, 2}, {2, 6, 2}, {2, 6, 1},
                                 {1, 6, 1}, {1, 3, 1}, {1, 1, 1}};
  return GetFirstSuitableWorkGroup(wgs, kernel_info.max_work_group_size);
}

absl::Status Winograd4x4To36TileX6::BindArguments(ArgumentsBinder* args) {
  const int tiles_x = DivideRoundUp(
      src_[0]->Width() + padding_.prepended.w + padding_.appended.w - 2, 4);
  const int tiles_y = DivideRoundUp(
      src_[0]->Height() + padding_.prepended.h + padding_.appended.h - 2, 4);
  const int tiles_total = tiles_x * tiles_y;
  RETURN_IF_ERROR(args->SetInt("padding_x", -padding_.prepended.w));
  RETURN_IF_ERROR(args->SetInt("padding_y", -padding_.prepended.h));
  RETURN_IF_ERROR(args->SetInt("tiles_total", tiles_total));
  RETURN_IF_ERROR(args->SetInt("tiles_x", tiles_x));
  return absl::OkStatus();
}

int3 Winograd4x4To36TileX6::GetGridSize() const {
  const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
  const int grid_y = 6;
  const int grid_z = dst_[0]->Slices();
  return int3(grid_x, grid_y, grid_z);
}

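// Intel GPUs get a fixed 4x6x1 work group; otherwise exhaustive tuning
// enumerates generic candidates and fast tuning falls back to the preference
// list in SelectBestWorkGroup().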
void Winograd4x4To36TileX6::GetPossibleKernelWorkGroups(
    TuningType tuning_type, const GpuInfo& gpu_info,
    const KernelInfo& kernel_info, std::vector<int3>* work_groups) const {
  if (gpu_info.IsIntel()) {
    work_groups->push_back(int3(4, 6, 1));
    return;
  }
  switch (tuning_type) {
    case TuningType::kExhaustive:
      GetPossibleWorkGroups(tuning_type, gpu_info, kernel_info, grid_size_,
                            work_groups);
      return;
    case TuningType::kFast:
    default:
      work_groups->push_back(SelectBestWorkGroup(kernel_info));
      return;
  }
}

Winograd4x4To36TileX6 CreateWinograd4x4To36TileX6(
    const GpuInfo& gpu_info, const OperationDef& definition,
    const Padding2D& padding) {
  Winograd4x4To36TileX6 result(definition, padding, gpu_info);
  result.UploadBt();
  return result;
}

int3 Winograd36To4x4::GetGridSize() const {
  return int3(src_[0]->Width(), 1, src_[0]->Slices());
}

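// Builds the monolithic output-transform operation; the per-channel biases
// are baked in as a constant linear tensor and At as a constant kernel-space
// buffer.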
Winograd36To4x4 CreateWinograd36To4x4(
    const OperationDef& definition,
    const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& biases) {
  Winograd36To4x4 desc(definition);
  desc.code_ = GetKernelWinograd36To4x4(definition);

  desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
  desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);

  TensorDescriptor bias_tensor_desc = CreateConstantLinearTensorDescriptor(
      definition.src_tensors[0].GetDataType(),
      definition.src_tensors[0].GetStorageType(), biases);
  desc.args_.AddObject("biases", std::make_unique<TensorDescriptor>(
                                     std::move(bias_tensor_desc)));

  BufferDescriptor buffer_desc;
  VectorToKernelBufferDesc(AtMatrixForWinograd4x4To6x6(),
                           definition.GetDataType(), &buffer_desc);
  desc.args_.AddObject(
      "At", std::make_unique<BufferDescriptor>(std::move(buffer_desc)));

  desc.work_group_size_ = int3(32, 1, 1);
  return desc;
}

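// Tiled variant of the output transform: the grid is (tiles, 4, slices) and
// each work-item computes one row (DST_Y in [0, 4)) of a single 4x4 output
// tile, reading the tile's 36 transformed values from the source tensor.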
Winograd36To4x4Tile4x1::Winograd36To4x4Tile4x1(const OperationDef& definition,
                                               const GpuInfo& gpu_info)
    : GPUOperation(definition) {
  work_group_size_ = int3(32, 1, 1);
  if (definition_.precision == CalculationsPrecision::F16 &&
      gpu_info.IsPowerVR()) {
    compiler_options_.push_back(CompilerOptions::kClFastRelaxedMath);
  }
  code_ = GetWinograd36To4x4Tile4x1Code(definition_, gpu_info);
}

std::string Winograd36To4x4Tile4x1::GetWinograd36To4x4Tile4x1Code(
    const OperationDef& op_def, const GpuInfo& gpu_info) {
  std::string c;

  AddSrcTensor("src_tensor", op_def.src_tensors[0]);
  AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
  args_.AddInt("tiles_x");

  c += "MAIN_FUNCTION($0) {\n";
  c += "  int tile_id = GLOBAL_ID_0;\n";
  c += "  int DST_Y = GLOBAL_ID_1;\n";
  c += "  int DST_Z = GLOBAL_ID_2;\n";
  c += "  int tile_x = (tile_id % args.tiles_x) * 4;\n";
  c += "  int tile_y = (tile_id / args.tiles_x) * 4 + DST_Y;\n";

  c += "  if (tile_x >= args.dst_tensor.Width() || tile_y >= "
       "args.dst_tensor.Height() || DST_Z >= args.dst_tensor.Slices()) {\n";
  c += "    return; \n";
  c += "  }\n";
  c += "  FLT4 I0, I1, I2, I3, I4, I5;\n";
  c += "  FLT at_ar[6];\n";
  c += "  FLT4 t00 = args.at_non_uniform.Read(DST_Y * 2 + 0);\n";
  c += "  FLT4 t01 = args.at_non_uniform.Read(DST_Y * 2 + 1);\n";
  c += "  at_ar[0] = t00.x;\n";
  c += "  at_ar[1] = t00.y;\n";
  c += "  at_ar[2] = t00.z;\n";
  c += "  at_ar[3] = t00.w;\n";
  c += "  at_ar[4] = t01.x;\n";
  c += "  at_ar[5] = t01.y;\n";
  const bool manual_unroll =
      !(op_def.precision == CalculationsPrecision::F32 && gpu_info.IsMali());
  if (manual_unroll) {
    c += "  {\n";
    c += "    FLT at = at_ar[0];\n";
    for (int x = 0; x < 6; ++x) {
      const std::string yc = std::to_string(x);
      const std::string src = "src" + std::to_string(x);
      c += "    FLT4 " + src + " = args.src_tensor.Read(tile_id, " + yc +
           ", DST_Z);\n";
      c += "    I" + std::to_string(x) + " = at * " + src + ";\n";
    }
    c += "  }\n";
    for (int y = 1; y < 6; ++y) {
      c += "  {\n";
      c += "    FLT at = at_ar[" + std::to_string(y) + "];\n";
      for (int x = 0; x < 6; ++x) {
        const std::string yc = std::to_string(y * 6 + x);
        const std::string src = "src" + std::to_string(x);
        c += "    FLT4 " + src + " = args.src_tensor.Read(tile_id, " + yc +
             ", DST_Z);\n";
        c += "    I" + std::to_string(x) + " += at * " + src + ";\n";
      }
      c += "  }\n";
    }
  } else {
    c += "  I0 = INIT_FLT4(0.0f);\n";
    c += "  I1 = INIT_FLT4(0.0f);\n";
    c += "  I2 = INIT_FLT4(0.0f);\n";
    c += "  I3 = INIT_FLT4(0.0f);\n";
    c += "  I4 = INIT_FLT4(0.0f);\n";
    c += "  I5 = INIT_FLT4(0.0f);\n";
    c += "  for (int y = 0; y < 6; ++y) {\n";
    c += "    FLT at = at_ar[y];\n";
    for (int x = 0; x < 6; ++x) {
      const std::string src = "src" + std::to_string(x);
      c += "    FLT4 " + src + " = args.src_tensor.Read(tile_id, y * 6 + " +
           std::to_string(x) + ", DST_Z);\n";
      c += "    I" + std::to_string(x) + " += at * " + src + ";\n";
    }
    c += "  }\n";
  }
  c += "  FLT4 t0 = I1 + I2;\n";
  c += "  FLT4 t1 = I3 + I4;\n";
  c += "  FLT4 bias_val = args.biases.Read(DST_Z);\n";
  c += "  {\n";
  c += "    FLT4 r0 = I0 + t0 + t1 + bias_val;\n";
  c += "    args.dst_tensor.Write(r0, tile_x, tile_y, DST_Z);\n";
  c += "    tile_x++;\n";
  c += "  }\n";
  c += "  FLT4 t2 = I1 - I2;\n";
  c += "  FLT4 t3 = I3 - I4;\n";
  c += "  if (tile_x < args.dst_tensor.Width()) {\n";
  c +=
      "    FLT4 r0 = t2 * args.At.Read(7) + t3 * args.At.Read(9) + bias_val;\n";
  c += "    args.dst_tensor.Write(r0, tile_x, tile_y, DST_Z);\n";
  c += "    tile_x++;\n";
  c += "  }\n";
  c += "  if (tile_x < args.dst_tensor.Width()) {\n";
  c += "    FLT4 r0 = t0 * args.At.Read(13) + t1 * args.At.Read(15) + "
       "bias_val;\n";
  c += "    args.dst_tensor.Write(r0, tile_x, tile_y, DST_Z);\n";
  c += "    tile_x++;\n";
  c += "  }\n";
  c += "  if (tile_x < args.dst_tensor.Width()) {\n";
  c += "    FLT4 r0 = t2 * args.At.Read(19) + t3 * args.At.Read(21) + I5 + "
       "bias_val;\n";
  c += "    args.dst_tensor.Write(r0, tile_x, tile_y, DST_Z);\n";
  c += "    tile_x++;\n";
  c += "  }\n";
  c += "}\n";
  return c;
}

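// Mirrors UploadBt(): At rows are padded from 6 to 8 floats for aligned FLT4
// reads (at_non_uniform), plus a flat 24-element constant buffer (At is 4x6)
// read at row-major index row * 6 + col.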
void Winograd36To4x4Tile4x1::UploadAt() {
  tflite::gpu::Tensor<Linear, DataType::FLOAT32> at_aligned;
  at_aligned.shape = Linear(4 * 8);
  at_aligned.data.resize(4 * 8);
  auto at_mat = AtMatrixForWinograd4x4To6x6();
  for (int y = 0; y < 4; ++y) {
    for (int x = 0; x < 6; ++x) {
      at_aligned.data[y * 8 + x] = at_mat[y * 6 + x];
    }
    at_aligned.data[y * 8 + 6] = 0.0f;
    at_aligned.data[y * 8 + 7] = 0.0f;
  }

  TensorDescriptor at_tensor_desc = CreateConstantLinearTensorDescriptor(
      definition_.src_tensors[0].GetDataType(),
      definition_.src_tensors[0].GetStorageType(), at_aligned);
  args_.AddObject("at_non_uniform", std::make_unique<TensorDescriptor>(
                                        std::move(at_tensor_desc)));

  BufferDescriptor buffer_desc;
  VectorToKernelBufferDesc(at_mat, definition_.GetDataType(), &buffer_desc);
  args_.AddObject("At",
                  std::make_unique<BufferDescriptor>(std::move(buffer_desc)));
}

int3 Winograd36To4x4Tile4x1::SelectBestWorkGroup(
    const KernelInfo& kernel_info) const {
  const std::vector<int3> wgs = {{32, 4, 2}, {16, 4, 2}, {16, 4, 1},
                                 {8, 4, 1},  {4, 4, 1},  {2, 4, 1},
                                 {1, 4, 1},  {1, 2, 1},  {1, 1, 1}};
  return GetFirstSuitableWorkGroup(wgs, kernel_info.max_work_group_size);
}

absl::Status Winograd36To4x4Tile4x1::BindArguments(ArgumentsBinder* args) {
  const int tiles_x = DivideRoundUp(dst_[0]->Width(), 4);
  RETURN_IF_ERROR(args->SetInt("tiles_x", tiles_x));
  return absl::OkStatus();
}

int3 Winograd36To4x4Tile4x1::GetGridSize() const {
  const int tiles_x = DivideRoundUp(dst_[0]->Width(), 4);
  const int tiles_y = DivideRoundUp(dst_[0]->Height(), 4);
  const int grid_x = tiles_x * tiles_y * dst_[0]->Batch();
  const int grid_y = 4;
  const int grid_z = dst_[0]->Slices();
  return int3(grid_x, grid_y, grid_z);
}

void Winograd36To4x4Tile4x1::GetPossibleKernelWorkGroups(
    TuningType tuning_type, const GpuInfo& gpu_info,
    const KernelInfo& kernel_info, std::vector<int3>* work_groups) const {
  if (gpu_info.IsIntel()) {
    work_groups->push_back(int3(8, 4, 1));
    return;
  }
  switch (tuning_type) {
    case TuningType::kExhaustive:
      GetPossibleWorkGroups(tuning_type, gpu_info, kernel_info, grid_size_,
                            work_groups);
      return;
    case TuningType::kFast:
    default:
      work_groups->push_back(SelectBestWorkGroup(kernel_info));
      return;
  }
}

Winograd36To4x4Tile4x1 CreateWinograd36To4x4Tile4x1(
    const GpuInfo& gpu_info, const OperationDef& definition,
    const tflite::gpu::Tensor<Linear, DataType::FLOAT32>& biases) {
  Winograd36To4x4Tile4x1 result(definition, gpu_info);
  TensorDescriptor bias_tensor_desc = CreateConstantLinearTensorDescriptor(
      gpu_info, definition.src_tensors[0].GetDataType(), biases);
  result.args_.AddObject("biases", std::make_unique<TensorDescriptor>(
                                       std::move(bias_tensor_desc)));
  result.UploadAt();
  return result;
}

}  // namespace gpu
}  // namespace tflite