xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/op_select_scatter.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <cinttypes>
10 #include <cstdint>
11 #include <cstring>
12 
13 #include <executorch/kernels/portable/cpu/util/index_util.h>
14 #include <executorch/runtime/kernel/kernel_includes.h>
15 
16 namespace torch {
17 namespace executor {
18 namespace native {
19 
20 using Tensor = exec_aten::Tensor;
21 
22 /// aten::select_scatter.out(Tensor self, Tensor src, int dim, SymInt index, *,
23 /// Tensor(a!) out) -> Tensor(a!)
select_scatter_out(KernelRuntimeContext & ctx,const Tensor & in,const Tensor & src,int64_t dim,int64_t index,Tensor & out)24 Tensor& select_scatter_out(
25     KernelRuntimeContext& ctx,
26     const Tensor& in,
27     const Tensor& src,
28     int64_t dim,
29     int64_t index,
30     Tensor& out) {
31   (void)ctx;
32 
33   ET_KERNEL_CHECK(
34       ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
35 
36   ET_KERNEL_CHECK(
37       ctx, tensors_have_same_dim_order(in, src, out), InvalidArgument, out);
38 
39   // Account for negative indices
40   if (dim < 0) {
41     dim += in.dim();
42   }
43 
44   ET_KERNEL_CHECK(ctx, dim >= 0 && dim < in.dim(), InvalidArgument, out);
45 
46   if (index < 0) {
47     index += in.size(dim);
48   }
49 
50   // Check args
51   ET_KERNEL_CHECK(
52       ctx,
53       check_select_scatter_args(in, src, dim, index, out),
54       InvalidArgument,
55       out);
56 
57   // If the input is an empty tensor, no other operation could be done. We just
58   // return the output.
59   if (in.numel() == 0) {
60     return out;
61   }
62 
63   // To start, copy the input into the output. Input will not be empty due to
64   // the checks performed above.
65   memcpy(out.mutable_data_ptr(), in.const_data_ptr(), in.nbytes());
66 
67   // Strides to help with memory address arithmetic
68   size_t leading_dims = getLeadingDims(in, dim);
69   size_t trailing_stride = getTrailingDims(in, dim);
70   size_t start_offset = index * trailing_stride;
71   size_t out_step = in.size(dim) * trailing_stride;
72 
73   ScalarType in_type = in.scalar_type();
74   ScalarType src_type = src.scalar_type();
75 
76   ET_SWITCH_REAL_TYPES_AND(
77       Bool, in_type, ctx, "select_scatter.out", CTYPE, [&]() {
78         ET_SWITCH_REAL_TYPES_AND(
79             Bool, src_type, ctx, "select_scatter.out", CTYPE_SRC, [&]() {
80               CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
81               const CTYPE_SRC* const src_data = src.const_data_ptr<CTYPE_SRC>();
82 
83               for (size_t i = 0; i < leading_dims; ++i) {
84                 for (size_t j = 0; j < trailing_stride; ++j) {
85                   out_data[start_offset + i * out_step + j] =
86                       convert<CTYPE, CTYPE_SRC>(
87                           src_data[i * trailing_stride + j]);
88                 }
89               }
90             });
91       });
92 
93   return out;
94 }
95 
96 } // namespace native
97 } // namespace executor
98 } // namespace torch
99