1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
10 #include <executorch/runtime/kernel/kernel_includes.h>
11
12 namespace torch {
13 namespace executor {
14 namespace native {
15
16 using SizesType = exec_aten::SizesType;
17 using Tensor = exec_aten::Tensor;
18 using IntArrayRef = exec_aten::ArrayRef<int64_t>;
19
20 namespace {
21
increment_coordinate_permuted(const Tensor & tensor,size_t * const coordinate,IntArrayRef dims)22 void increment_coordinate_permuted(
23 const Tensor& tensor,
24 size_t* const coordinate,
25 IntArrayRef dims) {
26 for (int i = dims.size() - 1; i >= 0; i--) {
27 size_t d = dims[i] >= 0 ? dims[i] : dims[i] + tensor.dim();
28 coordinate[d]++;
29 if (coordinate[d] == tensor.size(d)) {
30 coordinate[d] = 0;
31 } else {
32 return;
33 }
34 }
35 }
36
37 } // namespace
38
permute_copy_out(KernelRuntimeContext & ctx,const Tensor & in,IntArrayRef dims,Tensor & out)39 Tensor& permute_copy_out(
40 KernelRuntimeContext& ctx,
41 const Tensor& in,
42 IntArrayRef dims,
43 Tensor& out) {
44 (void)ctx;
45
46 ET_KERNEL_CHECK(
47 ctx, check_permute_copy_args(in, dims, out), InvalidArgument, out);
48
49 ET_KERNEL_CHECK(
50 ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
51
52 Tensor::SizesType expected_out_size[kTensorDimensionLimit];
53 size_t expected_out_dim = 0;
54 get_permute_copy_out_target_size(
55 in, dims, expected_out_size, &expected_out_dim);
56 ET_KERNEL_CHECK(
57 ctx,
58 resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
59 InvalidArgument,
60 out);
61
62 const auto in_type = out.scalar_type();
63
64 size_t in_coord[kTensorDimensionLimit] = {0};
65 size_t trailing_dims_memo[kTensorDimensionLimit];
66 executorch::runtime::memoizeTrailingDims(in, trailing_dims_memo);
67
68 // in and out must be the same dtype
69 ET_SWITCH_ALL_TYPES(in_type, ctx, "permute_copy.out", CTYPE, [&] {
70 const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
71 CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
72
73 for (size_t i = 0; i < out.numel(); ++i) {
74 out_data[i] =
75 in_data[executorch::runtime::coordinateToIndexWithTrailingDimsMemo(
76 in, in_coord, trailing_dims_memo)];
77 increment_coordinate_permuted(in, in_coord, dims);
78 }
79 });
80
81 return out;
82 }
83
84 } // namespace native
85 } // namespace executor
86 } // namespace torch
87