// Source: /aosp_15_r20/external/executorch/kernels/test/UnaryUfuncRealHBBF16ToFloatHBF16Test.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 #pragma once
10 
11 #include <executorch/kernels/test/TestUtil.h>
12 #include <executorch/kernels/test/supported_features.h>
13 #include <executorch/runtime/core/exec_aten/exec_aten.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
15 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
16 
17 #include <gtest/gtest.h>
18 
19 namespace torch::executor::testing {
20 // Generic test harness for ops that use unary_ufunc_realhb_to_floath
21 // -- in other words, ops that just apply an elementwise function
22 // mapping to a float or half.
23 class UnaryUfuncRealHBBF16ToFloatHBF16Test : public OperatorTest {
24  protected:
25   // Implement this to call the torch::executor::aten::op_outf function for the
26   // op.
27   virtual exec_aten::Tensor& op_out(
28       const exec_aten::Tensor& self,
29       exec_aten::Tensor& out) = 0;
30 
31   // Scalar reference implementation of the function in question for testing.
32   virtual double op_reference(double x) const = 0;
33 
34   // The SupportedFeatures system assumes that it can build each test
35   // target with a separate SupportedFeatures (really just one
36   // portable, one optimzed but between one and the infinite, two is
37   // ridiculous and can't exist). We work around that by calling
38   // SupportedFeatures::get() in the concrete test translation
39   // unit. You need to declare an override, but we implement it for you
40   // in IMPLEMENT_UNARY_UFUNC_REALHB_TO_FLOATH_TEST.
41   virtual SupportedFeatures* get_supported_features() const = 0;
42 
43   template <exec_aten::ScalarType IN_DTYPE, exec_aten::ScalarType OUT_DTYPE>
44   void test_floating_point_op_out(
45       const std::vector<int32_t>& out_shape = {1, 6},
46       exec_aten::TensorShapeDynamism dynamism =
47           exec_aten::TensorShapeDynamism::STATIC) {
48     TensorFactory<IN_DTYPE> tf_in;
49     TensorFactory<OUT_DTYPE> tf_out;
50 
51     exec_aten::Tensor out = tf_out.zeros(out_shape, dynamism);
52 
53     using IN_CTYPE = typename decltype(tf_in)::ctype;
54     using OUT_CTYPE = typename decltype(tf_out)::ctype;
55     std::vector<IN_CTYPE> test_vector = {0, 1, 3, 5, 10, 100};
56     std::vector<OUT_CTYPE> expected_vector;
57     for (int ii = 0; ii < test_vector.size(); ++ii) {
58       auto ref_result = this->op_reference(test_vector[ii]);
59       // Drop test cases with high magnitude results due to precision
60       // issues.
61       if ((std::abs(ref_result) > 1e30 || std::abs(ref_result) < -1e30)) {
62         test_vector[ii] = 2;
63         ref_result = this->op_reference(2);
64       }
65       expected_vector.push_back(ref_result);
66     }
67 
68     // clang-format off
69     op_out(tf_in.make({1, 6}, test_vector), out);
70 
71     auto expected = tf_out.make({1, 6}, expected_vector);
72     if (IN_DTYPE == ScalarType::BFloat16 || OUT_DTYPE == ScalarType::BFloat16) {
73       double rtol = executorch::runtime::testing::internal::kDefaultRtol;
74       // It appears we need a higher tolerance for at least some ATen
75       // tests, like aten_op_acosh_test.
76       if (get_supported_features()->is_aten) {
77         rtol = 3e-3;
78       }
79       EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, rtol, executorch::runtime::testing::internal::kDefaultBFloat16Atol);
80     } else if (IN_DTYPE == ScalarType::Half || OUT_DTYPE == ScalarType::Half) {
81       double rtol = executorch::runtime::testing::internal::kDefaultRtol;
82       // It appears we need a higher tolerance for at least some ATen
83       // tests, like aten_op_acosh_test.
84       if (get_supported_features()->is_aten) {
85         rtol = 1e-3;
86       }
87       EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, rtol, executorch::runtime::testing::internal::kDefaultHalfAtol);
88     } else {
89       EXPECT_TENSOR_CLOSE(out, expected);
90     }
91     // clang-format on
92   }
93 
94   // Unhandled output dtypes.
95   template <
96       exec_aten::ScalarType INPUT_DTYPE,
97       exec_aten::ScalarType OUTPUT_DTYPE>
test_op_invalid_output_dtype_dies()98   void test_op_invalid_output_dtype_dies() {
99     TensorFactory<INPUT_DTYPE> tf;
100     TensorFactory<OUTPUT_DTYPE> tf_out;
101 
102     const std::vector<int32_t> sizes = {2, 5};
103 
104     exec_aten::Tensor in = tf.ones(sizes);
105     exec_aten::Tensor out = tf_out.zeros(sizes);
106 
107     ET_EXPECT_KERNEL_FAILURE(context_, op_out(in, out));
108   }
109 
110   void test_bool_input();
111 
112   void test_mismatched_input_shapes_dies();
113 
114   void test_all_real_input_half_output_static_dynamism_support();
115 
116   void test_all_real_input_bfloat16_output_static_dynamism_support();
117 
118   void test_all_real_input_float_output_static_dynamism_support();
119 
120   void test_all_real_input_double_output_static_dynamism_support();
121 
122   void test_all_real_input_half_output_bound_dynamism_support();
123 
124   void test_all_real_input_bfloat16_output_bound_dynamism_support();
125 
126   void test_all_real_input_float_output_bound_dynamism_support();
127 
128   void test_all_real_input_double_output_bound_dynamism_support();
129 
130   void test_all_real_input_float_output_unbound_dynamism_support();
131 
132   void test_all_real_input_double_output_unbound_dynamism_support();
133 
134   void test_non_float_output_dtype_dies();
135 };
136 
// Expands, for a concrete fixture `TestName` (expected to derive from
// UnaryUfuncRealHBBF16ToFloatHBF16Test), to:
//   * a definition of TestName::get_supported_features() that returns
//     torch::executor::testing::SupportedFeatures::get() from the expanding
//     translation unit, and
//   * TEST_F cases that call the fixture's shared test_* helpers.
// Expand it at namespace scope after `TestName` has been declared. The
// SupportedFeatures names are fully qualified, presumably so the macro can
// be expanded outside torch::executor::testing.
//
// NOTE(review): no TEST_F is registered for
// test_all_real_input_half_output_bound_dynamism_support(), although the
// fixture declares it. Confirm whether that omission is intentional.
// (Comments live above the #define because a `//` comment on a line ending
// in `\` would swallow the following macro line.)
#define IMPLEMENT_UNARY_UFUNC_REALHB_TO_FLOATH_TEST(TestName)         \
  torch::executor::testing::SupportedFeatures*                        \
  TestName::get_supported_features() const {                          \
    return torch::executor::testing::SupportedFeatures::get();        \
  }                                                                   \
  TEST_F(TestName, HandleBoolInput) {                                 \
    test_bool_input();                                                \
  }                                                                   \
  TEST_F(TestName, AllRealInputHalfOutputStaticDynamismSupport) {     \
    test_all_real_input_half_output_static_dynamism_support();        \
  }                                                                   \
                                                                      \
  TEST_F(TestName, AllRealInputBFloat16OutputStaticDynamismSupport) { \
    test_all_real_input_bfloat16_output_static_dynamism_support();    \
  }                                                                   \
                                                                      \
  TEST_F(TestName, AllRealInputFloatOutputStaticDynamismSupport) {    \
    test_all_real_input_float_output_static_dynamism_support();       \
  }                                                                   \
                                                                      \
  TEST_F(TestName, AllRealInputDoubleOutputStaticDynamismSupport) {   \
    test_all_real_input_double_output_static_dynamism_support();      \
  }                                                                   \
                                                                      \
  TEST_F(TestName, AllRealInputBFloat16OutputBoundDynamismSupport) {  \
    test_all_real_input_bfloat16_output_bound_dynamism_support();     \
  }                                                                   \
                                                                      \
  TEST_F(TestName, AllRealInputFloatOutputBoundDynamismSupport) {     \
    test_all_real_input_float_output_bound_dynamism_support();        \
  }                                                                   \
                                                                      \
  TEST_F(TestName, AllRealInputDoubleOutputBoundDynamismSupport) {    \
    test_all_real_input_double_output_bound_dynamism_support();       \
  }                                                                   \
                                                                      \
  TEST_F(TestName, AllRealInputFloatOutputUnboundDynamismSupport) {   \
    test_all_real_input_float_output_unbound_dynamism_support();      \
  }                                                                   \
                                                                      \
  TEST_F(TestName, AllRealInputDoubleOutputUnboundDynamismSupport) {  \
    test_all_real_input_double_output_unbound_dynamism_support();     \
  }                                                                   \
                                                                      \
  TEST_F(TestName, AllNonFloatOutputDTypeDies) {                      \
    test_non_float_output_dtype_dies();                               \
  }                                                                   \
                                                                      \
  TEST_F(TestName, MismatchedInputShapesDies) {                       \
    test_mismatched_input_shapes_dies();                              \
  }
188 
189 } // namespace torch::executor::testing
190