/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace {

/**
 * Recursively copy input_data to output_data according to the given
 * size and stride.
 */
template <typename CTYPE>
void _as_strided_copy(
    CTYPE* input_data,
    CTYPE* output_data,
    Tensor& out,
    ArrayRef<int64_t> size,
    ArrayRef<int64_t> stride,
    int64_t dim) {
  // Base case: at the last dimension, copy size[dim] elements, stepping
  // through the input by stride[dim].
  if (dim == size.size() - 1) {
    for (size_t i = 0; i < size.at(dim); ++i) {
      output_data[i] = *input_data;
      input_data += stride.at(dim);
    }
    return;
  }
  size_t trailing_dims = getTrailingDims(out, dim);
  // Recursive case: fill each slice along this dimension, advancing the
  // input by stride[dim] and the output by the size of one output slice.
  for (size_t i = 0; i < size.at(dim); ++i) {
    _as_strided_copy<CTYPE>(
        input_data, output_data, out, size, stride, dim + 1);
    input_data += stride.at(dim);
    output_data += trailing_dims;
  }
}
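
/*
 * Example (derived from the recursion above): with a contiguous input,
 * size = {2, 2}, and stride = {1, 2}, the copy walks the input as
 *   output[i][j] = input[i * stride[0] + j * stride[1]],
 * producing output = {input[0], input[2], input[1], input[3]}, i.e. the
 * transpose of the first four input elements.
 */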

} // namespace

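/*
 * The helpers below follow a common naming convention in the portable
 * kernels: check_<op>_args returns true if the arguments form a valid
 * input to the corresponding <op>.out kernel, and
 * get_<op>_out_target_size computes the sizes (and rank) that `out`
 * should be resized to before the copy is performed.
 */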
bool check_as_strided_copy_args(
    const Tensor& in,
    ArrayRef<int64_t> size,
    ArrayRef<int64_t> stride,
    optional<int64_t> storage_offset,
    Tensor& out);

/**
 * Copies the elements of `in` selected by `size`, `stride`, and `offset`
 * into `out`, dispatching to _as_strided_copy above.
 */
template <typename CTYPE>
void as_strided_copy(
    const Tensor& in,
    ArrayRef<int64_t> size,
    ArrayRef<int64_t> stride,
    int64_t offset,
    Tensor& out) {
  CTYPE* in_data = in.mutable_data_ptr<CTYPE>() + offset;
  CTYPE* out_data = out.mutable_data_ptr<CTYPE>();

  if (size.empty()) {
    // 0-dimensional view: copy the single element.
    out_data[0] = in_data[0];
  } else {
    _as_strided_copy<CTYPE>(in_data, out_data, out, size, stride, 0);
  }
}
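
// Usage sketch (illustrative; `in` and `out` stand for suitably sized float
// tensors, and the brace-initialized ArrayRefs assume the (pointer, length)
// constructor):
//
//   const int64_t sizes[] = {2, 2};
//   const int64_t strides[] = {1, 2};
//   // Copy the first four elements of `in` into `out` as a 2x2 transpose.
//   as_strided_copy<float>(in, {sizes, 2}, {strides, 2}, /*offset=*/0, out);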

bool check_cat_args(
    exec_aten::ArrayRef<Tensor> tensors,
    int64_t dim,
    Tensor& out);

void get_cat_out_target_size(
    exec_aten::ArrayRef<Tensor> tensors,
    int64_t dim,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim);

bool check_expand_copy_args(
    const Tensor& self,
    ArrayRef<int64_t> expand_sizes,
    bool implicit,
    Tensor& out);

bool get_expand_copy_out_target_size(
    exec_aten::ArrayRef<exec_aten::SizesType> self_sizes,
    exec_aten::ArrayRef<int64_t> expand_sizes,
    exec_aten::SizesType* output_sizes,
    size_t* output_rank);

bool check_permute_copy_args(const Tensor& in, IntArrayRef dims, Tensor& out);

bool check_unbind_copy_args(const Tensor& in, int64_t dim, TensorList out);

void get_permute_copy_out_target_size(
    const Tensor& in,
    IntArrayRef dims,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim);

bool check_pixel_shuffle_args(
    const Tensor& in,
    int64_t upscale_factor,
    Tensor& out);
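
// Shape note: pixel_shuffle with upscale factor r maps an input of shape
// (..., C * r * r, H, W) to (..., C, H * r, W * r); pixel_unshuffle is the
// inverse mapping.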
void get_pixel_shuffle_out_target_size(
    const Tensor& in,
    int64_t upscale_factor,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim);

bool check_pixel_unshuffle_args(
    const Tensor& in,
    int64_t upscale_factor,
    Tensor& out);

void get_pixel_unshuffle_out_target_size(
    const Tensor& in,
    int64_t upscale_factor,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim);

bool check_select_copy_out_args(
    const Tensor& in,
    int64_t dim,
    int64_t index,
    Tensor& out);

void get_select_copy_out_target_size(
    const Tensor& in,
    int64_t dim,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim);

bool check_split_with_sizes_copy_args(
    const Tensor& in,
    exec_aten::ArrayRef<int64_t> split_sizes,
    int64_t dim,
    TensorList out);

void get_split_with_sizes_copy_out_target_size(
    const Tensor& in,
    int64_t split_size,
    int64_t dim,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim);

bool check_squeeze_copy_dim_args(
    const Tensor in,
    int64_t dim,
    const Tensor out);

void get_squeeze_copy_dim_out_target_size(
    const Tensor in,
    int64_t dim,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim);

bool check_squeeze_copy_dims_args(
    const Tensor in,
    const exec_aten::ArrayRef<int64_t> dims,
    const Tensor out);

void get_squeeze_copy_dims_out_target_size(
    const Tensor in,
    const exec_aten::ArrayRef<int64_t> dims,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim);

bool check_stack_args(
    exec_aten::ArrayRef<Tensor> tensors,
    int64_t dim,
    Tensor& out);

void get_stack_out_target_size(
    exec_aten::ArrayRef<Tensor> tensors,
    int64_t dim,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim);

bool check_tril_args(const Tensor& in, Tensor& out);

bool check_split_copy_args(
    const Tensor& input,
    int64_t split_size,
    int64_t dim,
    TensorList out);

bool check_to_copy_args(
    const Tensor& input,
    bool non_blocking,
    exec_aten::optional<exec_aten::MemoryFormat> memory_format,
    Tensor& out);

bool check__to_dim_order_copy_args(
    const Tensor& input,
    bool non_blocking,
    exec_aten::OptionalArrayRef<int64_t> dim_order,
    Tensor& out);

bool check_unsqueeze_copy_args(
    const Tensor input,
    int64_t dim,
    const Tensor out);

bool check_view_copy_args(
    const Tensor& self,
    exec_aten::ArrayRef<int64_t> size_int64_t,
    Tensor& out);
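
// Computes the sizes `out` should have for view_copy, where `dim` is the
// output rank. Given the bool return, this is assumed to also infer a
// single -1 entry in `size_int64_t` from the input's element count and to
// report failure by returning false.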
bool get_view_copy_target_size(
    const Tensor input,
    exec_aten::ArrayRef<int64_t> size_int64_t,
    int64_t dim,
    exec_aten::SizesType* out_sizes);

bool check_diagonal_copy_args(
    const Tensor& in,
    int64_t dim1,
    int64_t dim2,
    Tensor& out);

void get_diagonal_copy_out_target_size(
    const Tensor& in,
    int64_t offset,
    int64_t dim1,
    int64_t dim2,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim);

} // namespace executor
} // namespace torch