xref: /aosp_15_r20/external/executorch/kernels/test/op_scalar_tensor_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
16 
17 #include <gtest/gtest.h>
18 
19 using namespace ::testing;
20 using exec_aten::IntArrayRef;
21 using exec_aten::Scalar;
22 using exec_aten::ScalarType;
23 using exec_aten::Tensor;
24 using torch::executor::testing::TensorFactory;
25 
26 class OpScalarTensorOutTest : public OperatorTest {
27  protected:
op_scalar_tensor_out(const Scalar & s,Tensor & out)28   Tensor& op_scalar_tensor_out(const Scalar& s, Tensor& out) {
29     return torch::executor::aten::scalar_tensor_outf(context_, s, out);
30   }
31 
32   template <typename CTYPE, ScalarType DTYPE>
test_scalar_tensor_out_0d(CTYPE value)33   void test_scalar_tensor_out_0d(CTYPE value) {
34     TensorFactory<DTYPE> tf;
35 
36     std::vector<int32_t> sizes{};
37     Tensor expected = tf.make(sizes, /*data=*/{value});
38 
39     Tensor out = tf.ones(sizes);
40     op_scalar_tensor_out(value, out);
41 
42     EXPECT_TENSOR_EQ(out, expected);
43   }
44 
45   template <typename CTYPE, ScalarType DTYPE>
test_scalar_tensor_out_1d(CTYPE value)46   void test_scalar_tensor_out_1d(CTYPE value) {
47     TensorFactory<DTYPE> tf;
48 
49     std::vector<int32_t> sizes{1};
50     Tensor out = tf.ones(sizes);
51 
52     ET_EXPECT_KERNEL_FAILURE(context_, op_scalar_tensor_out(value, out));
53   }
54 
55   template <typename CTYPE, ScalarType DTYPE>
test_scalar_tensor_out_2d(CTYPE value)56   void test_scalar_tensor_out_2d(CTYPE value) {
57     TensorFactory<DTYPE> tf;
58 
59     std::vector<int32_t> sizes{1, 1};
60     Tensor out = tf.ones(sizes);
61 
62     ET_EXPECT_KERNEL_FAILURE(context_, op_scalar_tensor_out(value, out));
63   }
64 
65   template <typename CTYPE, ScalarType DTYPE>
test_scalar_tensor_out_3d(CTYPE value)66   void test_scalar_tensor_out_3d(CTYPE value) {
67     TensorFactory<DTYPE> tf;
68 
69     std::vector<int32_t> sizes{1, 1, 1};
70     Tensor out = tf.ones(sizes);
71 
72     ET_EXPECT_KERNEL_FAILURE(context_, op_scalar_tensor_out(value, out));
73   }
74 };
75 
76 #define GENERATE_TEST_0D(ctype, dtype)                      \
77   TEST_F(OpScalarTensorOutTest, dtype##TensorsDim0) {       \
78     test_scalar_tensor_out_0d<ctype, ScalarType::dtype>(4); \
79     test_scalar_tensor_out_0d<ctype, ScalarType::dtype>(8); \
80     test_scalar_tensor_out_0d<ctype, ScalarType::dtype>(9); \
81   }
82 
ET_FORALL_REAL_TYPES_AND3(Half,Bool,BFloat16,GENERATE_TEST_0D)83 ET_FORALL_REAL_TYPES_AND3(Half, Bool, BFloat16, GENERATE_TEST_0D)
84 
85 #define GENERATE_TEST(ctype, dtype)                                    \
86   TEST_F(OpScalarTensorOutTest, dtype##Tensors) {                      \
87     if (torch::executor::testing::SupportedFeatures::get()->is_aten) { \
88       GTEST_SKIP() << "ATen kernel resizes output to shape {}";        \
89     }                                                                  \
90     test_scalar_tensor_out_1d<ctype, ScalarType::dtype>(2);            \
91     test_scalar_tensor_out_2d<ctype, ScalarType::dtype>(2);            \
92     test_scalar_tensor_out_3d<ctype, ScalarType::dtype>(2);            \
93     test_scalar_tensor_out_1d<ctype, ScalarType::dtype>(4);            \
94     test_scalar_tensor_out_2d<ctype, ScalarType::dtype>(4);            \
95     test_scalar_tensor_out_3d<ctype, ScalarType::dtype>(4);            \
96     test_scalar_tensor_out_1d<ctype, ScalarType::dtype>(7);            \
97     test_scalar_tensor_out_2d<ctype, ScalarType::dtype>(7);            \
98     test_scalar_tensor_out_3d<ctype, ScalarType::dtype>(7);            \
99   }
100 
101 ET_FORALL_REAL_TYPES_AND3(Half, Bool, BFloat16, GENERATE_TEST)
102 
103 TEST_F(OpScalarTensorOutTest, InvalidOutShapeFails) {
104   if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
105     GTEST_SKIP() << "ATen kernel will reshape output";
106   }
107 
108   TensorFactory<ScalarType::Int> tf;
109   std::vector<int32_t> sizes{1, 2, 1};
110 
111   Tensor out = tf.ones(sizes);
112   ET_EXPECT_KERNEL_FAILURE(context_, op_scalar_tensor_out(7, out));
113 }
114 
TEST_F(OpScalarTensorOutTest, HalfSupport) {
  TensorFactory<ScalarType::Half> tf;
  Tensor out = tf.zeros({});

  // Writes `input` into the shared 0-dim Half `out`, then checks the result
  // against a 0-dim Half tensor holding `expected`.
  const auto expect_writes = [&](const Scalar& input, float expected) {
    op_scalar_tensor_out(input, out);
    EXPECT_TENSOR_CLOSE(out, tf.make({}, {expected}));
  };

  // Bool, integral, floating-point and non-finite scalars are all accepted.
  expect_writes(false, 0);
  expect_writes(true, 1);
  expect_writes(7, 7);
  expect_writes(2.5, 2.5);
  expect_writes(INFINITY, INFINITY);
}
134