xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/op_permute_copy.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
10 #include <executorch/runtime/kernel/kernel_includes.h>
11 
12 namespace torch {
13 namespace executor {
14 namespace native {
15 
16 using SizesType = exec_aten::SizesType;
17 using Tensor = exec_aten::Tensor;
18 using IntArrayRef = exec_aten::ArrayRef<int64_t>;
19 
20 namespace {
21 
increment_coordinate_permuted(const Tensor & tensor,size_t * const coordinate,IntArrayRef dims)22 void increment_coordinate_permuted(
23     const Tensor& tensor,
24     size_t* const coordinate,
25     IntArrayRef dims) {
26   for (int i = dims.size() - 1; i >= 0; i--) {
27     size_t d = dims[i] >= 0 ? dims[i] : dims[i] + tensor.dim();
28     coordinate[d]++;
29     if (coordinate[d] == tensor.size(d)) {
30       coordinate[d] = 0;
31     } else {
32       return;
33     }
34   }
35 }
36 
37 } // namespace
38 
permute_copy_out(KernelRuntimeContext & ctx,const Tensor & in,IntArrayRef dims,Tensor & out)39 Tensor& permute_copy_out(
40     KernelRuntimeContext& ctx,
41     const Tensor& in,
42     IntArrayRef dims,
43     Tensor& out) {
44   (void)ctx;
45 
46   ET_KERNEL_CHECK(
47       ctx, check_permute_copy_args(in, dims, out), InvalidArgument, out);
48 
49   ET_KERNEL_CHECK(
50       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
51 
52   Tensor::SizesType expected_out_size[kTensorDimensionLimit];
53   size_t expected_out_dim = 0;
54   get_permute_copy_out_target_size(
55       in, dims, expected_out_size, &expected_out_dim);
56   ET_KERNEL_CHECK(
57       ctx,
58       resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
59       InvalidArgument,
60       out);
61 
62   const auto in_type = out.scalar_type();
63 
64   size_t in_coord[kTensorDimensionLimit] = {0};
65   size_t trailing_dims_memo[kTensorDimensionLimit];
66   executorch::runtime::memoizeTrailingDims(in, trailing_dims_memo);
67 
68   // in and out must be the same dtype
69   ET_SWITCH_ALL_TYPES(in_type, ctx, "permute_copy.out", CTYPE, [&] {
70     const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
71     CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
72 
73     for (size_t i = 0; i < out.numel(); ++i) {
74       out_data[i] =
75           in_data[executorch::runtime::coordinateToIndexWithTrailingDimsMemo(
76               in, in_coord, trailing_dims_memo)];
77       increment_coordinate_permuted(in, in_coord, dims);
78     }
79   });
80 
81   return out;
82 }
83 
84 } // namespace native
85 } // namespace executor
86 } // namespace torch
87