/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/tasks/depthwise_conv_3x3.h"

#include <string>
#include <utility>

#include "absl/strings/match.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"

namespace tflite {
namespace gpu {

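// Each work-item produces a 2x2 block of output pixels, so the fixed 8x4x1
// work group covers a 16x8 output tile per depth slice.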
DepthwiseConv3x3::DepthwiseConv3x3(const OperationDef& definition,
                                   bool weights_are_buffer,
                                   bool local_mem_uploads,
                                   const GpuInfo& gpu_info)
    : GPUOperation(definition), local_mem_uploads_(local_mem_uploads) {
  work_group_size_ = int3(8, 4, 1);
  code_ = GenerateDepthwiseConvCode(gpu_info, definition_, weights_are_buffer,
                                    local_mem_uploads_);

  if (definition_.precision == CalculationsPrecision::F16 &&
      gpu_info.IsPowerVR()) {
    compiler_options_.push_back(CompilerOptions::kClFastRelaxedMath);
  }
}

DepthwiseConv3x3::DepthwiseConv3x3(DepthwiseConv3x3&& operation)
    : GPUOperation(std::move(operation)),
      local_mem_uploads_(operation.local_mem_uploads_) {}

DepthwiseConv3x3& DepthwiseConv3x3::operator=(DepthwiseConv3x3&& operation) {
  if (this != &operation) {
    std::swap(local_mem_uploads_, operation.local_mem_uploads_);
    GPUOperation::operator=(std::move(operation));
  }
  return *this;
}

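// Emits the kernel source. One invocation computes a 2x2 output block for a
// single depth slice S; the 10 per-slice FLT4 weight values (9 filter taps
// plus bias) are fetched from a texture, a global buffer, or local memory,
// depending on the flags.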
std::string DepthwiseConv3x3::GenerateDepthwiseConvCode(
    const GpuInfo& gpu_info, const OperationDef& op_def,
    bool weights_are_buffer, bool local_mem_uploads) {
  auto src_desc = op_def.src_tensors[0];
  AddSrcTensor("src_tensor", src_desc);
  AddDstTensor("dst_tensor", op_def.dst_tensors[0]);

  std::string c;
  if (local_mem_uploads && gpu_info.IsApiOpenCl()) {
    c += "__attribute__((reqd_work_group_size(8, 4, 1)))\n";
  }
  c += "MAIN_FUNCTION($0) {\n";
  if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
    c += "  int linear_id = GLOBAL_ID_0;\n";
    c += "  int X = (linear_id / args.dst_tensor.Batch()) * 2;\n";
    c += "  int B = linear_id % args.dst_tensor.Batch();\n";
    c += "  args.dst_tensor.SetBatchRef(B);\n";
    c += "  args.src_tensor.SetBatchRef(B);\n";
  } else {
    c += "  int X = GLOBAL_ID_0 * 2;\n";
  }
  c += "  int Y = GLOBAL_ID_1 * 2;\n";
  c += "  int S = GLOBAL_ID_2;\n";
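  // r0..r3 accumulate the four outputs of the 2x2 block:
  // (X, Y), (X+1, Y), (X, Y+1), (X+1, Y+1).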
  c += "   ACCUM_FLT4 r0 = INIT_ACCUM_FLT4(0.0f);\n";
  c += "   ACCUM_FLT4 r1 = INIT_ACCUM_FLT4(0.0f);\n";
  c += "   ACCUM_FLT4 r2 = INIT_ACCUM_FLT4(0.0f);\n";
  c += "   ACCUM_FLT4 r3 = INIT_ACCUM_FLT4(0.0f);\n";
  if (!local_mem_uploads) {
    c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
         "|| S >= args.dst_tensor.Slices()) { \n";
    c += "    return; \n";
    c += "  } \n";
  }
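  // With local_mem_uploads the early-exit guard above is omitted so that every
  // work-item reaches the cooperative weight load and barrier below;
  // out-of-range items bail out later, after the accumulation.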
  if (local_mem_uploads) {
    c += "  __local FLT4 f[10];\n";
    if (gpu_info.IsApiOpenCl() && gpu_info.IsPowerVR()) {
      c += "  event_t e = async_work_group_copy(f, args.weights.GetPtr() + S * "
           "10, 10, 0);\n";
      c += "  wait_group_events(1, &e);\n";
    } else {
      c += "  int local_id = LOCAL_ID_1 * 8 + LOCAL_ID_0;\n";
      c += "  if (local_id < 10) {\n";
      c += "    f[local_id] = args.weights.Read(S * 10 + local_id);\n";
      c += "  }\n";
      c += "  LOCAL_MEM_BARRIER;\n";
    }
  } else if (weights_are_buffer && gpu_info.SupportsPointersInKernels()) {
    c += "  __global FLT4* f = args.weights.GetPtr() + S * 10;\n";
  }
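  // s0..s3 hold one 4-wide row (x-1..x+2) of the 4x4 input patch that the
  // 2x2 output block depends on.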
  c += "  FLT4 s0;\n";
  c += "  FLT4 s1;\n";
  c += "  FLT4 s2;\n";
  c += "  FLT4 s3;\n";
  std::string W[9] = {"f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f8"};
  std::string bias = "bias";
  std::string xc[4] = {"X - 1", "X", "X + 1", "X + 2"};
  std::string yc[4] = {"Y - 1", "Y", "Y + 1", "Y + 2"};
  if (!weights_are_buffer) {
    c += "   FLT4 f0 = args.weights.Read(0, S);\n";
    c += "   FLT4 f1 = args.weights.Read(1, S);\n";
    c += "   FLT4 f2 = args.weights.Read(2, S);\n";
    c += "   FLT4 f3 = args.weights.Read(3, S);\n";
    c += "   FLT4 f4 = args.weights.Read(4, S);\n";
    c += "   FLT4 f5 = args.weights.Read(5, S);\n";
    c += "   FLT4 f6 = args.weights.Read(6, S);\n";
    c += "   FLT4 f7 = args.weights.Read(7, S);\n";
    c += "   FLT4 f8 = args.weights.Read(8, S);\n";
  }
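  // If the tensor storage cannot return zero for out-of-bounds reads, clamp
  // the coordinates into range and remember per-coordinate in-bounds flags so
  // the loaded values can be masked to zero instead.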
  if (!op_def.src_tensors[0].SupportsZeroClamp(Axis::WIDTH, gpu_info)) {
    c += "  int x0 = X - 1;\n";
    c += "  int x1 = X;\n";
    c += "  int x2 = X + 1;\n";
    c += "  int x3 = X + 2;\n";
    c += "  bool x0_in = x0 >= 0 && x0 < args.dst_tensor.Width();\n";
    c += "  bool x1_in = x1 >= 0 && x1 < args.dst_tensor.Width();\n";
    c += "  bool x2_in = x2 >= 0 && x2 < args.dst_tensor.Width();\n";
    c += "  bool x3_in = x3 >= 0 && x3 < args.dst_tensor.Width();\n";
    c += "  x0 = clamp(x0, 0, args.dst_tensor.Width() - 1);\n";
    c += "  x1 = clamp(x1, 0, args.dst_tensor.Width() - 1);\n";
    c += "  x2 = clamp(x2, 0, args.dst_tensor.Width() - 1);\n";
    c += "  x3 = clamp(x3, 0, args.dst_tensor.Width() - 1);\n";
    xc[0] = "x0";
    xc[1] = "x1";
    xc[2] = "x2";
    xc[3] = "x3";
  }
  if (!op_def.src_tensors[0].SupportsZeroClamp(Axis::HEIGHT, gpu_info)) {
    c += "  int y0 = Y - 1;\n";
    c += "  int y1 = Y;\n";
    c += "  int y2 = Y + 1;\n";
    c += "  int y3 = Y + 2;\n";
    c += "  bool y0_in = y0 >= 0 && y0 < args.dst_tensor.Height();\n";
    c += "  bool y1_in = y1 >= 0 && y1 < args.dst_tensor.Height();\n";
    c += "  bool y2_in = y2 >= 0 && y2 < args.dst_tensor.Height();\n";
    c += "  bool y3_in = y3 >= 0 && y3 < args.dst_tensor.Height();\n";
    c += "  y0 = clamp(y0, 0, args.dst_tensor.Height() - 1);\n";
    c += "  y1 = clamp(y1, 0, args.dst_tensor.Height() - 1);\n";
    c += "  y2 = clamp(y2, 0, args.dst_tensor.Height() - 1);\n";
    c += "  y3 = clamp(y3, 0, args.dst_tensor.Height() - 1);\n";
    yc[0] = "y0";
    yc[1] = "y1";
    yc[2] = "y2";
    yc[3] = "y3";
  }
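  // For buffer-backed weights, address taps as f[0..8] and bias as f[9],
  // either through the local/global pointer or via direct buffer reads when
  // pointers in kernels are unsupported.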
  if (local_mem_uploads || weights_are_buffer) {
    const bool use_direct_buffer =
        !local_mem_uploads && !gpu_info.SupportsPointersInKernels();
    const std::string fetch_start =
        use_direct_buffer ? "args.weights.Read(S * 10 + " : "f[";
    const std::string fetch_end = use_direct_buffer ? ")" : "]";
    W[0] = fetch_start + "0" + fetch_end;
    W[1] = fetch_start + "1" + fetch_end;
    W[2] = fetch_start + "2" + fetch_end;
    W[3] = fetch_start + "3" + fetch_end;
    W[4] = fetch_start + "4" + fetch_end;
    W[5] = fetch_start + "5" + fetch_end;
    W[6] = fetch_start + "6" + fetch_end;
    W[7] = fetch_start + "7" + fetch_end;
    W[8] = fetch_start + "8" + fetch_end;
    bias = fetch_start + "9" + fetch_end;
  }
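  // Emits the loads of one input row at vertical offset y, applying the
  // in-bounds masks when manual clamping is active.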
  auto read_4x_line = [&](int y) {
    std::string s0_check, s1_check, s2_check, s3_check;
    if (!op_def.src_tensors[0].SupportsZeroClamp(Axis::WIDTH, gpu_info)) {
      s0_check += "x0_in";
      s1_check += "x1_in";
      s2_check += "x2_in";
      s3_check += "x3_in";
    }
    if (!op_def.src_tensors[0].SupportsZeroClamp(Axis::HEIGHT, gpu_info)) {
      const std::string y_in = "y" + std::to_string(y) + "_in";
      s0_check += s0_check.empty() ? y_in : (" && " + y_in);
      s1_check += s1_check.empty() ? y_in : (" && " + y_in);
      s2_check += s2_check.empty() ? y_in : (" && " + y_in);
      s3_check += s3_check.empty() ? y_in : (" && " + y_in);
    }
    if (!s0_check.empty()) {
      s0_check = " * INIT_FLT(" + s0_check + ")";
    }
    if (!s1_check.empty()) {
      s1_check = " * INIT_FLT(" + s1_check + ")";
    }
    if (!s2_check.empty()) {
      s2_check = " * INIT_FLT(" + s2_check + ")";
    }
    if (!s3_check.empty()) {
      s3_check = " * INIT_FLT(" + s3_check + ")";
    }
    c += "    s0 = args.src_tensor.Read(" + xc[0] + ", " + yc[y] + ", S)" +
         s0_check + ";\n";
    c += "    s1 = args.src_tensor.Read(" + xc[1] + ", " + yc[y] + ", S)" +
         s1_check + ";\n";
    c += "    s2 = args.src_tensor.Read(" + xc[2] + ", " + yc[y] + ", S)" +
         s2_check + ";\n";
    c += "    s3 = args.src_tensor.Read(" + xc[3] + ", " + yc[y] + ", S)" +
         s3_check + ";\n";
  };
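  // Slide the 3x3 filter over the 4x4 patch one input row at a time; each
  // input row feeds at most two output rows, so its products accumulate into
  // the corresponding subset of r0..r3.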
  c += "  {\n";
  read_4x_line(0);
  c += "    r0 += TO_ACCUM_TYPE(" + W[0] + " * s0);\n";
  c += "    r0 += TO_ACCUM_TYPE(" + W[1] + " * s1);\n";
  c += "    r1 += TO_ACCUM_TYPE(" + W[0] + " * s1);\n";
  c += "    r0 += TO_ACCUM_TYPE(" + W[2] + " * s2);\n";
  c += "    r1 += TO_ACCUM_TYPE(" + W[1] + " * s2);\n";
  c += "    r1 += TO_ACCUM_TYPE(" + W[2] + " * s3);\n";
  c += "  }\n";
  c += "  {\n";
  read_4x_line(1);
  c += "    r0 += TO_ACCUM_TYPE(" + W[3] + " * s0);\n";
  c += "    r2 += TO_ACCUM_TYPE(" + W[0] + " * s0);\n";
  c += "    r0 += TO_ACCUM_TYPE(" + W[4] + " * s1);\n";
  c += "    r1 += TO_ACCUM_TYPE(" + W[3] + " * s1);\n";
  c += "    r2 += TO_ACCUM_TYPE(" + W[1] + " * s1);\n";
  c += "    r3 += TO_ACCUM_TYPE(" + W[0] + " * s1);\n";
  c += "    r0 += TO_ACCUM_TYPE(" + W[5] + " * s2);\n";
  c += "    r1 += TO_ACCUM_TYPE(" + W[4] + " * s2);\n";
  c += "    r2 += TO_ACCUM_TYPE(" + W[2] + " * s2);\n";
  c += "    r3 += TO_ACCUM_TYPE(" + W[1] + " * s2);\n";
  c += "    r1 += TO_ACCUM_TYPE(" + W[5] + " * s3);\n";
  c += "    r3 += TO_ACCUM_TYPE(" + W[2] + " * s3);\n";
  c += "  }\n";
  c += "  {\n";
  read_4x_line(2);
  c += "    r0 += TO_ACCUM_TYPE(" + W[6] + " * s0);\n";
  c += "    r2 += TO_ACCUM_TYPE(" + W[3] + " * s0);\n";
  c += "    r0 += TO_ACCUM_TYPE(" + W[7] + " * s1);\n";
  c += "    r1 += TO_ACCUM_TYPE(" + W[6] + " * s1);\n";
  c += "    r2 += TO_ACCUM_TYPE(" + W[4] + " * s1);\n";
  c += "    r3 += TO_ACCUM_TYPE(" + W[3] + " * s1);\n";
  c += "    r0 += TO_ACCUM_TYPE(" + W[8] + " * s2);\n";
  c += "    r1 += TO_ACCUM_TYPE(" + W[7] + " * s2);\n";
  c += "    r2 += TO_ACCUM_TYPE(" + W[5] + " * s2);\n";
  c += "    r3 += TO_ACCUM_TYPE(" + W[4] + " * s2);\n";
  c += "    r1 += TO_ACCUM_TYPE(" + W[8] + " * s3);\n";
  c += "    r3 += TO_ACCUM_TYPE(" + W[5] + " * s3);\n";
  c += "  }\n";
  c += "  {\n";
  read_4x_line(3);
  c += "    r2 += TO_ACCUM_TYPE(" + W[6] + " * s0);\n";
  c += "    r2 += TO_ACCUM_TYPE(" + W[7] + " * s1);\n";
  c += "    r3 += TO_ACCUM_TYPE(" + W[6] + " * s1);\n";
  c += "    r2 += TO_ACCUM_TYPE(" + W[8] + " * s2);\n";
  c += "    r3 += TO_ACCUM_TYPE(" + W[7] + " * s2);\n";
  c += "    r3 += TO_ACCUM_TYPE(" + W[8] + " * s3);\n";
  c += "  }\n";
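  // In texture mode the bias is the 10th per-slice value; add it to all four
  // outputs.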
  if (!weights_are_buffer) {
    c += "   FLT4 bias = args.weights.Read(9, S);\n";
  }
  c += "  r0 += TO_ACCUM_TYPE(" + bias + ");\n";
  c += "  r1 += TO_ACCUM_TYPE(" + bias + ");\n";
  c += "  r2 += TO_ACCUM_TYPE(" + bias + ");\n";
  c += "  r3 += TO_ACCUM_TYPE(" + bias + ");\n";
  if (local_mem_uploads) {
    c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
         "|| S >= args.dst_tensor.Slices()) { \n";
    c += "    return; \n";
    c += "  } \n";
  }
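  // Write the 2x2 block, guarding each pixel against the right/bottom edges.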
  c += "  if(X + 0 < args.dst_tensor.Width() && Y + 0 < "
       "args.dst_tensor.Height()) {\n";
  c += "    FLT4 result = TO_FLT4(r0);\n";
  c += "    args.dst_tensor.Write(result, X + 0, Y + 0, S);\n";
  c += "  }\n";
  c += "  if(X + 1 < args.dst_tensor.Width() && Y + 0 < "
       "args.dst_tensor.Height()) {\n";
  c += "    FLT4 result = TO_FLT4(r1);\n";
  c += "    args.dst_tensor.Write(result, X + 1, Y + 0, S);\n";
  c += "  }\n";
  c += "  if(X + 0 < args.dst_tensor.Width() && Y + 1 < "
       "args.dst_tensor.Height()) {\n";
  c += "    FLT4 result = TO_FLT4(r2);\n";
  c += "    args.dst_tensor.Write(result, X + 0, Y + 1, S);\n";
  c += "  }\n";
  c += "  if(X + 1 < args.dst_tensor.Width() && Y + 1 < "
       "args.dst_tensor.Height()) {\n";
  c += "    FLT4 result = TO_FLT4(r3);\n";
  c += "    args.dst_tensor.Write(result, X + 1, Y + 1, S);\n";
  c += "  }\n";
  c += "}\n";

  return c;
}

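// Each work-item covers a 2x2 spatial block, so the grid halves width and
// height (rounded up); batch is folded into the x dimension.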
int3 DepthwiseConv3x3::GetGridSize() const {
  const int grid_x = DivideRoundUp(dst_[0]->Width(), 2) * dst_[0]->Batch();
  const int grid_y = DivideRoundUp(dst_[0]->Height(), 2);
  const int grid_z = dst_[0]->Slices();
  return int3(grid_x, grid_y, grid_z);
}

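// With local memory uploads the kernel requires the fixed 8x4x1 work group
// used by the cooperative weight load, so only that size is offered;
// otherwise defer to the generic work-group picker.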
void DepthwiseConv3x3::GetPossibleKernelWorkGroups(
    TuningType tuning_type, const GpuInfo& gpu_info,
    const KernelInfo& kernel_info, std::vector<int3>* work_groups) const {
  if (local_mem_uploads_) {
    work_groups->push_back(work_group_size_);
  } else {
    GetPossibleWorkGroups(tuning_type, gpu_info, kernel_info, grid_size_,
                          work_groups);
  }
}

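// The fast path applies only to depth-multiplier-1 3x3 filters with stride 1,
// dilation 1, and SAME padding; one known-bad Adreno OpenCL driver is
// excluded explicitly.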
bool IsDepthwiseConv3x3Supported(const GpuInfo& gpu_info,
                                 const DepthwiseConvolution2DAttributes& attr) {
  if (gpu_info.IsApiOpenCl() && gpu_info.IsAdreno()) {
    const std::string kBadDriver =
        "OpenCL 2.0 QUALCOMM build: commit #7daed58 changeid #I7ece6fe30d "
        "Date: 10/19/16";
    if (absl::StrContains(gpu_info.opencl_info.platform_version, kBadDriver)) {
      return false;
    }
  }
  return attr.weights.shape.o == 1 && attr.dilations.w == 1 &&
         attr.dilations.h == 1 && attr.weights.shape.w == 3 &&
         attr.weights.shape.h == 3 && attr.strides.w == 1 &&
         attr.strides.h == 1 && attr.padding.prepended.w == 1 &&
         attr.padding.prepended.h == 1 && attr.padding.appended.w == 1 &&
         attr.padding.appended.h == 1;
}

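// Vendor heuristics: buffer-backed weights for GPUs without image support and
// for PowerVR, Mali, and Apple; local-memory uploads where local memory is
// expected to be the faster path.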
DepthwiseConv3x3 CreateDepthwiseConv3x3(
    const GpuInfo& gpu_info, const OperationDef& definition,
    const DepthwiseConvolution2DAttributes& attr) {
  bool weights_are_buffer = !gpu_info.SupportsImages() ||
                            gpu_info.IsPowerVR() || gpu_info.IsMali() ||
                            gpu_info.IsApple();
  bool local_mem_uploads = weights_are_buffer && gpu_info.IsPowerVR();
  if (gpu_info.IsApple() &&
      gpu_info.apple_info.IsLocalMemoryPreferredOverGlobal()) {
    local_mem_uploads = true;
  }
  DepthwiseConv3x3 result(definition, weights_are_buffer, local_mem_uploads,
                          gpu_info);
  result.UploadWeightsAndBiases(attr.weights, attr.bias, weights_are_buffer);
  return result;
}

}  // namespace gpu
}  // namespace tflite