1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/runtime/core/exec_aten/exec_aten.h>
12 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
14 #include <gtest/gtest.h>
15
16 using namespace ::testing;
17 using exec_aten::ScalarType;
18 using exec_aten::Tensor;
19 using torch::executor::testing::TensorFactory;
20
21 class OpNonzeroTest : public OperatorTest {
22 protected:
op_nonzero_out(const Tensor & self,Tensor & out)23 Tensor& op_nonzero_out(const Tensor& self, Tensor& out) {
24 return torch::executor::aten::nonzero_outf(context_, self, out);
25 }
26
27 template <class CTYPE, exec_aten::ScalarType DTYPE>
test_dtype()28 void test_dtype() {
29 TensorFactory<DTYPE> tf_input;
30 TensorFactory<ScalarType::Long> tf_long;
31 // clang-format off
32 Tensor a = tf_input.make(/*sizes=*/{2, 2}, /*data=*/{2, 0,
33 2, 4});
34 // clang-format on
35 Tensor out = tf_long.zeros({3, 2});
36
37 op_nonzero_out(a, out);
38 // clang-format off
39 EXPECT_TENSOR_EQ(out, tf_long.make({3, 2}, {0, 0,
40 1, 0,
41 1, 1}));
42 // clang-format on
43 }
44 };
45
// Runs the shared nonzero.out check once for every scalar type enumerated by
// ET_FORALL_REAL_TYPES.
TEST_F(OpNonzeroTest, AllDtypesSupported) {
#define RUN_NONZERO_FOR_DTYPE(ctype, dtype) \
  test_dtype<ctype, ScalarType::dtype>();
  ET_FORALL_REAL_TYPES(RUN_NONZERO_FOR_DTYPE);
#undef RUN_NONZERO_FOR_DTYPE
}
51
52 #if !defined(USE_ATEN_LIB)
// A STATIC output whose shape does not match the nonzero result must make the
// kernel fail.
TEST_F(OpNonzeroTest, StaticShapeInconsistentSize) {
  TensorFactory<ScalarType::Float> input_factory;
  TensorFactory<ScalarType::Long> index_factory;

  // clang-format off
  const Tensor input = input_factory.make(
      /*sizes=*/{2, 2},
      /*data=*/{2, 0,
                2, 4});
  // clang-format on

  // The input has three non-zero elements, so the result would be {3, 2}.
  // The output is declared STATIC with shape {4, 2}, so the kernel is
  // expected to report a failure instead of producing the result.
  Tensor out = index_factory.zeros(
      {4, 2}, torch::executor::TensorShapeDynamism::STATIC);

  ET_EXPECT_KERNEL_FAILURE(context_, op_nonzero_out(input, out));
}
67
// A DYNAMIC_BOUND output with enough capacity is resized by the kernel to the
// actual number of non-zero elements.
TEST_F(OpNonzeroTest, DynamicShape) {
  TensorFactory<ScalarType::Float> input_factory;
  TensorFactory<ScalarType::Long> index_factory;

  // clang-format off
  const Tensor input = input_factory.make(
      /*sizes=*/{2, 2},
      /*data=*/{2, 0,
                2, 4});
  // clang-format on

  // {4, 2} is an upper bound; the kernel is expected to shrink it to {3, 2}.
  Tensor out = index_factory.zeros(
      {4, 2}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);

  op_nonzero_out(input, out);

  // clang-format off
  const Tensor expected = index_factory.make(
      {3, 2},
      {0, 0,
       1, 0,
       1, 1});
  // clang-format on
  EXPECT_TENSOR_EQ(out, expected);
}
85
// A DYNAMIC_BOUND output whose bound is smaller than the result must make the
// kernel fail.
TEST_F(OpNonzeroTest, DynamicShapeInsufficientBuffer) {
  TensorFactory<ScalarType::Float> input_factory;
  TensorFactory<ScalarType::Long> index_factory;

  // clang-format off
  const Tensor input = input_factory.make(
      /*sizes=*/{2, 2},
      /*data=*/{2, 0,
                2, 4});
  // clang-format on

  // Only two rows of capacity, but three non-zero elements need three rows.
  Tensor out = index_factory.zeros(
      {2, 2}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);

  ET_EXPECT_KERNEL_FAILURE(context_, op_nonzero_out(input, out));
}
98 #endif
99