xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/floor_div.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <math.h>
16 #include <stddef.h>
17 #include <stdint.h>
18 
19 #include <functional>
20 
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow/lite/kernels/internal/reference/binary_function.h"
23 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
24 #include "tensorflow/lite/kernels/internal/tensor.h"
25 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
26 #include "tensorflow/lite/kernels/kernel_util.h"
27 
28 namespace tflite {
29 namespace ops {
30 namespace builtin {
31 namespace floor_div {
32 namespace {
33 
// Positions of the two operands in the node's input list and of the result
// in its output list.
constexpr int kInputTensor1{0};
constexpr int kInputTensor2{1};
constexpr int kOutputTensor{0};
38 
// Per-node state for the floor_div op.
struct OpData {
  // True when the two inputs do not have the same shape, in which case the
  // broadcasting code path is used at evaluation time. Defaults to false so
  // that a default-initialized OpData never carries an indeterminate value.
  bool requires_broadcast = false;
};
43 
Init(TfLiteContext * context,const char * buffer,size_t length)44 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
45   auto* data = new OpData;
46   data->requires_broadcast = false;
47   return data;
48 }
49 
Free(TfLiteContext * context,void * buffer)50 void Free(TfLiteContext* context, void* buffer) {
51   delete reinterpret_cast<OpData*>(buffer);
52 }
53 
Prepare(TfLiteContext * context,TfLiteNode * node)54 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
55   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
56   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
57 
58   // Reinterprete the opaque data provided by user.
59   OpData* data = reinterpret_cast<OpData*>(node->user_data);
60 
61   const TfLiteTensor* input1;
62   TF_LITE_ENSURE_OK(context,
63                     GetInputSafe(context, node, kInputTensor1, &input1));
64   const TfLiteTensor* input2;
65   TF_LITE_ENSURE_OK(context,
66                     GetInputSafe(context, node, kInputTensor2, &input2));
67   TfLiteTensor* output;
68   TF_LITE_ENSURE_OK(context,
69                     GetOutputSafe(context, node, kOutputTensor, &output));
70 
71   TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
72 
73   const TfLiteType type = input1->type;
74   switch (type) {
75     case kTfLiteFloat32:
76     case kTfLiteInt32:
77       break;
78     default:
79       TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by floor_div.",
80                          TfLiteTypeGetName(type));
81       return kTfLiteError;
82   }
83   output->type = type;
84 
85   data->requires_broadcast = !HaveSameShapes(input1, input2);
86 
87   TfLiteIntArray* output_size = nullptr;
88   if (data->requires_broadcast) {
89     TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
90                                    context, input1, input2, &output_size));
91   } else {
92     output_size = TfLiteIntArrayCopy(input1->dims);
93   }
94 
95   return context->ResizeTensor(context, output, output_size);
96 }
97 
98 template <typename T>
EvalImpl(TfLiteContext * context,bool requires_broadcast,const TfLiteTensor * input1,const TfLiteTensor * input2,TfLiteTensor * output)99 TfLiteStatus EvalImpl(TfLiteContext* context, bool requires_broadcast,
100                       const TfLiteTensor* input1, const TfLiteTensor* input2,
101                       TfLiteTensor* output) {
102   const T* denominator_data = GetTensorData<T>(input2);
103 
104   // Validate the denominator.
105   for (int i = 0; i < NumElements(input2); ++i) {
106     if (std::equal_to<T>()(denominator_data[i], 0)) {
107       TF_LITE_KERNEL_LOG(context, "Division by 0");
108       return kTfLiteError;
109     }
110   }
111   if (requires_broadcast) {
112     reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>(
113         GetTensorShape(input1), GetTensorData<T>(input1),
114         GetTensorShape(input2), denominator_data, GetTensorShape(output),
115         GetTensorData<T>(output), reference_ops::FloorDiv<T>);
116   } else {
117     reference_ops::BinaryFunction<T, T, T>(
118         GetTensorShape(input1), GetTensorData<T>(input1),
119         GetTensorShape(input2), GetTensorData<T>(input2),
120         GetTensorShape(output), GetTensorData<T>(output),
121         reference_ops::FloorDiv<T>);
122   }
123 
124   return kTfLiteOk;
125 }
126 
Eval(TfLiteContext * context,TfLiteNode * node)127 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
128   OpData* data = reinterpret_cast<OpData*>(node->user_data);
129 
130   const TfLiteTensor* input1;
131   TF_LITE_ENSURE_OK(context,
132                     GetInputSafe(context, node, kInputTensor1, &input1));
133   const TfLiteTensor* input2;
134   TF_LITE_ENSURE_OK(context,
135                     GetInputSafe(context, node, kInputTensor2, &input2));
136   TfLiteTensor* output;
137   TF_LITE_ENSURE_OK(context,
138                     GetOutputSafe(context, node, kOutputTensor, &output));
139 
140   switch (input1->type) {
141     case kTfLiteInt32: {
142       return EvalImpl<int32_t>(context, data->requires_broadcast, input1,
143                                input2, output);
144     }
145     case kTfLiteFloat32: {
146       return EvalImpl<float>(context, data->requires_broadcast, input1, input2,
147                              output);
148     }
149     default: {
150       TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by floor_div.",
151                          TfLiteTypeGetName(input1->type));
152       return kTfLiteError;
153     }
154   }
155 }
156 
157 }  // namespace
158 }  // namespace floor_div
159 
Register_FLOOR_DIV()160 TfLiteRegistration* Register_FLOOR_DIV() {
161   // Init, Free, Prepare, Eval are satisfying the Interface required by
162   // TfLiteRegistration.
163   static TfLiteRegistration r = {floor_div::Init, floor_div::Free,
164                                  floor_div::Prepare, floor_div::Eval};
165   return &r;
166 }
167 
168 }  // namespace builtin
169 }  // namespace ops
170 }  // namespace tflite
171