xref: /aosp_15_r20/external/executorch/kernels/test/BinaryLogicalOpTest.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
#include <cstddef>
#include <vector>

#include <executorch/kernels/test/TestUtil.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 
16 namespace torch::executor::testing {
17 class BinaryLogicalOpTest : public OperatorTest {
18  protected:
19   // Implement this to call the torch::executor::aten::op_outf function for the
20   // op.
21   virtual exec_aten::Tensor& op_out(
22       const exec_aten::Tensor& lhs,
23       const exec_aten::Tensor& rhs,
24       exec_aten::Tensor& out) = 0;
25 
26   // Scalar reference implementation of the function in question for testing.
27   virtual double op_reference(double x, double y) const = 0;
28 
29   template <
30       exec_aten::ScalarType IN_DTYPE,
31       exec_aten::ScalarType IN_DTYPE2,
32       exec_aten::ScalarType OUT_DTYPE>
test_op_out()33   void test_op_out() {
34     TensorFactory<IN_DTYPE> tf_in;
35     TensorFactory<IN_DTYPE2> tf_in2;
36     TensorFactory<OUT_DTYPE> tf_out;
37 
38     exec_aten::Tensor out = tf_out.zeros({1, 4});
39 
40     using CTYPE1 = typename decltype(tf_in)::ctype;
41     std::vector<CTYPE1> test_vector1 = {0, CTYPE1(-1), CTYPE1(0), CTYPE1(31)};
42 
43     using CTYPE2 = typename decltype(tf_in2)::ctype;
44     std::vector<CTYPE2> test_vector2 = {
45         CTYPE2(0),
46         CTYPE2(0),
47         CTYPE2(15),
48         CTYPE2(12),
49     };
50 
51     std::vector<typename decltype(tf_out)::ctype> expected_vector;
52     for (int ii = 0; ii < test_vector1.size(); ++ii) {
53       expected_vector.push_back(
54           op_reference(test_vector1[ii], test_vector2[ii]));
55     }
56 
57     op_out(
58         tf_in.make({1, 4}, test_vector1),
59         tf_in2.make({1, 4}, test_vector2),
60         out);
61 
62     EXPECT_TENSOR_CLOSE(out, tf_out.make({1, 4}, expected_vector));
63   }
64 
65   void test_all_dtypes();
66 };
67 
// Defines a gtest case named SimpleTestAllTypes on the fixture `TestName`.
// The generated test body only calls the fixture's test_all_dtypes().
#define IMPLEMENT_BINARY_LOGICAL_OP_TEST(TestName) \
  TEST_F(TestName, SimpleTestAllTypes) {           \
    this->test_all_dtypes();                       \
  }
72 } // namespace torch::executor::testing
73