1 #pragma once 2 3 #include <torch/nn/modules/batchnorm.h> 4 #include <torch/nn/options/instancenorm.h> 5 6 namespace torch { 7 namespace nn { 8 9 /// Base class for all (dimension-specialized) instance norm modules 10 template <size_t D, typename Derived> 11 class InstanceNormImpl 12 : public torch::nn::NormImplBase<D, Derived, InstanceNormOptions> { 13 private: apply_instance_norm(const Tensor & input)14 inline Tensor apply_instance_norm(const Tensor& input) { 15 return torch::nn::functional::detail::instance_norm( 16 input, 17 this->running_mean, 18 this->running_var, 19 this->weight, 20 this->bias, 21 this->is_training() || !this->options.track_running_stats(), 22 this->options.momentum(), 23 this->options.eps()); 24 } 25 handle_no_batch_input(const Tensor & input)26 inline Tensor handle_no_batch_input(const Tensor& input) { 27 return this->apply_instance_norm(input.unsqueeze(0)).squeeze(0); 28 } 29 30 public: 31 using torch::nn::NormImplBase<D, Derived, InstanceNormOptions>::NormImplBase; 32 forward(const Tensor & input)33 Tensor forward(const Tensor& input) { 34 this->_check_input_dim(input); 35 36 // For InstanceNorm1D, 2D is unbatched and 3D is batched 37 // For InstanceNorm2D, 3D is unbatched and 4D is batched 38 // For InstanceNorm3D, 4D is unbatched and 5D is batched 39 // check if input does not have a batch-dim 40 if (input.dim() == D + 1) { 41 return this->handle_no_batch_input(input); 42 } 43 44 return this->apply_instance_norm(input); 45 } 46 47 /// Pretty prints the `InstanceNorm{1,2,3}d` module into the given `stream`. pretty_print(std::ostream & stream)48 void pretty_print(std::ostream& stream) const override { 49 stream << std::boolalpha << "torch::nn::InstanceNorm" << D << "d(" 50 << this->options.num_features() << ", " 51 << "eps=" << this->options.eps() << ", " 52 << "momentum=" << this->options.momentum() << ", " 53 << "affine=" << this->options.affine() << ", " 54 << "track_running_stats=" << this->options.track_running_stats() 55 << ")"; 56 } 57 }; 58 59 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm1d 60 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 61 62 /// Applies the InstanceNorm1d function. 63 /// See https://pytorch.org/docs/main/nn.html#torch.nn.InstanceNorm1d to learn 64 /// about the exact behavior of this module. 65 /// 66 /// See the documentation for `torch::nn::InstanceNorm1dOptions` class to learn 67 /// what constructor arguments are supported for this module. 68 /// 69 /// Example: 70 /// ``` 71 /// InstanceNorm1d 72 /// model(InstanceNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); 73 /// ``` 74 class TORCH_API InstanceNorm1dImpl 75 : public InstanceNormImpl<1, InstanceNorm1dImpl> { 76 protected: 77 void _check_input_dim(const Tensor& input) override; 78 79 public: 80 using InstanceNormImpl<1, InstanceNorm1dImpl>::InstanceNormImpl; 81 }; 82 83 /// A `ModuleHolder` subclass for `InstanceNorm1dImpl`. 84 /// See the documentation for `InstanceNorm1dImpl` class to learn what methods 85 /// it provides, and examples of how to use `InstanceNorm1d` with 86 /// `torch::nn::InstanceNorm1dOptions`. See the documentation for `ModuleHolder` 87 /// to learn about PyTorch's module storage semantics. 88 TORCH_MODULE(InstanceNorm1d); 89 90 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm2d 91 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 92 93 /// Applies the InstanceNorm2d function. 94 /// See https://pytorch.org/docs/main/nn.html#torch.nn.InstanceNorm2d to learn 95 /// about the exact behavior of this module. 96 /// 97 /// See the documentation for `torch::nn::InstanceNorm2dOptions` class to learn 98 /// what constructor arguments are supported for this module. 99 /// 100 /// Example: 101 /// ``` 102 /// InstanceNorm2d 103 /// model(InstanceNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); 104 /// ``` 105 class TORCH_API InstanceNorm2dImpl 106 : public InstanceNormImpl<2, InstanceNorm2dImpl> { 107 protected: 108 void _check_input_dim(const Tensor& input) override; 109 110 public: 111 using InstanceNormImpl<2, InstanceNorm2dImpl>::InstanceNormImpl; 112 }; 113 114 /// A `ModuleHolder` subclass for `InstanceNorm2dImpl`. 115 /// See the documentation for `InstanceNorm2dImpl` class to learn what methods 116 /// it provides, and examples of how to use `InstanceNorm2d` with 117 /// `torch::nn::InstanceNorm2dOptions`. See the documentation for `ModuleHolder` 118 /// to learn about PyTorch's module storage semantics. 119 TORCH_MODULE(InstanceNorm2d); 120 121 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm3d 122 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 123 124 /// Applies the InstanceNorm3d function. 125 /// See https://pytorch.org/docs/main/nn.html#torch.nn.InstanceNorm3d to learn 126 /// about the exact behavior of this module. 127 /// 128 /// See the documentation for `torch::nn::InstanceNorm3dOptions` class to learn 129 /// what constructor arguments are supported for this module. 130 /// 131 /// Example: 132 /// ``` 133 /// InstanceNorm3d 134 /// model(InstanceNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false).track_running_stats(true)); 135 /// ``` 136 class TORCH_API InstanceNorm3dImpl 137 : public InstanceNormImpl<3, InstanceNorm3dImpl> { 138 protected: 139 void _check_input_dim(const Tensor& input) override; 140 141 public: 142 using InstanceNormImpl<3, InstanceNorm3dImpl>::InstanceNormImpl; 143 }; 144 145 /// A `ModuleHolder` subclass for `InstanceNorm3dImpl`. 146 /// See the documentation for `InstanceNorm3dImpl` class to learn what methods 147 /// it provides, and examples of how to use `InstanceNorm3d` with 148 /// `torch::nn::InstanceNorm3dOptions`. See the documentation for `ModuleHolder` 149 /// to learn about PyTorch's module storage semantics. 150 TORCH_MODULE(InstanceNorm3d); 151 152 } // namespace nn 153 } // namespace torch 154