#pragma once

#include <c10/util/irange.h>
#include <c10/util/overloaded.h>

#include <torch/expanding_array.h>
#include <torch/nn/cloneable.h>
#include <torch/nn/init.h>
#include <torch/nn/modules/common.h>
#include <torch/nn/modules/utils.h>
#include <torch/nn/options/conv.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <torch/csrc/Export.h>

#include <cstddef>
#include <vector>

namespace torch {
namespace nn {

/// Base class for all (dimension-specialized) convolution modules.
template <size_t D, typename Derived>
class ConvNdImpl : public torch::nn::Cloneable<Derived> {
 public:
  /// Constructs the module and immediately materializes its parameters from
  /// `options_` by calling `reset()`.
  explicit ConvNdImpl(detail::ConvNdOptions<D> options_)
      : options(std::move(options_)) {
    // Calling the virtual `reset()` from the constructor is deliberate here.
    // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
    reset();
  }

  /// Validates `options`, resolves the padding specification into
  /// `_reversed_padding_repeated_twice`, registers the `weight` (and, if
  /// requested, `bias`) parameters, and initializes them via
  /// `reset_parameters()`.
  void reset() override {
    TORCH_CHECK(
        options.in_channels() > 0 && options.groups() > 0 &&
            options.out_channels() > 0,
        "in_channels, groups and out_channels must be a positive integer.");
    TORCH_CHECK(
        options.in_channels() % options.groups() == 0,
        "in_channels must be divisible by groups");
    TORCH_CHECK(
        options.out_channels() % options.groups() == 0,
        "out_channels must be divisible by groups");

    // `padding` is a variant: the enum values 'valid'/'same', or an explicit
    // per-dimension amount.
    std::visit(
        c10::overloaded(
            [&](enumtype::kValid) {
              // 'valid' means no implicit padding at all.
              _reversed_padding_repeated_twice.resize(2 * D);
              std::fill_n(_reversed_padding_repeated_twice.begin(), 2 * D, 0);
            },
            [&](enumtype::kSame) {
              // 'same' is only supported for unit strides.
              for (const auto i : c10::irange(D)) {
                const auto stride = (*options.stride())[i];
                TORCH_CHECK(
                    stride == 1,
                    "padding='same' is not supported for strided convolutions");
              }

              // Split the total padding of each dimension into a left and a
              // right amount; the right side receives the extra element when
              // the total is odd.
              _reversed_padding_repeated_twice.resize(2 * D);
              for (const auto i : c10::irange(D)) {
                const auto dilation = (*options.dilation())[i];
                const auto kernel_size = (*options.kernel_size())[i];
                const auto total_padding = dilation * (kernel_size - 1);
                auto left_pad = total_padding / 2;
                auto right_pad = total_padding - left_pad;
                _reversed_padding_repeated_twice[2 * i] = left_pad;
                _reversed_padding_repeated_twice[2 * i + 1] = right_pad;
              }
            },
            [&](const ExpandingArray<D>& pad) {
              _reversed_padding_repeated_twice =
                  torch::nn::modules::utils::_reverse_repeat_vector(pad, 2);
            }),
        options.padding());

    if (options.transposed()) {
      // Transposed convolutions lay the weight out as
      // {in_channels, out_channels / groups, *kernel_size}.
      std::vector<int64_t> weight_sizes = {
          options.in_channels(), options.out_channels() / options.groups()};
      weight_sizes.insert(
          weight_sizes.end(),
          (*options.kernel_size()).begin(),
          (*options.kernel_size()).end());
      weight = this->register_parameter("weight", torch::empty(weight_sizes));
    } else {
      // Regular convolutions lay the weight out as
      // {out_channels, in_channels / groups, *kernel_size}.
      std::vector<int64_t> weight_sizes = {
          options.out_channels(), options.in_channels() / options.groups()};
      weight_sizes.insert(
          weight_sizes.end(),
          (*options.kernel_size()).begin(),
          (*options.kernel_size()).end());
      weight = this->register_parameter("weight", torch::empty(weight_sizes));
    }

    if (options.bias()) {
      bias = this->register_parameter(
          "bias", torch::empty({options.out_channels()}));
    } else {
      // Still register the "bias" name so the parameter list is stable, but
      // with an undefined tensor and no gradient.
      this->register_parameter("bias", Tensor(), /*requires_grad=*/false);
    }

    reset_parameters();
  }

