xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/common/tasks/resize.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/tasks/resize.h"
17 
18 #include <string>
19 #include <utility>
20 
21 #include "tensorflow/lite/delegates/gpu/common/operations.h"
22 #include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
23 
24 namespace tflite {
25 namespace gpu {
26 
Resize(const OperationDef & definition,const Resize2DAttributes & attr)27 Resize::Resize(const OperationDef& definition, const Resize2DAttributes& attr)
28     : GPUOperation(definition), attr_(attr) {
29   code_ = GetResizeCode(definition_, attr_);
30 }
31 
Resize(Resize && operation)32 Resize::Resize(Resize&& operation)
33     : GPUOperation(std::move(operation)), attr_(operation.attr_) {}
34 
operator =(Resize && operation)35 Resize& Resize::operator=(Resize&& operation) {
36   if (this != &operation) {
37     attr_ = operation.attr_;
38     GPUOperation::operator=(std::move(operation));
39   }
40   return *this;
41 }
42 
GetResizeCode(const OperationDef & op_def,const Resize2DAttributes & attr)43 std::string Resize::GetResizeCode(const OperationDef& op_def,
44                                   const Resize2DAttributes& attr) {
45   AddSrcTensor("src_tensor", op_def.src_tensors[0]);
46   AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
47   args_.AddFloat("scale_factor_x");
48   args_.AddFloat("scale_factor_y");
49 
50   std::string c;
51   c += "MAIN_FUNCTION($0) {\n";
52   if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
53     c += "  int linear_id = GLOBAL_ID_0;\n";
54     c += "  int X = linear_id / args.dst_tensor.Batch();\n";
55     c += "  int B = linear_id % args.dst_tensor.Batch();\n";
56     c += "  args.src_tensor.SetBatchRef(B);\n";
57     c += "  args.dst_tensor.SetBatchRef(B);\n";
58   } else {
59     c += "  int X = GLOBAL_ID_0;\n";
60   }
61   c += "  int Y = GLOBAL_ID_1;\n";
62   c += "  int Z = GLOBAL_ID_2;\n";
63   c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
64        "|| Z >= args.dst_tensor.Slices()) return;\n";
65   if (attr.half_pixel_centers) {
66     c += "  float f_coords_x = (INIT_FLOAT(X) + 0.5f) * args.scale_factor_x;\n";
67     c += "  float f_coords_y = (INIT_FLOAT(Y) + 0.5f) * args.scale_factor_y;\n";
68   } else {
69     c += "  float f_coords_x = INIT_FLOAT(X) * args.scale_factor_x;\n";
70     c += "  float f_coords_y = INIT_FLOAT(Y) * args.scale_factor_y;\n";
71   }
72   c += "  FLT4 r0;\n";
73   if (attr.type == SamplingType::NEAREST) {
74     if (attr.align_corners) {
75       c += "  f_coords_x += 0.5f;";
76       c += "  f_coords_y += 0.5f;";
77     }
78     c += "  args.src_tensor.ReadNearest(r0, f_coords_x, f_coords_y, Z);\n";
79   } else {
80     if (attr.half_pixel_centers) {
81       c += "  f_coords_x -= 0.5f;";
82       c += "  f_coords_y -= 0.5f;";
83     }
84     c += "  args.src_tensor.ReadBilinear(r0, f_coords_x, f_coords_y, Z);\n";
85   }
86   c += "  args.dst_tensor.Write(r0, X, Y, Z);\n";
87   c += "}\n";
88   return c;
89 }
90 
BindArguments(ArgumentsBinder * args)91 absl::Status Resize::BindArguments(ArgumentsBinder* args) {
92   RETURN_IF_ERROR(args->SetFloat(
93       "scale_factor_x",
94       CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_)));
95   RETURN_IF_ERROR(args->SetFloat(
96       "scale_factor_y",
97       CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_)));
98   return absl::OkStatus();
99 }
100 
GetGridSize() const101 int3 Resize::GetGridSize() const {
102   const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
103   const int grid_y = dst_[0]->Height();
104   const int grid_z = dst_[0]->Slices();
105   return int3(grid_x, grid_y, grid_z);
106 }
107 
CreateResize(const OperationDef & definition,const Resize2DAttributes & attr)108 Resize CreateResize(const OperationDef& definition,
109                     const Resize2DAttributes& attr) {
110   return Resize(definition, attr);
111 }
112 
Resize3D(const OperationDef & definition,const Resize3DAttributes & attr)113 Resize3D::Resize3D(const OperationDef& definition,
114                    const Resize3DAttributes& attr)
115     : GPUOperation(definition), attr_(attr) {
116   code_ = GetResize3DCode(definition_, attr_);
117 }
118 
Resize3D(Resize3D && operation)119 Resize3D::Resize3D(Resize3D&& operation)
120     : GPUOperation(std::move(operation)), attr_(operation.attr_) {}
121 
operator =(Resize3D && operation)122 Resize3D& Resize3D::operator=(Resize3D&& operation) {
123   if (this != &operation) {
124     attr_ = operation.attr_;
125     GPUOperation::operator=(std::move(operation));
126   }
127   return *this;
128 }
129 
GetResize3DCode(const OperationDef & op_def,const Resize3DAttributes & attr)130 std::string Resize3D::GetResize3DCode(const OperationDef& op_def,
131                                       const Resize3DAttributes& attr) {
132   AddSrcTensor("src_tensor", op_def.src_tensors[0]);
133   AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
134   args_.AddFloat("scale_factor_x");
135   args_.AddFloat("scale_factor_y");
136   args_.AddFloat("scale_factor_z");
137 
138   std::string c;
139   c += "MAIN_FUNCTION($0) {\n";
140   if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
141     c += "  int linear_id = GLOBAL_ID_0;\n";
142     c += "  int X = linear_id / args.dst_tensor.Batch();\n";
143     c += "  int B = linear_id % args.dst_tensor.Batch();\n";
144     c += "  args.src_tensor.SetBatchRef(B);\n";
145     c += "  args.dst_tensor.SetBatchRef(B);\n";
146   } else {
147     c += "  int X = GLOBAL_ID_0;\n";
148   }
149   c += "  int Y = GLOBAL_ID_1;\n";
150   c += "  int linear_id_z = GLOBAL_ID_2;\n";
151   c += "  int S = linear_id_z % args.dst_tensor.Slices();\n";
152   c += "  int Z = linear_id_z / args.dst_tensor.Slices();\n";
153   c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
154        "|| Z >= args.dst_tensor.Depth()) return;\n";
155   if (attr.half_pixel_centers) {
156     c += "  float f_coords_x = (INIT_FLOAT(X) + 0.5f) * args.scale_factor_x;\n";
157     c += "  float f_coords_y = (INIT_FLOAT(Y) + 0.5f) * args.scale_factor_y;\n";
158     c += "  float f_coords_z = (INIT_FLOAT(Z) + 0.5f) * args.scale_factor_z;\n";
159   } else {
160     c += "  float f_coords_x = INIT_FLOAT(X) * args.scale_factor_x;\n";
161     c += "  float f_coords_y = INIT_FLOAT(Y) * args.scale_factor_y;\n";
162     c += "  float f_coords_z = INIT_FLOAT(Z) * args.scale_factor_z;\n";
163   }
164   c += "  FLT4 r0;\n";
165   if (attr.type == SamplingType::NEAREST) {
166     if (attr.align_corners) {
167       c += "  f_coords_x += 0.5f;";
168       c += "  f_coords_y += 0.5f;";
169       c += "  f_coords_z += 0.5f;";
170     }
171     c += "  args.src_tensor.ReadNearest(r0, f_coords_x, f_coords_y, "
172          "f_coords_z, S);\n";
173   } else {
174     if (attr.half_pixel_centers) {
175       c += "  f_coords_x -= 0.5f;";
176       c += "  f_coords_y -= 0.5f;";
177       c += "  f_coords_z -= 0.5f;";
178     }
179     c += "  args.src_tensor.ReadBilinear(r0, f_coords_x, f_coords_y, "
180          "f_coords_z, S);\n";
181   }
182   c += "  args.dst_tensor.Write(r0, X, Y, Z, S);\n";
183   c += "}\n";
184   return c;
185 }
186 
BindArguments(ArgumentsBinder * args)187 absl::Status Resize3D::BindArguments(ArgumentsBinder* args) {
188   RETURN_IF_ERROR(args->SetFloat(
189       "scale_factor_x",
190       CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_)));
191   RETURN_IF_ERROR(args->SetFloat(
192       "scale_factor_y",
193       CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_)));
194   RETURN_IF_ERROR(args->SetFloat(
195       "scale_factor_z",
196       CalculateResizeScale(src_[0]->Depth(), dst_[0]->Depth(), attr_)));
197   return absl::OkStatus();
198 }
199 
GetGridSize() const200 int3 Resize3D::GetGridSize() const {
201   const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
202   const int grid_y = dst_[0]->Height();
203   const int grid_z = dst_[0]->Slices() * dst_[0]->Depth();
204   return int3(grid_x, grid_y, grid_z);
205 }
206 
CreateResize3D(const OperationDef & definition,const Resize3DAttributes & attr)207 Resize3D CreateResize3D(const OperationDef& definition,
208                         const Resize3DAttributes& attr) {
209   return Resize3D(definition, attr);
210 }
211 
212 }  // namespace gpu
213 }  // namespace tflite
214