1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
16
17 #include <gtest/gtest.h>
18
19 using namespace ::testing;
20 using exec_aten::IntArrayRef;
21 using exec_aten::Scalar;
22 using exec_aten::ScalarType;
23 using exec_aten::Tensor;
24 using torch::executor::testing::TensorFactory;
25
26 class OpScalarTensorOutTest : public OperatorTest {
27 protected:
op_scalar_tensor_out(const Scalar & s,Tensor & out)28 Tensor& op_scalar_tensor_out(const Scalar& s, Tensor& out) {
29 return torch::executor::aten::scalar_tensor_outf(context_, s, out);
30 }
31
32 template <typename CTYPE, ScalarType DTYPE>
test_scalar_tensor_out_0d(CTYPE value)33 void test_scalar_tensor_out_0d(CTYPE value) {
34 TensorFactory<DTYPE> tf;
35
36 std::vector<int32_t> sizes{};
37 Tensor expected = tf.make(sizes, /*data=*/{value});
38
39 Tensor out = tf.ones(sizes);
40 op_scalar_tensor_out(value, out);
41
42 EXPECT_TENSOR_EQ(out, expected);
43 }
44
45 template <typename CTYPE, ScalarType DTYPE>
test_scalar_tensor_out_1d(CTYPE value)46 void test_scalar_tensor_out_1d(CTYPE value) {
47 TensorFactory<DTYPE> tf;
48
49 std::vector<int32_t> sizes{1};
50 Tensor out = tf.ones(sizes);
51
52 ET_EXPECT_KERNEL_FAILURE(context_, op_scalar_tensor_out(value, out));
53 }
54
55 template <typename CTYPE, ScalarType DTYPE>
test_scalar_tensor_out_2d(CTYPE value)56 void test_scalar_tensor_out_2d(CTYPE value) {
57 TensorFactory<DTYPE> tf;
58
59 std::vector<int32_t> sizes{1, 1};
60 Tensor out = tf.ones(sizes);
61
62 ET_EXPECT_KERNEL_FAILURE(context_, op_scalar_tensor_out(value, out));
63 }
64
65 template <typename CTYPE, ScalarType DTYPE>
test_scalar_tensor_out_3d(CTYPE value)66 void test_scalar_tensor_out_3d(CTYPE value) {
67 TensorFactory<DTYPE> tf;
68
69 std::vector<int32_t> sizes{1, 1, 1};
70 Tensor out = tf.ones(sizes);
71
72 ET_EXPECT_KERNEL_FAILURE(context_, op_scalar_tensor_out(value, out));
73 }
74 };
75
76 #define GENERATE_TEST_0D(ctype, dtype) \
77 TEST_F(OpScalarTensorOutTest, dtype##TensorsDim0) { \
78 test_scalar_tensor_out_0d<ctype, ScalarType::dtype>(4); \
79 test_scalar_tensor_out_0d<ctype, ScalarType::dtype>(8); \
80 test_scalar_tensor_out_0d<ctype, ScalarType::dtype>(9); \
81 }
82
ET_FORALL_REAL_TYPES_AND3(Half,Bool,BFloat16,GENERATE_TEST_0D)83 ET_FORALL_REAL_TYPES_AND3(Half, Bool, BFloat16, GENERATE_TEST_0D)
84
85 #define GENERATE_TEST(ctype, dtype) \
86 TEST_F(OpScalarTensorOutTest, dtype##Tensors) { \
87 if (torch::executor::testing::SupportedFeatures::get()->is_aten) { \
88 GTEST_SKIP() << "ATen kernel resizes output to shape {}"; \
89 } \
90 test_scalar_tensor_out_1d<ctype, ScalarType::dtype>(2); \
91 test_scalar_tensor_out_2d<ctype, ScalarType::dtype>(2); \
92 test_scalar_tensor_out_3d<ctype, ScalarType::dtype>(2); \
93 test_scalar_tensor_out_1d<ctype, ScalarType::dtype>(4); \
94 test_scalar_tensor_out_2d<ctype, ScalarType::dtype>(4); \
95 test_scalar_tensor_out_3d<ctype, ScalarType::dtype>(4); \
96 test_scalar_tensor_out_1d<ctype, ScalarType::dtype>(7); \
97 test_scalar_tensor_out_2d<ctype, ScalarType::dtype>(7); \
98 test_scalar_tensor_out_3d<ctype, ScalarType::dtype>(7); \
99 }
100
101 ET_FORALL_REAL_TYPES_AND3(Half, Bool, BFloat16, GENERATE_TEST)
102
103 TEST_F(OpScalarTensorOutTest, InvalidOutShapeFails) {
104 if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
105 GTEST_SKIP() << "ATen kernel will reshape output";
106 }
107
108 TensorFactory<ScalarType::Int> tf;
109 std::vector<int32_t> sizes{1, 2, 1};
110
111 Tensor out = tf.ones(sizes);
112 ET_EXPECT_KERNEL_FAILURE(context_, op_scalar_tensor_out(7, out));
113 }
114
// Writes bool, integer, fractional and infinite scalars into the same 0-d Half
// output. After each write, the output must be close to the expected value.
TEST_F(OpScalarTensorOutTest, HalfSupport) {
  TensorFactory<ScalarType::Half> half_factory;
  Tensor result = half_factory.zeros({});

  // Runs the kernel with `input`, then compares `result` against `want`.
  auto write_and_check = [&](const Scalar& input, const Tensor& want) {
    op_scalar_tensor_out(input, result);
    EXPECT_TENSOR_CLOSE(result, want);
  };

  write_and_check(false, half_factory.make({}, {0}));
  write_and_check(true, half_factory.make({}, {1}));
  write_and_check(7, half_factory.make({}, {7}));
  write_and_check(2.5, half_factory.make({}, {2.5}));
  write_and_check(INFINITY, half_factory.make({}, {INFINITY}));
}
134