1 #include <torch/nn/modules/padding.h>
2
3 #include <torch/expanding_array.h>
4
5 namespace F = torch::nn::functional;
6
7 namespace torch {
8 namespace nn {
9
10 template <size_t D, typename Derived>
ReflectionPadImpl(const ReflectionPadOptions<D> & options_)11 ReflectionPadImpl<D, Derived>::ReflectionPadImpl(
12 const ReflectionPadOptions<D>& options_)
13 : options(options_) {}
14
15 template <size_t D, typename Derived>
reset()16 void ReflectionPadImpl<D, Derived>::reset() {}
17
18 template <size_t D, typename Derived>
forward(const Tensor & input)19 Tensor ReflectionPadImpl<D, Derived>::forward(const Tensor& input) {
20 return F::detail::pad(input, options.padding(), torch::kReflect, 0);
21 }
22
23 template <size_t D, typename Derived>
pretty_print(std::ostream & stream) const24 void ReflectionPadImpl<D, Derived>::pretty_print(std::ostream& stream) const {
25 stream << "torch::nn::ReflectionPad" << D << "d"
26 << "(padding=" << options.padding() << ")";
27 }
28
29 template class ReflectionPadImpl<1, ReflectionPad1dImpl>;
30 template class ReflectionPadImpl<2, ReflectionPad2dImpl>;
31 template class ReflectionPadImpl<3, ReflectionPad3dImpl>;
32
33 // ============================================================================
34
35 template <size_t D, typename Derived>
ReplicationPadImpl(const ReplicationPadOptions<D> & options_)36 ReplicationPadImpl<D, Derived>::ReplicationPadImpl(
37 const ReplicationPadOptions<D>& options_)
38 : options(options_) {}
39
40 template <size_t D, typename Derived>
reset()41 void ReplicationPadImpl<D, Derived>::reset() {}
42
43 template <size_t D, typename Derived>
forward(const Tensor & input)44 Tensor ReplicationPadImpl<D, Derived>::forward(const Tensor& input) {
45 return F::detail::pad(input, options.padding(), torch::kReplicate, 0);
46 }
47
48 template <size_t D, typename Derived>
pretty_print(std::ostream & stream) const49 void ReplicationPadImpl<D, Derived>::pretty_print(std::ostream& stream) const {
50 stream << "torch::nn::ReplicationPad" << D << "d"
51 << "(padding=" << options.padding() << ")";
52 }
53
54 template class ReplicationPadImpl<1, ReplicationPad1dImpl>;
55 template class ReplicationPadImpl<2, ReplicationPad2dImpl>;
56 template class ReplicationPadImpl<3, ReplicationPad3dImpl>;
57
58 // ============================================================================
59
60 template <size_t D, typename Derived>
ZeroPadImpl(const ZeroPadOptions<D> & options_)61 ZeroPadImpl<D, Derived>::ZeroPadImpl(const ZeroPadOptions<D>& options_)
62 : options(options_) {}
63
64 template <size_t D, typename Derived>
reset()65 void ZeroPadImpl<D, Derived>::reset() {}
66
67 template <size_t D, typename Derived>
forward(const Tensor & input)68 Tensor ZeroPadImpl<D, Derived>::forward(const Tensor& input) {
69 return F::detail::pad(input, options.padding(), torch::kConstant, 0);
70 }
71
72 template <size_t D, typename Derived>
pretty_print(std::ostream & stream) const73 void ZeroPadImpl<D, Derived>::pretty_print(std::ostream& stream) const {
74 stream << "torch::nn::ZeroPad" << D << "d"
75 << "(padding=" << options.padding() << ")";
76 }
77
78 template class ZeroPadImpl<1, ZeroPad1dImpl>;
79 template class ZeroPadImpl<2, ZeroPad2dImpl>;
80 template class ZeroPadImpl<3, ZeroPad3dImpl>;
81
82 // ============================================================================
83
84 template <size_t D, typename Derived>
ConstantPadImpl(const ConstantPadOptions<D> & options_)85 ConstantPadImpl<D, Derived>::ConstantPadImpl(
86 const ConstantPadOptions<D>& options_)
87 : options(options_) {}
88
89 template <size_t D, typename Derived>
reset()90 void ConstantPadImpl<D, Derived>::reset() {}
91
92 template <size_t D, typename Derived>
forward(const Tensor & input)93 Tensor ConstantPadImpl<D, Derived>::forward(const Tensor& input) {
94 return F::detail::pad(
95 input, options.padding(), torch::kConstant, options.value());
96 }
97
98 template <size_t D, typename Derived>
pretty_print(std::ostream & stream) const99 void ConstantPadImpl<D, Derived>::pretty_print(std::ostream& stream) const {
100 stream << "torch::nn::ConstantPad" << D << "d"
101 << "(padding=" << options.padding() << ", value=" << options.value()
102 << ")";
103 }
104
105 template class ConstantPadImpl<1, ConstantPad1dImpl>;
106 template class ConstantPadImpl<2, ConstantPad2dImpl>;
107 template class ConstantPadImpl<3, ConstantPad3dImpl>;
108
109 } // namespace nn
110 } // namespace torch
111