xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/Normalization.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/TensorIterator.h>
4 #include <ATen/native/DispatchStub.h>
5 
6 namespace at::native {
7 
8 using renorm_scale_factor_fn = void (*) (TensorIteratorBase& iter, double maxnorm);
9 DECLARE_DISPATCH(renorm_scale_factor_fn, renorm_scale_factor_stub);
10 
11 enum class BatchNormBackend {
12   Native,
13   Cudnn,
14   Miopen,
15 };
16 
17 TORCH_API BatchNormBackend _select_batch_norm_backend(const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double eps);
18 
19 }  // namespace at::native
20