1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <cinttypes>
10 #include <cstdint>
11 #include <cstring>
12
13 #include <executorch/kernels/portable/cpu/util/index_util.h>
14 #include <executorch/runtime/kernel/kernel_includes.h>
15
16 namespace torch {
17 namespace executor {
18 namespace native {
19
20 using Tensor = exec_aten::Tensor;
21
22 /// aten::select_scatter.out(Tensor self, Tensor src, int dim, SymInt index, *,
23 /// Tensor(a!) out) -> Tensor(a!)
select_scatter_out(KernelRuntimeContext & ctx,const Tensor & in,const Tensor & src,int64_t dim,int64_t index,Tensor & out)24 Tensor& select_scatter_out(
25 KernelRuntimeContext& ctx,
26 const Tensor& in,
27 const Tensor& src,
28 int64_t dim,
29 int64_t index,
30 Tensor& out) {
31 (void)ctx;
32
33 ET_KERNEL_CHECK(
34 ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
35
36 ET_KERNEL_CHECK(
37 ctx, tensors_have_same_dim_order(in, src, out), InvalidArgument, out);
38
39 // Account for negative indices
40 if (dim < 0) {
41 dim += in.dim();
42 }
43
44 ET_KERNEL_CHECK(ctx, dim >= 0 && dim < in.dim(), InvalidArgument, out);
45
46 if (index < 0) {
47 index += in.size(dim);
48 }
49
50 // Check args
51 ET_KERNEL_CHECK(
52 ctx,
53 check_select_scatter_args(in, src, dim, index, out),
54 InvalidArgument,
55 out);
56
57 // If the input is an empty tensor, no other operation could be done. We just
58 // return the output.
59 if (in.numel() == 0) {
60 return out;
61 }
62
63 // To start, copy the input into the output. Input will not be empty due to
64 // the checks performed above.
65 memcpy(out.mutable_data_ptr(), in.const_data_ptr(), in.nbytes());
66
67 // Strides to help with memory address arithmetic
68 size_t leading_dims = getLeadingDims(in, dim);
69 size_t trailing_stride = getTrailingDims(in, dim);
70 size_t start_offset = index * trailing_stride;
71 size_t out_step = in.size(dim) * trailing_stride;
72
73 ScalarType in_type = in.scalar_type();
74 ScalarType src_type = src.scalar_type();
75
76 ET_SWITCH_REAL_TYPES_AND(
77 Bool, in_type, ctx, "select_scatter.out", CTYPE, [&]() {
78 ET_SWITCH_REAL_TYPES_AND(
79 Bool, src_type, ctx, "select_scatter.out", CTYPE_SRC, [&]() {
80 CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
81 const CTYPE_SRC* const src_data = src.const_data_ptr<CTYPE_SRC>();
82
83 for (size_t i = 0; i < leading_dims; ++i) {
84 for (size_t j = 0; j < trailing_stride; ++j) {
85 out_data[start_offset + i * out_step + j] =
86 convert<CTYPE, CTYPE_SRC>(
87 src_data[i * trailing_stride + j]);
88 }
89 }
90 });
91 });
92
93 return out;
94 }
95
96 } // namespace native
97 } // namespace executor
98 } // namespace torch
99