1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <cmath>
10 #include <cstring>
11
12 #include <executorch/kernels/portable/cpu/util/index_util.h>
13 #include <executorch/runtime/kernel/kernel_includes.h>
14 #include <executorch/runtime/platform/assert.h>
15
16 namespace torch {
17 namespace executor {
18 namespace native {
19
// Short local aliases for the exec_aten types referenced by this kernel.
using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;
using SizesType = exec_aten::SizesType;
23
24 namespace {
25
increment_index(size_t * index,const ArrayRef<SizesType> sizes)26 void increment_index(size_t* index, const ArrayRef<SizesType> sizes) {
27 for (ssize_t i = sizes.size() - 1; i >= 0; --i) {
28 index[i]++;
29 if (index[i] == sizes[i]) {
30 index[i] = 0;
31 } else {
32 return;
33 }
34 }
35 }
36
37 /**
38 * Two pass algorithm where we first count the number of non zeros, then resize
39 * out to the appropriate size, and then loop again and properly write into out
40 */
41 template <typename CTYPE>
nonzero(KernelRuntimeContext & ctx,const Tensor & input,Tensor & output)42 void nonzero(KernelRuntimeContext& ctx, const Tensor& input, Tensor& output) {
43 const CTYPE* in_data = input.const_data_ptr<CTYPE>();
44 size_t lim = input.numel();
45 int32_t num_nonzero = 0;
46
47 // Count number of non zeros
48 for (size_t i = 0; i < lim; ++i) {
49 if (in_data[i] != 0) {
50 num_nonzero++;
51 }
52 }
53
54 // resize out
55 SizesType out_shape[2] = {
56 static_cast<SizesType>(num_nonzero), static_cast<SizesType>(input.dim())};
57 ET_KERNEL_CHECK(
58 ctx,
59 resize_tensor(output, ArrayRef<exec_aten::SizesType>(out_shape, 2)) ==
60 Error::Ok,
61 InvalidArgument, );
62
63 size_t index[kTensorDimensionLimit];
64 memset(index, 0, sizeof(index));
65
66 int64_t* out_data = output.mutable_data_ptr<int64_t>();
67 size_t out_idx = 0;
68
69 // Loop again and this time write the proper indices into out
70 for (size_t i = 0; i < lim; i++) {
71 if (in_data[i] != 0) {
72 for (size_t j = 0; j < input.dim(); j++) {
73 out_data[out_idx++] = index[j];
74 }
75 }
76 increment_index(index, input.sizes());
77 }
78 }
79
80 } // namespace
81
82 /**
83 * Determines the non zero indices of input.
84 * Out is a 2-D tensor where every row is a non zero index of the input.
85 */
nonzero_out(KernelRuntimeContext & ctx,const Tensor & in,Tensor & out)86 Tensor& nonzero_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
87 (void)ctx;
88
89 ET_KERNEL_CHECK(ctx, check_nonzero_args(in, out), InvalidArgument, out);
90
91 ET_SWITCH_REAL_TYPES_AND(
92 Bool, in.scalar_type(), ctx, "nonzero.out", CTYPE, [&] {
93 nonzero<CTYPE>(ctx, in, out);
94 });
95
96 return out;
97 }
98
99 } // namespace native
100 } // namespace executor
101 } // namespace torch
102