xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/op_nonzero.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <cmath>
10 #include <cstring>
11 
12 #include <executorch/kernels/portable/cpu/util/index_util.h>
13 #include <executorch/runtime/kernel/kernel_includes.h>
14 #include <executorch/runtime/platform/assert.h>
15 
16 namespace torch {
17 namespace executor {
18 namespace native {
19 
// Shorthand aliases for exec_aten types. Tensor and SizesType are used by the
// functions defined below in this file.
20 using Tensor = exec_aten::Tensor;
21 using ScalarType = exec_aten::ScalarType;
22 using SizesType = exec_aten::SizesType;
23 
24 namespace {
25 
increment_index(size_t * index,const ArrayRef<SizesType> sizes)26 void increment_index(size_t* index, const ArrayRef<SizesType> sizes) {
27   for (ssize_t i = sizes.size() - 1; i >= 0; --i) {
28     index[i]++;
29     if (index[i] == sizes[i]) {
30       index[i] = 0;
31     } else {
32       return;
33     }
34   }
35 }
36 
37 /**
38  * Two pass algorithm where we first count the number of non zeros, then resize
39  * out to the appropriate size, and then loop again and properly write into out
40  */
41 template <typename CTYPE>
nonzero(KernelRuntimeContext & ctx,const Tensor & input,Tensor & output)42 void nonzero(KernelRuntimeContext& ctx, const Tensor& input, Tensor& output) {
43   const CTYPE* in_data = input.const_data_ptr<CTYPE>();
44   size_t lim = input.numel();
45   int32_t num_nonzero = 0;
46 
47   // Count number of non zeros
48   for (size_t i = 0; i < lim; ++i) {
49     if (in_data[i] != 0) {
50       num_nonzero++;
51     }
52   }
53 
54   // resize out
55   SizesType out_shape[2] = {
56       static_cast<SizesType>(num_nonzero), static_cast<SizesType>(input.dim())};
57   ET_KERNEL_CHECK(
58       ctx,
59       resize_tensor(output, ArrayRef<exec_aten::SizesType>(out_shape, 2)) ==
60           Error::Ok,
61       InvalidArgument, );
62 
63   size_t index[kTensorDimensionLimit];
64   memset(index, 0, sizeof(index));
65 
66   int64_t* out_data = output.mutable_data_ptr<int64_t>();
67   size_t out_idx = 0;
68 
69   // Loop again and this time write the proper indices into out
70   for (size_t i = 0; i < lim; i++) {
71     if (in_data[i] != 0) {
72       for (size_t j = 0; j < input.dim(); j++) {
73         out_data[out_idx++] = index[j];
74       }
75     }
76     increment_index(index, input.sizes());
77   }
78 }
79 
80 } // namespace
81 
82 /**
83  * Determines the non zero indices of input.
84  * Out is a 2-D tensor where every row is a non zero index of the input.
85  */
nonzero_out(KernelRuntimeContext & ctx,const Tensor & in,Tensor & out)86 Tensor& nonzero_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
87   (void)ctx;
88 
89   ET_KERNEL_CHECK(ctx, check_nonzero_args(in, out), InvalidArgument, out);
90 
91   ET_SWITCH_REAL_TYPES_AND(
92       Bool, in.scalar_type(), ctx, "nonzero.out", CTYPE, [&] {
93         nonzero<CTYPE>(ctx, in, out);
94       });
95 
96   return out;
97 }
98 
99 } // namespace native
100 } // namespace executor
101 } // namespace torch
102