1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <cinttypes>
10 #include <cstdint>
11 #include <cstring>
12
13 #include <executorch/kernels/portable/cpu/scalar_utils.h>
14 #include <executorch/kernels/portable/cpu/util/index_util.h>
15 #include <executorch/runtime/kernel/kernel_includes.h>
16
17 namespace torch {
18 namespace executor {
19 namespace native {
20
21 using Tensor = exec_aten::Tensor;
22 using ScalarType = exec_aten::ScalarType;
23
24 namespace {
25
26 template <typename CTYPE>
scatter_src_helper(const Tensor & in,int64_t dim,const Tensor & index,const Tensor & src,Tensor & out)27 void scatter_src_helper(
28 const Tensor& in,
29 int64_t dim,
30 const Tensor& index,
31 const Tensor& src,
32 Tensor& out) {
33 const CTYPE* in_data = in.const_data_ptr<CTYPE>();
34 const long* index_data = index.const_data_ptr<long>();
35 const CTYPE* src_data = src.const_data_ptr<CTYPE>();
36 CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
37
38 memcpy(out_data, in_data, in.nbytes());
39
40 if (dim < 0) {
41 dim += nonzero_dim(in);
42 }
43
44 for (size_t ix = 0; ix < index.numel(); ++ix) {
45 // @lint-ignore CLANGTIDY facebook-hte-CArray
46 size_t ix_coord[kTensorDimensionLimit];
47 indexToCoordinate(index, ix, ix_coord);
48
49 size_t src_ix = coordinateToIndex(src, ix_coord);
50
51 // @lint-ignore CLANGTIDY facebook-hte-CArray
52 size_t out_coord[kTensorDimensionLimit];
53 for (size_t i = 0; i < out.dim(); ++i) {
54 if (i == dim) {
55 out_coord[i] = index_data[ix];
56 } else {
57 out_coord[i] = ix_coord[i];
58 }
59 }
60 size_t out_ix = coordinateToIndex(out, out_coord);
61
62 out_data[out_ix] = src_data[src_ix];
63 }
64 }
65
66 template <typename CTYPE, typename CTYPE_VAL>
scatter_value_helper(const Tensor & in,int64_t dim,const Tensor & index,CTYPE_VAL val,Tensor & out)67 void scatter_value_helper(
68 const Tensor& in,
69 int64_t dim,
70 const Tensor& index,
71 CTYPE_VAL val,
72 Tensor& out) {
73 const CTYPE* in_data = in.const_data_ptr<CTYPE>();
74 const long* index_data = index.const_data_ptr<long>();
75 CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
76
77 memcpy(out_data, in_data, in.nbytes());
78
79 if (dim < 0) {
80 dim += nonzero_dim(in);
81 }
82
83 for (size_t ix = 0; ix < index.numel(); ++ix) {
84 // @lint-ignore CLANGTIDY facebook-hte-CArray
85 size_t ix_coord[kTensorDimensionLimit];
86 indexToCoordinate(index, ix, ix_coord);
87
88 // @lint-ignore CLANGTIDY facebook-hte-CArray
89 size_t out_coord[kTensorDimensionLimit];
90 for (size_t i = 0; i < out.dim(); ++i) {
91 if (i == dim) {
92 out_coord[i] = index_data[ix];
93 } else {
94 out_coord[i] = ix_coord[i];
95 }
96 }
97 size_t out_ix = coordinateToIndex(out, out_coord);
98
99 out_data[out_ix] = static_cast<CTYPE>(val);
100 }
101 }
102
103 } // namespace
104
scatter_src_out(KernelRuntimeContext & context,const Tensor & in,int64_t dim,const Tensor & index,const Tensor & src,Tensor & out)105 Tensor& scatter_src_out(
106 KernelRuntimeContext& context,
107 const Tensor& in,
108 int64_t dim,
109 const Tensor& index,
110 const Tensor& src,
111 Tensor& out) {
112 (void)context;
113
114 ET_KERNEL_CHECK(
115 context,
116 check_scatter_src_args(in, dim, index, src, out),
117 InvalidArgument,
118 out);
119
120 ET_KERNEL_CHECK(
121 context,
122 resize_tensor(out, in.sizes()) == Error::Ok,
123 InvalidArgument,
124 out);
125
126 constexpr auto name = "scatter.src_out";
127
128 ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
129 scatter_src_helper<CTYPE>(in, dim, index, src, out);
130 });
131
132 return out;
133 }
134
scatter_value_out(KernelRuntimeContext & ctx,const Tensor & in,int64_t dim,const Tensor & index,const Scalar & value,Tensor & out)135 Tensor& scatter_value_out(
136 KernelRuntimeContext& ctx,
137 const Tensor& in,
138 int64_t dim,
139 const Tensor& index,
140 const Scalar& value,
141 Tensor& out) {
142 (void)ctx;
143
144 ET_KERNEL_CHECK(
145 ctx,
146 check_scatter_value_args(in, dim, index, value, out),
147 InvalidArgument,
148 out);
149
150 ET_KERNEL_CHECK(
151 ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
152
153 ScalarType val_type = utils::get_scalar_dtype(value);
154
155 constexpr auto name = "scatter.value_out";
156
157 ET_SWITCH_SCALAR_OBJ_TYPES(val_type, ctx, name, CTYPE_VAL, [&] {
158 CTYPE_VAL val;
159 utils::extract_scalar(value, &val);
160
161 ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
162 scatter_value_helper<CTYPE>(in, dim, index, val, out);
163 });
164 });
165
166 return out;
167 }
168
169 } // namespace native
170 } // namespace executor
171 } // namespace torch
172