1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/tasks/split.h"
17
18 #include <string>
19
20
21 namespace tflite {
22 namespace gpu {
23
Split(const GpuInfo & gpu_info,const OperationDef & definition,const SplitAttributes & attr,const std::vector<int> & channels)24 Split::Split(const GpuInfo& gpu_info, const OperationDef& definition,
25 const SplitAttributes& attr, const std::vector<int>& channels)
26 : GPUOperation(definition), attr_(attr) {
27 work_group_size_ = int3(8, 4, 1);
28 code_ = attr.axis == Axis::CHANNELS ? GetSplitChannelsCode(gpu_info, channels)
29 : GetSplitCode();
30 }
31
GetSplitCode()32 std::string Split::GetSplitCode() {
33 AddSrcTensor("src_tensor", definition_.src_tensors[0]);
34 for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
35 AddDstTensor("dst_tensor_" + std::to_string(i), definition_.dst_tensors[i]);
36 }
37 const std::string task_width =
38 attr_.axis == Axis::WIDTH ? "1" : "args.src_tensor.Width()";
39 const std::string task_height =
40 attr_.axis == Axis::HEIGHT ? "1" : "args.src_tensor.Height()";
41 const std::string task_depth =
42 attr_.axis == Axis::DEPTH ? "1" : "args.src_tensor.Depth()";
43 const std::string task_batch =
44 attr_.axis == Axis::BATCH ? "1" : "args.src_tensor.Batch()";
45 const std::string task_slices =
46 attr_.axis == Axis::CHANNELS ? "1" : "args.src_tensor.Slices()";
47
48 std::map<Axis, std::string> axis_to_selector = {
49 {Axis::WIDTH, "Width"}, {Axis::HEIGHT, "Height"},
50 {Axis::DEPTH, "Depth"}, {Axis::CHANNELS, "Slices"},
51 {Axis::BATCH, "Batch"},
52 };
53 std::map<Axis, std::string> axis_to_coord = {
54 {Axis::WIDTH, "X"}, {Axis::HEIGHT, "Y"}, {Axis::DEPTH, "D"},
55 {Axis::CHANNELS, "S"}, {Axis::BATCH, "B"},
56 };
57
58 std::string c;
59 c += "MAIN_FUNCTION($0) {\n";
60 if (definition_.src_tensors[0].HasAxis(Axis::BATCH)) {
61 c += " int linear_id = GLOBAL_ID_0;\n";
62 c += " int X = linear_id / " + task_batch + ";\n";
63 c += " int B = linear_id % " + task_batch + ";\n";
64 c += " if (X >= " + task_width + ") return;\n";
65 } else {
66 c += " int X = GLOBAL_ID_0;\n";
67 c += " if (X >= " + task_width + ") return;\n";
68 }
69 if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) {
70 c += " int linear_id = GLOBAL_ID_1;\n";
71 c += " int Y = linear_id % " + task_height + ";\n";
72 c += " int D = linear_id / " + task_height + ";\n";
73 c += " if (D >= " + task_depth + ") return;\n";
74 } else {
75 c += " int Y = GLOBAL_ID_1;\n";
76 c += " if (Y >= " + task_height + ") return;\n";
77 }
78 c += " int S = GLOBAL_ID_2;\n";
79 c += " if (S >= " + task_slices + ") return;\n";
80 c += " int src_counter = 0;\n";
81 std::vector<std::string> src_coords;
82 for (auto axis :
83 {Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH, Axis::CHANNELS, Axis::BATCH}) {
84 if (definition_.src_tensors[0].HasAxis(axis)) {
85 const std::string coord_name =
86 attr_.axis == axis ? "src_counter" : axis_to_coord[axis];
87 src_coords.push_back(coord_name);
88 }
89 }
90 std::string src_coords_str = src_coords[0];
91 for (int i = 1; i < src_coords.size(); ++i) {
92 src_coords_str += ", " + src_coords[i];
93 }
94 for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
95 std::vector<std::string> dst_coords;
96 for (auto axis : {Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH, Axis::CHANNELS,
97 Axis::BATCH}) {
98 if (definition_.dst_tensors[i].HasAxis(axis)) {
99 const std::string coord_name =
100 attr_.axis == axis ? "i" : axis_to_coord[axis];
101 dst_coords.push_back(coord_name);
102 }
103 }
104 std::string dst_coords_str = dst_coords[0];
105 for (int j = 1; j < dst_coords.size(); ++j) {
106 dst_coords_str += ", " + dst_coords[j];
107 }
108 const std::string dst_name = "args.dst_tensor_" + std::to_string(i);
109 c += " for (int i = 0; i < " + dst_name + "." +
110 axis_to_selector[attr_.axis] + "(); ++i, src_counter++) {\n";
111 c += " args.src_tensor::type result = args.src_tensor.Read(" +
112 src_coords_str + ");\n";
113 c += " " + dst_name + ".Write(result, " + dst_coords_str + ");\n";
114 c += " }\n";
115 }
116 c += "}\n";
117 return c;
118 }
119
GetSplitChannelsCode(const GpuInfo & gpu_info,const std::vector<int> & channels)120 std::string Split::GetSplitChannelsCode(const GpuInfo& gpu_info,
121 const std::vector<int>& channels) {
122 AddSrcTensor("src_tensor", definition_.src_tensors[0]);
123 for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
124 AddDstTensor("dst_tensor_" + std::to_string(i), definition_.dst_tensors[i]);
125 }
126
127 const std::string batch_coord =
128 definition_.src_tensors[0].HasAxis(Axis::BATCH) ? ", B" : "";
129 std::string coords = "X, Y";
130 std::string c;
131 c += "MAIN_FUNCTION($0) {\n";
132 if (definition_.src_tensors[0].HasAxis(Axis::BATCH)) {
133 c += " int linear_id = GLOBAL_ID_0;\n";
134 c += " int X = linear_id / args.src_tensor.Batch();\n";
135 c += " int B = linear_id % args.src_tensor.Batch();\n";
136 c += " if (X >= args.src_tensor.Width()) return;\n";
137 } else {
138 c += " int X = GLOBAL_ID_0;\n";
139 c += " if (X >= args.src_tensor.Width()) return;\n";
140 }
141 if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) {
142 c += " int linear_id = GLOBAL_ID_1;\n";
143 c += " int Y = linear_id % args.src_tensor.Height();\n";
144 c += " int Z = linear_id / args.src_tensor.Height();\n";
145 c += " if (Z >= args.src_tensor.Depth()) return;\n";
146 coords += ", Z";
147 } else {
148 c += " int Y = GLOBAL_ID_1;\n";
149 c += " if (Y >= args.src_tensor.Height()) return;\n";
150 }
151 int src_channels = 0;
152 for (auto dst_ch : channels) {
153 src_channels += dst_ch;
154 }
155 const int src_slices = DivideRoundUp(src_channels, 4);
156 int dst_ch = 0;
157 int dst_slice = 0;
158 int dst_tensor = 0;
159 const std::string postfix[] = {".x", ".y", ".z", ".w"};
160 c += " args.src_tensor::type dst_val;\n";
161 for (int s = 0; s < src_slices; ++s) {
162 c += " if (" + std::to_string(s) + " < args.src_tensor.Slices()) {\n";
163 c += " args.src_tensor::type src_val = args.src_tensor.Read(" + coords +
164 ", " + std::to_string(s) + batch_coord + ");\n";
165 for (int k = 0; k < 4; ++k) {
166 c += " dst_val" + postfix[dst_ch % 4] + " = src_val" + postfix[k] +
167 ";\n";
168 dst_ch++;
169 if (dst_ch == channels[dst_tensor]) {
170 const std::string dst_name =
171 "args.dst_tensor_" + std::to_string(dst_tensor);
172 c += " " + dst_name + ".Write(dst_val, " + coords + ", " +
173 std::to_string(dst_slice) + batch_coord + ");\n";
174 dst_tensor += 1;
175 dst_ch = 0;
176 dst_slice = 0;
177 }
178 if (dst_ch != 0 && dst_ch % 4 == 0) {
179 const std::string dst_name =
180 "args.dst_tensor_" + std::to_string(dst_tensor);
181 c += " " + dst_name + ".Write(dst_val, " + coords + ", " +
182 std::to_string(dst_slice) + batch_coord + ");\n";
183 dst_slice += 1;
184 }
185 }
186 if (gpu_info.IsMali()) {
187 // workaround for Mali
188 // without it, kernel can fail with CL_OUT_OF_RESOURCES with big enough
189 // src channels count.
190 c += " } else { return; }\n";
191 } else {
192 c += " }\n";
193 }
194 }
195 c += "}\n";
196 return c;
197 }
198
GetGridSize() const199 int3 Split::GetGridSize() const {
200 const int width = attr_.axis == Axis::WIDTH ? 1 : src_[0]->Width();
201 const int height = attr_.axis == Axis::HEIGHT ? 1 : src_[0]->Height();
202 const int depth = attr_.axis == Axis::DEPTH ? 1 : src_[0]->Depth();
203 const int batch = attr_.axis == Axis::BATCH ? 1 : src_[0]->Batch();
204 const int slices = attr_.axis == Axis::CHANNELS ? 1 : src_[0]->Slices();
205 const int grid_x = width * batch;
206 const int grid_y = height * depth;
207 const int grid_z = slices;
208 return int3(grid_x, grid_y, grid_z);
209 }
210
CreateSplit(const GpuInfo & gpu_info,const OperationDef & definition,const SplitAttributes & attr,const std::vector<int> & channels)211 Split CreateSplit(const GpuInfo& gpu_info, const OperationDef& definition,
212 const SplitAttributes& attr,
213 const std::vector<int>& channels) {
214 return Split(gpu_info, definition, attr, channels);
215 }
216
217 } // namespace gpu
218 } // namespace tflite
219