xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/util/padding_util.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <cstdint>
10 #include <cstring>
11 
12 #include <executorch/kernels/portable/cpu/util/padding_util.h>
13 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
14 
15 namespace torch {
16 namespace executor {
17 
check_padding_args(int64_t n,const Tensor & in,exec_aten::ArrayRef<int64_t> padding,Tensor & out,bool reflection)18 bool check_padding_args(
19     int64_t n,
20     const Tensor& in,
21     exec_aten::ArrayRef<int64_t> padding,
22     Tensor& out,
23     bool reflection) {
24   ET_LOG_AND_RETURN_IF_FALSE(padding.size() == 2 * n);
25   ET_LOG_AND_RETURN_IF_FALSE(in.dim() == n + 1 || in.dim() == n + 2);
26   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
27   for (size_t i = 1; i <= n; ++i) {
28     ET_LOG_AND_RETURN_IF_FALSE(
29         in.size(in.dim() - i) + padding[2 * i - 2] + padding[2 * i - 1] >= 0);
30     if (reflection) {
31       ET_LOG_AND_RETURN_IF_FALSE(
32           padding[2 * i - 2] < in.size(in.dim() - i) &&
33           padding[2 * i - 1] < in.size(in.dim() - i));
34     }
35   }
36   return true;
37 }
38 
get_padding_out_target_size(int64_t n,const Tensor & in,exec_aten::ArrayRef<int64_t> padding,Tensor::SizesType * out_sizes,size_t * out_ndim)39 void get_padding_out_target_size(
40     int64_t n,
41     const Tensor& in,
42     exec_aten::ArrayRef<int64_t> padding,
43     Tensor::SizesType* out_sizes,
44     size_t* out_ndim) {
45   *out_ndim = in.dim();
46   for (size_t i = 0; i < in.dim(); ++i) {
47     out_sizes[i] = in.size(i);
48   }
49   for (size_t i = 1; i <= n; ++i) {
50     out_sizes[in.dim() - i] =
51         in.size(in.dim() - i) + padding[2 * i - 2] + padding[2 * i - 1];
52   }
53 }
54 
55 } // namespace executor
56 } // namespace torch
57