1 #pragma once 2 3 #include <torch/arg.h> 4 #include <torch/csrc/Export.h> 5 #include <torch/types.h> 6 7 namespace torch { 8 namespace nn { 9 10 /// Options for the `BatchNorm` module. 11 struct TORCH_API BatchNormOptions { 12 /* implicit */ BatchNormOptions(int64_t num_features); 13 14 /// The number of features of the input tensor. 15 /// Changing this parameter after construction __has no effect__. 16 TORCH_ARG(int64_t, num_features); 17 18 /// The epsilon value added for numerical stability. 19 /// Changing this parameter after construction __is effective__. 20 TORCH_ARG(double, eps) = 1e-5; 21 22 /// A momentum multiplier for the mean and variance. 23 /// Changing this parameter after construction __is effective__. 24 TORCH_ARG(std::optional<double>, momentum) = 0.1; 25 26 /// Whether to learn a scale and bias that are applied in an affine 27 /// transformation on the input. 28 /// Changing this parameter after construction __has no effect__. 29 TORCH_ARG(bool, affine) = true; 30 31 /// Whether to store and update batch statistics (mean and variance) in the 32 /// module. 33 /// Changing this parameter after construction __has no effect__. 34 TORCH_ARG(bool, track_running_stats) = true; 35 }; 36 37 /// Options for the `BatchNorm1d` module. 38 /// 39 /// Example: 40 /// ``` 41 /// BatchNorm1d 42 /// model(BatchNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); 43 /// ``` 44 using BatchNorm1dOptions = BatchNormOptions; 45 46 /// Options for the `BatchNorm2d` module. 47 /// 48 /// Example: 49 /// ``` 50 /// BatchNorm2d 51 /// model(BatchNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); 52 /// ``` 53 using BatchNorm2dOptions = BatchNormOptions; 54 55 /// Options for the `BatchNorm3d` module. 56 /// 57 /// Example: 58 /// ``` 59 /// BatchNorm3d 60 /// model(BatchNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); 61 /// ``` 62 using BatchNorm3dOptions = BatchNormOptions; 63 64 // ============================================================================ 65 66 namespace functional { 67 68 /// Options for `torch::nn::functional::batch_norm`. 69 /// 70 /// Example: 71 /// ``` 72 /// namespace F = torch::nn::functional; 73 /// F::batch_norm(input, mean, variance, 74 /// F::BatchNormFuncOptions().weight(weight).bias(bias).momentum(0.1).eps(1e-05).training(false)); 75 /// ``` 76 struct TORCH_API BatchNormFuncOptions { 77 TORCH_ARG(Tensor, weight) = Tensor(); 78 79 TORCH_ARG(Tensor, bias) = Tensor(); 80 81 TORCH_ARG(bool, training) = false; 82 83 /// A momentum multiplier for the mean and variance. 84 /// Changing this parameter after construction __is effective__. 85 TORCH_ARG(std::optional<double>, momentum) = 0.1; 86 87 /// The epsilon value added for numerical stability. 88 /// Changing this parameter after construction __is effective__. 89 TORCH_ARG(double, eps) = 1e-5; 90 }; 91 92 } // namespace functional 93 94 } // namespace nn 95 } // namespace torch 96