xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/options/rnn.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/arg.h>
4 #include <torch/csrc/Export.h>
5 #include <torch/enum.h>
6 #include <torch/types.h>
7 
8 namespace torch {
9 namespace nn {
10 
11 namespace detail {
12 
13 /// Common options for RNN, LSTM and GRU modules.
14 struct TORCH_API RNNOptionsBase {
15   typedef std::variant<
16       enumtype::kLSTM,
17       enumtype::kGRU,
18       enumtype::kRNN_TANH,
19       enumtype::kRNN_RELU>
20       rnn_options_base_mode_t;
21 
22   RNNOptionsBase(
23       rnn_options_base_mode_t mode,
24       int64_t input_size,
25       int64_t hidden_size);
26 
27   TORCH_ARG(rnn_options_base_mode_t, mode);
28   /// The number of features of a single sample in the input sequence `x`.
29   TORCH_ARG(int64_t, input_size);
30   /// The number of features in the hidden state `h`.
31   TORCH_ARG(int64_t, hidden_size);
32   /// The number of recurrent layers (cells) to use.
33   TORCH_ARG(int64_t, num_layers) = 1;
34   /// Whether a bias term should be added to all linear operations.
35   TORCH_ARG(bool, bias) = true;
36   /// If true, the input sequence should be provided as `(batch, sequence,
37   /// features)`. If false (default), the expected layout is `(sequence, batch,
38   /// features)`.
39   TORCH_ARG(bool, batch_first) = false;
40   /// If non-zero, adds dropout with the given probability to the output of each
41   /// RNN layer, except the final layer.
42   TORCH_ARG(double, dropout) = 0.0;
43   /// Whether to make the RNN bidirectional.
44   TORCH_ARG(bool, bidirectional) = false;
45   /// Cell projection dimension. If 0, projections are not added. Can only be
46   /// used for LSTMs.
47   TORCH_ARG(int64_t, proj_size) = 0;
48 };
49 
50 } // namespace detail
51 
52 /// Options for the `RNN` module.
53 ///
54 /// Example:
55 /// ```
56 /// RNN model(RNNOptions(128,
57 /// 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh));
58 /// ```
59 struct TORCH_API RNNOptions {
60   typedef std::variant<enumtype::kTanh, enumtype::kReLU> nonlinearity_t;
61 
62   RNNOptions(int64_t input_size, int64_t hidden_size);
63 
64   /// The number of expected features in the input `x`
65   TORCH_ARG(int64_t, input_size);
66   /// The number of features in the hidden state `h`
67   TORCH_ARG(int64_t, hidden_size);
68   /// Number of recurrent layers. E.g., setting ``num_layers=2``
69   /// would mean stacking two RNNs together to form a `stacked RNN`,
70   /// with the second RNN taking in outputs of the first RNN and
71   /// computing the final results. Default: 1
72   TORCH_ARG(int64_t, num_layers) = 1;
73   /// The non-linearity to use. Can be either ``torch::kTanh`` or
74   /// ``torch::kReLU``. Default: ``torch::kTanh``
75   TORCH_ARG(nonlinearity_t, nonlinearity) = torch::kTanh;
76   /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`.
77   /// Default: ``true``
78   TORCH_ARG(bool, bias) = true;
79   /// If ``true``, then the input and output tensors are provided
80   /// as `(batch, seq, feature)`. Default: ``false``
81   TORCH_ARG(bool, batch_first) = false;
82   /// If non-zero, introduces a `Dropout` layer on the outputs of each
83   /// RNN layer except the last layer, with dropout probability equal to
84   /// `dropout`. Default: 0
85   TORCH_ARG(double, dropout) = 0.0;
86   /// If ``true``, becomes a bidirectional RNN. Default: ``false``
87   TORCH_ARG(bool, bidirectional) = false;
88 };
89 
/// Options for the `LSTM` module.
///
/// Example:
/// ```
/// LSTM model(LSTMOptions(2,
/// 4).num_layers(3).batch_first(false).bidirectional(true));
/// ```
struct TORCH_API LSTMOptions {
  LSTMOptions(int64_t input_size, int64_t hidden_size);

  /// The number of expected features in the input `x`
  TORCH_ARG(int64_t, input_size);
  /// The number of features in the hidden state `h`
  TORCH_ARG(int64_t, hidden_size);
  /// Number of recurrent layers. E.g., setting ``num_layers=2``
  /// would mean stacking two LSTMs together to form a `stacked LSTM`,
  /// with the second LSTM taking in outputs of the first LSTM and
  /// computing the final results. Default: 1
  TORCH_ARG(int64_t, num_layers) = 1;
  /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`.
  /// Default: ``true``
  TORCH_ARG(bool, bias) = true;
  /// If ``true``, then the input and output tensors are provided
  /// as (batch, seq, feature). Default: ``false``
  TORCH_ARG(bool, batch_first) = false;
  /// If non-zero, introduces a `Dropout` layer on the outputs of each
  /// LSTM layer except the last layer, with dropout probability equal to
  /// `dropout`. Default: 0
  TORCH_ARG(double, dropout) = 0.0;
  /// If ``true``, becomes a bidirectional LSTM. Default: ``false``
  TORCH_ARG(bool, bidirectional) = false;
  /// Cell projection dimension. If 0 (default), projections are not added.
  TORCH_ARG(int64_t, proj_size) = 0;
};
124 
/// Options for the `GRU` module.
///
/// Example:
/// ```
/// GRU model(GRUOptions(2,
/// 4).num_layers(3).batch_first(false).bidirectional(true));
/// ```
struct TORCH_API GRUOptions {
  GRUOptions(int64_t input_size, int64_t hidden_size);

  /// The number of expected features in the input `x`
  TORCH_ARG(int64_t, input_size);
  /// The number of features in the hidden state `h`
  TORCH_ARG(int64_t, hidden_size);
  /// Number of recurrent layers. E.g., setting ``num_layers=2``
  /// would mean stacking two GRUs together to form a `stacked GRU`,
  /// with the second GRU taking in outputs of the first GRU and
  /// computing the final results. Default: 1
  TORCH_ARG(int64_t, num_layers) = 1;
  /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`.
  /// Default: ``true``
  TORCH_ARG(bool, bias) = true;
  /// If ``true``, then the input and output tensors are provided
  /// as (batch, seq, feature). Default: ``false``
  TORCH_ARG(bool, batch_first) = false;
  /// If non-zero, introduces a `Dropout` layer on the outputs of each
  /// GRU layer except the last layer, with dropout probability equal to
  /// `dropout`. Default: 0
  TORCH_ARG(double, dropout) = 0.0;
  /// If ``true``, becomes a bidirectional GRU. Default: ``false``
  TORCH_ARG(bool, bidirectional) = false;
};
157 
158 namespace detail {
159 
/// Common options for RNNCell, LSTMCell and GRUCell modules
struct TORCH_API RNNCellOptionsBase {
  RNNCellOptionsBase(
      int64_t input_size,
      int64_t hidden_size,
      bool bias,
      int64_t num_chunks);
  /// The number of expected features in the input `x`.
  TORCH_ARG(int64_t, input_size);
  /// The number of features in the hidden state `h`.
  TORCH_ARG(int64_t, hidden_size);
  /// Whether the cell uses bias weights.
  TORCH_ARG(bool, bias);
  /// Multiplier for the per-cell weight matrices (set by the derived cell
  /// type; e.g. an LSTM cell needs one chunk per gate).
  TORCH_ARG(int64_t, num_chunks);
};
172 
173 } // namespace detail
174 
175 /// Options for the `RNNCell` module.
176 ///
177 /// Example:
178 /// ```
179 /// RNNCell model(RNNCellOptions(20,
180 /// 10).bias(false).nonlinearity(torch::kReLU));
181 /// ```
182 struct TORCH_API RNNCellOptions {
183   typedef std::variant<enumtype::kTanh, enumtype::kReLU> nonlinearity_t;
184 
185   RNNCellOptions(int64_t input_size, int64_t hidden_size);
186 
187   /// The number of expected features in the input `x`
188   TORCH_ARG(int64_t, input_size);
189   /// The number of features in the hidden state `h`
190   TORCH_ARG(int64_t, hidden_size);
191   /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`.
192   /// Default: ``true``
193   TORCH_ARG(bool, bias) = true;
194   /// The non-linearity to use. Can be either ``torch::kTanh`` or
195   /// ``torch::kReLU``. Default: ``torch::kTanh``
196   TORCH_ARG(nonlinearity_t, nonlinearity) = torch::kTanh;
197 };
198 
/// Options for the `LSTMCell` module.
///
/// Example:
/// ```
/// LSTMCell model(LSTMCellOptions(20, 10).bias(false));
/// ```
struct TORCH_API LSTMCellOptions {
  LSTMCellOptions(int64_t input_size, int64_t hidden_size);

  /// The number of expected features in the input `x`
  TORCH_ARG(int64_t, input_size);
  /// The number of features in the hidden state `h`
  TORCH_ARG(int64_t, hidden_size);
  /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`.
  /// Default: ``true``
  TORCH_ARG(bool, bias) = true;
};
216 
/// Options for the `GRUCell` module.
///
/// Example:
/// ```
/// GRUCell model(GRUCellOptions(20, 10).bias(false));
/// ```
struct TORCH_API GRUCellOptions {
  GRUCellOptions(int64_t input_size, int64_t hidden_size);

  /// The number of expected features in the input `x`
  TORCH_ARG(int64_t, input_size);
  /// The number of features in the hidden state `h`
  TORCH_ARG(int64_t, hidden_size);
  /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`.
  /// Default: ``true``
  TORCH_ARG(bool, bias) = true;
};
234 
235 } // namespace nn
236 } // namespace torch
237