  /// Re-initializes parameter values in place: Kaiming-uniform for `weight`
  /// and, when `bias` is defined, a uniform distribution bounded by
  /// 1/sqrt(fan_in).
  void reset_parameters() {
    init::kaiming_uniform_(
        weight,
        /*a=*/std::sqrt(5)); // NOLINT(cppcoreguidelines-avoid-magic-numbers)

    if (bias.defined()) {
      auto [fan_in, fan_out] = init::_calculate_fan_in_and_fan_out(weight);
      auto bound = 1 / std::sqrt(fan_in);
      init::uniform_(bias, -bound, bound);
    }
  }

  /// Pretty prints the `Conv{1,2,3}d` module into the given `stream`.
  /// Settings that still hold their default value (stride-like options of 1,
  /// padding-like options of 0, zeros padding mode, bias enabled) are
  /// omitted from the output.
  void pretty_print(std::ostream& stream) const override {
    stream << "torch::nn::Conv" << D << "d"
           << "(" << options.in_channels() << ", " << options.out_channels()
           << ", kernel_size=" << options.kernel_size()
           << ", stride=" << options.stride();
    std::visit(
        c10::overloaded(
            [&](enumtype::kValid) { stream << ", padding='valid'"; },
            [&](enumtype::kSame) { stream << ", padding='same'"; },
            [&](const ExpandingArray<D>& pad) {
              if (*pad != *ExpandingArray<D>(0)) {
                stream << ", padding=" << pad;
              }
            }),
        options.padding());
    if (*options.dilation() != *ExpandingArray<D>(1)) {
      stream << ", dilation=" << options.dilation();
    }
    if (*options.output_padding() != *ExpandingArray<D>(0)) {
      stream << ", output_padding=" << options.output_padding();
    }
    if (options.groups() != 1) {
      stream << ", groups=" << options.groups();
    }
    if (!options.bias()) {
      stream << ", bias=" << std::boolalpha << false;
    }
    if (!std::get_if<enumtype::kZeros>(&options.padding_mode())) {
      stream << ", padding_mode="
             << enumtype::get_enum_name(options.padding_mode());
    }
    stream << ")";
  }

  /// The options with which this `Module` was constructed.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  detail::ConvNdOptions<D> options;

  /// The learned kernel (or "weight").
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  Tensor weight;

  /// The learned bias. Only defined if the `bias` option was true.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  Tensor bias;

