/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/maximum_minimum.h"

#include <stdint.h>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace maximum_minimum {

// This file has reference and generic-optimized implementations of
// TFMaximum/TFMinimum.
enum KernelType {
  kReference,
  kGenericOptimized,
};

constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;

struct OpContext {
  OpContext(TfLiteContext* context, TfLiteNode* node) {
    input1 = GetInput(context, node, kInputTensor1);
    input2 = GetInput(context, node, kInputTensor2);
    output = GetOutput(context, node, kOutputTensor);
  }
  const TfLiteTensor* input1;
  const TfLiteTensor* input2;
  TfLiteTensor* output;
};

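// Checks that the node has exactly two inputs of the same type and one
// output, then resizes the output: to the broadcast shape when the input
// shapes differ, otherwise to the shape of the first input.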
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  OpContext op_context(context, node);
  TF_LITE_ENSURE_TYPES_EQ(context, op_context.input1->type,
                          op_context.input2->type);
  op_context.output->type = op_context.input1->type;

  bool requires_broadcast =
      !HaveSameShapes(op_context.input1, op_context.input2);

  TfLiteIntArray* output_size = nullptr;
  if (requires_broadcast) {
    TF_LITE_ENSURE_OK(
        context, CalculateShapeForBroadcast(context, op_context.input1,
                                            op_context.input2, &output_size));
  } else {
    output_size = TfLiteIntArrayCopy(op_context.input1->dims);
  }

  return context->ResizeTensor(context, op_context.output, output_size);
}

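// Element-wise comparison functors. Each exposes a static `op` template that
// the kernels below take as the per-element operation.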
struct MaximumOp {
  template <typename data_type>
  static data_type op(data_type el1, data_type el2) {
    return el1 > el2 ? el1 : el2;
  }
};

struct MinimumOp {
  template <typename data_type>
  static data_type op(data_type el1, data_type el2) {
    return el1 < el2 ? el1 : el2;
  }
};

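// Generic path: evaluates the op with the broadcast-capable reference
// implementation regardless of kernel type. Type-specific optimized paths are
// provided as specializations below.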
template <KernelType kernel_type, typename data_type, typename op_type>
void TFLiteOperation(TfLiteContext* context, TfLiteNode* node,
                     const OpContext& op_context) {
  reference_ops::MaximumMinimumBroadcastSlow(
      GetTensorShape(op_context.input1),
      GetTensorData<data_type>(op_context.input1),
      GetTensorShape(op_context.input2),
      GetTensorData<data_type>(op_context.input2),
      GetTensorShape(op_context.output),
      GetTensorData<data_type>(op_context.output),
      op_type::template op<data_type>);
}

// Maximum, generic-optimized int8: dispatches to the optimized broadcast
// kernel when broadcasting is needed, otherwise falls back to the reference
// implementation.
template <>
void TFLiteOperation<maximum_minimum::kGenericOptimized, int8, MaximumOp>(
    TfLiteContext* context, TfLiteNode* node, const OpContext& op_context) {
  tflite::ArithmeticParams op_params;
  const bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
      GetTensorShape(op_context.input1), GetTensorShape(op_context.input2),
      &op_params);
  if (need_broadcast) {
    optimized_ops::BroadcastMaximumDispatch(
        op_params, GetTensorShape(op_context.input1),
        GetTensorData<int8>(op_context.input1),
        GetTensorShape(op_context.input2),
        GetTensorData<int8>(op_context.input2),
        GetTensorShape(op_context.output),
        GetTensorData<int8>(op_context.output), MaximumOp::template op<int8>);
    return;
  }
  reference_ops::MaximumMinimumBroadcastSlow(
      GetTensorShape(op_context.input1), GetTensorData<int8>(op_context.input1),
      GetTensorShape(op_context.input2), GetTensorData<int8>(op_context.input2),
      GetTensorShape(op_context.output), GetTensorData<int8>(op_context.output),
      MaximumOp::template op<int8>);
}

// Minimum, generic-optimized int8: dispatches to the optimized broadcast
// kernel when broadcasting is needed, otherwise falls back to the reference
// implementation.
template <>
void TFLiteOperation<maximum_minimum::kGenericOptimized, int8, MinimumOp>(
    TfLiteContext* context, TfLiteNode* node, const OpContext& op_context) {
  tflite::ArithmeticParams op_params;
  const bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
      GetTensorShape(op_context.input1), GetTensorShape(op_context.input2),
      &op_params);
  if (need_broadcast) {
    optimized_ops::BroadcastMinimumDispatch(
        op_params, GetTensorShape(op_context.input1),
        GetTensorData<int8>(op_context.input1),
        GetTensorShape(op_context.input2),
        GetTensorData<int8>(op_context.input2),
        GetTensorShape(op_context.output),
        GetTensorData<int8>(op_context.output), MinimumOp::template op<int8>);
    return;
  }
  reference_ops::MaximumMinimumBroadcastSlow(
      GetTensorShape(op_context.input1), GetTensorData<int8>(op_context.input1),
      GetTensorShape(op_context.input2), GetTensorData<int8>(op_context.input2),
      GetTensorShape(op_context.output), GetTensorData<int8>(op_context.output),
      MinimumOp::template op<int8>);
}

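// Dispatches to the TFLiteOperation instantiation that matches the output
// tensor type.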
template <KernelType kernel_type, typename OpType>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  OpContext op_context(context, node);

  // If either input has no elements, short-circuit.
  if (NumElements(op_context.input1) == 0 ||
      NumElements(op_context.input2) == 0) {
    return kTfLiteOk;
  }

  switch (op_context.output->type) {
    case kTfLiteFloat32:
      TFLiteOperation<kernel_type, float, OpType>(context, node, op_context);
      break;
    case kTfLiteUInt8:
      TFLiteOperation<kernel_type, uint8_t, OpType>(context, node, op_context);
      break;
    case kTfLiteInt8:
      TFLiteOperation<kernel_type, int8_t, OpType>(context, node, op_context);
      break;
    case kTfLiteInt32:
      TFLiteOperation<kernel_type, int32_t, OpType>(context, node, op_context);
      break;
    case kTfLiteInt64:
      TFLiteOperation<kernel_type, int64_t, OpType>(context, node, op_context);
      break;
    case kTfLiteInt16:
      TFLiteOperation<kernel_type, int16_t, OpType>(context, node, op_context);
      break;
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Type %d is currently not supported by "
                         "Maximum/Minimum.",
                         op_context.output->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace maximum_minimum

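// Kernel registrations. The plain Register_MAXIMUM/Register_MINIMUM entry
// points use the generic-optimized kernels.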
TfLiteRegistration* Register_MAXIMUM_REF() {
  static TfLiteRegistration r = {
      nullptr, nullptr, maximum_minimum::Prepare,
      maximum_minimum::Eval<maximum_minimum::kReference,
                            maximum_minimum::MaximumOp>};
  return &r;
}

TfLiteRegistration* Register_MAXIMUM_GENERIC_OPT() {
  static TfLiteRegistration r = {
      nullptr, nullptr, maximum_minimum::Prepare,
      maximum_minimum::Eval<maximum_minimum::kGenericOptimized,
                            maximum_minimum::MaximumOp>};
  return &r;
}

TfLiteRegistration* Register_MINIMUM_REF() {
  static TfLiteRegistration r = {
      nullptr, nullptr, maximum_minimum::Prepare,
      maximum_minimum::Eval<maximum_minimum::kReference,
                            maximum_minimum::MinimumOp>};
  return &r;
}

TfLiteRegistration* Register_MINIMUM_GENERIC_OPT() {
  static TfLiteRegistration r = {
      nullptr, nullptr, maximum_minimum::Prepare,
      maximum_minimum::Eval<maximum_minimum::kGenericOptimized,
                            maximum_minimum::MinimumOp>};
  return &r;
}

TfLiteRegistration* Register_MAXIMUM() {
  return Register_MAXIMUM_GENERIC_OPT();
}
TfLiteRegistration* Register_MINIMUM() {
  return Register_MINIMUM_GENERIC_OPT();
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite