/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <algorithm>
#include <tuple>

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {

/**
 * Extracts the value at index i from an int array. If the array has exactly
 * one element, that element is returned regardless of the requested index, to
 * simulate broadcasting. If the array is empty, default_value is returned.
 */
inline int64_t val_at(IntArrayRef array, size_t i, int64_t default_value = 1) {
  if (array.size() == 1) {
    return array[0];
  } else if (array.size() > 1) {
    return array[i];
  } else {
    return default_value;
  }
}
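
// Illustrative usage sketch (not part of the upstream header): shows the
// broadcasting and fallback behavior of val_at(). Assumes an ArrayRef
// constructible from a pointer and a length, as done elsewhere in this file.
//
//   const int64_t stride_vals[1] = {2};
//   IntArrayRef stride(stride_vals, 1);
//   val_at(stride, 0);                              // -> 2
//   val_at(stride, 1);                              // -> 2 (single value broadcast)
//   val_at(IntArrayRef(), 1, /*default_value=*/3);  // -> 3 (empty array)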

/**
 * Checks that all elements of an IntArray are greater than or equal to `val`.
 */
bool int_array_all_ge(IntArrayRef array, int64_t val);

bool stride_is_valid(IntArrayRef stride, size_t kernel_ndim, bool allow_empty);

bool padding_is_valid(
    IntArrayRef padding,
    IntArrayRef kernel_size,
    size_t kernel_ndim,
    bool enforce_half_kernel = false);

bool dilation_is_valid(IntArrayRef dilation, size_t kernel_ndim);

bool output_size_is_valid(
    exec_aten::ArrayRef<exec_aten::SizesType> output_size,
    size_t kernel_ndim);

void get_unsqueezed_sizes(
    const Tensor& t,
    int64_t unsqueeze_dim,
    exec_aten::SizesType* sizes_arr,
    size_t& ndim);

void get_unsqueezed_dim_order(
    const Tensor& t,
    exec_aten::DimOrderType unsqueeze_dim,
    exec_aten::DimOrderType* dim_order_arr);

/**
 * Given an input tensor and N-dim kernel parameters, calculates the output
 * size of the N-dim kernel region.
 */
void calculate_kernel_output_sizes(
    const Tensor& in,
    size_t kernel_ndim,
    IntArrayRef kernel_sizes,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    exec_aten::SizesType* out_sizes,
    bool ceil_mode = false,
    bool transposed = false,
    IntArrayRef output_padding = {});
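
// For reference (a sketch of the usual convention, not a statement about the
// exact implementation): in the non-transposed case the output size along
// each kernel dimension is conventionally
//
//   out = floor((in + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1
//
// (rounded up instead when ceil_mode is true), while the transposed case is
// conventionally
//
//   out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1)
//         + output_padding + 1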

//
// Utility functions to apply a reduction over an N-dimensional kernel window
//

/**
 * Given a 3-D or 4-D tensor, applies a reduction function over a 2D kernel
 * region at a given batch and channel output index. Note that reduce_fn
 * should return both an accumulator value and an index value; for example,
 * if reducing using the max() function, this function will track both the
 * maximum value observed as well as the index at which it was observed.
 * Therefore reduce_fn should follow the signature:
 *
 * ```
 * std::tuple<T, int64_t> reduce_fn(
 *    const T in_val,
 *    const int64_t in_idx,
 *    T accum,
 *    int64_t accum_idx);
 * ```
 *
 * Then, before the accumulator and index values are written out to out_ptr
 * and indices_ptr respectively, the accumulator is mapped with map_fn, which
 * should follow the signature:
 *
 * ```
 * T map_fn(const int64_t count, const T accum);
 * ```
 *
 * The index is a linear index with respect to the 2D plane formed by the
 * height and width axes. So for a tensor of size (N, C, H, W), an element at
 * location (n, c, h, w) will have an index of h * W + w. Although an index
 * accumulator is always tracked, it is not written out if `indices_ptr` is
 * `nullptr`. Therefore, reductions that do not care about indices can ignore
 * the index accumulator.
 *
 * @param[in] reduce_fn The reduction function used to update accumulator
 * values.
 * @param[in] map_fn The map function used to post-process the accumulator
 * value before writing to out_ptr.
 * @param[in] include_pad Indicates if the reduction function should be applied
 * over the padded regions implicitly added by the `padding` argument.
 * @param[in] in_ptr The pointer to the input tensor data.
 * @param[in] in_sizes Sizes array describing the size of the input tensor.
 * @param[in] in_strides Strides array describing the strides of the input
 * tensor.
 * @param[in] kernel_size 2D array describing the height and width of the
 * kernel region.
 * @param[in] stride 2D array describing how the kernel region "traverses" over
 * the input tensor.
 * @param[in] padding 2D array describing padding to apply to the input tensor.
 * @param[in] dilation 2D array describing the dilation to apply to the kernel
 * region.
 * @param[in] out_ptr The pointer to the output tensor data.
 * @param[in] out_sizes Sizes array describing the size of the output tensor.
 * @param[in] out_strides Strides array describing the strides of the output
 * tensor.
 * @param[in] indices_ptr The pointer to the indices tensor data. Can be
 * `nullptr`.
 * @param[in] batch The batch index of the output locations being computed.
 * @param[in] out_c The channel index of the output locations being computed.
 */
template <typename CTYPE, typename ReduceOp, typename MapOp>
void kernel_reduction_then_map_2d(
    const ReduceOp& reduce_fn,
    const MapOp& map_fn,
    const bool include_pad,
    const CTYPE* const in_ptr,
    const exec_aten::ArrayRef<exec_aten::SizesType> in_sizes,
    const exec_aten::ArrayRef<exec_aten::StridesType> in_strides,
    const IntArrayRef kernel_size,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const IntArrayRef dilation,
    CTYPE* const out_ptr,
    const exec_aten::ArrayRef<exec_aten::SizesType> out_sizes,
    const exec_aten::ArrayRef<exec_aten::StridesType> out_strides,
    int64_t* const indices_ptr,
    const size_t batch,
    const size_t out_c) {
  size_t in_dim = in_sizes.size();
  size_t out_dim = out_sizes.size();

  size_t out_H = out_sizes[in_dim - 2];
  size_t in_H = in_sizes[in_dim - 2];

  size_t out_W = out_sizes[in_dim - 1];
  size_t in_W = in_sizes[in_dim - 1];

  // The batch (if present) and channel coordinates are fixed while iterating
  // over the 2D output region.
  exec_aten::SizesType in_coord[kTensorDimensionLimit];
  exec_aten::SizesType out_coord[kTensorDimensionLimit];
  if (in_dim == 4) {
    in_coord[0] = batch;
    out_coord[0] = batch;
  }
  in_coord[in_dim - 3] = out_c;
  out_coord[in_dim - 3] = out_c;

  int64_t k_H = val_at(kernel_size, 0);
  int64_t k_W = val_at(kernel_size, 1);
  int64_t s_H = val_at(stride, 0, /*default_value=*/k_H);
  int64_t s_W = val_at(stride, 1, /*default_value=*/k_W);
  int64_t p_H = val_at(padding, 0, /*default_value=*/0);
  int64_t p_W = val_at(padding, 1, /*default_value=*/0);
  int64_t d_H = val_at(dilation, 0, /*default_value=*/1);
  int64_t d_W = val_at(dilation, 1, /*default_value=*/1);

  // Compute 2D output region
  for (size_t out_y = 0; out_y < out_H; ++out_y) {
    out_coord[in_dim - 2] = out_y;
    for (size_t out_x = 0; out_x < out_W; ++out_x) {
      out_coord[in_dim - 1] = out_x;

      bool accum_initialized = false;
      CTYPE accum = 0;
      int64_t accum_idx = 0;
      int64_t count = 0;

      // Determine the extent of the kernel window; pool_size counts elements
      // including implicit padding and is used only when include_pad is true.
      int64_t ih0 = out_y * s_H - p_H;
      int64_t iw0 = out_x * s_W - p_W;
      int64_t ih1 = std::min(ih0 + k_H, static_cast<int64_t>(in_H) + p_H);
      int64_t iw1 = std::min(iw0 + k_W, static_cast<int64_t>(in_W) + p_W);
      int64_t pool_size = (ih1 - ih0) * (iw1 - iw0);
      ih0 = std::max(ih0, (int64_t)0);
      iw0 = std::max(iw0, (int64_t)0);
      ih1 = std::min(ih1, static_cast<int64_t>(in_H));
      iw1 = std::min(iw1, static_cast<int64_t>(in_W));

      if (ih0 >= ih1 || iw0 >= iw1) {
        continue;
      }

      if (include_pad) {
        count = pool_size;
      } else {
        count = (ih1 - ih0) * (iw1 - iw0);
      }

      for (size_t w_y = 0; w_y < k_H; ++w_y) {
        int64_t stride_y = s_H;
        int64_t padding_y = p_H;
        int64_t dilation_y = d_H;

        size_t in_y = stride_y * out_y + dilation_y * w_y - padding_y;
        in_coord[in_dim - 2] = in_y;

        for (size_t w_x = 0; w_x < k_W; ++w_x) {
          int64_t stride_x = s_W;
          int64_t padding_x = p_W;
          int64_t dilation_x = d_W;

          size_t in_x = stride_x * out_x + dilation_x * w_x - padding_x;
          in_coord[in_dim - 1] = in_x;

          const bool x_in_bound = (in_x >= 0 && in_x < in_W);
          const bool y_in_bound = (in_y >= 0 && in_y < in_H);
          const bool xy_in_bound = (x_in_bound && y_in_bound);

          CTYPE in_val = 0;
          if (xy_in_bound) {
            size_t in_idx =
                calculate_linear_index(in_coord, in_strides.data(), in_dim);
            in_val = in_ptr[in_idx];
          }

          int64_t idx = in_y * in_W + in_x;
          if (include_pad) {
            idx =
                in_y + padding_y * (in_W + 2 * padding_x) + (in_x + padding_x);
          }

          if (xy_in_bound) {
            if (!accum_initialized) {
              accum = in_val;
              accum_idx = idx;
              accum_initialized = true;
            } else {
              std::tuple<CTYPE, int64_t> ret =
                  reduce_fn(in_val, idx, accum, accum_idx);
              accum = std::get<0>(ret);
              accum_idx = std::get<1>(ret);
            }
          }
        }
      }

      size_t out_idx =
          calculate_linear_index(out_coord, out_strides.data(), out_dim);
      out_ptr[out_idx] = map_fn(count, accum);
      if (indices_ptr) {
        indices_ptr[out_idx] = accum_idx;
      }
    }
  }
}
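
// Illustrative sketch (not part of the upstream source): an average-pool
// style reduce_fn/map_fn pair matching the signatures documented above. The
// index accumulator is carried along but ignored, since averaging does not
// need indices.
//
//   auto reduce_fn = [](const float in_val,
//                       const int64_t in_idx,
//                       float accum,
//                       int64_t accum_idx) {
//     return std::tuple<float, int64_t>(in_val + accum, in_idx);
//   };
//   auto map_fn = [](const int64_t count, const float accum) {
//     return accum / static_cast<float>(count);
//   };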

/**
 * Given a 3-D {C, H, W} or 4-D {N, C, H, W} tensor, applies a reduction
 * function over a 2D kernel region. The reduction returns two accumulator
 * values: the first associated with the reduced value and the second
 * associated with an input index. A map function is then applied to the first
 * accumulator value before writing to out. Optionally, the second accumulator
 * value is written to indices, if provided.
 *
 * reduce_fn should have the following signature:
 *
 * ```
 * std::tuple<T, int64_t> reduce_fn(
 *    const T in_val,
 *    const int64_t in_idx,
 *    const T accum,
 *    const int64_t accum_idx);
 * ```
 *
 * map_fn should have the following signature:
 *
 * ```
 * T map_fn(const int64_t count, const T accum);
 * ```
 *
 * TODO(ssjia) Allow this to handle 1-D kernels as well by unsqueezing
 * appropriately.
 *
 * @param[in] reduce_fn The reduction function used to update accumulator
 * values.
 * @param[in] map_fn The map function used to post-process the accumulated
 * value before writing to out.
 * @param[in] include_pad Indicates if the reduction function should be applied
 * over the padded regions implicitly added by the `padding` argument.
 * @param[in] in The input tensor.
 * @param[in] kernel_size 2D array describing the height and width of the
 * kernel region.
 * @param[in] stride 2D array describing how the kernel region "traverses" over
 * the input tensor.
 * @param[in] padding 2D array describing padding to apply to the input tensor.
 * @param[in] dilation 2D array describing the dilation to apply to the kernel
 * region.
 * @param[in] out The output tensor.
 * @param[in] indices An optional indices output tensor to write out to.
 */
template <typename CTYPE, typename ReduceOp, typename MapOp>
void apply_kernel_2d_reduce_then_map_fn(
    const ReduceOp& reduce_fn,
    const MapOp& map_fn,
    const bool include_pad,
    const Tensor& in,
    const IntArrayRef kernel_size,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const IntArrayRef dilation,
    Tensor& out,
    exec_aten::optional<Tensor> indices = {}) {
  exec_aten::ArrayRef<exec_aten::SizesType> in_sizes = in.sizes();
  exec_aten::ArrayRef<exec_aten::SizesType> out_sizes = out.sizes();

  exec_aten::ArrayRef<exec_aten::DimOrderType> in_dim_order = in.dim_order();
  exec_aten::ArrayRef<exec_aten::DimOrderType> out_dim_order = out.dim_order();

  exec_aten::StridesType in_strides[kTensorDimensionLimit];
  dim_order_to_stride_nocheck(
      in_sizes.data(), in_dim_order.data(), in_sizes.size(), in_strides);

  exec_aten::StridesType out_strides[kTensorDimensionLimit];
  dim_order_to_stride_nocheck(
      out_sizes.data(), out_dim_order.data(), out_sizes.size(), out_strides);

  CTYPE* const out_ptr = out.mutable_data_ptr<CTYPE>();
  const CTYPE* const in_ptr = in.const_data_ptr<CTYPE>();

  int64_t* indices_ptr = nullptr;
  if (indices.has_value()) {
    indices_ptr = indices.value().mutable_data_ptr<int64_t>();
  }

  size_t batch_size = 1;
  if (in.dim() == 4) {
    batch_size = in_sizes[0];
  }
  for (size_t batch = 0; batch < batch_size; ++batch) {
    for (size_t channel = 0; channel < in_sizes[in.dim() - 3]; ++channel) {
      kernel_reduction_then_map_2d(
          reduce_fn,
          map_fn,
          include_pad,
          in_ptr,
          in_sizes,
          {in_strides, 4},
          kernel_size,
          stride,
          padding,
          dilation,
          out_ptr,
          out_sizes,
          {out_strides, 4},
          indices_ptr,
          batch,
          channel);
    }
  }
}
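
// Illustrative usage sketch (an assumed caller, not part of this header): a
// max_pool2d-with-indices style invocation that keeps the largest value seen
// and its linear index. The tensors and kernel parameters are assumed to have
// been validated beforehand, e.g. with check_max_pool2d_with_indices_args()
// declared below.
//
//   apply_kernel_2d_reduce_then_map_fn<float>(
//       [](const float in_val,
//          const int64_t in_idx,
//          const float accum,
//          const int64_t accum_idx) {
//         return in_val > accum ? std::tuple<float, int64_t>(in_val, in_idx)
//                               : std::tuple<float, int64_t>(accum, accum_idx);
//       },
//       [](const int64_t count, const float accum) { return accum; },
//       /*include_pad=*/false,
//       in,
//       kernel_size,
//       stride,
//       padding,
//       dilation,
//       out,
//       indices);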

//
// Operator specific utility functions
//

bool check_arange_args(double start, double end, double step, Tensor& out);

bool check_avg_pool2d_args(
    const Tensor& in,
    const IntArrayRef kernel_size,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const bool ceil_mode,
    const bool count_include_pad,
    const exec_aten::optional<int64_t>& divisor_override,
    const Tensor& out);

void get_avg_pool2d_out_target_size(
    const Tensor& in,
    const IntArrayRef kernel_size,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const bool ceil_mode,
    exec_aten::SizesType* const out_sizes,
    size_t* const out_ndim);

bool check_convolution_args(
    const Tensor& in,
    const Tensor& weight,
    const exec_aten::optional<Tensor>& bias,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool transposed,
    IntArrayRef output_padding,
    int64_t groups,
    const Tensor& out);

void get_convolution_out_target_size(
    const Tensor& in,
    const Tensor& weight,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool transposed,
    IntArrayRef output_padding,
    int64_t groups,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim);

bool check_cumsum_args(
    const Tensor& self,
    int64_t dim,
    optional<ScalarType> enforced_dtype,
    Tensor& out);

bool check_max_pool2d_with_indices_args(
    const Tensor& in,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode,
    Tensor& out,
    Tensor& indices);

void get_max_pool2d_with_indices_out_target_size(
    const Tensor& in,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool ceil_mode,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim);

bool check_masked_fill_args(
    const Tensor& in,
    const Tensor& mask,
    const Scalar& value,
    Tensor& out);

bool check_constant_pad_args(
    const Tensor& in,
    IntArrayRef pad,
    const Scalar& value,
    Tensor& out);

Error resize_constant_pad_output(
    const Tensor& in,
    IntArrayRef pad,
    Tensor& out);

bool check_embedding_args(
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& out);

Error resize_embedding_output(
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& out);

bool check_alpha_type(
    const ScalarType alpha_type,
    const ScalarType common_type);

} // namespace executor
} // namespace torch