1 #pragma once 2 3 #include <torch/arg.h> 4 #include <torch/csrc/Export.h> 5 #include <torch/enum.h> 6 #include <torch/types.h> 7 8 namespace torch { 9 namespace nn { 10 11 namespace detail { 12 13 /// Common options for RNN, LSTM and GRU modules. 14 struct TORCH_API RNNOptionsBase { 15 typedef std::variant< 16 enumtype::kLSTM, 17 enumtype::kGRU, 18 enumtype::kRNN_TANH, 19 enumtype::kRNN_RELU> 20 rnn_options_base_mode_t; 21 22 RNNOptionsBase( 23 rnn_options_base_mode_t mode, 24 int64_t input_size, 25 int64_t hidden_size); 26 27 TORCH_ARG(rnn_options_base_mode_t, mode); 28 /// The number of features of a single sample in the input sequence `x`. 29 TORCH_ARG(int64_t, input_size); 30 /// The number of features in the hidden state `h`. 31 TORCH_ARG(int64_t, hidden_size); 32 /// The number of recurrent layers (cells) to use. 33 TORCH_ARG(int64_t, num_layers) = 1; 34 /// Whether a bias term should be added to all linear operations. 35 TORCH_ARG(bool, bias) = true; 36 /// If true, the input sequence should be provided as `(batch, sequence, 37 /// features)`. If false (default), the expected layout is `(sequence, batch, 38 /// features)`. 39 TORCH_ARG(bool, batch_first) = false; 40 /// If non-zero, adds dropout with the given probability to the output of each 41 /// RNN layer, except the final layer. 42 TORCH_ARG(double, dropout) = 0.0; 43 /// Whether to make the RNN bidirectional. 44 TORCH_ARG(bool, bidirectional) = false; 45 /// Cell projection dimension. If 0, projections are not added. Can only be 46 /// used for LSTMs. 47 TORCH_ARG(int64_t, proj_size) = 0; 48 }; 49 50 } // namespace detail 51 52 /// Options for the `RNN` module. 53 /// 54 /// Example: 55 /// ``` 56 /// RNN model(RNNOptions(128, 57 /// 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh)); 58 /// ``` 59 struct TORCH_API RNNOptions { 60 typedef std::variant<enumtype::kTanh, enumtype::kReLU> nonlinearity_t; 61 62 RNNOptions(int64_t input_size, int64_t hidden_size); 63 64 /// The number of expected features in the input `x` 65 TORCH_ARG(int64_t, input_size); 66 /// The number of features in the hidden state `h` 67 TORCH_ARG(int64_t, hidden_size); 68 /// Number of recurrent layers. E.g., setting ``num_layers=2`` 69 /// would mean stacking two RNNs together to form a `stacked RNN`, 70 /// with the second RNN taking in outputs of the first RNN and 71 /// computing the final results. Default: 1 72 TORCH_ARG(int64_t, num_layers) = 1; 73 /// The non-linearity to use. Can be either ``torch::kTanh`` or 74 /// ``torch::kReLU``. Default: ``torch::kTanh`` 75 TORCH_ARG(nonlinearity_t, nonlinearity) = torch::kTanh; 76 /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`. 77 /// Default: ``true`` 78 TORCH_ARG(bool, bias) = true; 79 /// If ``true``, then the input and output tensors are provided 80 /// as `(batch, seq, feature)`. Default: ``false`` 81 TORCH_ARG(bool, batch_first) = false; 82 /// If non-zero, introduces a `Dropout` layer on the outputs of each 83 /// RNN layer except the last layer, with dropout probability equal to 84 /// `dropout`. Default: 0 85 TORCH_ARG(double, dropout) = 0.0; 86 /// If ``true``, becomes a bidirectional RNN. Default: ``false`` 87 TORCH_ARG(bool, bidirectional) = false; 88 }; 89 90 /// Options for the `LSTM` module. 91 /// 92 /// Example: 93 /// ``` 94 /// LSTM model(LSTMOptions(2, 95 /// 4).num_layers(3).batch_first(false).bidirectional(true)); 96 /// ``` 97 struct TORCH_API LSTMOptions { 98 LSTMOptions(int64_t input_size, int64_t hidden_size); 99 100 /// The number of expected features in the input `x` 101 TORCH_ARG(int64_t, input_size); 102 /// The number of features in the hidden state `h` 103 TORCH_ARG(int64_t, hidden_size); 104 /// Number of recurrent layers. E.g., setting ``num_layers=2`` 105 /// would mean stacking two LSTMs together to form a `stacked LSTM`, 106 /// with the second LSTM taking in outputs of the first LSTM and 107 /// computing the final results. Default: 1 108 TORCH_ARG(int64_t, num_layers) = 1; 109 /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`. 110 /// Default: ``true`` 111 TORCH_ARG(bool, bias) = true; 112 /// If ``true``, then the input and output tensors are provided 113 /// as (batch, seq, feature). Default: ``false`` 114 TORCH_ARG(bool, batch_first) = false; 115 /// If non-zero, introduces a `Dropout` layer on the outputs of each 116 /// LSTM layer except the last layer, with dropout probability equal to 117 /// `dropout`. Default: 0 118 TORCH_ARG(double, dropout) = 0.0; 119 /// If ``true``, becomes a bidirectional LSTM. Default: ``false`` 120 TORCH_ARG(bool, bidirectional) = false; 121 /// Cell projection dimension. If 0, projections are not added 122 TORCH_ARG(int64_t, proj_size) = 0; 123 }; 124 125 /// Options for the `GRU` module. 126 /// 127 /// Example: 128 /// ``` 129 /// GRU model(GRUOptions(2, 130 /// 4).num_layers(3).batch_first(false).bidirectional(true)); 131 /// ``` 132 struct TORCH_API GRUOptions { 133 GRUOptions(int64_t input_size, int64_t hidden_size); 134 135 /// The number of expected features in the input `x` 136 TORCH_ARG(int64_t, input_size); 137 /// The number of features in the hidden state `h` 138 TORCH_ARG(int64_t, hidden_size); 139 /// Number of recurrent layers. E.g., setting ``num_layers=2`` 140 /// would mean stacking two GRUs together to form a `stacked GRU`, 141 /// with the second GRU taking in outputs of the first GRU and 142 /// computing the final results. Default: 1 143 TORCH_ARG(int64_t, num_layers) = 1; 144 /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`. 145 /// Default: ``true`` 146 TORCH_ARG(bool, bias) = true; 147 /// If ``true``, then the input and output tensors are provided 148 /// as (batch, seq, feature). Default: ``false`` 149 TORCH_ARG(bool, batch_first) = false; 150 /// If non-zero, introduces a `Dropout` layer on the outputs of each 151 /// GRU layer except the last layer, with dropout probability equal to 152 /// `dropout`. Default: 0 153 TORCH_ARG(double, dropout) = 0.0; 154 /// If ``true``, becomes a bidirectional GRU. Default: ``false`` 155 TORCH_ARG(bool, bidirectional) = false; 156 }; 157 158 namespace detail { 159 160 /// Common options for RNNCell, LSTMCell and GRUCell modules 161 struct TORCH_API RNNCellOptionsBase { 162 RNNCellOptionsBase( 163 int64_t input_size, 164 int64_t hidden_size, 165 bool bias, 166 int64_t num_chunks); 167 TORCH_ARG(int64_t, input_size); 168 TORCH_ARG(int64_t, hidden_size); 169 TORCH_ARG(bool, bias); 170 TORCH_ARG(int64_t, num_chunks); 171 }; 172 173 } // namespace detail 174 175 /// Options for the `RNNCell` module. 176 /// 177 /// Example: 178 /// ``` 179 /// RNNCell model(RNNCellOptions(20, 180 /// 10).bias(false).nonlinearity(torch::kReLU)); 181 /// ``` 182 struct TORCH_API RNNCellOptions { 183 typedef std::variant<enumtype::kTanh, enumtype::kReLU> nonlinearity_t; 184 185 RNNCellOptions(int64_t input_size, int64_t hidden_size); 186 187 /// The number of expected features in the input `x` 188 TORCH_ARG(int64_t, input_size); 189 /// The number of features in the hidden state `h` 190 TORCH_ARG(int64_t, hidden_size); 191 /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`. 192 /// Default: ``true`` 193 TORCH_ARG(bool, bias) = true; 194 /// The non-linearity to use. Can be either ``torch::kTanh`` or 195 /// ``torch::kReLU``. Default: ``torch::kTanh`` 196 TORCH_ARG(nonlinearity_t, nonlinearity) = torch::kTanh; 197 }; 198 199 /// Options for the `LSTMCell` module. 200 /// 201 /// Example: 202 /// ``` 203 /// LSTMCell model(LSTMCellOptions(20, 10).bias(false)); 204 /// ``` 205 struct TORCH_API LSTMCellOptions { 206 LSTMCellOptions(int64_t input_size, int64_t hidden_size); 207 208 /// The number of expected features in the input `x` 209 TORCH_ARG(int64_t, input_size); 210 /// The number of features in the hidden state `h` 211 TORCH_ARG(int64_t, hidden_size); 212 /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`. 213 /// Default: ``true`` 214 TORCH_ARG(bool, bias) = true; 215 }; 216 217 /// Options for the `GRUCell` module. 218 /// 219 /// Example: 220 /// ``` 221 /// GRUCell model(GRUCellOptions(20, 10).bias(false)); 222 /// ``` 223 struct TORCH_API GRUCellOptions { 224 GRUCellOptions(int64_t input_size, int64_t hidden_size); 225 226 /// The number of expected features in the input `x` 227 TORCH_ARG(int64_t, input_size); 228 /// The number of features in the hidden state `h` 229 TORCH_ARG(int64_t, hidden_size); 230 /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`. 231 /// Default: ``true`` 232 TORCH_ARG(bool, bias) = true; 233 }; 234 235 } // namespace nn 236 } // namespace torch 237