1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_RESIZE_NEAREST_NEIGHBOR_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_RESIZE_NEAREST_NEIGHBOR_H_
17
#include <algorithm>
#include <cmath>
#include <cstring>

#include "tensorflow/lite/kernels/internal/cppmath.h"
#include "tensorflow/lite/kernels/internal/types.h"
23
24 namespace tflite {
25
26 namespace reference_ops {
27
GetNearestNeighbor(const int input_value,const int32_t input_size,const int32_t output_size,const bool align_corners,const bool half_pixel_centers)28 inline int32_t GetNearestNeighbor(const int input_value,
29 const int32_t input_size,
30 const int32_t output_size,
31 const bool align_corners,
32 const bool half_pixel_centers) {
33 const float scale =
34 (align_corners && output_size > 1)
35 ? (input_size - 1) / static_cast<float>(output_size - 1)
36 : input_size / static_cast<float>(output_size);
37 const float offset = half_pixel_centers ? 0.5f : 0.0f;
38 int32_t output_value = std::min(
39 align_corners
40 ? static_cast<int32_t>(TfLiteRound((input_value + offset) * scale))
41 : static_cast<int32_t>(std::floor((input_value + offset) * scale)),
42 input_size - 1);
43 if (half_pixel_centers) {
44 output_value = std::max(static_cast<int32_t>(0), output_value);
45 }
46 return output_value;
47 }
48
49 template <typename T>
ResizeNearestNeighbor(const tflite::ResizeNearestNeighborParams & op_params,const RuntimeShape & unextended_input_shape,const T * input_data,const RuntimeShape & output_size_shape,const int32_t * output_size_data,const RuntimeShape & unextended_output_shape,T * output_data)50 inline void ResizeNearestNeighbor(
51 const tflite::ResizeNearestNeighborParams& op_params,
52 const RuntimeShape& unextended_input_shape, const T* input_data,
53 const RuntimeShape& output_size_shape, const int32_t* output_size_data,
54 const RuntimeShape& unextended_output_shape, T* output_data) {
55 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
56 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
57
58 const RuntimeShape input_shape =
59 RuntimeShape::ExtendedShape(4, unextended_input_shape);
60 const RuntimeShape output_shape =
61 RuntimeShape::ExtendedShape(4, unextended_output_shape);
62
63 int32_t batches = MatchingDim(input_shape, 0, output_shape, 0);
64 int32_t input_height = input_shape.Dims(1);
65 int32_t input_width = input_shape.Dims(2);
66 int32_t depth = MatchingDim(input_shape, 3, output_shape, 3);
67
68 // The Tensorflow version of this op allows resize on the width and height
69 // axis only.
70 TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
71 int32_t output_height = output_size_data[0];
72 int32_t output_width = output_size_data[1];
73
74 const int col_offset = input_shape.Dims(3);
75 const int row_offset = input_shape.Dims(2) * col_offset;
76 const int batch_offset = input_shape.Dims(1) * row_offset;
77
78 const T* input_ptr = input_data;
79 T* output_ptr = output_data;
80 for (int b = 0; b < batches; ++b) {
81 for (int y = 0; y < output_height; ++y) {
82 int32_t in_y = GetNearestNeighbor(y, input_height, output_height,
83 op_params.align_corners,
84 op_params.half_pixel_centers);
85 const T* y_input_ptr = input_ptr + in_y * row_offset;
86 for (int x = 0; x < output_width; ++x) {
87 int32_t in_x = GetNearestNeighbor(x, input_width, output_width,
88 op_params.align_corners,
89 op_params.half_pixel_centers);
90 const T* x_input_ptr = y_input_ptr + in_x * col_offset;
91 memcpy(output_ptr, x_input_ptr, depth * sizeof(T));
92 output_ptr += depth;
93 }
94 }
95 input_ptr += batch_offset;
96 }
97 }
98
99 } // namespace reference_ops
100 } // namespace tflite
101
102 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_RESIZE_NEAREST_NEIGHBOR_H_
103