1*523fa7a6SAndroid Build Coastguard Worker /* 2*523fa7a6SAndroid Build Coastguard Worker * Copyright (c) Meta Platforms, Inc. and affiliates. 3*523fa7a6SAndroid Build Coastguard Worker * All rights reserved. 4*523fa7a6SAndroid Build Coastguard Worker * 5*523fa7a6SAndroid Build Coastguard Worker * This source code is licensed under the BSD-style license found in the 6*523fa7a6SAndroid Build Coastguard Worker * LICENSE file in the root directory of this source tree. 7*523fa7a6SAndroid Build Coastguard Worker */ 8*523fa7a6SAndroid Build Coastguard Worker 9*523fa7a6SAndroid Build Coastguard Worker #pragma once 10*523fa7a6SAndroid Build Coastguard Worker 11*523fa7a6SAndroid Build Coastguard Worker #include <executorch/kernels/test/TestUtil.h> 12*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/exec_aten/exec_aten.h> 13*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h> 14*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h> 15*523fa7a6SAndroid Build Coastguard Worker 16*523fa7a6SAndroid Build Coastguard Worker namespace torch::executor::testing { 17*523fa7a6SAndroid Build Coastguard Worker class BinaryLogicalOpTest : public OperatorTest { 18*523fa7a6SAndroid Build Coastguard Worker protected: 19*523fa7a6SAndroid Build Coastguard Worker // Implement this to call the torch::executor::aten::op_outf function for the 20*523fa7a6SAndroid Build Coastguard Worker // op. 21*523fa7a6SAndroid Build Coastguard Worker virtual exec_aten::Tensor& op_out( 22*523fa7a6SAndroid Build Coastguard Worker const exec_aten::Tensor& lhs, 23*523fa7a6SAndroid Build Coastguard Worker const exec_aten::Tensor& rhs, 24*523fa7a6SAndroid Build Coastguard Worker exec_aten::Tensor& out) = 0; 25*523fa7a6SAndroid Build Coastguard Worker 26*523fa7a6SAndroid Build Coastguard Worker // Scalar reference implementation of the function in question for testing. 27*523fa7a6SAndroid Build Coastguard Worker virtual double op_reference(double x, double y) const = 0; 28*523fa7a6SAndroid Build Coastguard Worker 29*523fa7a6SAndroid Build Coastguard Worker template < 30*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType IN_DTYPE, 31*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType IN_DTYPE2, 32*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType OUT_DTYPE> test_op_out()33*523fa7a6SAndroid Build Coastguard Worker void test_op_out() { 34*523fa7a6SAndroid Build Coastguard Worker TensorFactory<IN_DTYPE> tf_in; 35*523fa7a6SAndroid Build Coastguard Worker TensorFactory<IN_DTYPE2> tf_in2; 36*523fa7a6SAndroid Build Coastguard Worker TensorFactory<OUT_DTYPE> tf_out; 37*523fa7a6SAndroid Build Coastguard Worker 38*523fa7a6SAndroid Build Coastguard Worker exec_aten::Tensor out = tf_out.zeros({1, 4}); 39*523fa7a6SAndroid Build Coastguard Worker 40*523fa7a6SAndroid Build Coastguard Worker using CTYPE1 = typename decltype(tf_in)::ctype; 41*523fa7a6SAndroid Build Coastguard Worker std::vector<CTYPE1> test_vector1 = {0, CTYPE1(-1), CTYPE1(0), CTYPE1(31)}; 42*523fa7a6SAndroid Build Coastguard Worker 43*523fa7a6SAndroid Build Coastguard Worker using CTYPE2 = typename decltype(tf_in2)::ctype; 44*523fa7a6SAndroid Build Coastguard Worker std::vector<CTYPE2> test_vector2 = { 45*523fa7a6SAndroid Build Coastguard Worker CTYPE2(0), 46*523fa7a6SAndroid Build Coastguard Worker CTYPE2(0), 47*523fa7a6SAndroid Build Coastguard Worker CTYPE2(15), 48*523fa7a6SAndroid Build Coastguard Worker CTYPE2(12), 49*523fa7a6SAndroid Build Coastguard Worker }; 50*523fa7a6SAndroid Build Coastguard Worker 51*523fa7a6SAndroid Build Coastguard Worker std::vector<typename decltype(tf_out)::ctype> expected_vector; 52*523fa7a6SAndroid Build Coastguard Worker for (int ii = 0; ii < test_vector1.size(); ++ii) { 53*523fa7a6SAndroid Build Coastguard Worker expected_vector.push_back( 54*523fa7a6SAndroid Build Coastguard Worker op_reference(test_vector1[ii], test_vector2[ii])); 55*523fa7a6SAndroid Build Coastguard Worker } 56*523fa7a6SAndroid Build Coastguard Worker 57*523fa7a6SAndroid Build Coastguard Worker op_out( 58*523fa7a6SAndroid Build Coastguard Worker tf_in.make({1, 4}, test_vector1), 59*523fa7a6SAndroid Build Coastguard Worker tf_in2.make({1, 4}, test_vector2), 60*523fa7a6SAndroid Build Coastguard Worker out); 61*523fa7a6SAndroid Build Coastguard Worker 62*523fa7a6SAndroid Build Coastguard Worker EXPECT_TENSOR_CLOSE(out, tf_out.make({1, 4}, expected_vector)); 63*523fa7a6SAndroid Build Coastguard Worker } 64*523fa7a6SAndroid Build Coastguard Worker 65*523fa7a6SAndroid Build Coastguard Worker void test_all_dtypes(); 66*523fa7a6SAndroid Build Coastguard Worker }; 67*523fa7a6SAndroid Build Coastguard Worker 68*523fa7a6SAndroid Build Coastguard Worker #define IMPLEMENT_BINARY_LOGICAL_OP_TEST(TestName) \ 69*523fa7a6SAndroid Build Coastguard Worker TEST_F(TestName, SimpleTestAllTypes) { \ 70*523fa7a6SAndroid Build Coastguard Worker test_all_dtypes(); \ 71*523fa7a6SAndroid Build Coastguard Worker } 72*523fa7a6SAndroid Build Coastguard Worker } // namespace torch::executor::testing 73