1 #pragma once 2 3 #include <ATen/TensorIterator.h> 4 #include <ATen/native/DispatchStub.h> 5 6 namespace at::native { 7 8 using renorm_scale_factor_fn = void (*) (TensorIteratorBase& iter, double maxnorm); 9 DECLARE_DISPATCH(renorm_scale_factor_fn, renorm_scale_factor_stub); 10 11 enum class BatchNormBackend { 12 Native, 13 Cudnn, 14 Miopen, 15 }; 16 17 TORCH_API BatchNormBackend _select_batch_norm_backend(const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double eps); 18 19 } // namespace at::native 20