xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/modules/linear.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/nn/cloneable.h>
4 #include <torch/nn/functional/linear.h>
5 #include <torch/nn/module.h>
6 #include <torch/nn/options/linear.h>
7 #include <torch/nn/pimpl.h>
8 #include <torch/types.h>
9 
#include <cstddef>
#include <string>
#include <utility>
#include <vector>
12 
13 namespace torch {
14 namespace nn {
15 
16 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Identity ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
17 
18 /// A placeholder identity operator that is argument-insensitive.
19 /// See https://pytorch.org/docs/main/generated/torch.nn.Identity.html to
20 /// learn about the exact behavior of this module.
21 class TORCH_API IdentityImpl : public Cloneable<IdentityImpl> {
22  public:
23   void reset() override;
24 
25   /// Pretty prints the `Identity` module into the given `stream`.
26   void pretty_print(std::ostream& stream) const override;
27 
28   Tensor forward(const Tensor& input);
29 };
30 
31 /// A `ModuleHolder` subclass for `IdentityImpl`.
32 /// See the documentation for `IdentityImpl` class to learn what methods it
33 /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
34 /// module storage semantics.
35 TORCH_MODULE(Identity);
36 
37 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Linear ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
38 
39 /// Applies a linear transformation with optional bias.
40 /// See https://pytorch.org/docs/main/generated/torch.nn.Linear.html to learn
41 /// about the exact behavior of this module.
42 ///
43 /// See the documentation for `torch::nn::LinearOptions` class to learn what
44 /// constructor arguments are supported for this module.
45 ///
46 /// Example:
47 /// ```
48 /// Linear model(LinearOptions(5, 2).bias(false));
49 /// ```
50 class TORCH_API LinearImpl : public Cloneable<LinearImpl> {
51  public:
LinearImpl(int64_t in_features,int64_t out_features)52   LinearImpl(int64_t in_features, int64_t out_features)
53       : LinearImpl(LinearOptions(in_features, out_features)) {}
54   explicit LinearImpl(const LinearOptions& options_);
55 
56   void reset() override;
57 
58   void reset_parameters();
59 
60   /// Pretty prints the `Linear` module into the given `stream`.
61   void pretty_print(std::ostream& stream) const override;
62 
63   /// Transforms the `input` tensor by multiplying with the `weight` and
64   /// optionally adding the `bias`, if `with_bias` is true in the options.
65   Tensor forward(const Tensor& input);
66 
67   /// The options used to configure this module.
68   LinearOptions options;
69 
70   /// The learned weight.
71   Tensor weight;
72 
73   /// The learned bias. If `bias` is false in the `options`, this tensor is
74   /// undefined.
75   Tensor bias;
76 };
77 
78 /// A `ModuleHolder` subclass for `LinearImpl`.
79 /// See the documentation for `LinearImpl` class to learn what methods it
80 /// provides, and examples of how to use `Linear` with
81 /// `torch::nn::LinearOptions`. See the documentation for `ModuleHolder` to
82 /// learn about PyTorch's module storage semantics.
83 TORCH_MODULE(Linear);
84 
85 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Flatten ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
86 
87 /// A placeholder for Flatten operator
88 /// See https://pytorch.org/docs/main/generated/torch.nn.Flatten.html to learn
89 /// about the exact behavior of this module.
90 ///
91 /// See the documentation for `torch::nn::FlattenOptions` class to learn what
92 /// constructor arguments are supported for this module.
93 ///
94 /// Example:
95 /// ```
96 /// Flatten model(FlattenOptions().start_dim(2).end_dim(4));
97 /// ```
98 class TORCH_API FlattenImpl : public Cloneable<FlattenImpl> {
99  public:
100   explicit FlattenImpl(const FlattenOptions& options_ = {});
101 
102   void reset() override;
103 
104   /// Pretty prints the `Flatten` module into the given `stream`.
105   void pretty_print(std::ostream& stream) const override;
106 
107   /// Applies a flatten transform on the `input`.
108   Tensor forward(const Tensor& input);
109 
110   /// The options used to configure this module.
111   FlattenOptions options;
112 };
113 
114 /// A `ModuleHolder` subclass for `FlattenImpl`.
115 /// See the documentation for `FlattenImpl` class to learn what methods it
116 /// provides, and examples of how to use `Flatten` with
117 /// `torch::nn::FlattenOptions`. See the documentation for `ModuleHolder` to
118 /// learn about PyTorch's module storage semantics.
119 TORCH_MODULE(Flatten);
120 
121 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Unflatten
122 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
123 
124 /// A placeholder for unflatten operator
125 /// See https://pytorch.org/docs/main/generated/torch.nn.Unflatten.html to
126 /// learn about the exact behavior of this module.
127 ///
128 /// See the documentation for `torch::nn::UnflattenOptions` class to learn what
129 /// constructor arguments are supported for this module.
130 ///
131 /// Example:
132 /// ```
133 /// Unflatten model(UnflattenOptions(0, {2, 2}));
134 /// Unflatten model(UnflattenOptions("B", {{"B1", 2}, {"B2", 2}}));
135 /// ```
136 class TORCH_API UnflattenImpl : public Cloneable<UnflattenImpl> {
137  public:
UnflattenImpl(int64_t dim,std::vector<int64_t> sizes)138   UnflattenImpl(int64_t dim, std::vector<int64_t> sizes)
139       : UnflattenImpl(UnflattenOptions(dim, sizes)) {}
UnflattenImpl(std::string dimname,UnflattenOptions::namedshape_t namedshape)140   UnflattenImpl(std::string dimname, UnflattenOptions::namedshape_t namedshape)
141       : UnflattenImpl(UnflattenOptions(dimname, namedshape)) {}
142   explicit UnflattenImpl(UnflattenOptions options_);
143 
144   void reset() override;
145 
146   /// Pretty prints the `Unflatten` module into the given `stream`.
147   void pretty_print(std::ostream& stream) const override;
148 
149   /// Applies an unflatten transform on the `input`.
150   Tensor forward(const Tensor& input);
151 
152   /// The options used to configure this module.
153   UnflattenOptions options;
154 };
155 
156 /// A `ModuleHolder` subclass for `UnflattenImpl`.
157 /// See the documentation for `UnflattenImpl` class to learn what methods it
158 /// provides, and examples of how to use `Unflatten` with
159 /// `torch::nn::UnflattenOptions`. See the documentation for `ModuleHolder` to
160 /// learn about PyTorch's module storage semantics.
161 TORCH_MODULE(Unflatten);
162 
163 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Bilinear ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
164 
165 /// Applies a billinear transformation with optional bias.
166 /// See https://pytorch.org/docs/main/generated/torch.nn.Bilinear.html to
167 /// learn about the exact behavior of this module.
168 ///
169 /// See the documentation for `torch::nn::BilinearOptions` class to learn what
170 /// constructor arguments are supported for this module.
171 ///
172 /// Example:
173 /// ```
174 /// Bilinear model(BilinearOptions(3, 2, 4).bias(false));
175 /// ```
176 class TORCH_API BilinearImpl : public Cloneable<BilinearImpl> {
177  public:
BilinearImpl(int64_t in1_features,int64_t in2_features,int64_t out_features)178   BilinearImpl(int64_t in1_features, int64_t in2_features, int64_t out_features)
179       : BilinearImpl(
180             BilinearOptions(in1_features, in2_features, out_features)) {}
181   explicit BilinearImpl(const BilinearOptions& options_);
182 
183   void reset() override;
184 
185   void reset_parameters();
186 
187   /// Pretty prints the `Bilinear` module into the given `stream`.
188   void pretty_print(std::ostream& stream) const override;
189 
190   /// Applies a bilinear transform on the `input1` and `input2` tensor by
191   /// multiplying with the `weight` and optionally adding the `bias`, if
192   /// `with_bias` is true in the options.
193   Tensor forward(const Tensor& input1, const Tensor& input2);
194 
195   /// The options used to configure this module.
196   BilinearOptions options;
197 
198   /// The learned weight.
199   Tensor weight;
200 
201   /// The learned bias. If `with_bias` is false in the `options`, this tensor is
202   /// undefined.
203   Tensor bias;
204 };
205 
206 /// A `ModuleHolder` subclass for `BilinearImpl`.
207 /// See the documentation for `BilinearImpl` class to learn what methods it
208 /// provides, and examples of how to use `Bilinear` with
209 /// `torch::nn::BilinearOptions`. See the documentation for `ModuleHolder` to
210 /// learn about PyTorch's module storage semantics.
211 TORCH_MODULE(Bilinear);
212 
213 } // namespace nn
214 } // namespace torch
215