/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/tasks/split.h"

#include <map>
#include <string>
#include <vector>

namespace tflite {
namespace gpu {

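// Selects the code-generation path: a split along CHANNELS needs per-channel
// repacking (GetSplitChannelsCode); any other axis copies whole 4-channel
// slices (GetSplitCode).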
Split::Split(const GpuInfo& gpu_info, const OperationDef& definition,
             const SplitAttributes& attr, const std::vector<int>& channels)
    : GPUOperation(definition), attr_(attr) {
  work_group_size_ = int3(8, 4, 1);
  code_ = attr.axis == Axis::CHANNELS ? GetSplitChannelsCode(gpu_info, channels)
                                      : GetSplitCode();
}

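// Generates kernel code for splits along any axis except CHANNELS. Each work
// item handles one location on the non-split axes and copies the source
// elements along the split axis into the destination tensors in order.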
std::string Split::GetSplitCode() {
  AddSrcTensor("src_tensor", definition_.src_tensors[0]);
  for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
    AddDstTensor("dst_tensor_" + std::to_string(i), definition_.dst_tensors[i]);
  }
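  // The grid does not cover the split axis (its task extent is forced to "1"
  // below); the generated per-item loop walks that axis instead.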
  const std::string task_width =
      attr_.axis == Axis::WIDTH ? "1" : "args.src_tensor.Width()";
  const std::string task_height =
      attr_.axis == Axis::HEIGHT ? "1" : "args.src_tensor.Height()";
  const std::string task_depth =
      attr_.axis == Axis::DEPTH ? "1" : "args.src_tensor.Depth()";
  const std::string task_batch =
      attr_.axis == Axis::BATCH ? "1" : "args.src_tensor.Batch()";
  const std::string task_slices =
      attr_.axis == Axis::CHANNELS ? "1" : "args.src_tensor.Slices()";

  std::map<Axis, std::string> axis_to_selector = {
      {Axis::WIDTH, "Width"}, {Axis::HEIGHT, "Height"},
      {Axis::DEPTH, "Depth"}, {Axis::CHANNELS, "Slices"},
      {Axis::BATCH, "Batch"},
  };
  std::map<Axis, std::string> axis_to_coord = {
      {Axis::WIDTH, "X"},    {Axis::HEIGHT, "Y"}, {Axis::DEPTH, "D"},
      {Axis::CHANNELS, "S"}, {Axis::BATCH, "B"},
  };

  std::string c;
  c += "MAIN_FUNCTION($0) {\n";
  if (definition_.src_tensors[0].HasAxis(Axis::BATCH)) {
    c += "  int linear_id = GLOBAL_ID_0;\n";
    c += "  int X = linear_id / " + task_batch + ";\n";
    c += "  int B = linear_id % " + task_batch + ";\n";
    c += "  if (X >= " + task_width + ") return;\n";
  } else {
    c += "  int X = GLOBAL_ID_0;\n";
    c += "  if (X >= " + task_width + ") return;\n";
  }
  if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) {
    c += "  int linear_id = GLOBAL_ID_1;\n";
    c += "  int Y = linear_id % " + task_height + ";\n";
    c += "  int D = linear_id / " + task_height + ";\n";
    c += "  if (D >= " + task_depth + ") return;\n";
  } else {
    c += "  int Y = GLOBAL_ID_1;\n";
    c += "  if (Y >= " + task_height + ") return;\n";
  }
  c += "  int S = GLOBAL_ID_2;\n";
  c += "  if (S >= " + task_slices + ") return;\n";
  c += "  int src_counter = 0;\n";
  std::vector<std::string> src_coords;
  for (auto axis :
       {Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH, Axis::CHANNELS, Axis::BATCH}) {
    if (definition_.src_tensors[0].HasAxis(axis)) {
      const std::string coord_name =
          attr_.axis == axis ? "src_counter" : axis_to_coord[axis];
      src_coords.push_back(coord_name);
    }
  }
  std::string src_coords_str = src_coords[0];
  for (int i = 1; i < src_coords.size(); ++i) {
    src_coords_str += ", " + src_coords[i];
  }
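  // For each destination, emit a loop over its extent on the split axis;
  // src_counter keeps advancing across destinations, so the source is
  // consumed sequentially.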
  for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
    std::vector<std::string> dst_coords;
    for (auto axis : {Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH, Axis::CHANNELS,
                      Axis::BATCH}) {
      if (definition_.dst_tensors[i].HasAxis(axis)) {
        const std::string coord_name =
            attr_.axis == axis ? "i" : axis_to_coord[axis];
        dst_coords.push_back(coord_name);
      }
    }
    std::string dst_coords_str = dst_coords[0];
    for (int j = 1; j < dst_coords.size(); ++j) {
      dst_coords_str += ", " + dst_coords[j];
    }
    const std::string dst_name = "args.dst_tensor_" + std::to_string(i);
    c += "  for (int i = 0; i < " + dst_name + "." +
         axis_to_selector[attr_.axis] + "(); ++i, src_counter++) {\n";
    c += "    args.src_tensor::type result = args.src_tensor.Read(" +
         src_coords_str + ");\n";
    c += "    " + dst_name + ".Write(result, " + dst_coords_str + ");\n";
    c += "  }\n";
  }
  c += "}\n";
  return c;
}

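// Generates kernel code for a split along the CHANNELS axis. Destination
// channel counts need not be multiples of 4, so individual channels are
// repacked into the destinations' 4-channel slices. E.g. channels = {3, 5}
// gives src_channels = 8 and src_slices = 2: dst_tensor_0 gets channels 0..2
// and dst_tensor_1 gets channels 3..7.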
std::string Split::GetSplitChannelsCode(const GpuInfo& gpu_info,
                                        const std::vector<int>& channels) {
  AddSrcTensor("src_tensor", definition_.src_tensors[0]);
  for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
    AddDstTensor("dst_tensor_" + std::to_string(i), definition_.dst_tensors[i]);
  }

  const std::string batch_coord =
      definition_.src_tensors[0].HasAxis(Axis::BATCH) ? ", B" : "";
  std::string coords = "X, Y";
  std::string c;
  c += "MAIN_FUNCTION($0) {\n";
  if (definition_.src_tensors[0].HasAxis(Axis::BATCH)) {
    c += "  int linear_id = GLOBAL_ID_0;\n";
    c += "  int X = linear_id / args.src_tensor.Batch();\n";
    c += "  int B = linear_id % args.src_tensor.Batch();\n";
    c += "  if (X >= args.src_tensor.Width()) return;\n";
  } else {
    c += "  int X = GLOBAL_ID_0;\n";
    c += "  if (X >= args.src_tensor.Width()) return;\n";
  }
  if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) {
    c += "  int linear_id = GLOBAL_ID_1;\n";
    c += "  int Y = linear_id % args.src_tensor.Height();\n";
    c += "  int Z = linear_id / args.src_tensor.Height();\n";
    c += "  if (Z >= args.src_tensor.Depth()) return;\n";
    coords += ", Z";
  } else {
    c += "  int Y = GLOBAL_ID_1;\n";
    c += "  if (Y >= args.src_tensor.Height()) return;\n";
  }
  int src_channels = 0;
  for (auto dst_ch : channels) {
    src_channels += dst_ch;
  }
  const int src_slices = DivideRoundUp(src_channels, 4);
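  // Code-generation-time bookkeeping: dst_ch counts channels already written
  // into the current destination, dst_slice is that destination's current
  // slice, and dst_tensor indexes the destination being filled.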
  int dst_ch = 0;
  int dst_slice = 0;
  int dst_tensor = 0;
  const std::string postfix[] = {".x", ".y", ".z", ".w"};
  c += "  args.src_tensor::type dst_val;\n";
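  // Walk the source slices; scatter each of the 4 packed channels into
  // dst_val and flush it whenever a destination tensor is complete or a full
  // 4-channel slice has been accumulated.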
  for (int s = 0; s < src_slices; ++s) {
    c += "  if (" + std::to_string(s) + " < args.src_tensor.Slices()) {\n";
    c += "    args.src_tensor::type src_val = args.src_tensor.Read(" + coords +
         ", " + std::to_string(s) + batch_coord + ");\n";
    for (int k = 0; k < 4; ++k) {
      c += "    dst_val" + postfix[dst_ch % 4] + " = src_val" + postfix[k] +
           ";\n";
      dst_ch++;
      if (dst_ch == channels[dst_tensor]) {
        const std::string dst_name =
            "args.dst_tensor_" + std::to_string(dst_tensor);
        c += "    " + dst_name + ".Write(dst_val, " + coords + ", " +
             std::to_string(dst_slice) + batch_coord + ");\n";
        dst_tensor += 1;
        dst_ch = 0;
        dst_slice = 0;
      }
      if (dst_ch != 0 && dst_ch % 4 == 0) {
        const std::string dst_name =
            "args.dst_tensor_" + std::to_string(dst_tensor);
        c += "    " + dst_name + ".Write(dst_val, " + coords + ", " +
             std::to_string(dst_slice) + batch_coord + ");\n";
        dst_slice += 1;
      }
    }
    if (gpu_info.IsMali()) {
      // Workaround for Mali: without it, the kernel can fail with
      // CL_OUT_OF_RESOURCES when the source channel count is large enough.
      c += "  } else { return; }\n";
    } else {
      c += "  }\n";
    }
  }
  c += "}\n";
  return c;
}

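// The grid covers every axis except the split axis, which is reduced to 1
// because the generated kernel iterates over it internally.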
int3 Split::GetGridSize() const {
  const int width = attr_.axis == Axis::WIDTH ? 1 : src_[0]->Width();
  const int height = attr_.axis == Axis::HEIGHT ? 1 : src_[0]->Height();
  const int depth = attr_.axis == Axis::DEPTH ? 1 : src_[0]->Depth();
  const int batch = attr_.axis == Axis::BATCH ? 1 : src_[0]->Batch();
  const int slices = attr_.axis == Axis::CHANNELS ? 1 : src_[0]->Slices();
  const int grid_x = width * batch;
  const int grid_y = height * depth;
  const int grid_z = slices;
  return int3(grid_x, grid_y, grid_z);
}

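// Creates a Split operation. `channels` holds the per-output channel counts
// and is only used when splitting along the CHANNELS axis.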
Split CreateSplit(const GpuInfo& gpu_info, const OperationDef& definition,
                  const SplitAttributes& attr,
                  const std::vector<int>& channels) {
  return Split(gpu_info, definition, attr, channels);
}

}  // namespace gpu
}  // namespace tflite