xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/reverse.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <stdint.h>
#include <string.h>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
23 
24 namespace tflite {
25 namespace ops {
26 namespace builtin {
27 namespace reverse {
28 namespace {
29 
// Indices of the tensors this kernel touches on its node:
//   node input 0  -> the data tensor to be reversed,
//   node input 1  -> the axis tensor,
//   node output 0 -> the reversed result.
constexpr int kInputTensor{0};
constexpr int kAxisTensor{1};
constexpr int kOutputTensor{0};
33 
Prepare(TfLiteContext * context,TfLiteNode * node)34 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
35   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
36   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
37 
38   const TfLiteTensor* input;
39   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
40   const TfLiteTensor* axis;
41   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxisTensor, &axis));
42   TF_LITE_ENSURE_EQ(context, NumDimensions(axis), 1);
43   TF_LITE_ENSURE(context, NumDimensions(input) >= NumElements(axis));
44 
45   if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
46       input->type != kTfLiteUInt8 && input->type != kTfLiteInt8 &&
47       input->type != kTfLiteInt16 && input->type != kTfLiteInt64 &&
48       input->type != kTfLiteBool) {
49     TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by reverse.",
50                        TfLiteTypeGetName(input->type));
51     return kTfLiteError;
52   }
53 
54   if (axis->type != kTfLiteInt32) {
55     TF_LITE_KERNEL_LOG(context, "Axis Type '%s' is not supported by reverse.",
56                        TfLiteTypeGetName(axis->type));
57     return kTfLiteError;
58   }
59 
60   // TODO(b/186320180): support multi-axis case.
61   if (NumElements(axis) > 1) {
62     TF_LITE_KERNEL_LOG(context, "Current does not support more than 1 axis.");
63   }
64 
65   TfLiteTensor* output;
66   TF_LITE_ENSURE_OK(context,
67                     GetOutputSafe(context, node, kOutputTensor, &output));
68   TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
69   TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
70 
71   return context->ResizeTensor(context, output, output_shape);
72 }
73 
Eval(TfLiteContext * context,TfLiteNode * node)74 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
75   const TfLiteTensor* input;
76   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
77   const TfLiteTensor* axis_tensor;
78   TF_LITE_ENSURE_OK(context,
79                     GetInputSafe(context, node, kAxisTensor, &axis_tensor));
80   int axis = GetTensorData<int32_t>(axis_tensor)[0];
81   const int rank = NumDimensions(input);
82   if (axis < 0) {
83     axis += rank;
84   }
85 
86   TF_LITE_ENSURE(context, axis >= 0 && axis < rank);
87   TfLiteTensor* output;
88   TF_LITE_ENSURE_OK(context,
89                     GetOutputSafe(context, node, kOutputTensor, &output));
90 
91   switch (output->type) {
92     case kTfLiteFloat32: {
93       reference_ops::Reverse<float>(
94           axis, GetTensorShape(input), GetTensorData<float>(input),
95           GetTensorShape(output), GetTensorData<float>(output));
96       break;
97     }
98     case kTfLiteUInt8:
99     case kTfLiteInt8: {
100       reference_ops::Reverse<uint8_t>(
101           axis, GetTensorShape(input), GetTensorData<uint8_t>(input),
102           GetTensorShape(output), GetTensorData<uint8_t>(output));
103       break;
104     }
105     case kTfLiteInt16: {
106       reference_ops::Reverse<int16_t>(
107           axis, GetTensorShape(input), GetTensorData<int16_t>(input),
108           GetTensorShape(output), GetTensorData<int16_t>(output));
109       break;
110     }
111     case kTfLiteInt32: {
112       reference_ops::Reverse<int32_t>(
113           axis, GetTensorShape(input), GetTensorData<int32_t>(input),
114           GetTensorShape(output), GetTensorData<int32_t>(output));
115       break;
116     }
117     case kTfLiteInt64: {
118       reference_ops::Reverse<int64_t>(
119           axis, GetTensorShape(input), GetTensorData<int64_t>(input),
120           GetTensorShape(output), GetTensorData<int64_t>(output));
121       break;
122     }
123     case kTfLiteBool: {
124       reference_ops::Reverse<bool>(
125           axis, GetTensorShape(input), GetTensorData<bool>(input),
126           GetTensorShape(output), GetTensorData<bool>(output));
127       break;
128     }
129     default: {
130       TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by reverse.",
131                          TfLiteTypeGetName(output->type));
132       return kTfLiteError;
133     }
134   }
135 
136   return kTfLiteOk;
137 }
138 
139 }  // namespace
140 }  // namespace reverse
141 
Register_REVERSE_V2()142 TfLiteRegistration* Register_REVERSE_V2() {
143   static TfLiteRegistration r = {nullptr, nullptr, reverse::Prepare,
144                                  reverse::Eval};
145   return &r;
146 }
147 
148 }  // namespace builtin
149 }  // namespace ops
150 }  // namespace tflite
151