xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/common/tasks/resampler.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/tasks/resampler.h"
17 
18 #include <string>
19 
20 namespace tflite {
21 namespace gpu {
22 namespace {
23 
GetResamplerCode(const GpuInfo & gpu_info,const OperationDef & op_def)24 std::string GetResamplerCode(const GpuInfo& gpu_info,
25                              const OperationDef& op_def) {
26   std::string c;
27   c += "MAIN_FUNCTION($0) {\n";
28   if (op_def.dst_tensors[0].HasAxis(Axis::BATCH)) {
29     c += "  int linear_id = GLOBAL_ID_0;\n";
30     c += "  int X = linear_id / args.dst_tensor.Batch();\n";
31     c += "  int B = linear_id % args.dst_tensor.Batch();\n";
32     c += "  args.dst_tensor.SetBatchRef(B);\n";
33   } else {
34     c += "  int X = GLOBAL_ID_0;\n";
35   }
36   c += "  int Y = GLOBAL_ID_1;\n";
37   c += "  int S = GLOBAL_ID_2;\n";
38   c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
39        "S >= args.dst_tensor.Slices()) { \n";
40   c += "    return; \n";
41   c += "  } \n";
42   c += "  float2 f_coords = args.warp_tensor.Read<float>(X, Y, 0).xy;\n";
43   c += "  float2 f_coords_floor = floor(f_coords);\n";
44   c += "  int4 st;\n";
45   c += "  st.xy = INIT_INT2v2(f_coords_floor.x, f_coords_floor.y);\n";
46   c += "  st.zw = st.xy + INIT_INT2v2(1, 1);\n";
47   c += "  float2 t = f_coords - f_coords_floor;\n";
48   bool supports_hw_zero_clamp =
49       op_def.src_tensors[0].SupportsZeroClamp(Axis::WIDTH, gpu_info) &&
50       op_def.src_tensors[0].SupportsZeroClamp(Axis::HEIGHT, gpu_info);
51   if (supports_hw_zero_clamp) {
52     c += R"(
53   float4 src0 = args.src_tensor.Read<float>(st.x, st.y, S);
54   float4 src1 = args.src_tensor.Read<float>(st.z, st.y, S);
55   float4 src2 = args.src_tensor.Read<float>(st.x, st.w, S);
56   float4 src3 = args.src_tensor.Read<float>(st.z, st.w, S);
57 )";
58   } else {
59     c += R"(
60   bool stx_in = st.x >= 0 && st.x < args.src_tensor.Width();
61   bool stz_in = st.z >= 0 && st.z < args.src_tensor.Width();
62   bool sty_in = st.y >= 0 && st.y < args.src_tensor.Height();
63   bool stw_in = st.w >= 0 && st.w < args.src_tensor.Height();
64   float4 src0 = (stx_in && sty_in) ? args.src_tensor.Read<float>(st.x, st.y, S) : INIT_FLOAT4(0.0f);
65   float4 src1 = (stz_in && sty_in) ? args.src_tensor.Read<float>(st.z, st.y, S) : INIT_FLOAT4(0.0f);
66   float4 src2 = (stx_in && stw_in) ? args.src_tensor.Read<float>(st.x, st.w, S) : INIT_FLOAT4(0.0f);
67   float4 src3 = (stz_in && stw_in) ? args.src_tensor.Read<float>(st.z, st.w, S) : INIT_FLOAT4(0.0f);
68     )";
69   }
70   c += "  FLT4 r0 = TO_FLT4(mix(mix(src0, src1, t.x), mix(src2, src3, t.x), "
71        "t.y));\n";
72   c += "  args.dst_tensor.Write(r0, X, Y, S);\n";
73   c += "}\n";
74   return c;
75 }
76 
77 }  // namespace
78 
CreateResampler(const GpuInfo & gpu_info,const OperationDef & definition)79 GPUOperation CreateResampler(const GpuInfo& gpu_info,
80                              const OperationDef& definition) {
81   GPUOperation op(definition);
82   op.AddSrcTensor("src_tensor", definition.src_tensors[0]);
83   op.AddSrcTensor("warp_tensor", definition.src_tensors[1]);
84   op.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
85   op.code_ = GetResamplerCode(gpu_info, definition);
86   op.tensor_to_grid_ = TensorToGrid::kWBToX_HDToY_SToZ;
87   return op;
88 }
89 
90 }  // namespace gpu
91 }  // namespace tflite
92