/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/tasks/concat_z.h"

#include <algorithm>
#include <string>

#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"

namespace tflite {
namespace gpu {
namespace {

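// Returns true when every input's channel count is divisible by 4, i.e. each
// source tensor occupies only whole 4-channel slices.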
bool IsAllChannelsX4(const std::vector<int>& channels) {
  for (int channel : channels) {
    if (channel % 4 != 0) {
      return false;
    }
  }
  return true;
}

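// Emits the concatenation kernel for the channel (Z/slice) axis. Each work
// item handles one (X, Y) location (plus Z when a DEPTH axis is present) and
// copies every slice of every source tensor into consecutive slices of
// dst_tensor. Two paths exist: when all channel counts are multiples of 4,
// whole VEC4 slices are copied in loops; otherwise individual channel lanes
// are repacked across slice boundaries.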
std::string GetConcatKernelCode(const OperationDef& op_def,
                                const std::vector<int>& channels) {
  std::vector<std::string> tensor_names(op_def.src_tensors.size());
  for (int i = 0; i < op_def.src_tensors.size(); ++i) {
    tensor_names[i] = "src_tensor_" + std::to_string(i);
  }

  std::string c;
  c += "MAIN_FUNCTION($0) {\n";
  if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
    c += "  int linear_id = GLOBAL_ID_0;\n";
    c += "  int X = linear_id / args.dst_tensor.Batch();\n";
    c += "  int B = linear_id % args.dst_tensor.Batch();\n";
    c += "  args.dst_tensor.SetBatchRef(B);\n";
    for (int i = 0; i < op_def.src_tensors.size(); ++i) {
      c += "  args." + tensor_names[i] + ".SetBatchRef(B);\n";
    }
  } else {
    c += "  int X = GLOBAL_ID_0;\n";
  }
  c += "  int Y = GLOBAL_ID_1;\n";
  std::string coords = "X, Y";
  if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
    c += "  int Z = GLOBAL_ID_2;\n";
    c += "  if (Z >= args.dst_tensor.Depth()) return;\n";
    coords = "X, Y, Z";
  }
  c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height()) "
       "return; \n";

  if (IsAllChannelsX4(channels)) {
    // When every channel count is a multiple of 4 we can read/assign/write
    // whole VEC4 elements. It is also easy to emit a loop in this case, which
    // keeps the generated kernel short.
    c += "  int S = 0;\n";
    for (int i = 0; i < channels.size(); ++i) {
      std::string t_name = "args." + tensor_names[i];
      const int src_depth = DivideRoundUp(channels[i], 4);
      if (src_depth % 2 == 0) {
        // When src_depth % 2 == 0 we can read two slices per loop iteration,
        // which should help hide read latency.
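        // For illustration (hypothetical input: coords == "X, Y" and a first
        // source tensor with an even slice count), the emitted loop is:
        //   for (int i = 0; i < args.src_tensor_0.Slices(); i += 2) {
        //     args.src_tensor_0::type result0 = args.src_tensor_0.Read(X, Y, i);
        //     args.src_tensor_0::type result1 = args.src_tensor_0.Read(X, Y, i + 1);
        //     args.dst_tensor.Write(result0, X, Y, S);
        //     args.dst_tensor.Write(result1, X, Y, S + 1);
        //     S += 2;
        //   }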
        c += "  for (int i = 0; i < " + t_name + ".Slices(); i += 2) {\n";
        c += "    " + t_name + "::type result0 = " + t_name + ".Read(" +
             coords + ", i);\n";
        c += "    " + t_name + "::type result1 = " + t_name + ".Read(" +
             coords + ", i + 1);\n";
        c += "    args.dst_tensor.Write(result0, " + coords + ", S);\n";
        c += "    args.dst_tensor.Write(result1, " + coords + ", S + 1);\n";
        c += "    S += 2;\n";
        c += "  }\n";
      } else {
        c += "  for (int i = 0; i < " + t_name + ".Slices(); ++i) {\n";
        c += "    " + t_name + "::type result = " + t_name + ".Read(" + coords +
             ", i);\n";
        c += "    args.dst_tensor.Write(result, " + coords + ", S);\n";
        c += "    S++;\n";
        c += "  }\n";
      }
    }
  } else {
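    // Channel counts are not all multiples of 4, so individual channel lanes
    // must be repacked across slice boundaries. For example (hypothetical
    // channels = {3, 5}): the first output slice takes result.x/.y/.z from
    // src_tensor_0's t0.x/.y/.z and result.w from src_tensor_1's t1.x.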
    c += "  args.src_tensor_0::type result = args.src_tensor_0::zero_value;\n";
    int out_channel = 0;
    int read_index = 0;
    int z = 0;
    const std::string postfix[] = {".x", ".y", ".z", ".w"};
    for (int i = 0; i < channels.size(); ++i) {
      std::string tensor_name = "args." + tensor_names[i];
      const int depth = DivideRoundUp(channels[i], 4);
      for (int d = 0; d < depth; ++d) {
        const int channels_in_group = std::min(4, channels[i] - d * 4);
        const std::string temp_name = "t" + std::to_string(read_index);
        c += "  " + tensor_name + "::type " + temp_name + " = " + tensor_name +
             ".Read(" + coords + ", " + std::to_string(d) + ");\n";
        for (int ch = 0; ch < channels_in_group; ++ch) {
          c += "  result" + postfix[out_channel] + " = ";
          c += temp_name + postfix[ch] + ";\n";
          out_channel++;
          if (out_channel == 4) {
            out_channel = 0;
            c += "  args.dst_tensor.Write(result, " + coords + ", " +
                 std::to_string(z) + ");\n";
            z++;
          }
        }
        read_index++;
      }
    }
    if (out_channel != 0) {
      c += "  args.dst_tensor.Write(result, " + coords + ", " +
           std::to_string(z) + ");\n";
    }
  }
  c += "}\n";
  return c;
}

}  // namespace

GPUOperation CreateConcatZ(const OperationDef& definition,
                           const std::vector<int>& channels,
                           const GpuInfo& gpu_info) {
  GPUOperation op(definition);
  for (int i = 0; i < definition.src_tensors.size(); ++i) {
    const std::string name = "src_tensor_" + std::to_string(i);
    op.AddSrcTensor(name, definition.src_tensors[i]);
  }
  op.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
  op.code_ = GetConcatKernelCode(definition, channels);
  if (gpu_info.IsPowerVR() &&
      definition.precision == CalculationsPrecision::F32 &&
      !IsAllChannelsX4(channels)) {
    // Workaround for a driver bug: some PowerVR GPUs (e.g. GE8320) produce
    // incorrect results on this path without disabling optimizations.
    op.compiler_options_.push_back(CompilerOptions::kClDisableOptimizations);
  }
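  // kWBToX_HToY_DToZ maps width*batch to grid X, height to grid Y, and slice
  // depth to grid Z, matching the GLOBAL_ID_* indexing in the generated code.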
  op.tensor_to_grid_ = TensorToGrid::kWBToX_HToY_DToZ;
  return op;
}
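// A minimal usage sketch (the descriptors and gpu_info below are assumptions,
// not defined in this file):
//
//   OperationDef def;
//   def.precision = CalculationsPrecision::F32;
//   def.src_tensors = {src0_descriptor, src1_descriptor};
//   def.dst_tensors = {dst_descriptor};
//   GPUOperation op = CreateConcatZ(def, /*channels=*/{8, 12}, gpu_info);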

}  // namespace gpu
}  // namespace tflite