xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/common/tasks/conv_weights_converter.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/tasks/conv_weights_converter.h"
17 
18 #include <cstring>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 
23 #include "tensorflow/lite/delegates/gpu/common/task/util.h"
24 
25 namespace tflite {
26 namespace gpu {
27 
// Converts convolution weights held in a runtime tensor into the packed
// layout described by `weights_desc`. The shader source is generated
// eagerly from the given operation definition.
ConverterToConvWeights::ConverterToConvWeights(
    const OperationDef& definition, const WeightsDescription& weights_desc,
    Layout input_layout)
    : GPUOperation(definition),
      weights_desc_(weights_desc),
      input_layout_(input_layout) {
  // `input_layout` tells the generated code how the source tensor axes map
  // to OHWI weights; only OHWI and HWIO are handled by the codegen below.
  code_ = GetConverterToConvWeightsCode();
}
36 
GetConverterToConvWeightsCode()37 std::string ConverterToConvWeights::GetConverterToConvWeightsCode() {
38   AddSrcTensor("src_tensor", definition_.src_tensors[0]);
39   args_.AddFloat("mask_x");
40   args_.AddFloat("mask_y");
41   args_.AddFloat("mask_z");
42   args_.AddFloat("mask_w");
43   args_.AddInt("out_ch");
44   args_.AddInt("out_ch_x4_groups");
45   args_.AddInt("in_ch");
46   args_.AddInt("in_ch_x4_groups");
47   args_.AddInt("kernel_width");
48   args_.AddInt("kernel_height");
49   args_.AddInt("kernel_spatial_size");
50 
51   if (weights_desc_.layout == WeightsLayout::kOICustomSpatialI4O4 ||
52       weights_desc_.layout == WeightsLayout::kOICustomSpatialO4I4) {
53     std::vector<int32_t> remap(weights_desc_.spatial_remap.size());
54     for (int i = 0; i < remap.size(); ++i) {
55       remap[i] = weights_desc_.spatial_remap[i];
56     }
57     BufferDescriptor desc;
58     desc.element_type = DataType::INT32;
59     desc.element_size = 1;
60     desc.memory_type = MemoryType::GLOBAL;
61     desc.size = remap.size() * sizeof(int32_t);
62     desc.data.resize(desc.size);
63     std::memcpy(desc.data.data(), remap.data(), desc.size);
64     args_.AddObject("spatial_remap",
65                     std::make_unique<BufferDescriptor>(std::move(desc)));
66   }
67 
68   std::string c;
69   c += "MAIN_FUNCTION($0) {\n";
70   c += "  int O = GLOBAL_ID_0;\n";
71   c += "  int I = GLOBAL_ID_1;\n";
72   c += "  int spatial_linear = GLOBAL_ID_2;\n";
73   c += "  if (O >= args.out_ch_x4_groups) return;\n";
74   c += "  if (I >= args.in_ch_x4_groups) return;\n";
75   c += "  if (spatial_linear >= args.kernel_spatial_size) return;\n";
76   if (weights_desc_.layout == WeightsLayout::kOICustomSpatialI4O4 ||
77       weights_desc_.layout == WeightsLayout::kOICustomSpatialO4I4) {
78     c += "  int linear_remap = args.spatial_remap.Read(spatial_linear);\n";
79     c += "  int W = linear_remap % args.kernel_width;\n";
80     c += "  int H = linear_remap / args.kernel_width;\n";
81   } else {
82     c += "  int W = spatial_linear % args.kernel_width;\n";
83     c += "  int H = spatial_linear / args.kernel_width;\n";
84   }
85   // W and H is src coordinates, spatial_linear is dst coordinate
86   c += "  FLT4 v0 = INIT_FLT4(0.0f);\n";
87   c += "  FLT4 v1 = INIT_FLT4(0.0f);\n";
88   c += "  FLT4 v2 = INIT_FLT4(0.0f);\n";
89   c += "  FLT4 v3 = INIT_FLT4(0.0f);\n";
90   if (input_layout_ == Layout::OHWI) {
91     c += "  if (O * 4 < args.out_ch) {\n";
92     c += "    v0 = args.src_tensor.Read(W, H, I, O * 4);\n";
93     c += "  }\n";
94     c += "  if (O * 4 + 1 < args.out_ch) {\n";
95     c += "    v1 = args.src_tensor.Read(W, H, I, O * 4 + 1);\n";
96     c += "  }\n";
97     c += "  if (O * 4 + 2 < args.out_ch) {\n";
98     c += "    v2 = args.src_tensor.Read(W, H, I, O * 4 + 2);\n";
99     c += "  }\n";
100     c += "  if (O * 4 + 3 < args.out_ch) {\n";
101     c += "    v3 = args.src_tensor.Read(W, H, I, O * 4 + 3);\n";
102     c += "  }\n";
103     c += "  if (I == args.src_tensor.Slices() - 1) {\n";
104     c += "    FLT4 mask = INIT_FLT4v4(args.mask_x, args.mask_y, args.mask_z, "
105          "args.mask_w);\n";
106     c += "    v0 *= mask;\n";
107     c += "    v1 *= mask;\n";
108     c += "    v2 *= mask;\n";
109     c += "    v3 *= mask;\n";
110     c += "  }\n";
111   } else if (input_layout_ == Layout::HWIO) {
112     c += "  if (I * 4 < args.in_ch && O < args.src_tensor.Slices()) {\n";
113     c += "    v0 = args.src_tensor.Read(I * 4, W, O, H);\n";
114     c += "  }\n";
115     c += "  if (I * 4 + 1 < args.in_ch && O < args.src_tensor.Slices()) {\n";
116     c += "    v1 = args.src_tensor.Read(I * 4 + 1, W, O, H);\n";
117     c += "  }\n";
118     c += "  if (I * 4 + 2 < args.in_ch && O < args.src_tensor.Slices()) {\n";
119     c += "    v2 = args.src_tensor.Read(I * 4 + 2, W, O, H);\n";
120     c += "  }\n";
121     c += "  if (I * 4 + 3 < args.in_ch && O < args.src_tensor.Slices()) {\n";
122     c += "    v3 = args.src_tensor.Read(I * 4 + 3, W, O, H);\n";
123     c += "  }\n";
124     c += "  if (O == args.src_tensor.Slices() - 1) {\n";
125     c += "    FLT4 mask = INIT_FLT4v4(args.mask_x, args.mask_y, args.mask_z, "
126          "args.mask_w);\n";
127     c += "    v0 *= mask;\n";
128     c += "    v1 *= mask;\n";
129     c += "    v2 *= mask;\n";
130     c += "    v3 *= mask;\n";
131     c += "  }\n";
132   }
133   const bool need_transpose =
134       (input_layout_ == Layout::HWIO && weights_desc_.IsO4I4()) ||
135       (input_layout_ == Layout::OHWI && weights_desc_.IsI4O4());
136   if (need_transpose) {
137     c += "  FLT4 r0 = INIT_FLT4v4(v0.x, v1.x, v2.x, v3.x);\n";
138     c += "  FLT4 r1 = INIT_FLT4v4(v0.y, v1.y, v2.y, v3.y);\n";
139     c += "  FLT4 r2 = INIT_FLT4v4(v0.z, v1.z, v2.z, v3.z);\n";
140     c += "  FLT4 r3 = INIT_FLT4v4(v0.w, v1.w, v2.w, v3.w);\n";
141   } else {
142     c += "  FLT4 r0 = v0;\n";
143     c += "  FLT4 r1 = v1;\n";
144     c += "  FLT4 r2 = v2;\n";
145     c += "  FLT4 r3 = v3;\n";
146   }
147   if (weights_desc_.layout ==
148           WeightsLayout::k2DX4I4YIsSpatialIAndXIsOOGroupO4 ||
149       weights_desc_.layout ==
150           WeightsLayout::k2DX4O4YIsSpatialIAndXIsOOGroupI4) {
151     // Writing to 4X Textures 2D
152     AddDstTensor("dst_tensor0", definition_.dst_tensors[0]);
153     AddDstTensor("dst_tensor1", definition_.dst_tensors[1]);
154     AddDstTensor("dst_tensor2", definition_.dst_tensors[2]);
155     AddDstTensor("dst_tensor3", definition_.dst_tensors[3]);
156     c += "  int yc = spatial_linear *  args.in_ch_x4_groups + I;\n";
157     c += "  args.dst_tensor0.Write2D(r0, O, yc);\n";
158     c += "  args.dst_tensor1.Write2D(r1, O, yc);\n";
159     c += "  args.dst_tensor2.Write2D(r2, O, yc);\n";
160     c += "  args.dst_tensor3.Write2D(r3, O, yc);\n";
161     c += "}\n";
162   } else {
163     // Writing to linear buffer
164     AddDstTensor("dst_tensor", definition_.dst_tensors[0]);
165     c += "  int OUTPUT_GROUP_SIZE = " +
166          std::to_string(weights_desc_.GetOutputGroupSize()) + ";\n";
167     c += "  int d_index = (O * 4) / (OUTPUT_GROUP_SIZE * 4);\n";
168     c += "  int k_index = ((O * 4) % (OUTPUT_GROUP_SIZE * 4)) / 4;\n";
169     std::string index;
170     if (weights_desc_.layout == WeightsLayout::kOICustomSpatialI4O4 ||
171         weights_desc_.layout == WeightsLayout::kOICustomSpatialO4I4) {
172       index =
173           "(d_index * args.in_ch_x4_groups + I) * args.kernel_spatial_size + "
174           "spatial_linear";
175     } else if (weights_desc_.layout == WeightsLayout::kOSpatialIOGroupI4O4 ||
176                weights_desc_.layout == WeightsLayout::kOSpatialIOGroupO4I4) {
177       index =
178           "(d_index * args.kernel_spatial_size + spatial_linear) * "
179           "args.in_ch_x4_groups + I";
180     }
181     c += "  int dst_offset = (" + index + ") * OUTPUT_GROUP_SIZE + k_index;\n";
182     c += "  args.dst_tensor.WriteLinear(r0, dst_offset * 4 + 0);\n";
183     c += "  args.dst_tensor.WriteLinear(r1, dst_offset * 4 + 1);\n";
184     c += "  args.dst_tensor.WriteLinear(r2, dst_offset * 4 + 2);\n";
185     c += "  args.dst_tensor.WriteLinear(r3, dst_offset * 4 + 3);\n";
186     c += "}\n";
187   }
188   return c;
189 }
190 
GetWeightsSize() const191 OHWI ConverterToConvWeights::GetWeightsSize() const {
192   int output_channels = 0;
193   int input_channels = 0;
194   int kernel_width = 0;
195   int kernel_height = 0;
196   if (input_layout_ == Layout::HWIO) {
197     output_channels = src_[0]->Channels();
198     input_channels = src_[0]->Width();
199     kernel_width = src_[0]->Height();
200     kernel_height = src_[0]->Batch();
201   } else if (input_layout_ == Layout::OHWI) {
202     output_channels = src_[0]->Batch();
203     input_channels = src_[0]->Channels();
204     kernel_width = src_[0]->Width();
205     kernel_height = src_[0]->Height();
206   }
207   return OHWI(output_channels, kernel_height, kernel_width, input_channels);
208 }
209 
BindArguments(ArgumentsBinder * args)210 absl::Status ConverterToConvWeights::BindArguments(ArgumentsBinder* args) {
211   const auto& weights_shape = GetWeightsSize();
212   const int output_channels_x4_groups = DivideRoundUp(
213       AlignByN(weights_shape.o, 4 * weights_desc_.GetOutputGroupSize()), 4);
214   RETURN_IF_ERROR(args->SetInt("out_ch", weights_shape.o));
215   RETURN_IF_ERROR(args->SetInt("out_ch_x4_groups", output_channels_x4_groups));
216   RETURN_IF_ERROR(args->SetInt("in_ch", weights_shape.i));
217   RETURN_IF_ERROR(
218       args->SetInt("in_ch_x4_groups", DivideRoundUp(weights_shape.i, 4)));
219   RETURN_IF_ERROR(args->SetInt("kernel_width", weights_shape.w));
220   RETURN_IF_ERROR(args->SetInt("kernel_height", weights_shape.h));
221   RETURN_IF_ERROR(
222       args->SetInt("kernel_spatial_size", weights_shape.w * weights_shape.h));
223   float4 mask = GetMaskForLastPlane(src_[0]->Channels());
224   RETURN_IF_ERROR(args->SetFloat("mask_x", mask.x));
225   RETURN_IF_ERROR(args->SetFloat("mask_y", mask.y));
226   RETURN_IF_ERROR(args->SetFloat("mask_z", mask.z));
227   return args->SetFloat("mask_w", mask.w);
228 }
229 
GetGridSize() const230 int3 ConverterToConvWeights::GetGridSize() const {
231   const auto& weights_shape = GetWeightsSize();
232   const int out_group_size = weights_desc_.GetOutputGroupSize();
233   const int grid_x =
234       DivideRoundUp(AlignByN(weights_shape.o, 4 * out_group_size), 4);
235   const int grid_y = DivideRoundUp(weights_shape.i, 4);
236   const int grid_z = weights_shape.w * weights_shape.h;
237   return int3(grid_x, grid_y, grid_z);
238 }
239 
// Factory helper: builds a ConverterToConvWeights operation for the given
// operation definition, destination weights layout, and source tensor layout.
ConverterToConvWeights CreateConverterToConvWeights(
    const OperationDef& definition, const WeightsDescription& weights_desc,
    Layout input_layout) {
  return ConverterToConvWeights(definition, weights_desc, input_layout);
}
245 
246 }  // namespace gpu
247 }  // namespace tflite
248