 protected:
  /// Resolved padding, computed by `reset()` from `options.padding()`:
  /// 2 * D entries holding a left/right amount per spatial dimension.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::vector<int64_t> _reversed_padding_repeated_twice;
};

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Applies convolution over a 1-D input.
/// See https://pytorch.org/docs/main/nn.html#torch.nn.Conv1d to learn about
/// the exact behavior of this module.
///
/// See the documentation for `torch::nn::Conv1dOptions` class to learn what
/// constructor arguments are supported for this module.
///
/// Example:
/// ```
/// Conv1d model(Conv1dOptions(3, 2, 3).stride(1).bias(false));
/// ```
class TORCH_API Conv1dImpl : public ConvNdImpl<1, Conv1dImpl> {
 public:
  /// Convenience constructor; every other option keeps its default value.
  Conv1dImpl(
      int64_t input_channels,
      int64_t output_channels,
      ExpandingArray<1> kernel_size)
      : Conv1dImpl(
            Conv1dOptions(input_channels, output_channels, kernel_size)) {}
  explicit Conv1dImpl(Conv1dOptions options_);
  Tensor forward(const Tensor& input);
};

/// A `ModuleHolder` subclass for `Conv1dImpl`.
/// See the documentation for `Conv1dImpl` class to learn what methods it
/// provides, and examples of how to use `Conv1d` with
/// `torch::nn::Conv1dOptions`. See the documentation for `ModuleHolder` to
/// learn about PyTorch's module storage semantics.
TORCH_MODULE(Conv1d);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Applies convolution over a 2-D input.
/// See https://pytorch.org/docs/main/nn.html#torch.nn.Conv2d to learn about
/// the exact behavior of this module.
///
/// See the documentation for `torch::nn::Conv2dOptions` class to learn what
/// constructor arguments are supported for this module.
///
/// Example:
/// ```
/// Conv2d model(Conv2dOptions(3, 2, 3).stride(1).bias(false));
/// ```
class TORCH_API Conv2dImpl : public ConvNdImpl<2, Conv2dImpl> {
 public:
  /// Convenience constructor; every other option keeps its default value.
  Conv2dImpl(
      int64_t input_channels,
      int64_t output_channels,
      ExpandingArray<2> kernel_size)
      : Conv2dImpl(
            Conv2dOptions(input_channels, output_channels, kernel_size)) {}
  explicit Conv2dImpl(Conv2dOptions options_);
  Tensor forward(const Tensor& input);

