xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/util/normalization_ops_util.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <executorch/runtime/kernel/kernel_includes.h>
12 
13 namespace torch {
14 namespace executor {
15 
16 bool check_batch_norm_args(
17     const Tensor& in,
18     const exec_aten::optional<Tensor>& weight,
19     const exec_aten::optional<Tensor>& bias,
20     const exec_aten::optional<Tensor>& running_mean,
21     const exec_aten::optional<Tensor>& running_var,
22     double momentum,
23     double eps,
24     Tensor& out,
25     Tensor& mean_out,
26     Tensor& var_out);
27 
28 bool check_layer_norm_args(
29     const Tensor& input,
30     IntArrayRef normalized_shape,
31     const exec_aten::optional<Tensor>& weight,
32     const exec_aten::optional<Tensor>& bias,
33     Tensor& out,
34     Tensor& mean_out,
35     Tensor& rstd_out);
36 
37 void get_layer_norm_out_target_size(
38     const Tensor& in,
39     IntArrayRef normalized_shape,
40     Tensor::SizesType* mean_rstd_sizes,
41     size_t* mean_rstd_ndim);
42 
43 bool check_group_norm_args(
44     const Tensor& input,
45     const exec_aten::optional<Tensor>& weight,
46     const exec_aten::optional<Tensor>& bias,
47     int64_t N,
48     int64_t C,
49     int64_t HxW,
50     int64_t group,
51     Tensor& out,
52     Tensor& mean_out,
53     Tensor& rstd_out);
54 
55 } // namespace executor
56 } // namespace torch
57