/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstring>

#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
#include <executorch/kernels/portable/cpu/util/select_copy_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {

using Tensor = exec_aten::Tensor;

Error select_copy_util(
    const Tensor& in,
    int64_t dim,
    int64_t index,
    Tensor& out) {
  if (!check_select_copy_out_args(in, dim, index, out)) {
    return Error::InvalidArgument;
  }

  if (dim < 0) {
    dim += nonzero_dim(in);
  }

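  // Compute the expected output shape (the input shape with dimension `dim`
  // removed) and resize `out` to match it.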
  Tensor::SizesType target_sizes[kTensorDimensionLimit];
  size_t target_ndim = 0;
  get_select_copy_out_target_size(in, dim, target_sizes, &target_ndim);

  if (resize_tensor(out, {target_sizes, target_ndim}) != Error::Ok) {
    return Error::InvalidArgument;
  }

  if (!tensors_have_same_dim_order(in, out)) {
    return Error::InvalidArgument;
  }

  // If the input is an empty tensor, there is nothing else to do; just
  // return the (already resized) output.
  if (in.numel() == 0) {
    return Error::Ok;
  }
  // The code past this point assumes that the tensors are non-empty.

  // Support Python-style negative indexing
  if (index < 0) {
    index += in.size(dim);
  }

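  // View the input as a (leading_dims, dim_length, trailing_dims) block, where
  // leading_dims is the product of the sizes before `dim` and trailing_dims is
  // the product of the sizes after it.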
  size_t leading_dims = getLeadingDims(in, dim);
  size_t trailing_dims = getTrailingDims(in, dim);
  size_t dim_length = in.size(dim);

  // Number of bytes to copy in each memcpy operation
  size_t copy_size_per_op = trailing_dims * out.element_size();

  // Step between the src locations of two adjacent memcpy operations
  size_t src_step_per_op = dim_length * trailing_dims * in.element_size();

  // The first byte to copy is the start of the input data plus the offset of
  // the selected index within the first leading block.
  char* input_data = in.mutable_data_ptr<char>();

  size_t start_offset = index * trailing_dims * in.element_size();
  char* src = input_data + start_offset;
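
  // Worked example (hypothetical values): a contiguous float input of shape
  // (2, 3, 4) with dim = 1 and index = 2 gives leading_dims = 2,
  // trailing_dims = 4, and dim_length = 3, so copy_size_per_op = 16 bytes,
  // src_step_per_op = 48 bytes, and start_offset = 32 bytes. The loop below
  // then copies input bytes [32, 48) and [80, 96) into the output.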

  char* dest = out.mutable_data_ptr<char>();

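  // For each leading block, copy the contiguous slice at position `index`
  // along `dim` into the output.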
  for (size_t j = 0; j < leading_dims; ++j) {
    memcpy(dest, src, copy_size_per_op);
    src += src_step_per_op;
    dest += copy_size_per_op;
  }

  return Error::Ok;
}
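
// A minimal usage sketch (illustrative names, not part of this file), assuming
// `input` and `output` are valid Tensors with matching dtypes and `output` is
// resizable to the selected slice's shape:
//
//   Error err = select_copy_util(input, /*dim=*/1, /*index=*/0, output);
//   if (err != Error::Ok) {
//     // handle invalid arguments or a resize failure
//   }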

} // namespace executor
} // namespace torch