1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <math.h>
16 #include <stddef.h>
17 #include <stdint.h>
18
19 #include <functional>
20
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/kernels/internal/reference/binary_function.h"
23 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
24 #include "tensorflow/lite/kernels/internal/tensor.h"
25 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
26 #include "tensorflow/lite/kernels/kernel_util.h"
27
28 namespace tflite {
29 namespace ops {
30 namespace builtin {
31 namespace floor_div {
32 namespace {
33
// Positions of the tensors within the node: the two operands in the input
// list (the second one is the tensor validated as the denominator in
// EvalImpl) and the single result in the output list.
constexpr int kInputTensor1{0};
constexpr int kInputTensor2{1};
constexpr int kOutputTensor{0};
38
// Per-node state for the floor_div op. It is written by Prepare and read by
// Eval through node->user_data.
struct OpData {
  // True when the two inputs have different shapes, so the broadcasting
  // implementation has to be used. Defaults to false so that a freshly
  // constructed OpData never carries an indeterminate value.
  bool requires_broadcast = false;
};
43
Init(TfLiteContext * context,const char * buffer,size_t length)44 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
45 auto* data = new OpData;
46 data->requires_broadcast = false;
47 return data;
48 }
49
Free(TfLiteContext * context,void * buffer)50 void Free(TfLiteContext* context, void* buffer) {
51 delete reinterpret_cast<OpData*>(buffer);
52 }
53
Prepare(TfLiteContext * context,TfLiteNode * node)54 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
55 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
56 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
57
58 // Reinterprete the opaque data provided by user.
59 OpData* data = reinterpret_cast<OpData*>(node->user_data);
60
61 const TfLiteTensor* input1;
62 TF_LITE_ENSURE_OK(context,
63 GetInputSafe(context, node, kInputTensor1, &input1));
64 const TfLiteTensor* input2;
65 TF_LITE_ENSURE_OK(context,
66 GetInputSafe(context, node, kInputTensor2, &input2));
67 TfLiteTensor* output;
68 TF_LITE_ENSURE_OK(context,
69 GetOutputSafe(context, node, kOutputTensor, &output));
70
71 TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
72
73 const TfLiteType type = input1->type;
74 switch (type) {
75 case kTfLiteFloat32:
76 case kTfLiteInt32:
77 break;
78 default:
79 TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by floor_div.",
80 TfLiteTypeGetName(type));
81 return kTfLiteError;
82 }
83 output->type = type;
84
85 data->requires_broadcast = !HaveSameShapes(input1, input2);
86
87 TfLiteIntArray* output_size = nullptr;
88 if (data->requires_broadcast) {
89 TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
90 context, input1, input2, &output_size));
91 } else {
92 output_size = TfLiteIntArrayCopy(input1->dims);
93 }
94
95 return context->ResizeTensor(context, output, output_size);
96 }
97
98 template <typename T>
EvalImpl(TfLiteContext * context,bool requires_broadcast,const TfLiteTensor * input1,const TfLiteTensor * input2,TfLiteTensor * output)99 TfLiteStatus EvalImpl(TfLiteContext* context, bool requires_broadcast,
100 const TfLiteTensor* input1, const TfLiteTensor* input2,
101 TfLiteTensor* output) {
102 const T* denominator_data = GetTensorData<T>(input2);
103
104 // Validate the denominator.
105 for (int i = 0; i < NumElements(input2); ++i) {
106 if (std::equal_to<T>()(denominator_data[i], 0)) {
107 TF_LITE_KERNEL_LOG(context, "Division by 0");
108 return kTfLiteError;
109 }
110 }
111 if (requires_broadcast) {
112 reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>(
113 GetTensorShape(input1), GetTensorData<T>(input1),
114 GetTensorShape(input2), denominator_data, GetTensorShape(output),
115 GetTensorData<T>(output), reference_ops::FloorDiv<T>);
116 } else {
117 reference_ops::BinaryFunction<T, T, T>(
118 GetTensorShape(input1), GetTensorData<T>(input1),
119 GetTensorShape(input2), GetTensorData<T>(input2),
120 GetTensorShape(output), GetTensorData<T>(output),
121 reference_ops::FloorDiv<T>);
122 }
123
124 return kTfLiteOk;
125 }
126
Eval(TfLiteContext * context,TfLiteNode * node)127 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
128 OpData* data = reinterpret_cast<OpData*>(node->user_data);
129
130 const TfLiteTensor* input1;
131 TF_LITE_ENSURE_OK(context,
132 GetInputSafe(context, node, kInputTensor1, &input1));
133 const TfLiteTensor* input2;
134 TF_LITE_ENSURE_OK(context,
135 GetInputSafe(context, node, kInputTensor2, &input2));
136 TfLiteTensor* output;
137 TF_LITE_ENSURE_OK(context,
138 GetOutputSafe(context, node, kOutputTensor, &output));
139
140 switch (input1->type) {
141 case kTfLiteInt32: {
142 return EvalImpl<int32_t>(context, data->requires_broadcast, input1,
143 input2, output);
144 }
145 case kTfLiteFloat32: {
146 return EvalImpl<float>(context, data->requires_broadcast, input1, input2,
147 output);
148 }
149 default: {
150 TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by floor_div.",
151 TfLiteTypeGetName(input1->type));
152 return kTfLiteError;
153 }
154 }
155 }
156
157 } // namespace
158 } // namespace floor_div
159
Register_FLOOR_DIV()160 TfLiteRegistration* Register_FLOOR_DIV() {
161 // Init, Free, Prepare, Eval are satisfying the Interface required by
162 // TfLiteRegistration.
163 static TfLiteRegistration r = {floor_div::Init, floor_div::Free,
164 floor_div::Prepare, floor_div::Eval};
165 return &r;
166 }
167
168 } // namespace builtin
169 } // namespace ops
170 } // namespace tflite
171