xref: /aosp_15_r20/external/executorch/kernels/test/UnaryUfuncRealHBBF16ToFloatHBF16Test.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker  * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker  *
5*523fa7a6SAndroid Build Coastguard Worker  * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker  * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker  */
8*523fa7a6SAndroid Build Coastguard Worker 
9*523fa7a6SAndroid Build Coastguard Worker #pragma once
10*523fa7a6SAndroid Build Coastguard Worker 
11*523fa7a6SAndroid Build Coastguard Worker #include <executorch/kernels/test/TestUtil.h>
12*523fa7a6SAndroid Build Coastguard Worker #include <executorch/kernels/test/supported_features.h>
13*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/exec_aten/exec_aten.h>
14*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
15*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
16*523fa7a6SAndroid Build Coastguard Worker 
17*523fa7a6SAndroid Build Coastguard Worker #include <gtest/gtest.h>
18*523fa7a6SAndroid Build Coastguard Worker 
19*523fa7a6SAndroid Build Coastguard Worker namespace torch::executor::testing {
20*523fa7a6SAndroid Build Coastguard Worker // Generic test harness for ops that use unary_ufunc_realhb_to_floath
21*523fa7a6SAndroid Build Coastguard Worker // -- in other words, ops that just apply an elementwise function
22*523fa7a6SAndroid Build Coastguard Worker // mapping to a float or half.
23*523fa7a6SAndroid Build Coastguard Worker class UnaryUfuncRealHBBF16ToFloatHBF16Test : public OperatorTest {
24*523fa7a6SAndroid Build Coastguard Worker  protected:
25*523fa7a6SAndroid Build Coastguard Worker   // Implement this to call the torch::executor::aten::op_outf function for the
26*523fa7a6SAndroid Build Coastguard Worker   // op.
27*523fa7a6SAndroid Build Coastguard Worker   virtual exec_aten::Tensor& op_out(
28*523fa7a6SAndroid Build Coastguard Worker       const exec_aten::Tensor& self,
29*523fa7a6SAndroid Build Coastguard Worker       exec_aten::Tensor& out) = 0;
30*523fa7a6SAndroid Build Coastguard Worker 
31*523fa7a6SAndroid Build Coastguard Worker   // Scalar reference implementation of the function in question for testing.
32*523fa7a6SAndroid Build Coastguard Worker   virtual double op_reference(double x) const = 0;
33*523fa7a6SAndroid Build Coastguard Worker 
34*523fa7a6SAndroid Build Coastguard Worker   // The SupportedFeatures system assumes that it can build each test
35*523fa7a6SAndroid Build Coastguard Worker   // target with a separate SupportedFeatures (really just one
36*523fa7a6SAndroid Build Coastguard Worker   // portable, one optimzed but between one and the infinite, two is
37*523fa7a6SAndroid Build Coastguard Worker   // ridiculous and can't exist). We work around that by calling
38*523fa7a6SAndroid Build Coastguard Worker   // SupportedFeatures::get() in the concrete test translation
39*523fa7a6SAndroid Build Coastguard Worker   // unit. You need to declare an override, but we implement it for you
40*523fa7a6SAndroid Build Coastguard Worker   // in IMPLEMENT_UNARY_UFUNC_REALHB_TO_FLOATH_TEST.
41*523fa7a6SAndroid Build Coastguard Worker   virtual SupportedFeatures* get_supported_features() const = 0;
42*523fa7a6SAndroid Build Coastguard Worker 
43*523fa7a6SAndroid Build Coastguard Worker   template <exec_aten::ScalarType IN_DTYPE, exec_aten::ScalarType OUT_DTYPE>
44*523fa7a6SAndroid Build Coastguard Worker   void test_floating_point_op_out(
45*523fa7a6SAndroid Build Coastguard Worker       const std::vector<int32_t>& out_shape = {1, 6},
46*523fa7a6SAndroid Build Coastguard Worker       exec_aten::TensorShapeDynamism dynamism =
47*523fa7a6SAndroid Build Coastguard Worker           exec_aten::TensorShapeDynamism::STATIC) {
48*523fa7a6SAndroid Build Coastguard Worker     TensorFactory<IN_DTYPE> tf_in;
49*523fa7a6SAndroid Build Coastguard Worker     TensorFactory<OUT_DTYPE> tf_out;
50*523fa7a6SAndroid Build Coastguard Worker 
51*523fa7a6SAndroid Build Coastguard Worker     exec_aten::Tensor out = tf_out.zeros(out_shape, dynamism);
52*523fa7a6SAndroid Build Coastguard Worker 
53*523fa7a6SAndroid Build Coastguard Worker     using IN_CTYPE = typename decltype(tf_in)::ctype;
54*523fa7a6SAndroid Build Coastguard Worker     using OUT_CTYPE = typename decltype(tf_out)::ctype;
55*523fa7a6SAndroid Build Coastguard Worker     std::vector<IN_CTYPE> test_vector = {0, 1, 3, 5, 10, 100};
56*523fa7a6SAndroid Build Coastguard Worker     std::vector<OUT_CTYPE> expected_vector;
57*523fa7a6SAndroid Build Coastguard Worker     for (int ii = 0; ii < test_vector.size(); ++ii) {
58*523fa7a6SAndroid Build Coastguard Worker       auto ref_result = this->op_reference(test_vector[ii]);
59*523fa7a6SAndroid Build Coastguard Worker       // Drop test cases with high magnitude results due to precision
60*523fa7a6SAndroid Build Coastguard Worker       // issues.
61*523fa7a6SAndroid Build Coastguard Worker       if ((std::abs(ref_result) > 1e30 || std::abs(ref_result) < -1e30)) {
62*523fa7a6SAndroid Build Coastguard Worker         test_vector[ii] = 2;
63*523fa7a6SAndroid Build Coastguard Worker         ref_result = this->op_reference(2);
64*523fa7a6SAndroid Build Coastguard Worker       }
65*523fa7a6SAndroid Build Coastguard Worker       expected_vector.push_back(ref_result);
66*523fa7a6SAndroid Build Coastguard Worker     }
67*523fa7a6SAndroid Build Coastguard Worker 
68*523fa7a6SAndroid Build Coastguard Worker     // clang-format off
69*523fa7a6SAndroid Build Coastguard Worker     op_out(tf_in.make({1, 6}, test_vector), out);
70*523fa7a6SAndroid Build Coastguard Worker 
71*523fa7a6SAndroid Build Coastguard Worker     auto expected = tf_out.make({1, 6}, expected_vector);
72*523fa7a6SAndroid Build Coastguard Worker     if (IN_DTYPE == ScalarType::BFloat16 || OUT_DTYPE == ScalarType::BFloat16) {
73*523fa7a6SAndroid Build Coastguard Worker       double rtol = executorch::runtime::testing::internal::kDefaultRtol;
74*523fa7a6SAndroid Build Coastguard Worker       // It appears we need a higher tolerance for at least some ATen
75*523fa7a6SAndroid Build Coastguard Worker       // tests, like aten_op_acosh_test.
76*523fa7a6SAndroid Build Coastguard Worker       if (get_supported_features()->is_aten) {
77*523fa7a6SAndroid Build Coastguard Worker         rtol = 3e-3;
78*523fa7a6SAndroid Build Coastguard Worker       }
79*523fa7a6SAndroid Build Coastguard Worker       EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, rtol, executorch::runtime::testing::internal::kDefaultBFloat16Atol);
80*523fa7a6SAndroid Build Coastguard Worker     } else if (IN_DTYPE == ScalarType::Half || OUT_DTYPE == ScalarType::Half) {
81*523fa7a6SAndroid Build Coastguard Worker       double rtol = executorch::runtime::testing::internal::kDefaultRtol;
82*523fa7a6SAndroid Build Coastguard Worker       // It appears we need a higher tolerance for at least some ATen
83*523fa7a6SAndroid Build Coastguard Worker       // tests, like aten_op_acosh_test.
84*523fa7a6SAndroid Build Coastguard Worker       if (get_supported_features()->is_aten) {
85*523fa7a6SAndroid Build Coastguard Worker         rtol = 1e-3;
86*523fa7a6SAndroid Build Coastguard Worker       }
87*523fa7a6SAndroid Build Coastguard Worker       EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, rtol, executorch::runtime::testing::internal::kDefaultHalfAtol);
88*523fa7a6SAndroid Build Coastguard Worker     } else {
89*523fa7a6SAndroid Build Coastguard Worker       EXPECT_TENSOR_CLOSE(out, expected);
90*523fa7a6SAndroid Build Coastguard Worker     }
91*523fa7a6SAndroid Build Coastguard Worker     // clang-format on
92*523fa7a6SAndroid Build Coastguard Worker   }
93*523fa7a6SAndroid Build Coastguard Worker 
94*523fa7a6SAndroid Build Coastguard Worker   // Unhandled output dtypes.
95*523fa7a6SAndroid Build Coastguard Worker   template <
96*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType INPUT_DTYPE,
97*523fa7a6SAndroid Build Coastguard Worker       exec_aten::ScalarType OUTPUT_DTYPE>
test_op_invalid_output_dtype_dies()98*523fa7a6SAndroid Build Coastguard Worker   void test_op_invalid_output_dtype_dies() {
99*523fa7a6SAndroid Build Coastguard Worker     TensorFactory<INPUT_DTYPE> tf;
100*523fa7a6SAndroid Build Coastguard Worker     TensorFactory<OUTPUT_DTYPE> tf_out;
101*523fa7a6SAndroid Build Coastguard Worker 
102*523fa7a6SAndroid Build Coastguard Worker     const std::vector<int32_t> sizes = {2, 5};
103*523fa7a6SAndroid Build Coastguard Worker 
104*523fa7a6SAndroid Build Coastguard Worker     exec_aten::Tensor in = tf.ones(sizes);
105*523fa7a6SAndroid Build Coastguard Worker     exec_aten::Tensor out = tf_out.zeros(sizes);
106*523fa7a6SAndroid Build Coastguard Worker 
107*523fa7a6SAndroid Build Coastguard Worker     ET_EXPECT_KERNEL_FAILURE(context_, op_out(in, out));
108*523fa7a6SAndroid Build Coastguard Worker   }
109*523fa7a6SAndroid Build Coastguard Worker 
110*523fa7a6SAndroid Build Coastguard Worker   void test_bool_input();
111*523fa7a6SAndroid Build Coastguard Worker 
112*523fa7a6SAndroid Build Coastguard Worker   void test_mismatched_input_shapes_dies();
113*523fa7a6SAndroid Build Coastguard Worker 
114*523fa7a6SAndroid Build Coastguard Worker   void test_all_real_input_half_output_static_dynamism_support();
115*523fa7a6SAndroid Build Coastguard Worker 
116*523fa7a6SAndroid Build Coastguard Worker   void test_all_real_input_bfloat16_output_static_dynamism_support();
117*523fa7a6SAndroid Build Coastguard Worker 
118*523fa7a6SAndroid Build Coastguard Worker   void test_all_real_input_float_output_static_dynamism_support();
119*523fa7a6SAndroid Build Coastguard Worker 
120*523fa7a6SAndroid Build Coastguard Worker   void test_all_real_input_double_output_static_dynamism_support();
121*523fa7a6SAndroid Build Coastguard Worker 
122*523fa7a6SAndroid Build Coastguard Worker   void test_all_real_input_half_output_bound_dynamism_support();
123*523fa7a6SAndroid Build Coastguard Worker 
124*523fa7a6SAndroid Build Coastguard Worker   void test_all_real_input_bfloat16_output_bound_dynamism_support();
125*523fa7a6SAndroid Build Coastguard Worker 
126*523fa7a6SAndroid Build Coastguard Worker   void test_all_real_input_float_output_bound_dynamism_support();
127*523fa7a6SAndroid Build Coastguard Worker 
128*523fa7a6SAndroid Build Coastguard Worker   void test_all_real_input_double_output_bound_dynamism_support();
129*523fa7a6SAndroid Build Coastguard Worker 
130*523fa7a6SAndroid Build Coastguard Worker   void test_all_real_input_float_output_unbound_dynamism_support();
131*523fa7a6SAndroid Build Coastguard Worker 
132*523fa7a6SAndroid Build Coastguard Worker   void test_all_real_input_double_output_unbound_dynamism_support();
133*523fa7a6SAndroid Build Coastguard Worker 
134*523fa7a6SAndroid Build Coastguard Worker   void test_non_float_output_dtype_dies();
135*523fa7a6SAndroid Build Coastguard Worker };
136*523fa7a6SAndroid Build Coastguard Worker 
// Defines TestName::get_supported_features() in the including translation
// unit and registers every shared test body of
// UnaryUfuncRealHBBF16ToFloatHBF16Test as a gtest case on TestName. TestName
// is expected to be a subclass of that harness that declares (but does not
// define) the get_supported_features() override. No comments may appear
// inside the macro body below, since a // comment would swallow the line
// continuation.
#define IMPLEMENT_UNARY_UFUNC_REALHB_TO_FLOATH_TEST(TestName)         \
  torch::executor::testing::SupportedFeatures*                        \
  TestName::get_supported_features() const {                          \
    return torch::executor::testing::SupportedFeatures::get();        \
  }                                                                   \
                                                                      \
  TEST_F(TestName, HandleBoolInput) {                                 \
    test_bool_input();                                                \
  }                                                                   \
                                                                      \
  TEST_F(TestName, AllRealInputHalfOutputStaticDynamismSupport) {     \
    test_all_real_input_half_output_static_dynamism_support();        \
  }                                                                   \
                                                                      \
  TEST_F(TestName, AllRealInputBFloat16OutputStaticDynamismSupport) { \
    test_all_real_input_bfloat16_output_static_dynamism_support();    \
  }                                                                   \
                                                                      \
  TEST_F(TestName, AllRealInputFloatOutputStaticDynamismSupport) {    \
    test_all_real_input_float_output_static_dynamism_support();       \
  }                                                                   \
                                                                      \
  TEST_F(TestName, AllRealInputDoubleOutputStaticDynamismSupport) {   \
    test_all_real_input_double_output_static_dynamism_support();      \
  }                                                                   \
                                                                      \
  TEST_F(TestName, AllRealInputHalfOutputBoundDynamismSupport) {      \
    test_all_real_input_half_output_bound_dynamism_support();         \
  }                                                                   \
                                                                      \
  TEST_F(TestName, AllRealInputBFloat16OutputBoundDynamismSupport) {  \
    test_all_real_input_bfloat16_output_bound_dynamism_support();     \
  }                                                                   \
                                                                      \
  TEST_F(TestName, AllRealInputFloatOutputBoundDynamismSupport) {     \
    test_all_real_input_float_output_bound_dynamism_support();        \
  }                                                                   \
                                                                      \
  TEST_F(TestName, AllRealInputDoubleOutputBoundDynamismSupport) {    \
    test_all_real_input_double_output_bound_dynamism_support();       \
  }                                                                   \
                                                                      \
  TEST_F(TestName, AllRealInputFloatOutputUnboundDynamismSupport) {   \
    test_all_real_input_float_output_unbound_dynamism_support();      \
  }                                                                   \
                                                                      \
  TEST_F(TestName, AllRealInputDoubleOutputUnboundDynamismSupport) {  \
    test_all_real_input_double_output_unbound_dynamism_support();     \
  }                                                                   \
                                                                      \
  TEST_F(TestName, AllNonFloatOutputDTypeDies) {                      \
    test_non_float_output_dtype_dies();                               \
  }                                                                   \
                                                                      \
  TEST_F(TestName, MismatchedInputShapesDies) {                       \
    test_mismatched_input_shapes_dies();                              \
  }
188*523fa7a6SAndroid Build Coastguard Worker 
189*523fa7a6SAndroid Build Coastguard Worker } // namespace torch::executor::testing
190