xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/modules/instancenorm.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/nn/modules/batchnorm.h>
4 #include <torch/nn/options/instancenorm.h>
5 
6 namespace torch {
7 namespace nn {
8 
9 /// Base class for all (dimension-specialized) instance norm modules
10 template <size_t D, typename Derived>
11 class InstanceNormImpl
12     : public torch::nn::NormImplBase<D, Derived, InstanceNormOptions> {
13  private:
apply_instance_norm(const Tensor & input)14   inline Tensor apply_instance_norm(const Tensor& input) {
15     return torch::nn::functional::detail::instance_norm(
16         input,
17         this->running_mean,
18         this->running_var,
19         this->weight,
20         this->bias,
21         this->is_training() || !this->options.track_running_stats(),
22         this->options.momentum(),
23         this->options.eps());
24   }
25 
handle_no_batch_input(const Tensor & input)26   inline Tensor handle_no_batch_input(const Tensor& input) {
27     return this->apply_instance_norm(input.unsqueeze(0)).squeeze(0);
28   }
29 
30  public:
31   using torch::nn::NormImplBase<D, Derived, InstanceNormOptions>::NormImplBase;
32 
forward(const Tensor & input)33   Tensor forward(const Tensor& input) {
34     this->_check_input_dim(input);
35 
36     // For InstanceNorm1D, 2D is unbatched and 3D is batched
37     // For InstanceNorm2D, 3D is unbatched and 4D is batched
38     // For InstanceNorm3D, 4D is unbatched and 5D is batched
39     // check if input does not have a batch-dim
40     if (input.dim() == D + 1) {
41       return this->handle_no_batch_input(input);
42     }
43 
44     return this->apply_instance_norm(input);
45   }
46 
47   /// Pretty prints the `InstanceNorm{1,2,3}d` module into the given `stream`.
pretty_print(std::ostream & stream)48   void pretty_print(std::ostream& stream) const override {
49     stream << std::boolalpha << "torch::nn::InstanceNorm" << D << "d("
50            << this->options.num_features() << ", "
51            << "eps=" << this->options.eps() << ", "
52            << "momentum=" << this->options.momentum() << ", "
53            << "affine=" << this->options.affine() << ", "
54            << "track_running_stats=" << this->options.track_running_stats()
55            << ")";
56   }
57 };
58 
59 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm1d
60 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
61 
62 /// Applies the InstanceNorm1d function.
63 /// See https://pytorch.org/docs/main/nn.html#torch.nn.InstanceNorm1d to learn
64 /// about the exact behavior of this module.
65 ///
66 /// See the documentation for `torch::nn::InstanceNorm1dOptions` class to learn
67 /// what constructor arguments are supported for this module.
68 ///
69 /// Example:
70 /// ```
71 /// InstanceNorm1d
72 /// model(InstanceNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true));
73 /// ```
74 class TORCH_API InstanceNorm1dImpl
75     : public InstanceNormImpl<1, InstanceNorm1dImpl> {
76  protected:
77   void _check_input_dim(const Tensor& input) override;
78 
79  public:
80   using InstanceNormImpl<1, InstanceNorm1dImpl>::InstanceNormImpl;
81 };
82 
83 /// A `ModuleHolder` subclass for `InstanceNorm1dImpl`.
84 /// See the documentation for `InstanceNorm1dImpl` class to learn what methods
85 /// it provides, and examples of how to use `InstanceNorm1d` with
86 /// `torch::nn::InstanceNorm1dOptions`. See the documentation for `ModuleHolder`
87 /// to learn about PyTorch's module storage semantics.
88 TORCH_MODULE(InstanceNorm1d);
89 
90 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm2d
91 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
92 
93 /// Applies the InstanceNorm2d function.
94 /// See https://pytorch.org/docs/main/nn.html#torch.nn.InstanceNorm2d to learn
95 /// about the exact behavior of this module.
96 ///
97 /// See the documentation for `torch::nn::InstanceNorm2dOptions` class to learn
98 /// what constructor arguments are supported for this module.
99 ///
100 /// Example:
101 /// ```
102 /// InstanceNorm2d
103 /// model(InstanceNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true));
104 /// ```
105 class TORCH_API InstanceNorm2dImpl
106     : public InstanceNormImpl<2, InstanceNorm2dImpl> {
107  protected:
108   void _check_input_dim(const Tensor& input) override;
109 
110  public:
111   using InstanceNormImpl<2, InstanceNorm2dImpl>::InstanceNormImpl;
112 };
113 
114 /// A `ModuleHolder` subclass for `InstanceNorm2dImpl`.
115 /// See the documentation for `InstanceNorm2dImpl` class to learn what methods
116 /// it provides, and examples of how to use `InstanceNorm2d` with
117 /// `torch::nn::InstanceNorm2dOptions`. See the documentation for `ModuleHolder`
118 /// to learn about PyTorch's module storage semantics.
119 TORCH_MODULE(InstanceNorm2d);
120 
121 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm3d
122 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
123 
124 /// Applies the InstanceNorm3d function.
125 /// See https://pytorch.org/docs/main/nn.html#torch.nn.InstanceNorm3d to learn
126 /// about the exact behavior of this module.
127 ///
128 /// See the documentation for `torch::nn::InstanceNorm3dOptions` class to learn
129 /// what constructor arguments are supported for this module.
130 ///
131 /// Example:
132 /// ```
133 /// InstanceNorm3d
134 /// model(InstanceNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true));
135 /// ```
136 class TORCH_API InstanceNorm3dImpl
137     : public InstanceNormImpl<3, InstanceNorm3dImpl> {
138  protected:
139   void _check_input_dim(const Tensor& input) override;
140 
141  public:
142   using InstanceNormImpl<3, InstanceNorm3dImpl>::InstanceNormImpl;
143 };
144 
145 /// A `ModuleHolder` subclass for `InstanceNorm3dImpl`.
146 /// See the documentation for `InstanceNorm3dImpl` class to learn what methods
147 /// it provides, and examples of how to use `InstanceNorm3d` with
148 /// `torch::nn::InstanceNorm3dOptions`. See the documentation for `ModuleHolder`
149 /// to learn about PyTorch's module storage semantics.
150 TORCH_MODULE(InstanceNorm3d);
151 
152 } // namespace nn
153 } // namespace torch
154