1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 #include <executorch/runtime/kernel/kernel_includes.h> 12 13 namespace torch { 14 namespace executor { 15 16 bool check_batch_norm_args( 17 const Tensor& in, 18 const exec_aten::optional<Tensor>& weight, 19 const exec_aten::optional<Tensor>& bias, 20 const exec_aten::optional<Tensor>& running_mean, 21 const exec_aten::optional<Tensor>& running_var, 22 double momentum, 23 double eps, 24 Tensor& out, 25 Tensor& mean_out, 26 Tensor& var_out); 27 28 bool check_layer_norm_args( 29 const Tensor& input, 30 IntArrayRef normalized_shape, 31 const exec_aten::optional<Tensor>& weight, 32 const exec_aten::optional<Tensor>& bias, 33 Tensor& out, 34 Tensor& mean_out, 35 Tensor& rstd_out); 36 37 void get_layer_norm_out_target_size( 38 const Tensor& in, 39 IntArrayRef normalized_shape, 40 Tensor::SizesType* mean_rstd_sizes, 41 size_t* mean_rstd_ndim); 42 43 bool check_group_norm_args( 44 const Tensor& input, 45 const exec_aten::optional<Tensor>& weight, 46 const exec_aten::optional<Tensor>& bias, 47 int64_t N, 48 int64_t C, 49 int64_t HxW, 50 int64_t group, 51 Tensor& out, 52 Tensor& mean_out, 53 Tensor& rstd_out); 54 55 } // namespace executor 56 } // namespace torch 57