1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 #include <executorch/kernels/test/TestUtil.h> 12 #include <executorch/runtime/core/exec_aten/exec_aten.h> 13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h> 14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h> 15 16 namespace torch::executor::testing { 17 class BinaryLogicalOpTest : public OperatorTest { 18 protected: 19 // Implement this to call the torch::executor::aten::op_outf function for the 20 // op. 21 virtual exec_aten::Tensor& op_out( 22 const exec_aten::Tensor& lhs, 23 const exec_aten::Tensor& rhs, 24 exec_aten::Tensor& out) = 0; 25 26 // Scalar reference implementation of the function in question for testing. 27 virtual double op_reference(double x, double y) const = 0; 28 29 template < 30 exec_aten::ScalarType IN_DTYPE, 31 exec_aten::ScalarType IN_DTYPE2, 32 exec_aten::ScalarType OUT_DTYPE> test_op_out()33 void test_op_out() { 34 TensorFactory<IN_DTYPE> tf_in; 35 TensorFactory<IN_DTYPE2> tf_in2; 36 TensorFactory<OUT_DTYPE> tf_out; 37 38 exec_aten::Tensor out = tf_out.zeros({1, 4}); 39 40 using CTYPE1 = typename decltype(tf_in)::ctype; 41 std::vector<CTYPE1> test_vector1 = {0, CTYPE1(-1), CTYPE1(0), CTYPE1(31)}; 42 43 using CTYPE2 = typename decltype(tf_in2)::ctype; 44 std::vector<CTYPE2> test_vector2 = { 45 CTYPE2(0), 46 CTYPE2(0), 47 CTYPE2(15), 48 CTYPE2(12), 49 }; 50 51 std::vector<typename decltype(tf_out)::ctype> expected_vector; 52 for (int ii = 0; ii < test_vector1.size(); ++ii) { 53 expected_vector.push_back( 54 op_reference(test_vector1[ii], test_vector2[ii])); 55 } 56 57 op_out( 58 tf_in.make({1, 4}, test_vector1), 59 tf_in2.make({1, 4}, test_vector2), 60 out); 61 62 EXPECT_TENSOR_CLOSE(out, tf_out.make({1, 4}, expected_vector)); 63 } 64 65 void test_all_dtypes(); 66 }; 67 68 #define IMPLEMENT_BINARY_LOGICAL_OP_TEST(TestName) \ 69 TEST_F(TestName, SimpleTestAllTypes) { \ 70 test_all_dtypes(); \ 71 } 72 } // namespace torch::executor::testing 73