1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 #include <executorch/kernels/test/TestUtil.h> 12 #include <executorch/kernels/test/supported_features.h> 13 #include <executorch/runtime/core/exec_aten/exec_aten.h> 14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h> 15 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h> 16 17 #include <gtest/gtest.h> 18 19 namespace torch::executor::testing { 20 // Generic test harness for ops that use unary_ufunc_realhb_to_floath 21 // -- in other words, ops that just apply an elementwise function 22 // mapping to a float or half. 23 class UnaryUfuncRealHBBF16ToFloatHBF16Test : public OperatorTest { 24 protected: 25 // Implement this to call the torch::executor::aten::op_outf function for the 26 // op. 27 virtual exec_aten::Tensor& op_out( 28 const exec_aten::Tensor& self, 29 exec_aten::Tensor& out) = 0; 30 31 // Scalar reference implementation of the function in question for testing. 32 virtual double op_reference(double x) const = 0; 33 34 // The SupportedFeatures system assumes that it can build each test 35 // target with a separate SupportedFeatures (really just one 36 // portable, one optimzed but between one and the infinite, two is 37 // ridiculous and can't exist). We work around that by calling 38 // SupportedFeatures::get() in the concrete test translation 39 // unit. You need to declare an override, but we implement it for you 40 // in IMPLEMENT_UNARY_UFUNC_REALHB_TO_FLOATH_TEST. 41 virtual SupportedFeatures* get_supported_features() const = 0; 42 43 template <exec_aten::ScalarType IN_DTYPE, exec_aten::ScalarType OUT_DTYPE> 44 void test_floating_point_op_out( 45 const std::vector<int32_t>& out_shape = {1, 6}, 46 exec_aten::TensorShapeDynamism dynamism = 47 exec_aten::TensorShapeDynamism::STATIC) { 48 TensorFactory<IN_DTYPE> tf_in; 49 TensorFactory<OUT_DTYPE> tf_out; 50 51 exec_aten::Tensor out = tf_out.zeros(out_shape, dynamism); 52 53 using IN_CTYPE = typename decltype(tf_in)::ctype; 54 using OUT_CTYPE = typename decltype(tf_out)::ctype; 55 std::vector<IN_CTYPE> test_vector = {0, 1, 3, 5, 10, 100}; 56 std::vector<OUT_CTYPE> expected_vector; 57 for (int ii = 0; ii < test_vector.size(); ++ii) { 58 auto ref_result = this->op_reference(test_vector[ii]); 59 // Drop test cases with high magnitude results due to precision 60 // issues. 61 if ((std::abs(ref_result) > 1e30 || std::abs(ref_result) < -1e30)) { 62 test_vector[ii] = 2; 63 ref_result = this->op_reference(2); 64 } 65 expected_vector.push_back(ref_result); 66 } 67 68 // clang-format off 69 op_out(tf_in.make({1, 6}, test_vector), out); 70 71 auto expected = tf_out.make({1, 6}, expected_vector); 72 if (IN_DTYPE == ScalarType::BFloat16 || OUT_DTYPE == ScalarType::BFloat16) { 73 double rtol = executorch::runtime::testing::internal::kDefaultRtol; 74 // It appears we need a higher tolerance for at least some ATen 75 // tests, like aten_op_acosh_test. 76 if (get_supported_features()->is_aten) { 77 rtol = 3e-3; 78 } 79 EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, rtol, executorch::runtime::testing::internal::kDefaultBFloat16Atol); 80 } else if (IN_DTYPE == ScalarType::Half || OUT_DTYPE == ScalarType::Half) { 81 double rtol = executorch::runtime::testing::internal::kDefaultRtol; 82 // It appears we need a higher tolerance for at least some ATen 83 // tests, like aten_op_acosh_test. 84 if (get_supported_features()->is_aten) { 85 rtol = 1e-3; 86 } 87 EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, rtol, executorch::runtime::testing::internal::kDefaultHalfAtol); 88 } else { 89 EXPECT_TENSOR_CLOSE(out, expected); 90 } 91 // clang-format on 92 } 93 94 // Unhandled output dtypes. 95 template < 96 exec_aten::ScalarType INPUT_DTYPE, 97 exec_aten::ScalarType OUTPUT_DTYPE> test_op_invalid_output_dtype_dies()98 void test_op_invalid_output_dtype_dies() { 99 TensorFactory<INPUT_DTYPE> tf; 100 TensorFactory<OUTPUT_DTYPE> tf_out; 101 102 const std::vector<int32_t> sizes = {2, 5}; 103 104 exec_aten::Tensor in = tf.ones(sizes); 105 exec_aten::Tensor out = tf_out.zeros(sizes); 106 107 ET_EXPECT_KERNEL_FAILURE(context_, op_out(in, out)); 108 } 109 110 void test_bool_input(); 111 112 void test_mismatched_input_shapes_dies(); 113 114 void test_all_real_input_half_output_static_dynamism_support(); 115 116 void test_all_real_input_bfloat16_output_static_dynamism_support(); 117 118 void test_all_real_input_float_output_static_dynamism_support(); 119 120 void test_all_real_input_double_output_static_dynamism_support(); 121 122 void test_all_real_input_half_output_bound_dynamism_support(); 123 124 void test_all_real_input_bfloat16_output_bound_dynamism_support(); 125 126 void test_all_real_input_float_output_bound_dynamism_support(); 127 128 void test_all_real_input_double_output_bound_dynamism_support(); 129 130 void test_all_real_input_float_output_unbound_dynamism_support(); 131 132 void test_all_real_input_double_output_unbound_dynamism_support(); 133 134 void test_non_float_output_dtype_dies(); 135 }; 136 137 #define IMPLEMENT_UNARY_UFUNC_REALHB_TO_FLOATH_TEST(TestName) \ 138 torch::executor::testing::SupportedFeatures* \ 139 TestName::get_supported_features() const { \ 140 return torch::executor::testing::SupportedFeatures::get(); \ 141 } \ 142 TEST_F(TestName, HandleBoolInput) { \ 143 test_bool_input(); \ 144 } \ 145 TEST_F(TestName, AllRealInputHalfOutputStaticDynamismSupport) { \ 146 test_all_real_input_half_output_static_dynamism_support(); \ 147 } \ 148 \ 149 TEST_F(TestName, AllRealInputBFloat16OutputStaticDynamismSupport) { \ 150 test_all_real_input_bfloat16_output_static_dynamism_support(); \ 151 } \ 152 \ 153 TEST_F(TestName, AllRealInputFloatOutputStaticDynamismSupport) { \ 154 test_all_real_input_float_output_static_dynamism_support(); \ 155 } \ 156 \ 157 TEST_F(TestName, AllRealInputDoubleOutputStaticDynamismSupport) { \ 158 test_all_real_input_double_output_static_dynamism_support(); \ 159 } \ 160 \ 161 TEST_F(TestName, AllRealInputBFloat16OutputBoundDynamismSupport) { \ 162 test_all_real_input_bfloat16_output_bound_dynamism_support(); \ 163 } \ 164 \ 165 TEST_F(TestName, AllRealInputFloatOutputBoundDynamismSupport) { \ 166 test_all_real_input_float_output_bound_dynamism_support(); \ 167 } \ 168 \ 169 TEST_F(TestName, AllRealInputDoubleOutputBoundDynamismSupport) { \ 170 test_all_real_input_double_output_bound_dynamism_support(); \ 171 } \ 172 \ 173 TEST_F(TestName, AllRealInputFloatOutputUnboundDynamismSupport) { \ 174 test_all_real_input_float_output_unbound_dynamism_support(); \ 175 } \ 176 \ 177 TEST_F(TestName, AllRealInputDoubleOutputUnboundDynamismSupport) { \ 178 test_all_real_input_double_output_unbound_dynamism_support(); \ 179 } \ 180 \ 181 TEST_F(TestName, AllNonFloatOutputDTypeDies) { \ 182 test_non_float_output_dtype_dies(); \ 183 } \ 184 \ 185 TEST_F(TestName, MismatchedInputShapesDies) { \ 186 test_mismatched_input_shapes_dies(); \ 187 } 188 189 } // namespace torch::executor::testing 190