xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/op_scatter.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <cinttypes>
10 #include <cstdint>
11 #include <cstring>
12 
13 #include <executorch/kernels/portable/cpu/scalar_utils.h>
14 #include <executorch/kernels/portable/cpu/util/index_util.h>
15 #include <executorch/runtime/kernel/kernel_includes.h>
16 
17 namespace torch {
18 namespace executor {
19 namespace native {
20 
21 using Tensor = exec_aten::Tensor;
22 using ScalarType = exec_aten::ScalarType;
23 
24 namespace {
25 
26 template <typename CTYPE>
scatter_src_helper(const Tensor & in,int64_t dim,const Tensor & index,const Tensor & src,Tensor & out)27 void scatter_src_helper(
28     const Tensor& in,
29     int64_t dim,
30     const Tensor& index,
31     const Tensor& src,
32     Tensor& out) {
33   const CTYPE* in_data = in.const_data_ptr<CTYPE>();
34   const long* index_data = index.const_data_ptr<long>();
35   const CTYPE* src_data = src.const_data_ptr<CTYPE>();
36   CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
37 
38   memcpy(out_data, in_data, in.nbytes());
39 
40   if (dim < 0) {
41     dim += nonzero_dim(in);
42   }
43 
44   for (size_t ix = 0; ix < index.numel(); ++ix) {
45     // @lint-ignore CLANGTIDY facebook-hte-CArray
46     size_t ix_coord[kTensorDimensionLimit];
47     indexToCoordinate(index, ix, ix_coord);
48 
49     size_t src_ix = coordinateToIndex(src, ix_coord);
50 
51     // @lint-ignore CLANGTIDY facebook-hte-CArray
52     size_t out_coord[kTensorDimensionLimit];
53     for (size_t i = 0; i < out.dim(); ++i) {
54       if (i == dim) {
55         out_coord[i] = index_data[ix];
56       } else {
57         out_coord[i] = ix_coord[i];
58       }
59     }
60     size_t out_ix = coordinateToIndex(out, out_coord);
61 
62     out_data[out_ix] = src_data[src_ix];
63   }
64 }
65 
66 template <typename CTYPE, typename CTYPE_VAL>
scatter_value_helper(const Tensor & in,int64_t dim,const Tensor & index,CTYPE_VAL val,Tensor & out)67 void scatter_value_helper(
68     const Tensor& in,
69     int64_t dim,
70     const Tensor& index,
71     CTYPE_VAL val,
72     Tensor& out) {
73   const CTYPE* in_data = in.const_data_ptr<CTYPE>();
74   const long* index_data = index.const_data_ptr<long>();
75   CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
76 
77   memcpy(out_data, in_data, in.nbytes());
78 
79   if (dim < 0) {
80     dim += nonzero_dim(in);
81   }
82 
83   for (size_t ix = 0; ix < index.numel(); ++ix) {
84     // @lint-ignore CLANGTIDY facebook-hte-CArray
85     size_t ix_coord[kTensorDimensionLimit];
86     indexToCoordinate(index, ix, ix_coord);
87 
88     // @lint-ignore CLANGTIDY facebook-hte-CArray
89     size_t out_coord[kTensorDimensionLimit];
90     for (size_t i = 0; i < out.dim(); ++i) {
91       if (i == dim) {
92         out_coord[i] = index_data[ix];
93       } else {
94         out_coord[i] = ix_coord[i];
95       }
96     }
97     size_t out_ix = coordinateToIndex(out, out_coord);
98 
99     out_data[out_ix] = static_cast<CTYPE>(val);
100   }
101 }
102 
103 } // namespace
104 
scatter_src_out(KernelRuntimeContext & context,const Tensor & in,int64_t dim,const Tensor & index,const Tensor & src,Tensor & out)105 Tensor& scatter_src_out(
106     KernelRuntimeContext& context,
107     const Tensor& in,
108     int64_t dim,
109     const Tensor& index,
110     const Tensor& src,
111     Tensor& out) {
112   (void)context;
113 
114   ET_KERNEL_CHECK(
115       context,
116       check_scatter_src_args(in, dim, index, src, out),
117       InvalidArgument,
118       out);
119 
120   ET_KERNEL_CHECK(
121       context,
122       resize_tensor(out, in.sizes()) == Error::Ok,
123       InvalidArgument,
124       out);
125 
126   constexpr auto name = "scatter.src_out";
127 
128   ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
129     scatter_src_helper<CTYPE>(in, dim, index, src, out);
130   });
131 
132   return out;
133 }
134 
scatter_value_out(KernelRuntimeContext & ctx,const Tensor & in,int64_t dim,const Tensor & index,const Scalar & value,Tensor & out)135 Tensor& scatter_value_out(
136     KernelRuntimeContext& ctx,
137     const Tensor& in,
138     int64_t dim,
139     const Tensor& index,
140     const Scalar& value,
141     Tensor& out) {
142   (void)ctx;
143 
144   ET_KERNEL_CHECK(
145       ctx,
146       check_scatter_value_args(in, dim, index, value, out),
147       InvalidArgument,
148       out);
149 
150   ET_KERNEL_CHECK(
151       ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
152 
153   ScalarType val_type = utils::get_scalar_dtype(value);
154 
155   constexpr auto name = "scatter.value_out";
156 
157   ET_SWITCH_SCALAR_OBJ_TYPES(val_type, ctx, name, CTYPE_VAL, [&] {
158     CTYPE_VAL val;
159     utils::extract_scalar(value, &val);
160 
161     ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
162       scatter_value_helper<CTYPE>(in, dim, index, val, out);
163     });
164   });
165 
166   return out;
167 }
168 
169 } // namespace native
170 } // namespace executor
171 } // namespace torch
172