xref: /aosp_15_r20/external/executorch/kernels/test/op_nonzero_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/runtime/core/exec_aten/exec_aten.h>
12 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
14 #include <gtest/gtest.h>
15 
16 using namespace ::testing;
17 using exec_aten::ScalarType;
18 using exec_aten::Tensor;
19 using torch::executor::testing::TensorFactory;
20 
21 class OpNonzeroTest : public OperatorTest {
22  protected:
op_nonzero_out(const Tensor & self,Tensor & out)23   Tensor& op_nonzero_out(const Tensor& self, Tensor& out) {
24     return torch::executor::aten::nonzero_outf(context_, self, out);
25   }
26 
27   template <class CTYPE, exec_aten::ScalarType DTYPE>
test_dtype()28   void test_dtype() {
29     TensorFactory<DTYPE> tf_input;
30     TensorFactory<ScalarType::Long> tf_long;
31     // clang-format off
32     Tensor a = tf_input.make(/*sizes=*/{2, 2}, /*data=*/{2, 0,
33                                                          2, 4});
34     // clang-format on
35     Tensor out = tf_long.zeros({3, 2});
36 
37     op_nonzero_out(a, out);
38     // clang-format off
39     EXPECT_TENSOR_EQ(out, tf_long.make({3, 2}, {0, 0,
40                                                 1, 0,
41                                                 1, 1}));
42     // clang-format on
43   }
44 };
45 
TEST_F(OpNonzeroTest, AllDtypesSupported) {
  // Run the fixed-input check once for each dtype enumerated by
  // ET_FORALL_REAL_TYPES; the helper macro is scoped to this test only.
#define RUN_NONZERO_FOR_DTYPE(ctype, dtype) \
  test_dtype<ctype, ScalarType::dtype>();
  ET_FORALL_REAL_TYPES(RUN_NONZERO_FOR_DTYPE);
#undef RUN_NONZERO_FOR_DTYPE
}
51 
52 #if !defined(USE_ATEN_LIB)
TEST_F(OpNonzeroTest, StaticShapeInconsistentSize) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> index_factory;

  // Three of four elements are nonzero, so the true result shape is {3, 2}.
  // clang-format off
  Tensor input = float_factory.make(/*sizes=*/{2, 2}, /*data=*/{2, 0,
                                                                2, 4});
  // clang-format on

  // A STATIC output keeps the shape it was created with; since {4, 2} does
  // not match the {3, 2} result, the kernel is expected to fail.
  Tensor out = index_factory.zeros(
      {4, 2}, torch::executor::TensorShapeDynamism::STATIC);

  ET_EXPECT_KERNEL_FAILURE(context_, op_nonzero_out(input, out));
}
67 
TEST_F(OpNonzeroTest, DynamicShape) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> index_factory;

  // clang-format off
  Tensor input = float_factory.make(/*sizes=*/{2, 2}, /*data=*/{2, 0,
                                                                2, 4});
  // clang-format on

  // DYNAMIC_BOUND output with capacity for {4, 2}; after the call it is
  // expected to hold exactly the {3, 2} result.
  Tensor out = index_factory.zeros(
      {4, 2}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
  op_nonzero_out(input, out);

  // clang-format off
  Tensor expected = index_factory.make({3, 2}, {0, 0,
                                                1, 0,
                                                1, 1});
  // clang-format on
  EXPECT_TENSOR_EQ(out, expected);
}
85 
TEST_F(OpNonzeroTest, DynamicShapeInsufficientBuffer) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Long> index_factory;

  // clang-format off
  Tensor input = float_factory.make(/*sizes=*/{2, 2}, /*data=*/{2, 0,
                                                                2, 4});
  // clang-format on

  // The {3, 2} result needs 6 slots, but this DYNAMIC_BOUND output is bounded
  // at {2, 2} (4 slots), so the kernel is expected to fail.
  Tensor out = index_factory.zeros(
      {2, 2}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);

  ET_EXPECT_KERNEL_FAILURE(context_, op_nonzero_out(input, out));
}
98 #endif
99