1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/tasks/conv_weights_converter.h"
17
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>
22
23 #include "tensorflow/lite/delegates/gpu/common/task/util.h"
24
25 namespace tflite {
26 namespace gpu {
27
ConverterToConvWeights(const OperationDef & definition,const WeightsDescription & weights_desc,Layout input_layout)28 ConverterToConvWeights::ConverterToConvWeights(
29 const OperationDef& definition, const WeightsDescription& weights_desc,
30 Layout input_layout)
31 : GPUOperation(definition),
32 weights_desc_(weights_desc),
33 input_layout_(input_layout) {
34 code_ = GetConverterToConvWeightsCode();
35 }
36
GetConverterToConvWeightsCode()37 std::string ConverterToConvWeights::GetConverterToConvWeightsCode() {
38 AddSrcTensor("src_tensor", definition_.src_tensors[0]);
39 args_.AddFloat("mask_x");
40 args_.AddFloat("mask_y");
41 args_.AddFloat("mask_z");
42 args_.AddFloat("mask_w");
43 args_.AddInt("out_ch");
44 args_.AddInt("out_ch_x4_groups");
45 args_.AddInt("in_ch");
46 args_.AddInt("in_ch_x4_groups");
47 args_.AddInt("kernel_width");
48 args_.AddInt("kernel_height");
49 args_.AddInt("kernel_spatial_size");
50
51 if (weights_desc_.layout == WeightsLayout::kOICustomSpatialI4O4 ||
52 weights_desc_.layout == WeightsLayout::kOICustomSpatialO4I4) {
53 std::vector<int32_t> remap(weights_desc_.spatial_remap.size());
54 for (int i = 0; i < remap.size(); ++i) {
55 remap[i] = weights_desc_.spatial_remap[i];
56 }
57 BufferDescriptor desc;
58 desc.element_type = DataType::INT32;
59 desc.element_size = 1;
60 desc.memory_type = MemoryType::GLOBAL;
61 desc.size = remap.size() * sizeof(int32_t);
62 desc.data.resize(desc.size);
63 std::memcpy(desc.data.data(), remap.data(), desc.size);
64 args_.AddObject("spatial_remap",
65 std::make_unique<BufferDescriptor>(std::move(desc)));
66 }
67
68 std::string c;
69 c += "MAIN_FUNCTION($0) {\n";
70 c += " int O = GLOBAL_ID_0;\n";
71 c += " int I = GLOBAL_ID_1;\n";
72 c += " int spatial_linear = GLOBAL_ID_2;\n";
73 c += " if (O >= args.out_ch_x4_groups) return;\n";
74 c += " if (I >= args.in_ch_x4_groups) return;\n";
75 c += " if (spatial_linear >= args.kernel_spatial_size) return;\n";
76 if (weights_desc_.layout == WeightsLayout::kOICustomSpatialI4O4 ||
77 weights_desc_.layout == WeightsLayout::kOICustomSpatialO4I4) {
78 c += " int linear_remap = args.spatial_remap.Read(spatial_linear);\n";
79 c += " int W = linear_remap % args.kernel_width;\n";
80 c += " int H = linear_remap / args.kernel_width;\n";
81 } else {
82 c += " int W = spatial_linear % args.kernel_width;\n";
83 c += " int H = spatial_linear / args.kernel_width;\n";
84 }
85 // W and H is src coordinates, spatial_linear is dst coordinate
86 c += " FLT4 v0 = INIT_FLT4(0.0f);\n";
87 c += " FLT4 v1 = INIT_FLT4(0.0f);\n";
88 c += " FLT4 v2 = INIT_FLT4(0.0f);\n";
89 c += " FLT4 v3 = INIT_FLT4(0.0f);\n";
90 if (input_layout_ == Layout::OHWI) {
91 c += " if (O * 4 < args.out_ch) {\n";
92 c += " v0 = args.src_tensor.Read(W, H, I, O * 4);\n";
93 c += " }\n";
94 c += " if (O * 4 + 1 < args.out_ch) {\n";
95 c += " v1 = args.src_tensor.Read(W, H, I, O * 4 + 1);\n";
96 c += " }\n";
97 c += " if (O * 4 + 2 < args.out_ch) {\n";
98 c += " v2 = args.src_tensor.Read(W, H, I, O * 4 + 2);\n";
99 c += " }\n";
100 c += " if (O * 4 + 3 < args.out_ch) {\n";
101 c += " v3 = args.src_tensor.Read(W, H, I, O * 4 + 3);\n";
102 c += " }\n";
103 c += " if (I == args.src_tensor.Slices() - 1) {\n";
104 c += " FLT4 mask = INIT_FLT4v4(args.mask_x, args.mask_y, args.mask_z, "
105 "args.mask_w);\n";
106 c += " v0 *= mask;\n";
107 c += " v1 *= mask;\n";
108 c += " v2 *= mask;\n";
109 c += " v3 *= mask;\n";
110 c += " }\n";
111 } else if (input_layout_ == Layout::HWIO) {
112 c += " if (I * 4 < args.in_ch && O < args.src_tensor.Slices()) {\n";
113 c += " v0 = args.src_tensor.Read(I * 4, W, O, H);\n";
114 c += " }\n";
115 c += " if (I * 4 + 1 < args.in_ch && O < args.src_tensor.Slices()) {\n";
116 c += " v1 = args.src_tensor.Read(I * 4 + 1, W, O, H);\n";
117 c += " }\n";
118 c += " if (I * 4 + 2 < args.in_ch && O < args.src_tensor.Slices()) {\n";
119 c += " v2 = args.src_tensor.Read(I * 4 + 2, W, O, H);\n";
120 c += " }\n";
121 c += " if (I * 4 + 3 < args.in_ch && O < args.src_tensor.Slices()) {\n";
122 c += " v3 = args.src_tensor.Read(I * 4 + 3, W, O, H);\n";
123 c += " }\n";
124 c += " if (O == args.src_tensor.Slices() - 1) {\n";
125 c += " FLT4 mask = INIT_FLT4v4(args.mask_x, args.mask_y, args.mask_z, "
126 "args.mask_w);\n";
127 c += " v0 *= mask;\n";
128 c += " v1 *= mask;\n";
129 c += " v2 *= mask;\n";
130 c += " v3 *= mask;\n";
131 c += " }\n";
132 }
133 const bool need_transpose =
134 (input_layout_ == Layout::HWIO && weights_desc_.IsO4I4()) ||
135 (input_layout_ == Layout::OHWI && weights_desc_.IsI4O4());
136 if (need_transpose) {
137 c += " FLT4 r0 = INIT_FLT4v4(v0.x, v1.x, v2.x, v3.x);\n";
138 c += " FLT4 r1 = INIT_FLT4v4(v0.y, v1.y, v2.y, v3.y);\n";
139 c += " FLT4 r2 = INIT_FLT4v4(v0.z, v1.z, v2.z, v3.z);\n";
140 c += " FLT4 r3 = INIT_FLT4v4(v0.w, v1.w, v2.w, v3.w);\n";
141 } else {
142 c += " FLT4 r0 = v0;\n";
143 c += " FLT4 r1 = v1;\n";
144 c += " FLT4 r2 = v2;\n";
145 c += " FLT4 r3 = v3;\n";
146 }
147 if (weights_desc_.layout ==
148 WeightsLayout::k2DX4I4YIsSpatialIAndXIsOOGroupO4 ||
149 weights_desc_.layout ==
150 WeightsLayout::k2DX4O4YIsSpatialIAndXIsOOGroupI4) {
151 // Writing to 4X Textures 2D
152 AddDstTensor("dst_tensor0", definition_.dst_tensors[0]);
153 AddDstTensor("dst_tensor1", definition_.dst_tensors[1]);
154 AddDstTensor("dst_tensor2", definition_.dst_tensors[2]);
155 AddDstTensor("dst_tensor3", definition_.dst_tensors[3]);
156 c += " int yc = spatial_linear * args.in_ch_x4_groups + I;\n";
157 c += " args.dst_tensor0.Write2D(r0, O, yc);\n";
158 c += " args.dst_tensor1.Write2D(r1, O, yc);\n";
159 c += " args.dst_tensor2.Write2D(r2, O, yc);\n";
160 c += " args.dst_tensor3.Write2D(r3, O, yc);\n";
161 c += "}\n";
162 } else {
163 // Writing to linear buffer
164 AddDstTensor("dst_tensor", definition_.dst_tensors[0]);
165 c += " int OUTPUT_GROUP_SIZE = " +
166 std::to_string(weights_desc_.GetOutputGroupSize()) + ";\n";
167 c += " int d_index = (O * 4) / (OUTPUT_GROUP_SIZE * 4);\n";
168 c += " int k_index = ((O * 4) % (OUTPUT_GROUP_SIZE * 4)) / 4;\n";
169 std::string index;
170 if (weights_desc_.layout == WeightsLayout::kOICustomSpatialI4O4 ||
171 weights_desc_.layout == WeightsLayout::kOICustomSpatialO4I4) {
172 index =
173 "(d_index * args.in_ch_x4_groups + I) * args.kernel_spatial_size + "
174 "spatial_linear";
175 } else if (weights_desc_.layout == WeightsLayout::kOSpatialIOGroupI4O4 ||
176 weights_desc_.layout == WeightsLayout::kOSpatialIOGroupO4I4) {
177 index =
178 "(d_index * args.kernel_spatial_size + spatial_linear) * "
179 "args.in_ch_x4_groups + I";
180 }
181 c += " int dst_offset = (" + index + ") * OUTPUT_GROUP_SIZE + k_index;\n";
182 c += " args.dst_tensor.WriteLinear(r0, dst_offset * 4 + 0);\n";
183 c += " args.dst_tensor.WriteLinear(r1, dst_offset * 4 + 1);\n";
184 c += " args.dst_tensor.WriteLinear(r2, dst_offset * 4 + 2);\n";
185 c += " args.dst_tensor.WriteLinear(r3, dst_offset * 4 + 3);\n";
186 c += "}\n";
187 }
188 return c;
189 }
190
GetWeightsSize() const191 OHWI ConverterToConvWeights::GetWeightsSize() const {
192 int output_channels = 0;
193 int input_channels = 0;
194 int kernel_width = 0;
195 int kernel_height = 0;
196 if (input_layout_ == Layout::HWIO) {
197 output_channels = src_[0]->Channels();
198 input_channels = src_[0]->Width();
199 kernel_width = src_[0]->Height();
200 kernel_height = src_[0]->Batch();
201 } else if (input_layout_ == Layout::OHWI) {
202 output_channels = src_[0]->Batch();
203 input_channels = src_[0]->Channels();
204 kernel_width = src_[0]->Width();
205 kernel_height = src_[0]->Height();
206 }
207 return OHWI(output_channels, kernel_height, kernel_width, input_channels);
208 }
209
BindArguments(ArgumentsBinder * args)210 absl::Status ConverterToConvWeights::BindArguments(ArgumentsBinder* args) {
211 const auto& weights_shape = GetWeightsSize();
212 const int output_channels_x4_groups = DivideRoundUp(
213 AlignByN(weights_shape.o, 4 * weights_desc_.GetOutputGroupSize()), 4);
214 RETURN_IF_ERROR(args->SetInt("out_ch", weights_shape.o));
215 RETURN_IF_ERROR(args->SetInt("out_ch_x4_groups", output_channels_x4_groups));
216 RETURN_IF_ERROR(args->SetInt("in_ch", weights_shape.i));
217 RETURN_IF_ERROR(
218 args->SetInt("in_ch_x4_groups", DivideRoundUp(weights_shape.i, 4)));
219 RETURN_IF_ERROR(args->SetInt("kernel_width", weights_shape.w));
220 RETURN_IF_ERROR(args->SetInt("kernel_height", weights_shape.h));
221 RETURN_IF_ERROR(
222 args->SetInt("kernel_spatial_size", weights_shape.w * weights_shape.h));
223 float4 mask = GetMaskForLastPlane(src_[0]->Channels());
224 RETURN_IF_ERROR(args->SetFloat("mask_x", mask.x));
225 RETURN_IF_ERROR(args->SetFloat("mask_y", mask.y));
226 RETURN_IF_ERROR(args->SetFloat("mask_z", mask.z));
227 return args->SetFloat("mask_w", mask.w);
228 }
229
GetGridSize() const230 int3 ConverterToConvWeights::GetGridSize() const {
231 const auto& weights_shape = GetWeightsSize();
232 const int out_group_size = weights_desc_.GetOutputGroupSize();
233 const int grid_x =
234 DivideRoundUp(AlignByN(weights_shape.o, 4 * out_group_size), 4);
235 const int grid_y = DivideRoundUp(weights_shape.i, 4);
236 const int grid_z = weights_shape.w * weights_shape.h;
237 return int3(grid_x, grid_y, grid_z);
238 }
239
CreateConverterToConvWeights(const OperationDef & definition,const WeightsDescription & weights_desc,Layout input_layout)240 ConverterToConvWeights CreateConverterToConvWeights(
241 const OperationDef& definition, const WeightsDescription& weights_desc,
242 Layout input_layout) {
243 return ConverterToConvWeights(definition, weights_desc, input_layout);
244 }
245
246 } // namespace gpu
247 } // namespace tflite
248