xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/modules/padding.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/nn/modules/padding.h>
2 
3 #include <torch/expanding_array.h>
4 
5 namespace F = torch::nn::functional;
6 
7 namespace torch {
8 namespace nn {
9 
10 template <size_t D, typename Derived>
ReflectionPadImpl(const ReflectionPadOptions<D> & options_)11 ReflectionPadImpl<D, Derived>::ReflectionPadImpl(
12     const ReflectionPadOptions<D>& options_)
13     : options(options_) {}
14 
15 template <size_t D, typename Derived>
reset()16 void ReflectionPadImpl<D, Derived>::reset() {}
17 
18 template <size_t D, typename Derived>
forward(const Tensor & input)19 Tensor ReflectionPadImpl<D, Derived>::forward(const Tensor& input) {
20   return F::detail::pad(input, options.padding(), torch::kReflect, 0);
21 }
22 
23 template <size_t D, typename Derived>
pretty_print(std::ostream & stream) const24 void ReflectionPadImpl<D, Derived>::pretty_print(std::ostream& stream) const {
25   stream << "torch::nn::ReflectionPad" << D << "d"
26          << "(padding=" << options.padding() << ")";
27 }
28 
29 template class ReflectionPadImpl<1, ReflectionPad1dImpl>;
30 template class ReflectionPadImpl<2, ReflectionPad2dImpl>;
31 template class ReflectionPadImpl<3, ReflectionPad3dImpl>;
32 
33 // ============================================================================
34 
35 template <size_t D, typename Derived>
ReplicationPadImpl(const ReplicationPadOptions<D> & options_)36 ReplicationPadImpl<D, Derived>::ReplicationPadImpl(
37     const ReplicationPadOptions<D>& options_)
38     : options(options_) {}
39 
40 template <size_t D, typename Derived>
reset()41 void ReplicationPadImpl<D, Derived>::reset() {}
42 
43 template <size_t D, typename Derived>
forward(const Tensor & input)44 Tensor ReplicationPadImpl<D, Derived>::forward(const Tensor& input) {
45   return F::detail::pad(input, options.padding(), torch::kReplicate, 0);
46 }
47 
48 template <size_t D, typename Derived>
pretty_print(std::ostream & stream) const49 void ReplicationPadImpl<D, Derived>::pretty_print(std::ostream& stream) const {
50   stream << "torch::nn::ReplicationPad" << D << "d"
51          << "(padding=" << options.padding() << ")";
52 }
53 
54 template class ReplicationPadImpl<1, ReplicationPad1dImpl>;
55 template class ReplicationPadImpl<2, ReplicationPad2dImpl>;
56 template class ReplicationPadImpl<3, ReplicationPad3dImpl>;
57 
58 // ============================================================================
59 
60 template <size_t D, typename Derived>
ZeroPadImpl(const ZeroPadOptions<D> & options_)61 ZeroPadImpl<D, Derived>::ZeroPadImpl(const ZeroPadOptions<D>& options_)
62     : options(options_) {}
63 
64 template <size_t D, typename Derived>
reset()65 void ZeroPadImpl<D, Derived>::reset() {}
66 
67 template <size_t D, typename Derived>
forward(const Tensor & input)68 Tensor ZeroPadImpl<D, Derived>::forward(const Tensor& input) {
69   return F::detail::pad(input, options.padding(), torch::kConstant, 0);
70 }
71 
72 template <size_t D, typename Derived>
pretty_print(std::ostream & stream) const73 void ZeroPadImpl<D, Derived>::pretty_print(std::ostream& stream) const {
74   stream << "torch::nn::ZeroPad" << D << "d"
75          << "(padding=" << options.padding() << ")";
76 }
77 
78 template class ZeroPadImpl<1, ZeroPad1dImpl>;
79 template class ZeroPadImpl<2, ZeroPad2dImpl>;
80 template class ZeroPadImpl<3, ZeroPad3dImpl>;
81 
82 // ============================================================================
83 
84 template <size_t D, typename Derived>
ConstantPadImpl(const ConstantPadOptions<D> & options_)85 ConstantPadImpl<D, Derived>::ConstantPadImpl(
86     const ConstantPadOptions<D>& options_)
87     : options(options_) {}
88 
89 template <size_t D, typename Derived>
reset()90 void ConstantPadImpl<D, Derived>::reset() {}
91 
92 template <size_t D, typename Derived>
forward(const Tensor & input)93 Tensor ConstantPadImpl<D, Derived>::forward(const Tensor& input) {
94   return F::detail::pad(
95       input, options.padding(), torch::kConstant, options.value());
96 }
97 
98 template <size_t D, typename Derived>
pretty_print(std::ostream & stream) const99 void ConstantPadImpl<D, Derived>::pretty_print(std::ostream& stream) const {
100   stream << "torch::nn::ConstantPad" << D << "d"
101          << "(padding=" << options.padding() << ", value=" << options.value()
102          << ")";
103 }
104 
105 template class ConstantPadImpl<1, ConstantPad1dImpl>;
106 template class ConstantPadImpl<2, ConstantPad2dImpl>;
107 template class ConstantPadImpl<3, ConstantPad3dImpl>;
108 
109 } // namespace nn
110 } // namespace torch
111