xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/options/batchnorm.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/arg.h>
4 #include <torch/csrc/Export.h>
5 #include <torch/types.h>
6 
7 namespace torch {
8 namespace nn {
9 
/// Options for the `BatchNorm` module.
///
/// Each option is declared with `TORCH_ARG`, which provides a setter that
/// returns the options object, so settings can be chained, e.g.
/// `BatchNormOptions(4).eps(0.5).momentum(0.1).affine(false)`.
struct TORCH_API BatchNormOptions {
  /// Constructs options for a layer with `num_features` features.
  /// Defined out of line (not in this header); presumably it only sets
  /// `num_features` and leaves every other option at the default shown below.
  /// Deliberately not `explicit`: an `int64_t` converts implicitly to
  /// `BatchNormOptions`, so e.g. a module constructor taking these options can
  /// be called with just the feature count.
  /* implicit */ BatchNormOptions(int64_t num_features);

  /// The number of features of the input tensor.
  /// Changing this parameter after construction __has no effect__.
  TORCH_ARG(int64_t, num_features);

  /// The epsilon value added for numerical stability. Defaults to `1e-5`.
  /// Changing this parameter after construction __is effective__.
  TORCH_ARG(double, eps) = 1e-5;

  /// A momentum multiplier for the mean and variance. Defaults to `0.1`.
  /// Changing this parameter after construction __is effective__.
  /// NOTE(review): the type is `std::optional`, so an unset momentum
  /// (`std::nullopt`) is representable. The module implementation decides what
  /// it means (Python's `momentum=None` selects a cumulative moving average);
  /// confirm before relying on it.
  TORCH_ARG(std::optional<double>, momentum) = 0.1;

  /// Whether to learn a scale and bias that are applied in an affine
  /// transformation on the input. Defaults to `true`.
  /// Changing this parameter after construction __has no effect__.
  TORCH_ARG(bool, affine) = true;

  /// Whether to store and update batch statistics (mean and variance) in the
  /// module. Defaults to `true`.
  /// Changing this parameter after construction __has no effect__.
  TORCH_ARG(bool, track_running_stats) = true;
};
36 
/// Options for the `BatchNorm1d` module.
///
/// This is an alias of `BatchNormOptions`. The 1d, 2d and 3d batch-norm
/// option types in this header all share that one set of options.
///
/// Example:
/// ```
/// BatchNorm1d model(BatchNorm1dOptions(4)
///                       .eps(0.5)
///                       .momentum(0.1)
///                       .affine(false)
///                       .track_running_stats(true));
/// ```
using BatchNorm1dOptions = BatchNormOptions;
45 
/// Options for the `BatchNorm2d` module.
///
/// This is an alias of `BatchNormOptions`. The 1d, 2d and 3d batch-norm
/// option types in this header all share that one set of options.
///
/// Example:
/// ```
/// BatchNorm2d model(BatchNorm2dOptions(4)
///                       .eps(0.5)
///                       .momentum(0.1)
///                       .affine(false)
///                       .track_running_stats(true));
/// ```
using BatchNorm2dOptions = BatchNormOptions;
54 
/// Options for the `BatchNorm3d` module.
///
/// This is an alias of `BatchNormOptions`. The 1d, 2d and 3d batch-norm
/// option types in this header all share that one set of options.
///
/// Example:
/// ```
/// BatchNorm3d model(BatchNorm3dOptions(4)
///                       .eps(0.5)
///                       .momentum(0.1)
///                       .affine(false)
///                       .track_running_stats(true));
/// ```
using BatchNorm3dOptions = BatchNormOptions;
63 
64 // ============================================================================
65 
66 namespace functional {
67 
/// Options for `torch::nn::functional::batch_norm`.
///
/// This struct declares no constructor of its own and every field has a
/// default, so `BatchNormFuncOptions()` is valid. Individual fields are then
/// set through the chainable setters that `TORCH_ARG` provides.
///
/// Example:
/// ```
/// namespace F = torch::nn::functional;
/// F::batch_norm(input, mean, variance,
///               F::BatchNormFuncOptions()
///                   .weight(weight)
///                   .bias(bias)
///                   .momentum(0.1)
///                   .eps(1e-05)
///                   .training(false));
/// ```
struct TORCH_API BatchNormFuncOptions {
  /// Optional affine weight (presumably a per-channel scale applied after
  /// normalization). Defaults to a default-constructed `Tensor()`, which
  /// presumably means "no weight". Confirm against the functional
  /// implementation.
  TORCH_ARG(Tensor, weight) = Tensor();

  /// Optional affine bias (presumably a per-channel shift applied after
  /// normalization). Defaults to a default-constructed `Tensor()`, which
  /// presumably means "no bias".
  TORCH_ARG(Tensor, bias) = Tensor();

  /// Whether to run in training mode. Defaults to `false`.
  /// NOTE(review): presumably `true` normalizes with the current batch's
  /// statistics, while `false` uses the running mean/variance passed to
  /// `batch_norm`. Confirm against the functional implementation.
  TORCH_ARG(bool, training) = false;

  /// A momentum multiplier for the mean and variance. Defaults to `0.1`.
  /// Changing this parameter after construction __is effective__.
  TORCH_ARG(std::optional<double>, momentum) = 0.1;

  /// The epsilon value added for numerical stability. Defaults to `1e-5`.
  /// Changing this parameter after construction __is effective__.
  TORCH_ARG(double, eps) = 1e-5;
};
91 
92 } // namespace functional
93 
94 } // namespace nn
95 } // namespace torch
96