1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/kernels/internal/reference/maximum_minimum.h"
16
17 #include <stdint.h>
18
19 #include "tensorflow/lite/c/common.h"
20 #include "tensorflow/lite/kernels/internal/compatibility.h"
21 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
22 #include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
23 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
24 #include "tensorflow/lite/kernels/internal/tensor.h"
25 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
26 #include "tensorflow/lite/kernels/internal/types.h"
27 #include "tensorflow/lite/kernels/kernel_util.h"
28
29 namespace tflite {
30 namespace ops {
31 namespace builtin {
32 namespace maximum_minimum {
33
// This file has the reference and generic optimized implementations of
// TFMaximum/TFMinimum.
// Identifies which kernel flavour an Eval/TFLiteOperation instantiation uses.
enum KernelType {
  // Used by the Register_*_REF registrations.
  kReference = 0,
  // Used by the Register_*_GENERIC_OPT registrations and by the int8
  // TFLiteOperation specializations.
  kGenericOptimized = 1,
};
39
// Positions of the tensors within the node's input and output lists.
constexpr int kInputTensor1{0};  // First operand.
constexpr int kInputTensor2{1};  // Second operand.
constexpr int kOutputTensor{0};  // Result tensor.
43
44 struct OpContext {
OpContexttflite::ops::builtin::maximum_minimum::OpContext45 OpContext(TfLiteContext* context, TfLiteNode* node) {
46 input1 = GetInput(context, node, kInputTensor1);
47 input2 = GetInput(context, node, kInputTensor2);
48 output = GetOutput(context, node, kOutputTensor);
49 }
50 const TfLiteTensor* input1;
51 const TfLiteTensor* input2;
52 TfLiteTensor* output;
53 };
54
Prepare(TfLiteContext * context,TfLiteNode * node)55 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
56 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
57 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
58
59 OpContext op_context(context, node);
60 TF_LITE_ENSURE_TYPES_EQ(context, op_context.input1->type,
61 op_context.input2->type);
62 op_context.output->type = op_context.input1->type;
63
64 bool requires_broadcast =
65 !HaveSameShapes(op_context.input1, op_context.input2);
66
67 TfLiteIntArray* output_size = nullptr;
68 if (requires_broadcast) {
69 TF_LITE_ENSURE_OK(
70 context, CalculateShapeForBroadcast(context, op_context.input1,
71 op_context.input2, &output_size));
72 } else {
73 output_size = TfLiteIntArrayCopy(op_context.input1->dims);
74 }
75
76 return context->ResizeTensor(context, op_context.output, output_size);
77 }
78
// Binary functor used by the MAXIMUM kernel.
struct MaximumOp {
  // Returns el1 when `el1 > el2` holds, otherwise el2. Ties, and any case
  // where the comparison evaluates to false, therefore yield el2.
  template <typename data_type>
  static data_type op(data_type el1, data_type el2) {
    if (el1 > el2) {
      return el1;
    }
    return el2;
  }
};
85
// Binary functor used by the MINIMUM kernel.
struct MinimumOp {
  // Returns el1 when `el1 < el2` holds, otherwise el2. Ties, and any case
  // where the comparison evaluates to false, therefore yield el2.
  template <typename data_type>
  static data_type op(data_type el1, data_type el2) {
    if (el1 < el2) {
      return el1;
    }
    return el2;
  }
};
92
93 template <KernelType kernel_type, typename data_type, typename op_type>
TFLiteOperation(TfLiteContext * context,TfLiteNode * node,const OpContext & op_context)94 void TFLiteOperation(TfLiteContext* context, TfLiteNode* node,
95 const OpContext& op_context) {
96 reference_ops::MaximumMinimumBroadcastSlow(
97 GetTensorShape(op_context.input1),
98 GetTensorData<data_type>(op_context.input1),
99 GetTensorShape(op_context.input2),
100 GetTensorData<data_type>(op_context.input2),
101 GetTensorShape(op_context.output),
102 GetTensorData<data_type>(op_context.output),
103 op_type::template op<data_type>);
104 }
105
106 // Maximum generic opt int8.
107 template <>
TFLiteOperation(TfLiteContext * context,TfLiteNode * node,const OpContext & op_context)108 void TFLiteOperation<maximum_minimum::kGenericOptimized, int8, MaximumOp>(
109 TfLiteContext* context, TfLiteNode* node, const OpContext& op_context) {
110 tflite::ArithmeticParams op_params;
111 const bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
112 GetTensorShape(op_context.input1), GetTensorShape(op_context.input2),
113 &op_params);
114 if (need_broadcast) {
115 optimized_ops::BroadcastMaximumDispatch(
116 op_params, GetTensorShape(op_context.input1),
117 GetTensorData<int8>(op_context.input1),
118 GetTensorShape(op_context.input2),
119 GetTensorData<int8>(op_context.input2),
120 GetTensorShape(op_context.output),
121 GetTensorData<int8>(op_context.output), MaximumOp::template op<int8>);
122 return;
123 }
124 reference_ops::MaximumMinimumBroadcastSlow(
125 GetTensorShape(op_context.input1), GetTensorData<int8>(op_context.input1),
126 GetTensorShape(op_context.input2), GetTensorData<int8>(op_context.input2),
127 GetTensorShape(op_context.output), GetTensorData<int8>(op_context.output),
128 MaximumOp::template op<int8>);
129 }
130
131 // Minimum generic opt int8.
132 template <>
TFLiteOperation(TfLiteContext * context,TfLiteNode * node,const OpContext & op_context)133 void TFLiteOperation<maximum_minimum::kGenericOptimized, int8, MinimumOp>(
134 TfLiteContext* context, TfLiteNode* node, const OpContext& op_context) {
135 tflite::ArithmeticParams op_params;
136 const bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
137 GetTensorShape(op_context.input1), GetTensorShape(op_context.input2),
138 &op_params);
139 if (need_broadcast) {
140 optimized_ops::BroadcastMinimumDispatch(
141 op_params, GetTensorShape(op_context.input1),
142 GetTensorData<int8>(op_context.input1),
143 GetTensorShape(op_context.input2),
144 GetTensorData<int8>(op_context.input2),
145 GetTensorShape(op_context.output),
146 GetTensorData<int8>(op_context.output), MinimumOp::template op<int8>);
147 return;
148 }
149 reference_ops::MaximumMinimumBroadcastSlow(
150 GetTensorShape(op_context.input1), GetTensorData<int8>(op_context.input1),
151 GetTensorShape(op_context.input2), GetTensorData<int8>(op_context.input2),
152 GetTensorShape(op_context.output), GetTensorData<int8>(op_context.output),
153 MinimumOp::template op<int8>);
154 }
155
156 template <KernelType kernel_type, typename OpType>
Eval(TfLiteContext * context,TfLiteNode * node)157 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
158 OpContext op_context(context, node);
159
160 // If inputs have no element, shortcircuit.
161 if (NumElements(op_context.input1) == 0 ||
162 NumElements(op_context.input2) == 0) {
163 return kTfLiteOk;
164 }
165
166 switch (op_context.output->type) {
167 case kTfLiteFloat32:
168 TFLiteOperation<kernel_type, float, OpType>(context, node, op_context);
169 break;
170 case kTfLiteUInt8:
171 TFLiteOperation<kernel_type, uint8_t, OpType>(context, node, op_context);
172 break;
173 case kTfLiteInt8:
174 TFLiteOperation<kernel_type, int8_t, OpType>(context, node, op_context);
175 break;
176 case kTfLiteInt32:
177 TFLiteOperation<kernel_type, int32_t, OpType>(context, node, op_context);
178 break;
179 case kTfLiteInt64:
180 TFLiteOperation<kernel_type, int64_t, OpType>(context, node, op_context);
181 break;
182 case kTfLiteInt16:
183 TFLiteOperation<kernel_type, int16_t, OpType>(context, node, op_context);
184 break;
185 default:
186 TF_LITE_KERNEL_LOG(context,
187 "Type %d is currently not supported by Maximum.",
188 op_context.output->type);
189 return kTfLiteError;
190 }
191 return kTfLiteOk;
192 }
193
194 } // namespace maximum_minimum
195
Register_MAXIMUM_REF()196 TfLiteRegistration* Register_MAXIMUM_REF() {
197 static TfLiteRegistration r = {
198 nullptr, nullptr, maximum_minimum::Prepare,
199 maximum_minimum::Eval<maximum_minimum::kReference,
200 maximum_minimum::MaximumOp>};
201 return &r;
202 }
203
Register_MAXIMUM_GENERIC_OPT()204 TfLiteRegistration* Register_MAXIMUM_GENERIC_OPT() {
205 static TfLiteRegistration r = {
206 nullptr, nullptr, maximum_minimum::Prepare,
207 maximum_minimum::Eval<maximum_minimum::kGenericOptimized,
208 maximum_minimum::MaximumOp>};
209 return &r;
210 }
211
Register_MINIMUM_REF()212 TfLiteRegistration* Register_MINIMUM_REF() {
213 static TfLiteRegistration r = {
214 nullptr, nullptr, maximum_minimum::Prepare,
215 maximum_minimum::Eval<maximum_minimum::kReference,
216 maximum_minimum::MinimumOp>};
217 return &r;
218 }
219
Register_MINIMUM_GENERIC_OPT()220 TfLiteRegistration* Register_MINIMUM_GENERIC_OPT() {
221 static TfLiteRegistration r = {
222 nullptr, nullptr, maximum_minimum::Prepare,
223 maximum_minimum::Eval<maximum_minimum::kGenericOptimized,
224 maximum_minimum::MinimumOp>};
225 return &r;
226 }
227
Register_MAXIMUM()228 TfLiteRegistration* Register_MAXIMUM() {
229 return Register_MAXIMUM_GENERIC_OPT();
230 }
Register_MINIMUM()231 TfLiteRegistration* Register_MINIMUM() {
232 return Register_MINIMUM_GENERIC_OPT();
233 }
234
235 } // namespace builtin
236 } // namespace ops
237 } // namespace tflite
238