1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15
16 #include <gtest/gtest.h>
17
18 using namespace ::testing;
19 using exec_aten::Scalar;
20 using exec_aten::ScalarType;
21 using exec_aten::Tensor;
22 using torch::executor::testing::TensorFactory;
23
24 class OpBitwiseNotOutTest : public OperatorTest {
25 protected:
op_bitwise_not_out(const Tensor & a,Tensor & out)26 Tensor& op_bitwise_not_out(const Tensor& a, Tensor& out) {
27 return torch::executor::aten::bitwise_not_outf(context_, a, out);
28 }
29
30 // Common testing for bitwise_not operator
31 template <ScalarType DTYPE>
test_bitwise_not_out()32 void test_bitwise_not_out() {
33 TensorFactory<DTYPE> tf;
34
35 const std::vector<int32_t> sizes = {2, 2};
36
37 // Destination for the bitwise_not operator.
38 Tensor out = tf.zeros(sizes);
39
40 // Check that it matches the expected output.
41 op_bitwise_not_out(tf.make(sizes, /*data=*/{0, -1, -2, 3}), out);
42 EXPECT_TENSOR_EQ(out, tf.make(sizes, /*data=*/{-1, 0, 1, -4}));
43 }
44
45 // Unhandled output dtypes.
46 template <ScalarType DTYPE>
test_bitwise_not_invalid_dtype_dies()47 void test_bitwise_not_invalid_dtype_dies() {
48 TensorFactory<DTYPE> tf;
49
50 const std::vector<int32_t> sizes = {2, 5};
51
52 Tensor in = tf.ones(sizes);
53 Tensor out = tf.zeros(sizes);
54
55 ET_EXPECT_KERNEL_FAILURE(context_, op_bitwise_not_out(in, out));
56 }
57 };
58
59 template <>
test_bitwise_not_out()60 void OpBitwiseNotOutTest::test_bitwise_not_out<ScalarType::Byte>() {
61 TensorFactory<ScalarType::Byte> tf;
62
63 const std::vector<int32_t> sizes = {2, 2};
64
65 // Destination for the bitwise_not operator.
66 Tensor out = tf.zeros(sizes);
67
68 // Check that it matches the expected output.
69 op_bitwise_not_out(tf.make(sizes, /*data=*/{0, 1, 2, 3}), out);
70 EXPECT_TENSOR_EQ(out, tf.make(sizes, /*data=*/{255, 254, 253, 252}));
71 }
72
73 template <>
test_bitwise_not_out()74 void OpBitwiseNotOutTest::test_bitwise_not_out<ScalarType::Bool>() {
75 TensorFactory<ScalarType::Bool> tf;
76
77 const std::vector<int32_t> sizes = {2, 2};
78
79 // Destination for the bitwise_not operator.
80 Tensor out = tf.zeros(sizes);
81
82 // Check that it matches the expected output.
83 op_bitwise_not_out(tf.make(sizes, /*data=*/{true, false, true, false}), out);
84 EXPECT_TENSOR_EQ(out, tf.make(sizes, /*data=*/{false, true, false, true}));
85 }
86
// Runs test_bitwise_not_out for every dtype enumerated by ET_FORALL_INT_TYPES,
// with the same dtype for input and output.
TEST_F(OpBitwiseNotOutTest, AllIntInputOutputSupport) {
#define RUN_INT_DTYPE_CASE(ctype, dtype) \
  test_bitwise_not_out<ScalarType::dtype>();
  ET_FORALL_INT_TYPES(RUN_INT_DTYPE_CASE);
#undef RUN_INT_DTYPE_CASE
}
92
// Exercises the Bool specialization of test_bitwise_not_out.
TEST_F(OpBitwiseNotOutTest, BoolInputOutputSupport) {
  this->test_bitwise_not_out<ScalarType::Bool>();
}
96
97 // Mismatched shape tests.
// Outside ATen mode, the kernel must fail when out has shape {2, 2} but the
// input has shape {4}, even though both hold four elements.
TEST_F(OpBitwiseNotOutTest, MismatchedShapesDies) {
  const bool running_aten =
      torch::executor::testing::SupportedFeatures::get()->is_aten;
  if (running_aten) {
    GTEST_SKIP() << "ATen kernel can handle mismatched shapes";
  }

  TensorFactory<ScalarType::Int> tf_int;
  Tensor input = tf_int.ones(/*sizes=*/{4});
  Tensor output = tf_int.ones(/*sizes=*/{2, 2});

  ET_EXPECT_KERNEL_FAILURE(context_, op_bitwise_not_out(input, output));
}
109
// bitwise_not only accepts integral and Bool tensors. For every dtype
// enumerated by ET_FORALL_FLOAT_TYPES, the kernel call is expected to fail.
TEST_F(OpBitwiseNotOutTest, AllFloatInputDTypeDies) {
#define RUN_FLOAT_DTYPE_CASE(ctype, dtype) \
  test_bitwise_not_invalid_dtype_dies<ScalarType::dtype>();
  ET_FORALL_FLOAT_TYPES(RUN_FLOAT_DTYPE_CASE);
#undef RUN_FLOAT_DTYPE_CASE
}
116
117 /* %python
118 import torch
119 torch.manual_seed(0)
120 x = torch.randint(10, (3, 2))
121 res = torch.bitwise_not(x)
122 op = "op_bitwise_not_out"
123 dtype = "ScalarType::Int"
124 check = "EXPECT_TENSOR_EQ" */
125
// bitwise_not of a 3x2 Int tensor into a DYNAMIC_BOUND output whose bound
// ({3, 2}) equals the result shape.
// NOTE(review): the body after `%rewrite(unary_op)` looks generated from the
// %python directives (the shared block above and out_args here); edit those
// rather than the body — TODO confirm with the codegen script.
TEST_F(OpBitwiseNotOutTest, DynamicShapeUpperBoundSameAsExpected) {
  /* %python
  out_args = "{3, 2}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND"
  %rewrite(unary_op) */

  TensorFactory<ScalarType::Int> tf;

  Tensor x = tf.make({3, 2}, {4, 9, 3, 0, 3, 9});
  Tensor expected = tf.make({3, 2}, {-5, -10, -4, -1, -4, -10});

  Tensor out =
      tf.zeros({3, 2}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
  op_bitwise_not_out(x, out);
  EXPECT_TENSOR_EQ(out, expected);
}
141
// bitwise_not of a 3x2 Int tensor into a DYNAMIC_BOUND output allocated as
// {10, 10}, larger than the result. The comparison against the 3x2 expected
// tensor presumably requires the kernel to resize out down to {3, 2}.
// NOTE(review): the body after `%rewrite(unary_op)` looks generated from the
// %python directives; edit those rather than the body — TODO confirm.
TEST_F(OpBitwiseNotOutTest, DynamicShapeUpperBoundLargerThanExpected) {
  /* %python
  out_args = "{10, 10}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND"
  %rewrite(unary_op) */

  TensorFactory<ScalarType::Int> tf;

  Tensor x = tf.make({3, 2}, {4, 9, 3, 0, 3, 9});
  Tensor expected = tf.make({3, 2}, {-5, -10, -4, -1, -4, -10});

  Tensor out =
      tf.zeros({10, 10}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
  op_bitwise_not_out(x, out);
  EXPECT_TENSOR_EQ(out, expected);
}
157
// bitwise_not into a {1, 1} DYNAMIC_UNBOUND output that would have to grow to
// {3, 2}. The test is skipped unconditionally: GTEST_SKIP returns from the
// test body, so the code after it is unreachable. It is kept so the case can
// be re-enabled once unbound dynamic shapes are supported.
// NOTE(review): the body after `%rewrite(unary_op)` looks generated from the
// %python directives — TODO confirm.
TEST_F(OpBitwiseNotOutTest, DynamicShapeUnbound) {
  GTEST_SKIP() << "Dynamic shape unbound not supported";
  /* %python
  out_args = "{1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND"
  %rewrite(unary_op) */

  TensorFactory<ScalarType::Int> tf;

  Tensor x = tf.make({3, 2}, {4, 9, 3, 0, 3, 9});
  Tensor expected = tf.make({3, 2}, {-5, -10, -4, -1, -4, -10});

  Tensor out =
      tf.zeros({1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
  op_bitwise_not_out(x, out);
  EXPECT_TENSOR_EQ(out, expected);
}
174