xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/util/select_copy_util.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstring>

#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/kernels/portable/cpu/util/select_copy_util.h>

namespace torch {
namespace executor {

using Tensor = exec_aten::Tensor;

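// Copies the `index`-th slice of `in` along dimension `dim` into `out`,
// resizing `out` to the input shape with `dim` removed (the semantics of
// torch.select). Illustrative call site (an assumed sketch, not part of this
// file):
//
//   Error err = select_copy_util(in, /*dim=*/1, /*index=*/2, out);
//   if (err != Error::Ok) {
//     return err;  // propagate the failure to the caller
//   }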
Error select_copy_util(
    const Tensor& in,
    int64_t dim,
    int64_t index,
    Tensor& out) {
  if (!check_select_copy_out_args(in, dim, index, out)) {
    return Error::InvalidArgument;
  }

  if (dim < 0) {
    dim += nonzero_dim(in);
  }

  Tensor::SizesType target_sizes[kTensorDimensionLimit];
  size_t target_ndim = 0;
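  // The target size is the input's shape with dimension `dim` removed; e.g.
  // (illustrative) selecting along dim = 1 of a [2, 3, 4] input resizes `out`
  // to [2, 4].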
  get_select_copy_out_target_size(in, dim, target_sizes, &target_ndim);

  if (resize_tensor(out, {target_sizes, target_ndim}) != Error::Ok) {
    return Error::InvalidArgument;
  }

  if (!tensors_have_same_dim_order(in, out)) {
    return Error::InvalidArgument;
  }

  // If the input is an empty tensor, there is nothing to copy; just return
  // the (also empty) output.
  if (in.numel() == 0) {
    return Error::Ok;
  }
  // The code past this point assumes that the tensors are non-empty.

  // Support Python-style negative indexing.
  if (index < 0) {
    index += in.size(dim);
  }

  size_t leading_dims = getLeadingDims(in, dim);
  size_t trailing_dims = getTrailingDims(in, dim);
  size_t dim_length = in.size(dim);

  // Number of bytes to copy in each memcpy operation.
  size_t copy_size_per_op = trailing_dims * out.element_size();

  // Step, in bytes, between the source locations of two adjacent memcpy
  // operations.
  size_t src_step_per_op = dim_length * trailing_dims * in.element_size();

  // The first source location is the start of the input buffer plus the
  // offset of the selected slice within the first block along `dim`.
  char* input_data = in.mutable_data_ptr<char>();

  size_t start_offset = index * trailing_dims * in.element_size();
  char* src = input_data + start_offset;

  char* dest = out.mutable_data_ptr<char>();

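  // Illustrative walk-through (assumed example values, not from the original
  // source): for a float input of shape [2, 3, 4] with dim = 1 and index = 2,
  // leading_dims = 2, trailing_dims = 4, and dim_length = 3, so
  // copy_size_per_op = 16 bytes, src_step_per_op = 48 bytes, and
  // start_offset = 32 bytes; the loop below runs twice and fills an output of
  // shape [2, 4].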
  for (size_t j = 0; j < leading_dims; ++j) {
    memcpy(dest, src, copy_size_per_op);
    src += src_step_per_op;
    dest += copy_size_per_op;
  }

  return Error::Ok;
}

} // namespace executor
} // namespace torch