xref: /aosp_15_r20/external/executorch/kernels/test/BinaryLogicalOpTest.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8*523fa7a6SAndroid Build Coastguard Worker 
9*523fa7a6SAndroid Build Coastguard Worker #pragma once
10*523fa7a6SAndroid Build Coastguard Worker 
#include <cstddef>
#include <vector>

#include <executorch/kernels/test/TestUtil.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15*523fa7a6SAndroid Build Coastguard Worker 
16*523fa7a6SAndroid Build Coastguard Worker namespace torch::executor::testing {
17*523fa7a6SAndroid Build Coastguard Worker class BinaryLogicalOpTest : public OperatorTest {
18*523fa7a6SAndroid Build Coastguard Worker  protected:
19*523fa7a6SAndroid Build Coastguard Worker   // Implement this to call the torch::executor::aten::op_outf function for the
20*523fa7a6SAndroid Build Coastguard Worker   // op.
21*523fa7a6SAndroid Build Coastguard Worker   virtual exec_aten::Tensor& op_out(
22*523fa7a6SAndroid Build Coastguard Worker       const exec_aten::Tensor& lhs,
23*523fa7a6SAndroid Build Coastguard Worker       const exec_aten::Tensor& rhs,
24*523fa7a6SAndroid Build Coastguard Worker       exec_aten::Tensor& out) = 0;
25*523fa7a6SAndroid Build Coastguard Worker 
26*523fa7a6SAndroid Build Coastguard Worker   // Scalar reference implementation of the function in question for testing.
27*523fa7a6SAndroid Build Coastguard Worker   virtual double op_reference(double x, double y) const = 0;
28*523fa7a6SAndroid Build Coastguard Worker 
29*523fa7a6SAndroid Build Coastguard Worker   template <
30*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType IN_DTYPE,
31*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType IN_DTYPE2,
32*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType OUT_DTYPE>
test_op_out()33*523fa7a6SAndroid Build Coastguard Worker   void test_op_out() {
34*523fa7a6SAndroid Build Coastguard Worker     TensorFactory<IN_DTYPE> tf_in;
35*523fa7a6SAndroid Build Coastguard Worker     TensorFactory<IN_DTYPE2> tf_in2;
36*523fa7a6SAndroid Build Coastguard Worker     TensorFactory<OUT_DTYPE> tf_out;
37*523fa7a6SAndroid Build Coastguard Worker 
38*523fa7a6SAndroid Build Coastguard Worker     exec_aten::Tensor out = tf_out.zeros({1, 4});
39*523fa7a6SAndroid Build Coastguard Worker 
40*523fa7a6SAndroid Build Coastguard Worker     using CTYPE1 = typename decltype(tf_in)::ctype;
41*523fa7a6SAndroid Build Coastguard Worker     std::vector<CTYPE1> test_vector1 = {0, CTYPE1(-1), CTYPE1(0), CTYPE1(31)};
42*523fa7a6SAndroid Build Coastguard Worker 
43*523fa7a6SAndroid Build Coastguard Worker     using CTYPE2 = typename decltype(tf_in2)::ctype;
44*523fa7a6SAndroid Build Coastguard Worker     std::vector<CTYPE2> test_vector2 = {
45*523fa7a6SAndroid Build Coastguard Worker         CTYPE2(0),
46*523fa7a6SAndroid Build Coastguard Worker         CTYPE2(0),
47*523fa7a6SAndroid Build Coastguard Worker         CTYPE2(15),
48*523fa7a6SAndroid Build Coastguard Worker         CTYPE2(12),
49*523fa7a6SAndroid Build Coastguard Worker     };
50*523fa7a6SAndroid Build Coastguard Worker 
51*523fa7a6SAndroid Build Coastguard Worker     std::vector<typename decltype(tf_out)::ctype> expected_vector;
52*523fa7a6SAndroid Build Coastguard Worker     for (int ii = 0; ii < test_vector1.size(); ++ii) {
53*523fa7a6SAndroid Build Coastguard Worker       expected_vector.push_back(
54*523fa7a6SAndroid Build Coastguard Worker           op_reference(test_vector1[ii], test_vector2[ii]));
55*523fa7a6SAndroid Build Coastguard Worker     }
56*523fa7a6SAndroid Build Coastguard Worker 
57*523fa7a6SAndroid Build Coastguard Worker     op_out(
58*523fa7a6SAndroid Build Coastguard Worker         tf_in.make({1, 4}, test_vector1),
59*523fa7a6SAndroid Build Coastguard Worker         tf_in2.make({1, 4}, test_vector2),
60*523fa7a6SAndroid Build Coastguard Worker         out);
61*523fa7a6SAndroid Build Coastguard Worker 
62*523fa7a6SAndroid Build Coastguard Worker     EXPECT_TENSOR_CLOSE(out, tf_out.make({1, 4}, expected_vector));
63*523fa7a6SAndroid Build Coastguard Worker   }
64*523fa7a6SAndroid Build Coastguard Worker 
65*523fa7a6SAndroid Build Coastguard Worker   void test_all_dtypes();
66*523fa7a6SAndroid Build Coastguard Worker };
67*523fa7a6SAndroid Build Coastguard Worker 
68*523fa7a6SAndroid Build Coastguard Worker #define IMPLEMENT_BINARY_LOGICAL_OP_TEST(TestName) \
69*523fa7a6SAndroid Build Coastguard Worker   TEST_F(TestName, SimpleTestAllTypes) {           \
70*523fa7a6SAndroid Build Coastguard Worker     test_all_dtypes();                             \
71*523fa7a6SAndroid Build Coastguard Worker   }
72*523fa7a6SAndroid Build Coastguard Worker } // namespace torch::executor::testing
73