xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/resize_bilinear.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
// clang-format off: Clang-format thinks this header is paired.
#include "tensorflow/lite/kernels/internal/optimized/resize_bilinear.h"
// clang-format on
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace resize_bilinear {
35 
36 // This file has three implementation of RESIZE_BILINEAR.
37 enum KernelType {
38   kReference,
39   kOptimized,
40 };
41 
42 constexpr int kInputTensor = 0;
43 constexpr int kSizeTensor = 1;
44 constexpr int kOutputTensor = 0;
45 
ResizeOutputTensor(TfLiteContext * context,const TfLiteTensor * input,const TfLiteTensor * size,TfLiteTensor * output)46 TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
47                                 const TfLiteTensor* input,
48                                 const TfLiteTensor* size,
49                                 TfLiteTensor* output) {
50   const int32* size_data = GetTensorData<int32>(size);
51   // Sanity check, the up/down sampling size should always be positive.
52   TF_LITE_ENSURE(context, size_data[0] > 0);
53   TF_LITE_ENSURE(context, size_data[1] > 0);
54   TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
55   output_size->data[0] = input->dims->data[0];
56   output_size->data[1] = size_data[0];
57   output_size->data[2] = size_data[1];
58   output_size->data[3] = input->dims->data[3];
59   return context->ResizeTensor(context, output, output_size);
60 }
61 
Prepare(TfLiteContext * context,TfLiteNode * node)62 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
63   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
64   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
65 
66   const TfLiteTensor* input;
67   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
68   const TfLiteTensor* size;
69   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));
70   TfLiteTensor* output;
71   TF_LITE_ENSURE_OK(context,
72                     GetOutputSafe(context, node, kOutputTensor, &output));
73 
74   // TODO(ahentz): Our current implementations rely on the inputs being 4D.
75   TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
76   TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1);
77 
78   TF_LITE_ENSURE_EQ(context, size->type, kTfLiteInt32);
79   // ResizeBilinear creates a float tensor even when the input is made of
80   // integers.
81   output->type = input->type;
82 
83   if (!IsConstantTensor(size)) {
84     SetTensorToDynamic(output);
85     return kTfLiteOk;
86   }
87 
88   // Ensure params are valid.
89   auto* params =
90       reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data);
91   if (params->half_pixel_centers && params->align_corners) {
92     TF_LITE_KERNEL_LOG(
93         context, "If half_pixel_centers is True, align_corners must be False.");
94     return kTfLiteError;
95   }
96 
97   return ResizeOutputTensor(context, input, size, output);
98 }
99 
100 template <KernelType kernel_type>
Eval(TfLiteContext * context,TfLiteNode * node)101 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
102   auto* params =
103       reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data);
104 
105   const TfLiteTensor* input;
106   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
107   TfLiteTensor* output;
108   TF_LITE_ENSURE_OK(context,
109                     GetOutputSafe(context, node, kOutputTensor, &output));
110   const TfLiteTensor* size;
111   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSizeTensor, &size));
112 
113   if (IsDynamicTensor(output)) {
114     TF_LITE_ENSURE_OK(context,
115                       ResizeOutputTensor(context, input, size, output));
116   }
117 
118   if (output->type == kTfLiteFloat32) {
119 #define TF_LITE_RESIZE_BILINEAR(type, opname, datatype)              \
120   tflite::ResizeBilinearParams op_params;                            \
121   op_params.align_corners = params->align_corners;                   \
122   op_params.half_pixel_centers = params->half_pixel_centers;         \
123   type::opname(op_params, GetTensorShape(input),                     \
124                GetTensorData<datatype>(input), GetTensorShape(size), \
125                GetTensorData<int32>(size), GetTensorShape(output),   \
126                GetTensorData<datatype>(output))
127 
128     if (kernel_type == kReference) {
129       TF_LITE_RESIZE_BILINEAR(reference_ops, ResizeBilinear, float);
130     } else if (kernel_type == kOptimized) {
131       TF_LITE_RESIZE_BILINEAR(optimized_ops, ResizeBilinear, float);
132     }
133   } else if (output->type == kTfLiteUInt8) {
134     if (kernel_type == kReference) {
135       TF_LITE_RESIZE_BILINEAR(reference_ops, ResizeBilinear, uint8_t);
136     } else if (kernel_type == kOptimized) {
137       TF_LITE_RESIZE_BILINEAR(optimized_ops, ResizeBilinear, uint8_t);
138     }
139   } else if (output->type == kTfLiteInt8) {
140     if (kernel_type == kReference) {
141       TF_LITE_RESIZE_BILINEAR(reference_ops, ResizeBilinearInteger, int8_t);
142     } else if (kernel_type == kOptimized) {
143       TF_LITE_RESIZE_BILINEAR(optimized_ops, ResizeBilinear, int8_t);
144     }
145   } else if (output->type == kTfLiteInt16) {
146     TF_LITE_RESIZE_BILINEAR(reference_ops, ResizeBilinearInteger, int16_t);
147 #undef TF_LITE_RESIZE_BILINEAR
148   } else {
149     TF_LITE_KERNEL_LOG(context, "Output type is %d, requires float.",
150                        output->type);
151     return kTfLiteError;
152   }
153 
154   return kTfLiteOk;
155 }

}  // namespace resize_bilinear

Register_RESIZE_BILINEAR_REF()159 TfLiteRegistration* Register_RESIZE_BILINEAR_REF() {
160   static TfLiteRegistration r = {
161       nullptr, nullptr, resize_bilinear::Prepare,
162       resize_bilinear::Eval<resize_bilinear::kReference>};
163   return &r;
164 }
165 
Register_RESIZE_BILINEAR()166 TfLiteRegistration* Register_RESIZE_BILINEAR() {
167   static TfLiteRegistration r = {
168       nullptr, nullptr, resize_bilinear::Prepare,
169       resize_bilinear::Eval<resize_bilinear::kOptimized>};
170   return &r;
171 }

}  // namespace builtin
}  // namespace ops
}  // namespace tflite