1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020, 2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include "ElementwiseUnaryTestHelper.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "LogicalTestHelper.hpp"
8*89c4ff92SAndroid Build Coastguard Worker
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn_delegate.hpp>
10*89c4ff92SAndroid Build Coastguard Worker
11*89c4ff92SAndroid Build Coastguard Worker #include <flatbuffers/flatbuffers.h>
12*89c4ff92SAndroid Build Coastguard Worker #include <schema_generated.h>
13*89c4ff92SAndroid Build Coastguard Worker
14*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
15*89c4ff92SAndroid Build Coastguard Worker
16*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
17*89c4ff92SAndroid Build Coastguard Worker {
18*89c4ff92SAndroid Build Coastguard Worker
LogicalBinaryAndBoolTest(std::vector<armnn::BackendId> & backends)19*89c4ff92SAndroid Build Coastguard Worker void LogicalBinaryAndBoolTest(std::vector<armnn::BackendId>& backends)
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> input0Shape { 1, 2, 2 };
22*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> input1Shape { 1, 2, 2 };
23*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> expectedOutputShape { 1, 2, 2 };
24*89c4ff92SAndroid Build Coastguard Worker
25*89c4ff92SAndroid Build Coastguard Worker // Set input and output values
26*89c4ff92SAndroid Build Coastguard Worker std::vector<bool> input0Values { 0, 0, 1, 1 };
27*89c4ff92SAndroid Build Coastguard Worker std::vector<bool> input1Values { 0, 1, 0, 1 };
28*89c4ff92SAndroid Build Coastguard Worker std::vector<bool> expectedOutputValues { 0, 0, 0, 1 };
29*89c4ff92SAndroid Build Coastguard Worker
30*89c4ff92SAndroid Build Coastguard Worker LogicalBinaryTest(tflite::BuiltinOperator_LOGICAL_AND,
31*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_BOOL,
32*89c4ff92SAndroid Build Coastguard Worker backends,
33*89c4ff92SAndroid Build Coastguard Worker input0Shape,
34*89c4ff92SAndroid Build Coastguard Worker input1Shape,
35*89c4ff92SAndroid Build Coastguard Worker expectedOutputShape,
36*89c4ff92SAndroid Build Coastguard Worker input0Values,
37*89c4ff92SAndroid Build Coastguard Worker input1Values,
38*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues);
39*89c4ff92SAndroid Build Coastguard Worker }
40*89c4ff92SAndroid Build Coastguard Worker
LogicalBinaryAndBroadcastTest(std::vector<armnn::BackendId> & backends)41*89c4ff92SAndroid Build Coastguard Worker void LogicalBinaryAndBroadcastTest(std::vector<armnn::BackendId>& backends)
42*89c4ff92SAndroid Build Coastguard Worker {
43*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> input0Shape { 1, 2, 2 };
44*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> input1Shape { 1, 1, 1 };
45*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> expectedOutputShape { 1, 2, 2 };
46*89c4ff92SAndroid Build Coastguard Worker
47*89c4ff92SAndroid Build Coastguard Worker std::vector<bool> input0Values { 0, 1, 0, 1 };
48*89c4ff92SAndroid Build Coastguard Worker std::vector<bool> input1Values { 1 };
49*89c4ff92SAndroid Build Coastguard Worker std::vector<bool> expectedOutputValues { 0, 1, 0, 1 };
50*89c4ff92SAndroid Build Coastguard Worker
51*89c4ff92SAndroid Build Coastguard Worker LogicalBinaryTest(tflite::BuiltinOperator_LOGICAL_AND,
52*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_BOOL,
53*89c4ff92SAndroid Build Coastguard Worker backends,
54*89c4ff92SAndroid Build Coastguard Worker input0Shape,
55*89c4ff92SAndroid Build Coastguard Worker input1Shape,
56*89c4ff92SAndroid Build Coastguard Worker expectedOutputShape,
57*89c4ff92SAndroid Build Coastguard Worker input0Values,
58*89c4ff92SAndroid Build Coastguard Worker input1Values,
59*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues);
60*89c4ff92SAndroid Build Coastguard Worker }
61*89c4ff92SAndroid Build Coastguard Worker
LogicalBinaryOrBoolTest(std::vector<armnn::BackendId> & backends)62*89c4ff92SAndroid Build Coastguard Worker void LogicalBinaryOrBoolTest(std::vector<armnn::BackendId>& backends)
63*89c4ff92SAndroid Build Coastguard Worker {
64*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> input0Shape { 1, 2, 2 };
65*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> input1Shape { 1, 2, 2 };
66*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> expectedOutputShape { 1, 2, 2 };
67*89c4ff92SAndroid Build Coastguard Worker
68*89c4ff92SAndroid Build Coastguard Worker std::vector<bool> input0Values { 0, 0, 1, 1 };
69*89c4ff92SAndroid Build Coastguard Worker std::vector<bool> input1Values { 0, 1, 0, 1 };
70*89c4ff92SAndroid Build Coastguard Worker std::vector<bool> expectedOutputValues { 0, 1, 1, 1 };
71*89c4ff92SAndroid Build Coastguard Worker
72*89c4ff92SAndroid Build Coastguard Worker LogicalBinaryTest(tflite::BuiltinOperator_LOGICAL_OR,
73*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_BOOL,
74*89c4ff92SAndroid Build Coastguard Worker backends,
75*89c4ff92SAndroid Build Coastguard Worker input0Shape,
76*89c4ff92SAndroid Build Coastguard Worker input1Shape,
77*89c4ff92SAndroid Build Coastguard Worker expectedOutputShape,
78*89c4ff92SAndroid Build Coastguard Worker input0Values,
79*89c4ff92SAndroid Build Coastguard Worker input1Values,
80*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues);
81*89c4ff92SAndroid Build Coastguard Worker }
82*89c4ff92SAndroid Build Coastguard Worker
LogicalBinaryOrBroadcastTest(std::vector<armnn::BackendId> & backends)83*89c4ff92SAndroid Build Coastguard Worker void LogicalBinaryOrBroadcastTest(std::vector<armnn::BackendId>& backends)
84*89c4ff92SAndroid Build Coastguard Worker {
85*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> input0Shape { 1, 2, 2 };
86*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> input1Shape { 1, 1, 1 };
87*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> expectedOutputShape { 1, 2, 2 };
88*89c4ff92SAndroid Build Coastguard Worker
89*89c4ff92SAndroid Build Coastguard Worker std::vector<bool> input0Values { 0, 1, 0, 1 };
90*89c4ff92SAndroid Build Coastguard Worker std::vector<bool> input1Values { 1 };
91*89c4ff92SAndroid Build Coastguard Worker std::vector<bool> expectedOutputValues { 1, 1, 1, 1 };
92*89c4ff92SAndroid Build Coastguard Worker
93*89c4ff92SAndroid Build Coastguard Worker LogicalBinaryTest(tflite::BuiltinOperator_LOGICAL_OR,
94*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_BOOL,
95*89c4ff92SAndroid Build Coastguard Worker backends,
96*89c4ff92SAndroid Build Coastguard Worker input0Shape,
97*89c4ff92SAndroid Build Coastguard Worker input1Shape,
98*89c4ff92SAndroid Build Coastguard Worker expectedOutputShape,
99*89c4ff92SAndroid Build Coastguard Worker input0Values,
100*89c4ff92SAndroid Build Coastguard Worker input1Values,
101*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues);
102*89c4ff92SAndroid Build Coastguard Worker }
103*89c4ff92SAndroid Build Coastguard Worker
104*89c4ff92SAndroid Build Coastguard Worker // LogicalNot operator uses ElementwiseUnary unary layer and descriptor but is still classed as logical operator.
LogicalNotBoolTest(std::vector<armnn::BackendId> & backends)105*89c4ff92SAndroid Build Coastguard Worker void LogicalNotBoolTest(std::vector<armnn::BackendId>& backends)
106*89c4ff92SAndroid Build Coastguard Worker {
107*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> inputShape { 1, 2, 2 };
108*89c4ff92SAndroid Build Coastguard Worker
109*89c4ff92SAndroid Build Coastguard Worker std::vector<bool> inputValues { 0, 1, 0, 1 };
110*89c4ff92SAndroid Build Coastguard Worker std::vector<bool> expectedOutputValues { 1, 0, 1, 0 };
111*89c4ff92SAndroid Build Coastguard Worker
112*89c4ff92SAndroid Build Coastguard Worker ElementwiseUnaryBoolTest(tflite::BuiltinOperator_LOGICAL_NOT,
113*89c4ff92SAndroid Build Coastguard Worker backends,
114*89c4ff92SAndroid Build Coastguard Worker inputShape,
115*89c4ff92SAndroid Build Coastguard Worker inputValues,
116*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues);
117*89c4ff92SAndroid Build Coastguard Worker }
118*89c4ff92SAndroid Build Coastguard Worker
119*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("LogicalBinaryTests_GpuAccTests")
120*89c4ff92SAndroid Build Coastguard Worker {
121*89c4ff92SAndroid Build Coastguard Worker
122*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("LogicalBinary_AND_Bool_GpuAcc_Test")
123*89c4ff92SAndroid Build Coastguard Worker {
124*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
125*89c4ff92SAndroid Build Coastguard Worker LogicalBinaryAndBoolTest(backends);
126*89c4ff92SAndroid Build Coastguard Worker }
127*89c4ff92SAndroid Build Coastguard Worker
128*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("LogicalBinary_AND_Broadcast_GpuAcc_Test")
129*89c4ff92SAndroid Build Coastguard Worker {
130*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
131*89c4ff92SAndroid Build Coastguard Worker LogicalBinaryAndBroadcastTest(backends);
132*89c4ff92SAndroid Build Coastguard Worker }
133*89c4ff92SAndroid Build Coastguard Worker
134*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Logical_NOT_Bool_GpuAcc_Test")
135*89c4ff92SAndroid Build Coastguard Worker {
136*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
137*89c4ff92SAndroid Build Coastguard Worker LogicalNotBoolTest(backends);
138*89c4ff92SAndroid Build Coastguard Worker }
139*89c4ff92SAndroid Build Coastguard Worker
140*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("LogicalBinary_OR_Bool_GpuAcc_Test")
141*89c4ff92SAndroid Build Coastguard Worker {
142*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
143*89c4ff92SAndroid Build Coastguard Worker LogicalBinaryOrBoolTest(backends);
144*89c4ff92SAndroid Build Coastguard Worker }
145*89c4ff92SAndroid Build Coastguard Worker
146*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("LogicalBinary_OR_Broadcast_GpuAcc_Test")
147*89c4ff92SAndroid Build Coastguard Worker {
148*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
149*89c4ff92SAndroid Build Coastguard Worker LogicalBinaryOrBroadcastTest(backends);
150*89c4ff92SAndroid Build Coastguard Worker }
151*89c4ff92SAndroid Build Coastguard Worker
152*89c4ff92SAndroid Build Coastguard Worker }
153*89c4ff92SAndroid Build Coastguard Worker
154*89c4ff92SAndroid Build Coastguard Worker
155*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("LogicalBinaryTests_CpuAccTests")
156*89c4ff92SAndroid Build Coastguard Worker {
157*89c4ff92SAndroid Build Coastguard Worker
158*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("LogicalBinary_AND_Bool_CpuAcc_Test")
159*89c4ff92SAndroid Build Coastguard Worker {
160*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
161*89c4ff92SAndroid Build Coastguard Worker LogicalBinaryAndBoolTest(backends);
162*89c4ff92SAndroid Build Coastguard Worker }
163*89c4ff92SAndroid Build Coastguard Worker
164*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("LogicalBinary_AND_Broadcast_CpuAcc_Test")
165*89c4ff92SAndroid Build Coastguard Worker {
166*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
167*89c4ff92SAndroid Build Coastguard Worker LogicalBinaryAndBroadcastTest(backends);
168*89c4ff92SAndroid Build Coastguard Worker }
169*89c4ff92SAndroid Build Coastguard Worker
170*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Logical_NOT_Bool_CpuAcc_Test")
171*89c4ff92SAndroid Build Coastguard Worker {
172*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
173*89c4ff92SAndroid Build Coastguard Worker LogicalNotBoolTest(backends);
174*89c4ff92SAndroid Build Coastguard Worker }
175*89c4ff92SAndroid Build Coastguard Worker
176*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("LogicalBinary_OR_Bool_CpuAcc_Test")
177*89c4ff92SAndroid Build Coastguard Worker {
178*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
179*89c4ff92SAndroid Build Coastguard Worker LogicalBinaryOrBoolTest(backends);
180*89c4ff92SAndroid Build Coastguard Worker }
181*89c4ff92SAndroid Build Coastguard Worker
182*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("LogicalBinary_OR_Broadcast_CpuAcc_Test")
183*89c4ff92SAndroid Build Coastguard Worker {
184*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
185*89c4ff92SAndroid Build Coastguard Worker LogicalBinaryOrBroadcastTest(backends);
186*89c4ff92SAndroid Build Coastguard Worker }
187*89c4ff92SAndroid Build Coastguard Worker
188*89c4ff92SAndroid Build Coastguard Worker }
189*89c4ff92SAndroid Build Coastguard Worker
190*89c4ff92SAndroid Build Coastguard Worker
191*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("LogicalBinaryTests_CpuRefTests")
192*89c4ff92SAndroid Build Coastguard Worker {
193*89c4ff92SAndroid Build Coastguard Worker
194*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("LogicalBinary_AND_Bool_CpuRef_Test")
195*89c4ff92SAndroid Build Coastguard Worker {
196*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
197*89c4ff92SAndroid Build Coastguard Worker LogicalBinaryAndBoolTest(backends);
198*89c4ff92SAndroid Build Coastguard Worker }
199*89c4ff92SAndroid Build Coastguard Worker
200*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("LogicalBinary_AND_Broadcast_CpuRef_Test")
201*89c4ff92SAndroid Build Coastguard Worker {
202*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
203*89c4ff92SAndroid Build Coastguard Worker LogicalBinaryAndBroadcastTest(backends);
204*89c4ff92SAndroid Build Coastguard Worker }
205*89c4ff92SAndroid Build Coastguard Worker
206*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Logical_NOT_Bool_CpuRef_Test")
207*89c4ff92SAndroid Build Coastguard Worker {
208*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
209*89c4ff92SAndroid Build Coastguard Worker LogicalNotBoolTest(backends);
210*89c4ff92SAndroid Build Coastguard Worker }
211*89c4ff92SAndroid Build Coastguard Worker
212*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("LogicalBinary_OR_Bool_CpuRef_Test")
213*89c4ff92SAndroid Build Coastguard Worker {
214*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
215*89c4ff92SAndroid Build Coastguard Worker LogicalBinaryOrBoolTest(backends);
216*89c4ff92SAndroid Build Coastguard Worker }
217*89c4ff92SAndroid Build Coastguard Worker
218*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("LogicalBinary_OR_Broadcast_CpuRef_Test")
219*89c4ff92SAndroid Build Coastguard Worker {
220*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
221*89c4ff92SAndroid Build Coastguard Worker LogicalBinaryOrBroadcastTest(backends);
222*89c4ff92SAndroid Build Coastguard Worker }
223*89c4ff92SAndroid Build Coastguard Worker
224*89c4ff92SAndroid Build Coastguard Worker }
225*89c4ff92SAndroid Build Coastguard Worker
226*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate