xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/logical.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stddef.h>
16 
17 #include "tensorflow/lite/c/common.h"
18 #include "tensorflow/lite/kernels/internal/reference/binary_function.h"
19 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
20 #include "tensorflow/lite/kernels/internal/tensor.h"
21 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
22 #include "tensorflow/lite/kernels/kernel_util.h"
23 
24 namespace tflite {
25 namespace ops {
26 namespace builtin {
27 namespace logical {
28 namespace {
29 
// Positions of the operands within the node's input list, and of the result
// within its output list.
constexpr int kInputTensor1{0};
constexpr int kInputTensor2{1};
constexpr int kOutputTensor{0};
34 
// Per-node state shared between Prepare() and the eval functions of the
// logical ops.
struct OpData {
  // True when the two inputs have different shapes, i.e. the broadcasting
  // kernel has to be used. Defaults to false so that an OpData created
  // anywhere never carries an indeterminate flag.
  bool requires_broadcast = false;
};
39 
Init(TfLiteContext * context,const char * buffer,size_t length)40 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
41   auto* data = new OpData;
42   data->requires_broadcast = false;
43   return data;
44 }
45 
Free(TfLiteContext * context,void * buffer)46 void Free(TfLiteContext* context, void* buffer) {
47   delete reinterpret_cast<OpData*>(buffer);
48 }
49 
Prepare(TfLiteContext * context,TfLiteNode * node)50 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
51   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
52   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
53 
54   // Reinterprete the opaque data provided by user.
55   OpData* data = reinterpret_cast<OpData*>(node->user_data);
56 
57   const TfLiteTensor* input1;
58   TF_LITE_ENSURE_OK(context,
59                     GetInputSafe(context, node, kInputTensor1, &input1));
60   const TfLiteTensor* input2;
61   TF_LITE_ENSURE_OK(context,
62                     GetInputSafe(context, node, kInputTensor2, &input2));
63   TfLiteTensor* output;
64   TF_LITE_ENSURE_OK(context,
65                     GetOutputSafe(context, node, kOutputTensor, &output));
66 
67   TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
68 
69   const TfLiteType type = input1->type;
70   if (type != kTfLiteBool) {
71     TF_LITE_KERNEL_LOG(context, "Logical ops only support bool type.");
72     return kTfLiteError;
73   }
74   output->type = type;
75 
76   data->requires_broadcast = !HaveSameShapes(input1, input2);
77 
78   TfLiteIntArray* output_size = nullptr;
79   if (data->requires_broadcast) {
80     TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
81                                    context, input1, input2, &output_size));
82   } else {
83     output_size = TfLiteIntArrayCopy(input1->dims);
84   }
85 
86   return context->ResizeTensor(context, output, output_size);
87 }
88 
LogicalImpl(TfLiteContext * context,TfLiteNode * node,bool (* func)(bool,bool))89 TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node,
90                          bool (*func)(bool, bool)) {
91   OpData* data = reinterpret_cast<OpData*>(node->user_data);
92 
93   const TfLiteTensor* input1;
94   TF_LITE_ENSURE_OK(context,
95                     GetInputSafe(context, node, kInputTensor1, &input1));
96   const TfLiteTensor* input2;
97   TF_LITE_ENSURE_OK(context,
98                     GetInputSafe(context, node, kInputTensor2, &input2));
99   TfLiteTensor* output;
100   TF_LITE_ENSURE_OK(context,
101                     GetOutputSafe(context, node, kOutputTensor, &output));
102 
103   if (data->requires_broadcast) {
104     reference_ops::BroadcastBinaryFunction4DSlow<bool, bool, bool>(
105         GetTensorShape(input1), GetTensorData<bool>(input1),
106         GetTensorShape(input2), GetTensorData<bool>(input2),
107         GetTensorShape(output), GetTensorData<bool>(output), func);
108   } else {
109     reference_ops::BinaryFunction<bool, bool, bool>(
110         GetTensorShape(input1), GetTensorData<bool>(input1),
111         GetTensorShape(input2), GetTensorData<bool>(input2),
112         GetTensorShape(output), GetTensorData<bool>(output), func);
113   }
114 
115   return kTfLiteOk;
116 }
117 
// Element function of LOGICAL_OR: true when at least one operand is true.
bool LogicalOr(bool x, bool y) {
  return x ? true : y;
}
119 
LogicalOrEval(TfLiteContext * context,TfLiteNode * node)120 TfLiteStatus LogicalOrEval(TfLiteContext* context, TfLiteNode* node) {
121   return LogicalImpl(context, node, LogicalOr);
122 }
123 
// Element function of LOGICAL_AND: true only when both operands are true.
bool LogicalAnd(bool x, bool y) {
  return x ? y : false;
}
125 
LogicalAndEval(TfLiteContext * context,TfLiteNode * node)126 TfLiteStatus LogicalAndEval(TfLiteContext* context, TfLiteNode* node) {
127   return LogicalImpl(context, node, LogicalAnd);
128 }
129 
130 }  // namespace
131 }  // namespace logical
132 
Register_LOGICAL_OR()133 TfLiteRegistration* Register_LOGICAL_OR() {
134   // Init, Free, Prepare, Eval are satisfying the Interface required by
135   // TfLiteRegistration.
136   static TfLiteRegistration r = {logical::Init, logical::Free, logical::Prepare,
137                                  logical::LogicalOrEval};
138   return &r;
139 }
140 
Register_LOGICAL_AND()141 TfLiteRegistration* Register_LOGICAL_AND() {
142   // Init, Free, Prepare, Eval are satisfying the Interface required by
143   // TfLiteRegistration.
144   static TfLiteRegistration r = {logical::Init, logical::Free, logical::Prepare,
145                                  logical::LogicalAndEval};
146   return &r;
147 }
148 
149 }  // namespace builtin
150 }  // namespace ops
151 }  // namespace tflite
152