1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/tasks/resize.h"
17
18 #include <string>
19 #include <utility>
20
21 #include "tensorflow/lite/delegates/gpu/common/operations.h"
22 #include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
23
24 namespace tflite {
25 namespace gpu {
26
Resize(const OperationDef & definition,const Resize2DAttributes & attr)27 Resize::Resize(const OperationDef& definition, const Resize2DAttributes& attr)
28 : GPUOperation(definition), attr_(attr) {
29 code_ = GetResizeCode(definition_, attr_);
30 }
31
Resize(Resize && operation)32 Resize::Resize(Resize&& operation)
33 : GPUOperation(std::move(operation)), attr_(operation.attr_) {}
34
operator =(Resize && operation)35 Resize& Resize::operator=(Resize&& operation) {
36 if (this != &operation) {
37 attr_ = operation.attr_;
38 GPUOperation::operator=(std::move(operation));
39 }
40 return *this;
41 }
42
GetResizeCode(const OperationDef & op_def,const Resize2DAttributes & attr)43 std::string Resize::GetResizeCode(const OperationDef& op_def,
44 const Resize2DAttributes& attr) {
45 AddSrcTensor("src_tensor", op_def.src_tensors[0]);
46 AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
47 args_.AddFloat("scale_factor_x");
48 args_.AddFloat("scale_factor_y");
49
50 std::string c;
51 c += "MAIN_FUNCTION($0) {\n";
52 if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
53 c += " int linear_id = GLOBAL_ID_0;\n";
54 c += " int X = linear_id / args.dst_tensor.Batch();\n";
55 c += " int B = linear_id % args.dst_tensor.Batch();\n";
56 c += " args.src_tensor.SetBatchRef(B);\n";
57 c += " args.dst_tensor.SetBatchRef(B);\n";
58 } else {
59 c += " int X = GLOBAL_ID_0;\n";
60 }
61 c += " int Y = GLOBAL_ID_1;\n";
62 c += " int Z = GLOBAL_ID_2;\n";
63 c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
64 "|| Z >= args.dst_tensor.Slices()) return;\n";
65 if (attr.half_pixel_centers) {
66 c += " float f_coords_x = (INIT_FLOAT(X) + 0.5f) * args.scale_factor_x;\n";
67 c += " float f_coords_y = (INIT_FLOAT(Y) + 0.5f) * args.scale_factor_y;\n";
68 } else {
69 c += " float f_coords_x = INIT_FLOAT(X) * args.scale_factor_x;\n";
70 c += " float f_coords_y = INIT_FLOAT(Y) * args.scale_factor_y;\n";
71 }
72 c += " FLT4 r0;\n";
73 if (attr.type == SamplingType::NEAREST) {
74 if (attr.align_corners) {
75 c += " f_coords_x += 0.5f;";
76 c += " f_coords_y += 0.5f;";
77 }
78 c += " args.src_tensor.ReadNearest(r0, f_coords_x, f_coords_y, Z);\n";
79 } else {
80 if (attr.half_pixel_centers) {
81 c += " f_coords_x -= 0.5f;";
82 c += " f_coords_y -= 0.5f;";
83 }
84 c += " args.src_tensor.ReadBilinear(r0, f_coords_x, f_coords_y, Z);\n";
85 }
86 c += " args.dst_tensor.Write(r0, X, Y, Z);\n";
87 c += "}\n";
88 return c;
89 }
90
BindArguments(ArgumentsBinder * args)91 absl::Status Resize::BindArguments(ArgumentsBinder* args) {
92 RETURN_IF_ERROR(args->SetFloat(
93 "scale_factor_x",
94 CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_)));
95 RETURN_IF_ERROR(args->SetFloat(
96 "scale_factor_y",
97 CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_)));
98 return absl::OkStatus();
99 }
100
GetGridSize() const101 int3 Resize::GetGridSize() const {
102 const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
103 const int grid_y = dst_[0]->Height();
104 const int grid_z = dst_[0]->Slices();
105 return int3(grid_x, grid_y, grid_z);
106 }
107
CreateResize(const OperationDef & definition,const Resize2DAttributes & attr)108 Resize CreateResize(const OperationDef& definition,
109 const Resize2DAttributes& attr) {
110 return Resize(definition, attr);
111 }
112
Resize3D(const OperationDef & definition,const Resize3DAttributes & attr)113 Resize3D::Resize3D(const OperationDef& definition,
114 const Resize3DAttributes& attr)
115 : GPUOperation(definition), attr_(attr) {
116 code_ = GetResize3DCode(definition_, attr_);
117 }
118
Resize3D(Resize3D && operation)119 Resize3D::Resize3D(Resize3D&& operation)
120 : GPUOperation(std::move(operation)), attr_(operation.attr_) {}
121
operator =(Resize3D && operation)122 Resize3D& Resize3D::operator=(Resize3D&& operation) {
123 if (this != &operation) {
124 attr_ = operation.attr_;
125 GPUOperation::operator=(std::move(operation));
126 }
127 return *this;
128 }
129
GetResize3DCode(const OperationDef & op_def,const Resize3DAttributes & attr)130 std::string Resize3D::GetResize3DCode(const OperationDef& op_def,
131 const Resize3DAttributes& attr) {
132 AddSrcTensor("src_tensor", op_def.src_tensors[0]);
133 AddDstTensor("dst_tensor", op_def.dst_tensors[0]);
134 args_.AddFloat("scale_factor_x");
135 args_.AddFloat("scale_factor_y");
136 args_.AddFloat("scale_factor_z");
137
138 std::string c;
139 c += "MAIN_FUNCTION($0) {\n";
140 if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
141 c += " int linear_id = GLOBAL_ID_0;\n";
142 c += " int X = linear_id / args.dst_tensor.Batch();\n";
143 c += " int B = linear_id % args.dst_tensor.Batch();\n";
144 c += " args.src_tensor.SetBatchRef(B);\n";
145 c += " args.dst_tensor.SetBatchRef(B);\n";
146 } else {
147 c += " int X = GLOBAL_ID_0;\n";
148 }
149 c += " int Y = GLOBAL_ID_1;\n";
150 c += " int linear_id_z = GLOBAL_ID_2;\n";
151 c += " int S = linear_id_z % args.dst_tensor.Slices();\n";
152 c += " int Z = linear_id_z / args.dst_tensor.Slices();\n";
153 c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
154 "|| Z >= args.dst_tensor.Depth()) return;\n";
155 if (attr.half_pixel_centers) {
156 c += " float f_coords_x = (INIT_FLOAT(X) + 0.5f) * args.scale_factor_x;\n";
157 c += " float f_coords_y = (INIT_FLOAT(Y) + 0.5f) * args.scale_factor_y;\n";
158 c += " float f_coords_z = (INIT_FLOAT(Z) + 0.5f) * args.scale_factor_z;\n";
159 } else {
160 c += " float f_coords_x = INIT_FLOAT(X) * args.scale_factor_x;\n";
161 c += " float f_coords_y = INIT_FLOAT(Y) * args.scale_factor_y;\n";
162 c += " float f_coords_z = INIT_FLOAT(Z) * args.scale_factor_z;\n";
163 }
164 c += " FLT4 r0;\n";
165 if (attr.type == SamplingType::NEAREST) {
166 if (attr.align_corners) {
167 c += " f_coords_x += 0.5f;";
168 c += " f_coords_y += 0.5f;";
169 c += " f_coords_z += 0.5f;";
170 }
171 c += " args.src_tensor.ReadNearest(r0, f_coords_x, f_coords_y, "
172 "f_coords_z, S);\n";
173 } else {
174 if (attr.half_pixel_centers) {
175 c += " f_coords_x -= 0.5f;";
176 c += " f_coords_y -= 0.5f;";
177 c += " f_coords_z -= 0.5f;";
178 }
179 c += " args.src_tensor.ReadBilinear(r0, f_coords_x, f_coords_y, "
180 "f_coords_z, S);\n";
181 }
182 c += " args.dst_tensor.Write(r0, X, Y, Z, S);\n";
183 c += "}\n";
184 return c;
185 }
186
BindArguments(ArgumentsBinder * args)187 absl::Status Resize3D::BindArguments(ArgumentsBinder* args) {
188 RETURN_IF_ERROR(args->SetFloat(
189 "scale_factor_x",
190 CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_)));
191 RETURN_IF_ERROR(args->SetFloat(
192 "scale_factor_y",
193 CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_)));
194 RETURN_IF_ERROR(args->SetFloat(
195 "scale_factor_z",
196 CalculateResizeScale(src_[0]->Depth(), dst_[0]->Depth(), attr_)));
197 return absl::OkStatus();
198 }
199
GetGridSize() const200 int3 Resize3D::GetGridSize() const {
201 const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
202 const int grid_y = dst_[0]->Height();
203 const int grid_z = dst_[0]->Slices() * dst_[0]->Depth();
204 return int3(grid_x, grid_y, grid_z);
205 }
206
CreateResize3D(const OperationDef & definition,const Resize3DAttributes & attr)207 Resize3D CreateResize3D(const OperationDef& definition,
208 const Resize3DAttributes& attr) {
209 return Resize3D(definition, attr);
210 }
211
212 } // namespace gpu
213 } // namespace tflite
214