1 #pragma once 2 3 #include <torch/nn/cloneable.h> 4 #include <torch/nn/functional/linear.h> 5 #include <torch/nn/module.h> 6 #include <torch/nn/options/linear.h> 7 #include <torch/nn/pimpl.h> 8 #include <torch/types.h> 9 10 #include <cstddef> 11 #include <vector> 12 13 namespace torch { 14 namespace nn { 15 16 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Identity ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 17 18 /// A placeholder identity operator that is argument-insensitive. 19 /// See https://pytorch.org/docs/main/generated/torch.nn.Identity.html to 20 /// learn about the exact behavior of this module. 21 class TORCH_API IdentityImpl : public Cloneable<IdentityImpl> { 22 public: 23 void reset() override; 24 25 /// Pretty prints the `Identity` module into the given `stream`. 26 void pretty_print(std::ostream& stream) const override; 27 28 Tensor forward(const Tensor& input); 29 }; 30 31 /// A `ModuleHolder` subclass for `IdentityImpl`. 32 /// See the documentation for `IdentityImpl` class to learn what methods it 33 /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's 34 /// module storage semantics. 35 TORCH_MODULE(Identity); 36 37 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Linear ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 38 39 /// Applies a linear transformation with optional bias. 40 /// See https://pytorch.org/docs/main/generated/torch.nn.Linear.html to learn 41 /// about the exact behavior of this module. 42 /// 43 /// See the documentation for `torch::nn::LinearOptions` class to learn what 44 /// constructor arguments are supported for this module. 45 /// 46 /// Example: 47 /// ``` 48 /// Linear model(LinearOptions(5, 2).bias(false)); 49 /// ``` 50 class TORCH_API LinearImpl : public Cloneable<LinearImpl> { 51 public: LinearImpl(int64_t in_features,int64_t out_features)52 LinearImpl(int64_t in_features, int64_t out_features) 53 : LinearImpl(LinearOptions(in_features, out_features)) {} 54 explicit LinearImpl(const LinearOptions& options_); 55 56 void reset() override; 57 58 void reset_parameters(); 59 60 /// Pretty prints the `Linear` module into the given `stream`. 61 void pretty_print(std::ostream& stream) const override; 62 63 /// Transforms the `input` tensor by multiplying with the `weight` and 64 /// optionally adding the `bias`, if `with_bias` is true in the options. 65 Tensor forward(const Tensor& input); 66 67 /// The options used to configure this module. 68 LinearOptions options; 69 70 /// The learned weight. 71 Tensor weight; 72 73 /// The learned bias. If `bias` is false in the `options`, this tensor is 74 /// undefined. 75 Tensor bias; 76 }; 77 78 /// A `ModuleHolder` subclass for `LinearImpl`. 79 /// See the documentation for `LinearImpl` class to learn what methods it 80 /// provides, and examples of how to use `Linear` with 81 /// `torch::nn::LinearOptions`. See the documentation for `ModuleHolder` to 82 /// learn about PyTorch's module storage semantics. 83 TORCH_MODULE(Linear); 84 85 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Flatten ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 86 87 /// A placeholder for Flatten operator 88 /// See https://pytorch.org/docs/main/generated/torch.nn.Flatten.html to learn 89 /// about the exact behavior of this module. 90 /// 91 /// See the documentation for `torch::nn::FlattenOptions` class to learn what 92 /// constructor arguments are supported for this module. 93 /// 94 /// Example: 95 /// ``` 96 /// Flatten model(FlattenOptions().start_dim(2).end_dim(4)); 97 /// ``` 98 class TORCH_API FlattenImpl : public Cloneable<FlattenImpl> { 99 public: 100 explicit FlattenImpl(const FlattenOptions& options_ = {}); 101 102 void reset() override; 103 104 /// Pretty prints the `Flatten` module into the given `stream`. 105 void pretty_print(std::ostream& stream) const override; 106 107 /// Applies a flatten transform on the `input`. 108 Tensor forward(const Tensor& input); 109 110 /// The options used to configure this module. 111 FlattenOptions options; 112 }; 113 114 /// A `ModuleHolder` subclass for `FlattenImpl`. 115 /// See the documentation for `FlattenImpl` class to learn what methods it 116 /// provides, and examples of how to use `Flatten` with 117 /// `torch::nn::FlattenOptions`. See the documentation for `ModuleHolder` to 118 /// learn about PyTorch's module storage semantics. 119 TORCH_MODULE(Flatten); 120 121 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Unflatten 122 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 123 124 /// A placeholder for unflatten operator 125 /// See https://pytorch.org/docs/main/generated/torch.nn.Unflatten.html to 126 /// learn about the exact behavior of this module. 127 /// 128 /// See the documentation for `torch::nn::UnflattenOptions` class to learn what 129 /// constructor arguments are supported for this module. 130 /// 131 /// Example: 132 /// ``` 133 /// Unflatten model(UnflattenOptions(0, {2, 2})); 134 /// Unflatten model(UnflattenOptions("B", {{"B1", 2}, {"B2", 2}})); 135 /// ``` 136 class TORCH_API UnflattenImpl : public Cloneable<UnflattenImpl> { 137 public: UnflattenImpl(int64_t dim,std::vector<int64_t> sizes)138 UnflattenImpl(int64_t dim, std::vector<int64_t> sizes) 139 : UnflattenImpl(UnflattenOptions(dim, sizes)) {} UnflattenImpl(std::string dimname,UnflattenOptions::namedshape_t namedshape)140 UnflattenImpl(std::string dimname, UnflattenOptions::namedshape_t namedshape) 141 : UnflattenImpl(UnflattenOptions(dimname, namedshape)) {} 142 explicit UnflattenImpl(UnflattenOptions options_); 143 144 void reset() override; 145 146 /// Pretty prints the `Unflatten` module into the given `stream`. 147 void pretty_print(std::ostream& stream) const override; 148 149 /// Applies an unflatten transform on the `input`. 150 Tensor forward(const Tensor& input); 151 152 /// The options used to configure this module. 153 UnflattenOptions options; 154 }; 155 156 /// A `ModuleHolder` subclass for `UnflattenImpl`. 157 /// See the documentation for `UnflattenImpl` class to learn what methods it 158 /// provides, and examples of how to use `Unflatten` with 159 /// `torch::nn::UnflattenOptions`. See the documentation for `ModuleHolder` to 160 /// learn about PyTorch's module storage semantics. 161 TORCH_MODULE(Unflatten); 162 163 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Bilinear ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 164 165 /// Applies a billinear transformation with optional bias. 166 /// See https://pytorch.org/docs/main/generated/torch.nn.Bilinear.html to 167 /// learn about the exact behavior of this module. 168 /// 169 /// See the documentation for `torch::nn::BilinearOptions` class to learn what 170 /// constructor arguments are supported for this module. 171 /// 172 /// Example: 173 /// ``` 174 /// Bilinear model(BilinearOptions(3, 2, 4).bias(false)); 175 /// ``` 176 class TORCH_API BilinearImpl : public Cloneable<BilinearImpl> { 177 public: BilinearImpl(int64_t in1_features,int64_t in2_features,int64_t out_features)178 BilinearImpl(int64_t in1_features, int64_t in2_features, int64_t out_features) 179 : BilinearImpl( 180 BilinearOptions(in1_features, in2_features, out_features)) {} 181 explicit BilinearImpl(const BilinearOptions& options_); 182 183 void reset() override; 184 185 void reset_parameters(); 186 187 /// Pretty prints the `Bilinear` module into the given `stream`. 188 void pretty_print(std::ostream& stream) const override; 189 190 /// Applies a bilinear transform on the `input1` and `input2` tensor by 191 /// multiplying with the `weight` and optionally adding the `bias`, if 192 /// `with_bias` is true in the options. 193 Tensor forward(const Tensor& input1, const Tensor& input2); 194 195 /// The options used to configure this module. 196 BilinearOptions options; 197 198 /// The learned weight. 199 Tensor weight; 200 201 /// The learned bias. If `with_bias` is false in the `options`, this tensor is 202 /// undefined. 203 Tensor bias; 204 }; 205 206 /// A `ModuleHolder` subclass for `BilinearImpl`. 207 /// See the documentation for `BilinearImpl` class to learn what methods it 208 /// provides, and examples of how to use `Bilinear` with 209 /// `torch::nn::BilinearOptions`. See the documentation for `ModuleHolder` to 210 /// learn about PyTorch's module storage semantics. 211 TORCH_MODULE(Bilinear); 212 213 } // namespace nn 214 } // namespace torch 215