/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace {

/**
 * Recursively copies input_data to output_data according to `size` and
 * `stride`, one dimension at a time.
 */
template <typename CTYPE>
void _as_strided_copy(
    CTYPE* input_data,
    CTYPE* output_data,
    Tensor& out,
    ArrayRef<int64_t> size,
    ArrayRef<int64_t> stride,
    int64_t dim) {
  // Base case: at the last dimension, copy the elements one by one,
  // advancing the input pointer by the stride of that dimension.
  if (dim == size.size() - 1) {
    for (size_t i = 0; i < size.at(dim); ++i) {
      output_data[i] = *input_data;
      input_data += stride.at(dim);
    }
    return;
  }
  // Number of output elements spanned by one step along `dim`.
  size_t trailing_dims = getTrailingDims(out, dim);
  // Recursively copy the data for the next dimension.
  for (size_t i = 0; i < size.at(dim); ++i) {
    _as_strided_copy<CTYPE>(
        input_data, output_data, out, size, stride, dim + 1);
    input_data += stride.at(dim);
    output_data += trailing_dims;
  }
}
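// Illustrative note (an addition, not behavior defined here): for a
// hypothetical contiguous 2x3 input, size = {2, 3} with stride = {3, 1}
// copies the rows in order, while stride = {1, 2} gathers a transposed-like
// pattern. Sketch of the traversal, assuming `out` is already a 2x3 tensor:
//
//   float in[6] = {0, 1, 2, 3, 4, 5};
//   float out_buf[6];
//   // _as_strided_copy<float>(in, out_buf, out, {2, 3}, {3, 1}, 0);
//   //   -> out_buf == {0, 1, 2, 3, 4, 5}
//   // _as_strided_copy<float>(in, out_buf, out, {2, 3}, {1, 2}, 0);
//   //   -> out_buf == {0, 2, 4, 1, 3, 5}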

} // namespace

bool check_as_strided_copy_args(
    const Tensor& in,
    ArrayRef<int64_t> size,
    ArrayRef<int64_t> stride,
    optional<int64_t> storage_offset,
    Tensor& out);

template <typename CTYPE>
void as_strided_copy(
    const Tensor& in,
    ArrayRef<int64_t> size,
    ArrayRef<int64_t> stride,
    int64_t offset,
    Tensor& out) {
  CTYPE* in_data = in.mutable_data_ptr<CTYPE>() + offset;
  CTYPE* out_data = out.mutable_data_ptr<CTYPE>();

  // An empty size list means a single-element copy; otherwise recurse over
  // the requested dimensions starting at dim 0.
  if (size.empty()) {
    out_data[0] = in_data[0];
  } else {
    _as_strided_copy<CTYPE>(in_data, out_data, out, size, stride, 0);
  }
}
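// Usage sketch (an assumed call site, not something this header mandates):
// a kernel would typically validate the arguments, resize `out` to `size`,
// and then dispatch on dtype, e.g. for float inputs:
//
//   // if (check_as_strided_copy_args(in, size, stride, storage_offset, out)) {
//   //   // ... resize out to `size` ...
//   //   as_strided_copy<float>(in, size, stride, storage_offset.value_or(0), out);
//   // }
//
// The copy itself relies on `out` already having the requested shape, since
// _as_strided_copy() uses getTrailingDims(out, dim) to advance the output.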

bool check_cat_args(
    exec_aten::ArrayRef<Tensor> tensors,
    int64_t dim,
    Tensor& out);

void get_cat_out_target_size(
    exec_aten::ArrayRef<Tensor> tensors,
    int64_t dim,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim);
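// Shape illustration (assuming the standard aten::cat semantics): for input
// tensors of sizes {2, 3} and {4, 3} with dim = 0,
// get_cat_out_target_size() is expected to produce out_sizes = {6, 3} and
// *out_ndim = 2, i.e. sizes match everywhere except along `dim`, which sums.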

bool check_expand_copy_args(
    const Tensor& self,
    ArrayRef<int64_t> expand_sizes,
    bool implicit,
    Tensor& out);

bool get_expand_copy_out_target_size(
    exec_aten::ArrayRef<exec_aten::SizesType> self_sizes,
    exec_aten::ArrayRef<int64_t> expand_sizes,
    exec_aten::SizesType* output_sizes,
    size_t* output_rank);
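// Shape illustration (assuming the standard aten::expand_copy semantics):
// for self_sizes = {1, 3} and expand_sizes = {4, 3} (or equivalently
// {4, -1}), get_expand_copy_out_target_size() is expected to produce
// output_sizes = {4, 3} and *output_rank = 2, using its bool return value
// to signal incompatible sizes.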

bool check_permute_copy_args(const Tensor& in, IntArrayRef dims, Tensor& out);

bool check_unbind_copy_args(const Tensor& in, int64_t dim, TensorList out);

void get_permute_copy_out_target_size(
    const Tensor& in,
    IntArrayRef dims,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim);
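// Shape illustration (assuming the standard aten::permute_copy semantics):
// for an input of size {2, 3, 5} and dims = {2, 0, 1},
// get_permute_copy_out_target_size() is expected to produce
// out_sizes = {5, 2, 3} and *out_ndim = 3.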

bool check_pixel_shuffle_args(
    const Tensor& in,
    int64_t upscale_factor,
    Tensor& out);

void get_pixel_shuffle_out_target_size(
    const Tensor& in,
    int64_t upscale_factor,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim);
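// Shape illustration (assuming the standard aten::pixel_shuffle semantics):
// for an input of size {1, 8, 2, 3} and upscale_factor = 2, the target size
// is expected to be {1, 2, 4, 6}: channels divided by the square of the
// factor, spatial dimensions multiplied by it.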

bool check_pixel_unshuffle_args(
    const Tensor& in,
    int64_t upscale_factor,
    Tensor& out);

void get_pixel_unshuffle_out_target_size(
    const Tensor& in,
    int64_t upscale_factor,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim);
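// Shape illustration (assuming the standard aten::pixel_unshuffle semantics,
// where the factor acts as a downscale factor despite the parameter name):
// for an input of size {1, 2, 4, 6} and a factor of 2, the target size is
// expected to be {1, 8, 2, 3}.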

bool check_select_copy_out_args(
    const Tensor& in,
    int64_t dim,
    int64_t index,
    Tensor& out);

void get_select_copy_out_target_size(
    const Tensor& in,
    int64_t dim,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim);
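// Shape illustration (assuming the standard aten::select_copy semantics):
// selecting along dim = 1 of an input of size {2, 3, 5} drops that
// dimension, so get_select_copy_out_target_size() is expected to produce
// out_sizes = {2, 5} and *out_ndim = 2.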

bool check_split_with_sizes_copy_args(
    const Tensor& in,
    exec_aten::ArrayRef<int64_t> split_sizes,
    int64_t dim,
    TensorList out);

void get_split_with_sizes_copy_out_target_size(
    const Tensor& in,
    int64_t split_size,
    int64_t dim,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim);
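// Shape illustration (an assumption: the target size is computed per output
// chunk): splitting an input of size {2, 7, 5} along dim = 1, a chunk with
// split_size = 3 is expected to get out_sizes = {2, 3, 5} and *out_ndim = 3.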

bool check_squeeze_copy_dim_args(
    const Tensor in,
    int64_t dim,
    const Tensor out);

void get_squeeze_copy_dim_out_target_size(
    const Tensor in,
    int64_t dim,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim);

bool check_squeeze_copy_dims_args(
    const Tensor in,
    const exec_aten::ArrayRef<int64_t> dims,
    const Tensor out);

void get_squeeze_copy_dims_out_target_size(
    const Tensor in,
    const exec_aten::ArrayRef<int64_t> dims,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim);
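// Shape illustration (assuming the standard aten::squeeze_copy semantics):
// for an input of size {2, 1, 3, 1}, squeezing dim = 1 is expected to give
// out_sizes = {2, 3, 1}, and squeezing dims = {1, 3} to give
// out_sizes = {2, 3}; dimensions whose size is not 1 are left in place.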

bool check_stack_args(
    exec_aten::ArrayRef<Tensor> tensors,
    int64_t dim,
    Tensor& out);

void get_stack_out_target_size(
    exec_aten::ArrayRef<Tensor> tensors,
    int64_t dim,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim);
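// Shape illustration (assuming the standard aten::stack semantics): stacking
// three tensors of size {2, 5} along dim = 1 is expected to yield
// out_sizes = {2, 3, 5} and *out_ndim = 3, i.e. a new dimension of length
// tensors.size() inserted at `dim`.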

bool check_tril_args(const Tensor& in, Tensor& out);

bool check_split_copy_args(
    const Tensor& input,
    int64_t split_size,
    int64_t dim,
    TensorList out);

bool check_to_copy_args(
    const Tensor& input,
    bool non_blocking,
    exec_aten::optional<exec_aten::MemoryFormat> memory_format,
    Tensor& out);

bool check__to_dim_order_copy_args(
    const Tensor& input,
    bool non_blocking,
    exec_aten::OptionalArrayRef<int64_t> dim_order,
    Tensor& out);

bool check_unsqueeze_copy_args(
    const Tensor input,
    int64_t dim,
    const Tensor out);
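// Shape expectation for the unsqueeze check (assuming the standard
// aten::unsqueeze_copy semantics): for an input of size {2, 3} and dim = 1,
// `out` is expected to have size {2, 1, 3}, i.e. one extra dimension of
// length 1 inserted at `dim`.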

bool check_view_copy_args(
    const Tensor& self,
    exec_aten::ArrayRef<int64_t> size_int64_t,
    Tensor& out);

bool get_view_copy_target_size(
    const Tensor input,
    exec_aten::ArrayRef<int64_t> size_int64_t,
    int64_t dim,
    exec_aten::SizesType* out_sizes);
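// Shape illustration (assuming the standard aten::view_copy semantics, and
// taking `dim` to be the rank of the requested output): for an input with
// 12 elements and size_int64_t = {3, -1}, get_view_copy_target_size() is
// expected to infer the -1 entry and produce out_sizes = {3, 4}.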

bool check_diagonal_copy_args(
    const Tensor& in,
    int64_t dim1,
    int64_t dim2,
    Tensor& out);

void get_diagonal_copy_out_target_size(
    const Tensor& in,
    int64_t offset,
    int64_t dim1,
    int64_t dim2,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim);
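// Shape illustration (assuming the standard aten::diagonal_copy semantics):
// for an input of size {3, 4} with offset = 0, dim1 = 0, dim2 = 1, the two
// diagonal dimensions collapse into a trailing dimension of length
// min(3, 4) = 3, so out_sizes = {3} and *out_ndim = 1.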

} // namespace executor
} // namespace torch