 protected:
  /// Out-of-line helper that takes the weight explicitly; presumably runs the
  /// convolution of `input` with `weight` — see the .cpp for the semantics.
  Tensor _conv_forward(const Tensor& input, const Tensor& weight);
};

/// A `ModuleHolder` subclass for `Conv2dImpl`.
/// See the documentation for `Conv2dImpl` class to learn what methods it
/// provides, and examples of how to use `Conv2d` with
/// `torch::nn::Conv2dOptions`. See the documentation for `ModuleHolder` to
/// learn about PyTorch's module storage semantics.
TORCH_MODULE(Conv2d);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Applies convolution over a 3-D input.
/// See https://pytorch.org/docs/main/nn.html#torch.nn.Conv3d to learn about
/// the exact behavior of this module.
///
/// See the documentation for `torch::nn::Conv3dOptions` class to learn what
/// constructor arguments are supported for this module.
243 /// 244 /// Example: 245 /// ``` 246 /// Conv3d model(Conv3dOptions(3, 2, 3).stride(1).bias(false)); 247 /// ``` 248 class TORCH_API Conv3dImpl : public ConvNdImpl<3, Conv3dImpl> { 249 public: Conv3dImpl(int64_t input_channels,int64_t output_channels,ExpandingArray<3> kernel_size)250 Conv3dImpl( 251 int64_t input_channels, 252 int64_t output_channels, 253 ExpandingArray<3> kernel_size) 254 : Conv3dImpl( 255 Conv3dOptions(input_channels, output_channels, kernel_size)) {} 256 explicit Conv3dImpl(Conv3dOptions options_); 257 Tensor forward(const Tensor& input); 258 }; 259 260 /// A `ModuleHolder` subclass for `Conv3dImpl`. 261 /// See the documentation for `Conv3dImpl` class to learn what methods it 262 /// provides, and examples of how to use `Conv3d` with 263 /// `torch::nn::Conv3dOptions`. See the documentation for `ModuleHolder` to 264 /// learn about PyTorch's module storage semantics. 265 TORCH_MODULE(Conv3d); 266 267 // ~~~~~~~~~~~~~~~~~~~~~~~~~ ConvTranspose ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 268 269 /// Base class for all (dimension-specialized) convolution transpose modules. 270 template <size_t D, typename Derived> 271 class ConvTransposeNdImpl : public ConvNdImpl<D, Derived> { 272 public: 273 using torch::nn::ConvNdImpl<D, Derived>::ConvNdImpl; ConvTransposeNdImpl(detail::ConvNdOptions<D> options_)274 explicit ConvTransposeNdImpl(detail::ConvNdOptions<D> options_) 275 : ConvNdImpl<D, Derived>(options_) { 276 TORCH_INTERNAL_ASSERT( 277 std::holds_alternative<ExpandingArray<D>>(this->options.padding()), 278 "ConvTranspose padding cannot be a string"); 279 } 280 281 /// Pretty prints the `ConvTranspose{1,2,3}d` module into the given `stream`. 
pretty_print(std::ostream & stream)282 void pretty_print(std::ostream& stream) const override { 283 stream << "torch::nn::ConvTranspose" << D << "d" 284 << "(" << this->options.in_channels() << ", " 285 << this->options.out_channels() 286 << ", kernel_size=" << this->options.kernel_size() 287 << ", stride=" << this->options.stride(); 288 const auto& pad = padding(); 289 if (*pad != *ExpandingArray<D>(0)) { 290 stream << ", padding=" << pad; 291 } 292 if (*this->options.dilation() != *ExpandingArray<D>(1)) { 293 stream << ", dilation=" << this->options.dilation(); 294 } 295 if (*this->options.output_padding() != *ExpandingArray<D>(0)) { 296 stream << ", output_padding=" << this->options.output_padding(); 297 } 298 if (this->options.groups() != 1) { 299 stream << ", groups=" << this->options.groups(); 300 } 301 if (!this->options.bias()) { 302 stream << ", bias=" << std::boolalpha << false; 303 } 304 if (!std::get_if<enumtype::kZeros>(&this->options.padding_mode())) { 305 stream << ", padding_mode=" 306 << enumtype::get_enum_name(this->options.padding_mode()); 307 } 308 stream << ")"; 309 } 310 311 protected: padding()312 const ExpandingArray<D>& padding() const { 313 return std::get<ExpandingArray<D>>(this->options.padding()); 314 } 315 316 std::vector<int64_t> _output_padding( 317 const Tensor& input, 318 const std::optional<at::IntArrayRef>& output_size, 319 const ExpandingArray<D>& stride, 320 const ExpandingArray<D>& padding, 321 const ExpandingArray<D>& kernel_size); 322 }; 323 324 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConvTranspose1d 325 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 326 327 /// Applies the ConvTranspose1d function. 328 /// See https://pytorch.org/docs/main/nn.html#torch.nn.ConvTranspose1d to 329 /// learn about the exact behavior of this module. 330 /// 331 /// See the documentation for `torch::nn::ConvTranspose1dOptions` class to learn 332 /// what constructor arguments are supported for this module. 
///
/// Example:
/// ```
/// ConvTranspose1d model(ConvTranspose1dOptions(3, 2,
/// 3).stride(1).bias(false));
/// ```
class TORCH_API ConvTranspose1dImpl
    : public ConvTransposeNdImpl<1, ConvTranspose1dImpl> {
 public:
  /// Convenience constructor; every other option keeps its default value.
  ConvTranspose1dImpl(
      int64_t input_channels,
      int64_t output_channels,
      ExpandingArray<1> kernel_size)
      : ConvTranspose1dImpl(ConvTranspose1dOptions(
            input_channels,
            output_channels,
            kernel_size)) {}
  explicit ConvTranspose1dImpl(ConvTranspose1dOptions options_);
  Tensor forward(
      const Tensor& input,
      const std::optional<at::IntArrayRef>& output_size = std::nullopt);

 protected:
  // Registers the default value of `forward`'s argument at index 1
  // (`output_size`); macro from <torch/nn/modules/common.h>.
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(std::optional<at::IntArrayRef>())})
};

