1*523fa7a6SAndroid Build Coastguard Worker /* 2*523fa7a6SAndroid Build Coastguard Worker * Copyright (c) Meta Platforms, Inc. and affiliates. 3*523fa7a6SAndroid Build Coastguard Worker * All rights reserved. 4*523fa7a6SAndroid Build Coastguard Worker * 5*523fa7a6SAndroid Build Coastguard Worker * This source code is licensed under the BSD-style license found in the 6*523fa7a6SAndroid Build Coastguard Worker * LICENSE file in the root directory of this source tree. 7*523fa7a6SAndroid Build Coastguard Worker */ 8*523fa7a6SAndroid Build Coastguard Worker 9*523fa7a6SAndroid Build Coastguard Worker #pragma once 10*523fa7a6SAndroid Build Coastguard Worker 11*523fa7a6SAndroid Build Coastguard Worker #include <executorch/kernels/test/TestUtil.h> 12*523fa7a6SAndroid Build Coastguard Worker #include <executorch/kernels/test/supported_features.h> 13*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/exec_aten/exec_aten.h> 14*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h> 15*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h> 16*523fa7a6SAndroid Build Coastguard Worker 17*523fa7a6SAndroid Build Coastguard Worker #include <gtest/gtest.h> 18*523fa7a6SAndroid Build Coastguard Worker 19*523fa7a6SAndroid Build Coastguard Worker namespace torch::executor::testing { 20*523fa7a6SAndroid Build Coastguard Worker // Generic test harness for ops that use unary_ufunc_realhb_to_floath 21*523fa7a6SAndroid Build Coastguard Worker // -- in other words, ops that just apply an elementwise function 22*523fa7a6SAndroid Build Coastguard Worker // mapping to a float or half. 23*523fa7a6SAndroid Build Coastguard Worker class UnaryUfuncRealHBBF16ToFloatHBF16Test : public OperatorTest { 24*523fa7a6SAndroid Build Coastguard Worker protected: 25*523fa7a6SAndroid Build Coastguard Worker // Implement this to call the torch::executor::aten::op_outf function for the 26*523fa7a6SAndroid Build Coastguard Worker // op. 27*523fa7a6SAndroid Build Coastguard Worker virtual exec_aten::Tensor& op_out( 28*523fa7a6SAndroid Build Coastguard Worker const exec_aten::Tensor& self, 29*523fa7a6SAndroid Build Coastguard Worker exec_aten::Tensor& out) = 0; 30*523fa7a6SAndroid Build Coastguard Worker 31*523fa7a6SAndroid Build Coastguard Worker // Scalar reference implementation of the function in question for testing. 32*523fa7a6SAndroid Build Coastguard Worker virtual double op_reference(double x) const = 0; 33*523fa7a6SAndroid Build Coastguard Worker 34*523fa7a6SAndroid Build Coastguard Worker // The SupportedFeatures system assumes that it can build each test 35*523fa7a6SAndroid Build Coastguard Worker // target with a separate SupportedFeatures (really just one 36*523fa7a6SAndroid Build Coastguard Worker // portable, one optimzed but between one and the infinite, two is 37*523fa7a6SAndroid Build Coastguard Worker // ridiculous and can't exist). We work around that by calling 38*523fa7a6SAndroid Build Coastguard Worker // SupportedFeatures::get() in the concrete test translation 39*523fa7a6SAndroid Build Coastguard Worker // unit. You need to declare an override, but we implement it for you 40*523fa7a6SAndroid Build Coastguard Worker // in IMPLEMENT_UNARY_UFUNC_REALHB_TO_FLOATH_TEST. 41*523fa7a6SAndroid Build Coastguard Worker virtual SupportedFeatures* get_supported_features() const = 0; 42*523fa7a6SAndroid Build Coastguard Worker 43*523fa7a6SAndroid Build Coastguard Worker template <exec_aten::ScalarType IN_DTYPE, exec_aten::ScalarType OUT_DTYPE> 44*523fa7a6SAndroid Build Coastguard Worker void test_floating_point_op_out( 45*523fa7a6SAndroid Build Coastguard Worker const std::vector<int32_t>& out_shape = {1, 6}, 46*523fa7a6SAndroid Build Coastguard Worker exec_aten::TensorShapeDynamism dynamism = 47*523fa7a6SAndroid Build Coastguard Worker exec_aten::TensorShapeDynamism::STATIC) { 48*523fa7a6SAndroid Build Coastguard Worker TensorFactory<IN_DTYPE> tf_in; 49*523fa7a6SAndroid Build Coastguard Worker TensorFactory<OUT_DTYPE> tf_out; 50*523fa7a6SAndroid Build Coastguard Worker 51*523fa7a6SAndroid Build Coastguard Worker exec_aten::Tensor out = tf_out.zeros(out_shape, dynamism); 52*523fa7a6SAndroid Build Coastguard Worker 53*523fa7a6SAndroid Build Coastguard Worker using IN_CTYPE = typename decltype(tf_in)::ctype; 54*523fa7a6SAndroid Build Coastguard Worker using OUT_CTYPE = typename decltype(tf_out)::ctype; 55*523fa7a6SAndroid Build Coastguard Worker std::vector<IN_CTYPE> test_vector = {0, 1, 3, 5, 10, 100}; 56*523fa7a6SAndroid Build Coastguard Worker std::vector<OUT_CTYPE> expected_vector; 57*523fa7a6SAndroid Build Coastguard Worker for (int ii = 0; ii < test_vector.size(); ++ii) { 58*523fa7a6SAndroid Build Coastguard Worker auto ref_result = this->op_reference(test_vector[ii]); 59*523fa7a6SAndroid Build Coastguard Worker // Drop test cases with high magnitude results due to precision 60*523fa7a6SAndroid Build Coastguard Worker // issues. 61*523fa7a6SAndroid Build Coastguard Worker if ((std::abs(ref_result) > 1e30 || std::abs(ref_result) < -1e30)) { 62*523fa7a6SAndroid Build Coastguard Worker test_vector[ii] = 2; 63*523fa7a6SAndroid Build Coastguard Worker ref_result = this->op_reference(2); 64*523fa7a6SAndroid Build Coastguard Worker } 65*523fa7a6SAndroid Build Coastguard Worker expected_vector.push_back(ref_result); 66*523fa7a6SAndroid Build Coastguard Worker } 67*523fa7a6SAndroid Build Coastguard Worker 68*523fa7a6SAndroid Build Coastguard Worker // clang-format off 69*523fa7a6SAndroid Build Coastguard Worker op_out(tf_in.make({1, 6}, test_vector), out); 70*523fa7a6SAndroid Build Coastguard Worker 71*523fa7a6SAndroid Build Coastguard Worker auto expected = tf_out.make({1, 6}, expected_vector); 72*523fa7a6SAndroid Build Coastguard Worker if (IN_DTYPE == ScalarType::BFloat16 || OUT_DTYPE == ScalarType::BFloat16) { 73*523fa7a6SAndroid Build Coastguard Worker double rtol = executorch::runtime::testing::internal::kDefaultRtol; 74*523fa7a6SAndroid Build Coastguard Worker // It appears we need a higher tolerance for at least some ATen 75*523fa7a6SAndroid Build Coastguard Worker // tests, like aten_op_acosh_test. 76*523fa7a6SAndroid Build Coastguard Worker if (get_supported_features()->is_aten) { 77*523fa7a6SAndroid Build Coastguard Worker rtol = 3e-3; 78*523fa7a6SAndroid Build Coastguard Worker } 79*523fa7a6SAndroid Build Coastguard Worker EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, rtol, executorch::runtime::testing::internal::kDefaultBFloat16Atol); 80*523fa7a6SAndroid Build Coastguard Worker } else if (IN_DTYPE == ScalarType::Half || OUT_DTYPE == ScalarType::Half) { 81*523fa7a6SAndroid Build Coastguard Worker double rtol = executorch::runtime::testing::internal::kDefaultRtol; 82*523fa7a6SAndroid Build Coastguard Worker // It appears we need a higher tolerance for at least some ATen 83*523fa7a6SAndroid Build Coastguard Worker // tests, like aten_op_acosh_test. 84*523fa7a6SAndroid Build Coastguard Worker if (get_supported_features()->is_aten) { 85*523fa7a6SAndroid Build Coastguard Worker rtol = 1e-3; 86*523fa7a6SAndroid Build Coastguard Worker } 87*523fa7a6SAndroid Build Coastguard Worker EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, rtol, executorch::runtime::testing::internal::kDefaultHalfAtol); 88*523fa7a6SAndroid Build Coastguard Worker } else { 89*523fa7a6SAndroid Build Coastguard Worker EXPECT_TENSOR_CLOSE(out, expected); 90*523fa7a6SAndroid Build Coastguard Worker } 91*523fa7a6SAndroid Build Coastguard Worker // clang-format on 92*523fa7a6SAndroid Build Coastguard Worker } 93*523fa7a6SAndroid Build Coastguard Worker 94*523fa7a6SAndroid Build Coastguard Worker // Unhandled output dtypes. 95*523fa7a6SAndroid Build Coastguard Worker template < 96*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType INPUT_DTYPE, 97*523fa7a6SAndroid Build Coastguard Worker exec_aten::ScalarType OUTPUT_DTYPE> test_op_invalid_output_dtype_dies()98*523fa7a6SAndroid Build Coastguard Worker void test_op_invalid_output_dtype_dies() { 99*523fa7a6SAndroid Build Coastguard Worker TensorFactory<INPUT_DTYPE> tf; 100*523fa7a6SAndroid Build Coastguard Worker TensorFactory<OUTPUT_DTYPE> tf_out; 101*523fa7a6SAndroid Build Coastguard Worker 102*523fa7a6SAndroid Build Coastguard Worker const std::vector<int32_t> sizes = {2, 5}; 103*523fa7a6SAndroid Build Coastguard Worker 104*523fa7a6SAndroid Build Coastguard Worker exec_aten::Tensor in = tf.ones(sizes); 105*523fa7a6SAndroid Build Coastguard Worker exec_aten::Tensor out = tf_out.zeros(sizes); 106*523fa7a6SAndroid Build Coastguard Worker 107*523fa7a6SAndroid Build Coastguard Worker ET_EXPECT_KERNEL_FAILURE(context_, op_out(in, out)); 108*523fa7a6SAndroid Build Coastguard Worker } 109*523fa7a6SAndroid Build Coastguard Worker 110*523fa7a6SAndroid Build Coastguard Worker void test_bool_input(); 111*523fa7a6SAndroid Build Coastguard Worker 112*523fa7a6SAndroid Build Coastguard Worker void test_mismatched_input_shapes_dies(); 113*523fa7a6SAndroid Build Coastguard Worker 114*523fa7a6SAndroid Build Coastguard Worker void test_all_real_input_half_output_static_dynamism_support(); 115*523fa7a6SAndroid Build Coastguard Worker 116*523fa7a6SAndroid Build Coastguard Worker void test_all_real_input_bfloat16_output_static_dynamism_support(); 117*523fa7a6SAndroid Build Coastguard Worker 118*523fa7a6SAndroid Build Coastguard Worker void test_all_real_input_float_output_static_dynamism_support(); 119*523fa7a6SAndroid Build Coastguard Worker 120*523fa7a6SAndroid Build Coastguard Worker void test_all_real_input_double_output_static_dynamism_support(); 121*523fa7a6SAndroid Build Coastguard Worker 122*523fa7a6SAndroid Build Coastguard Worker void test_all_real_input_half_output_bound_dynamism_support(); 123*523fa7a6SAndroid Build Coastguard Worker 124*523fa7a6SAndroid Build Coastguard Worker void test_all_real_input_bfloat16_output_bound_dynamism_support(); 125*523fa7a6SAndroid Build Coastguard Worker 126*523fa7a6SAndroid Build Coastguard Worker void test_all_real_input_float_output_bound_dynamism_support(); 127*523fa7a6SAndroid Build Coastguard Worker 128*523fa7a6SAndroid Build Coastguard Worker void test_all_real_input_double_output_bound_dynamism_support(); 129*523fa7a6SAndroid Build Coastguard Worker 130*523fa7a6SAndroid Build Coastguard Worker void test_all_real_input_float_output_unbound_dynamism_support(); 131*523fa7a6SAndroid Build Coastguard Worker 132*523fa7a6SAndroid Build Coastguard Worker void test_all_real_input_double_output_unbound_dynamism_support(); 133*523fa7a6SAndroid Build Coastguard Worker 134*523fa7a6SAndroid Build Coastguard Worker void test_non_float_output_dtype_dies(); 135*523fa7a6SAndroid Build Coastguard Worker }; 136*523fa7a6SAndroid Build Coastguard Worker 137*523fa7a6SAndroid Build Coastguard Worker #define IMPLEMENT_UNARY_UFUNC_REALHB_TO_FLOATH_TEST(TestName) \ 138*523fa7a6SAndroid Build Coastguard Worker torch::executor::testing::SupportedFeatures* \ 139*523fa7a6SAndroid Build Coastguard Worker TestName::get_supported_features() const { \ 140*523fa7a6SAndroid Build Coastguard Worker return torch::executor::testing::SupportedFeatures::get(); \ 141*523fa7a6SAndroid Build Coastguard Worker } \ 142*523fa7a6SAndroid Build Coastguard Worker TEST_F(TestName, HandleBoolInput) { \ 143*523fa7a6SAndroid Build Coastguard Worker test_bool_input(); \ 144*523fa7a6SAndroid Build Coastguard Worker } \ 145*523fa7a6SAndroid Build Coastguard Worker TEST_F(TestName, AllRealInputHalfOutputStaticDynamismSupport) { \ 146*523fa7a6SAndroid Build Coastguard Worker test_all_real_input_half_output_static_dynamism_support(); \ 147*523fa7a6SAndroid Build Coastguard Worker } \ 148*523fa7a6SAndroid Build Coastguard Worker \ 149*523fa7a6SAndroid Build Coastguard Worker TEST_F(TestName, AllRealInputBFloat16OutputStaticDynamismSupport) { \ 150*523fa7a6SAndroid Build Coastguard Worker test_all_real_input_bfloat16_output_static_dynamism_support(); \ 151*523fa7a6SAndroid Build Coastguard Worker } \ 152*523fa7a6SAndroid Build Coastguard Worker \ 153*523fa7a6SAndroid Build Coastguard Worker TEST_F(TestName, AllRealInputFloatOutputStaticDynamismSupport) { \ 154*523fa7a6SAndroid Build Coastguard Worker test_all_real_input_float_output_static_dynamism_support(); \ 155*523fa7a6SAndroid Build Coastguard Worker } \ 156*523fa7a6SAndroid Build Coastguard Worker \ 157*523fa7a6SAndroid Build Coastguard Worker TEST_F(TestName, AllRealInputDoubleOutputStaticDynamismSupport) { \ 158*523fa7a6SAndroid Build Coastguard Worker test_all_real_input_double_output_static_dynamism_support(); \ 159*523fa7a6SAndroid Build Coastguard Worker } \ 160*523fa7a6SAndroid Build Coastguard Worker \ 161*523fa7a6SAndroid Build Coastguard Worker TEST_F(TestName, AllRealInputBFloat16OutputBoundDynamismSupport) { \ 162*523fa7a6SAndroid Build Coastguard Worker test_all_real_input_bfloat16_output_bound_dynamism_support(); \ 163*523fa7a6SAndroid Build Coastguard Worker } \ 164*523fa7a6SAndroid Build Coastguard Worker \ 165*523fa7a6SAndroid Build Coastguard Worker TEST_F(TestName, AllRealInputFloatOutputBoundDynamismSupport) { \ 166*523fa7a6SAndroid Build Coastguard Worker test_all_real_input_float_output_bound_dynamism_support(); \ 167*523fa7a6SAndroid Build Coastguard Worker } \ 168*523fa7a6SAndroid Build Coastguard Worker \ 169*523fa7a6SAndroid Build Coastguard Worker TEST_F(TestName, AllRealInputDoubleOutputBoundDynamismSupport) { \ 170*523fa7a6SAndroid Build Coastguard Worker test_all_real_input_double_output_bound_dynamism_support(); \ 171*523fa7a6SAndroid Build Coastguard Worker } \ 172*523fa7a6SAndroid Build Coastguard Worker \ 173*523fa7a6SAndroid Build Coastguard Worker TEST_F(TestName, AllRealInputFloatOutputUnboundDynamismSupport) { \ 174*523fa7a6SAndroid Build Coastguard Worker test_all_real_input_float_output_unbound_dynamism_support(); \ 175*523fa7a6SAndroid Build Coastguard Worker } \ 176*523fa7a6SAndroid Build Coastguard Worker \ 177*523fa7a6SAndroid Build Coastguard Worker TEST_F(TestName, AllRealInputDoubleOutputUnboundDynamismSupport) { \ 178*523fa7a6SAndroid Build Coastguard Worker test_all_real_input_double_output_unbound_dynamism_support(); \ 179*523fa7a6SAndroid Build Coastguard Worker } \ 180*523fa7a6SAndroid Build Coastguard Worker \ 181*523fa7a6SAndroid Build Coastguard Worker TEST_F(TestName, AllNonFloatOutputDTypeDies) { \ 182*523fa7a6SAndroid Build Coastguard Worker test_non_float_output_dtype_dies(); \ 183*523fa7a6SAndroid Build Coastguard Worker } \ 184*523fa7a6SAndroid Build Coastguard Worker \ 185*523fa7a6SAndroid Build Coastguard Worker TEST_F(TestName, MismatchedInputShapesDies) { \ 186*523fa7a6SAndroid Build Coastguard Worker test_mismatched_input_shapes_dies(); \ 187*523fa7a6SAndroid Build Coastguard Worker } 188*523fa7a6SAndroid Build Coastguard Worker 189*523fa7a6SAndroid Build Coastguard Worker } // namespace torch::executor::testing 190