/// A `ModuleHolder` subclass for `ConvTranspose1dImpl`.
/// See the documentation for `ConvTranspose1dImpl` class to learn what methods
/// it provides, and examples of how to use `ConvTranspose1d` with
/// `torch::nn::ConvTranspose1dOptions`. See the documentation for
/// `ModuleHolder` to learn about PyTorch's module storage semantics.
TORCH_MODULE(ConvTranspose1d);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConvTranspose2d
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Applies the ConvTranspose2d function.
/// See https://pytorch.org/docs/main/nn.html#torch.nn.ConvTranspose2d to
/// learn about the exact behavior of this module.
///
/// See the documentation for `torch::nn::ConvTranspose2dOptions` class to learn
/// what constructor arguments are supported for this module.
///
/// Example:
/// ```
/// ConvTranspose2d model(ConvTranspose2dOptions(3, 2,
/// 3).stride(1).bias(false));
/// ```
class TORCH_API ConvTranspose2dImpl
    : public ConvTransposeNdImpl<2, ConvTranspose2dImpl> {
 public:
  /// Convenience constructor; every other option keeps its default value.
  ConvTranspose2dImpl(
      int64_t input_channels,
      int64_t output_channels,
      ExpandingArray<2> kernel_size)
      : ConvTranspose2dImpl(ConvTranspose2dOptions(
            input_channels,
            output_channels,
            kernel_size)) {}
  explicit ConvTranspose2dImpl(ConvTranspose2dOptions options_);
  Tensor forward(
      const Tensor& input,
      const std::optional<at::IntArrayRef>& output_size = std::nullopt);

 protected:
  // Registers the default value of `forward`'s argument at index 1
  // (`output_size`); macro from <torch/nn/modules/common.h>.
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(std::optional<at::IntArrayRef>())})
};

/// A `ModuleHolder` subclass for `ConvTranspose2dImpl`.
/// See the documentation for `ConvTranspose2dImpl` class to learn what methods
/// it provides, and examples of how to use `ConvTranspose2d` with
/// `torch::nn::ConvTranspose2dOptions`. See the documentation for
/// `ModuleHolder` to learn about PyTorch's module storage semantics.
TORCH_MODULE(ConvTranspose2d);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ConvTranspose3d
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Applies the ConvTranspose3d function.
/// See https://pytorch.org/docs/main/nn.html#torch.nn.ConvTranspose3d to
/// learn about the exact behavior of this module.
///
/// See the documentation for `torch::nn::ConvTranspose3dOptions` class to learn
/// what constructor arguments are supported for this module.
///
/// Example:
/// ```
/// ConvTranspose3d model(ConvTranspose3dOptions(2, 2,
/// 2).stride(1).bias(false));
/// ```
class TORCH_API ConvTranspose3dImpl
    : public ConvTransposeNdImpl<3, ConvTranspose3dImpl> {
 public:
  /// Convenience constructor; every other option keeps its default value.
  ConvTranspose3dImpl(
      int64_t input_channels,
      int64_t output_channels,
      ExpandingArray<3> kernel_size)
      : ConvTranspose3dImpl(ConvTranspose3dOptions(
            input_channels,
            output_channels,
            kernel_size)) {}
  explicit ConvTranspose3dImpl(ConvTranspose3dOptions options_);
  Tensor forward(
      const Tensor& input,
      const std::optional<at::IntArrayRef>& output_size = std::nullopt);

 protected:
  // Registers the default value of `forward`'s argument at index 1
  // (`output_size`); macro from <torch/nn/modules/common.h>.
  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(std::optional<at::IntArrayRef>())})
};

/// A `ModuleHolder` subclass for `ConvTranspose3dImpl`.
/// See the documentation for `ConvTranspose3dImpl` class to learn what methods
/// it provides, and examples of how to use `ConvTranspose3d` with
/// `torch::nn::ConvTranspose3dOptions`. See the documentation for
/// `ModuleHolder` to learn about PyTorch's module storage semantics.
TORCH_MODULE(ConvTranspose3d);

} // namespace nn
} // namespace torch