#include <torch/nn/modules/rnn.h>

#include <torch/nn/init.h>
#include <torch/types.h>
#include <torch/utils.h>

#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#include <cmath>
#include <cstdint>
#include <regex>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>

using namespace torch::nn::utils::rnn;

namespace torch {
namespace nn {

/// These must line up with the CUDNN mode codes:
/// https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnRNNMode_t
enum class CuDNNMode { RNN_RELU = 0, RNN_TANH = 1, LSTM = 2, GRU = 3 };

static CuDNNMode get_cudnn_mode_for_rnn(
    detail::RNNOptionsBase::rnn_options_base_mode_t mode) {
  if (std::holds_alternative<enumtype::kRNN_RELU>(mode)) {
    return CuDNNMode::RNN_RELU;
  } else if (std::holds_alternative<enumtype::kRNN_TANH>(mode)) {
    return CuDNNMode::RNN_TANH;
  } else if (std::holds_alternative<enumtype::kLSTM>(mode)) {
    return CuDNNMode::LSTM;
  } else if (std::holds_alternative<enumtype::kGRU>(mode)) {
    return CuDNNMode::GRU;
  } else {
    TORCH_CHECK(false, "Unknown mode: ", torch::enumtype::get_enum_name(mode));
  }
}

static Tensor apply_permutation(
    const Tensor& tensor,
    const Tensor& permutation,
    int64_t dim = 1) {
  return tensor.index_select(dim, permutation);
}
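
// A small worked illustration of apply_permutation (comment only, not used by
// this file): for a hidden state `hx` of shape
// (num_layers * num_directions, batch, hidden_size) and a permutation tensor
// `perm` of length `batch`, apply_permutation(hx, perm) reorders the batch
// dimension (dim = 1), so output[:, i, :] == hx[:, perm[i], :].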

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNImplBase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
namespace detail {
template <typename Derived>
RNNImplBase<Derived>::RNNImplBase(const RNNOptionsBase& options_)
    : options_base(options_) {
  // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
  reset();
}

template <typename Derived>
void RNNImplBase<Derived>::reset() {
  const int64_t num_directions = options_base.bidirectional() ? 2 : 1;

  TORCH_CHECK(
      0 <= options_base.dropout() && options_base.dropout() <= 1,
      "dropout should be a number in range [0, 1] ",
      "representing the probability of an element being ",
      "zeroed");

  if (options_base.dropout() > 0 && options_base.num_layers() == 1) {
    TORCH_WARN(
        "dropout option adds dropout after all but the last ",
        "recurrent layer, so non-zero dropout expects ",
        "num_layers greater than 1, but got dropout=",
        options_base.dropout(),
        " and ",
        "num_layers=",
        options_base.num_layers());
  }

  TORCH_CHECK(
      options_base.hidden_size() > 0, "hidden_size must be greater than zero");

  TORCH_CHECK(
      options_base.num_layers() > 0, "num_layers must be greater than zero");

  TORCH_CHECK(
      0 <= options_base.proj_size() &&
          options_base.proj_size() < options_base.hidden_size(),
      "proj_size has to be a positive integer, smaller than ",
      "hidden_size or zero to disable projections");

  if (options_base.proj_size() > 0) {
    TORCH_CHECK(
        std::get_if<enumtype::kLSTM>(&options_base.mode()),
        "proj_size argument is only supported for LSTM, not RNN or GRU");
  }

  int64_t gate_size = 0;
  if (std::holds_alternative<enumtype::kLSTM>(options_base.mode())) {
    gate_size = 4 * options_base.hidden_size();
  } else if (std::holds_alternative<enumtype::kGRU>(options_base.mode())) {
    gate_size = 3 * options_base.hidden_size();
    // NOLINTNEXTLINE(bugprone-branch-clone)
  } else if (std::holds_alternative<enumtype::kRNN_TANH>(options_base.mode())) {
    gate_size = options_base.hidden_size();
  } else if (std::holds_alternative<enumtype::kRNN_RELU>(options_base.mode())) {
    gate_size = options_base.hidden_size();
  } else {
    TORCH_CHECK(
        false,
        "Unrecognized RNN mode: " +
            torch::enumtype::get_enum_name(options_base.mode()));
  }

  flat_weights_names_ = {};
  all_weights_ = {};

  for (const auto layer : c10::irange(options_base.num_layers())) {
    for (const auto direction : c10::irange(num_directions)) {
      int64_t real_hidden_size = options_base.proj_size() > 0
          ? options_base.proj_size()
          : options_base.hidden_size();
      int64_t layer_input_size = layer == 0 ? options_base.input_size()
                                            : real_hidden_size * num_directions;

      auto w_ih = torch::empty({gate_size, layer_input_size});
      auto w_hh = torch::empty({gate_size, real_hidden_size});
      auto b_ih = torch::empty({gate_size});
      // Second bias vector included for CuDNN compatibility. Only one
      // bias vector is needed in standard definition.
      auto b_hh = torch::empty({gate_size});
      std::vector<Tensor> layer_params = {w_ih, w_hh};

      std::string suffix = direction == 1 ? "_reverse" : "";
      std::vector<std::string> param_names = {
          "weight_ih_l{layer}{suffix}", "weight_hh_l{layer}{suffix}"};
      if (options_base.bias()) {
        param_names.emplace_back("bias_ih_l{layer}{suffix}");
        param_names.emplace_back("bias_hh_l{layer}{suffix}");
        layer_params.emplace_back(b_ih);
        layer_params.emplace_back(b_hh);
      }
      if (options_base.proj_size() > 0) {
        auto w_hr = torch::empty(
            {options_base.proj_size(), options_base.hidden_size()});
        layer_params.emplace_back(w_hr);
        param_names.emplace_back("weight_hr_l{layer}{suffix}");
      }
      for (auto& param_name : param_names) {
        std::string x = std::regex_replace(
            param_name, std::regex("\\{layer\\}"), c10::str(layer));
        param_name =
            std::regex_replace(x, std::regex("\\{suffix\\}"), c10::str(suffix));
      }

      for (const auto i : c10::irange(param_names.size())) {
        this->register_parameter(param_names[i], std::move(layer_params[i]));
      }
      flat_weights_names_.insert(
          flat_weights_names_.end(), param_names.begin(), param_names.end());
      all_weights_.emplace_back(std::move(param_names));
    }
  }

  flat_weights_ = {};
  for (const auto& wn : flat_weights_names_) {
    auto named_parameters = this->named_parameters(/*recurse=*/false);
    if (named_parameters.contains(wn)) {
      flat_weights_.emplace_back(named_parameters[wn]);
    } else {
      flat_weights_.emplace_back();
    }
  }

  this->flatten_parameters();
  this->reset_parameters();
}
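
// Illustration of the parameter layout produced by reset() (comment only):
// a 2-layer bidirectional module with bias and proj_size == 0 registers, per
// layer and direction,
//   weight_ih_l{layer}[_reverse], weight_hh_l{layer}[_reverse],
//   bias_ih_l{layer}[_reverse],   bias_hh_l{layer}[_reverse]
// e.g. "weight_ih_l0", "bias_hh_l1_reverse". With proj_size > 0, an extra
// "weight_hr_l{layer}[_reverse]" projection weight is registered as well.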

template <typename Derived>
void RNNImplBase<Derived>::flatten_parameters() {
  // Resets parameter data pointer so that they can use faster code paths.
  //
  // Right now, this works only if the module is on the GPU and cuDNN is
  // enabled. Otherwise, it's a no-op.

  // Short-circuits if flat_weights_ is only partially instantiated
  if (flat_weights_.size() != flat_weights_names_.size()) {
    return;
  }

  // Short-circuits if any tensor in self.flat_weights_ is not acceptable to
  // cuDNN or the tensors in flat_weights_ are of different dtypes

  auto first_fw = flat_weights_[0];
  auto dtype = first_fw.dtype();
  for (const auto& fw : flat_weights_) {
    if (!(fw.dtype() == dtype) || !fw.is_cuda() ||
        !torch::cudnn_is_acceptable(fw)) {
      return;
    }
  }

  // If any parameters alias, we fall back to the slower, copying code path.
  // This is a sufficient check, because overlapping parameter buffers that
  // don't completely alias would break the assumptions of the uniqueness check
  // in Module::named_parameters().
  std::unordered_set<void*> unique_data_ptrs;
  for (const auto& p : flat_weights_) {
    unique_data_ptrs.emplace(p.data_ptr());
  }
  if (unique_data_ptrs.size() != flat_weights_.size()) {
    return;
  }

  {
    torch::DeviceGuard device_guard(first_fw.device());

    // Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
    // an inplace operation on self.flat_weights_
    {
      torch::NoGradGuard no_grad;
      if (torch::_use_cudnn_rnn_flatten_weight()) {
        int64_t num_weights = options_base.bias() ? 4 : 2;
        if (options_base.proj_size() > 0) {
          ++num_weights;
        }
        torch::_cudnn_rnn_flatten_weight(
            flat_weights_,
            num_weights,
            options_base.input_size(),
            static_cast<int64_t>(get_cudnn_mode_for_rnn(options_base.mode())),
            options_base.hidden_size(),
            options_base.proj_size(),
            options_base.num_layers(),
            options_base.batch_first(),
            options_base.bidirectional());
      }
    }
  }
}

template <typename Derived>
void RNNImplBase<Derived>::reset_flat_weights() {
  flat_weights_ = {};
  for (const auto& wn : flat_weights_names_) {
    auto named_parameters = this->named_parameters(/*recurse=*/false);
    if (named_parameters.contains(wn)) {
      flat_weights_.emplace_back(named_parameters[wn]);
    } else {
      flat_weights_.emplace_back();
    }
  }
}

template <typename Derived>
void RNNImplBase<Derived>::to(
    torch::Device device,
    torch::Dtype dtype,
    bool non_blocking) {
  nn::Module::to(device, dtype, non_blocking);
  reset_flat_weights();
  flatten_parameters();
}

template <typename Derived>
void RNNImplBase<Derived>::to(torch::Dtype dtype, bool non_blocking) {
  nn::Module::to(dtype, non_blocking);
  reset_flat_weights();
  flatten_parameters();
}

template <typename Derived>
void RNNImplBase<Derived>::to(torch::Device device, bool non_blocking) {
  nn::Module::to(device, non_blocking);
  reset_flat_weights();
  flatten_parameters();
}

template <typename Derived>
void RNNImplBase<Derived>::reset_parameters() {
  const double stdv = 1.0 / std::sqrt(options_base.hidden_size());
  for (auto& weight : this->parameters()) {
    init::uniform_(weight, -stdv, stdv);
  }
}

template <typename Derived>
void RNNImplBase<Derived>::check_input(
    const Tensor& input,
    const Tensor& batch_sizes) const {
  int64_t expected_input_dim = batch_sizes.defined() ? 2 : 3;
  TORCH_CHECK(
      input.dim() == expected_input_dim,
      "input must have ",
      expected_input_dim,
      " dimensions, got ",
      input.dim());
  TORCH_CHECK(
      options_base.input_size() == input.size(-1),
      "input.size(-1) must be equal to input_size. Expected ",
      options_base.input_size(),
      ", got ",
      input.size(-1));
}

template <typename Derived>
std::tuple<int64_t, int64_t, int64_t> RNNImplBase<Derived>::
    get_expected_hidden_size(const Tensor& input, const Tensor& batch_sizes)
        const {
  int64_t mini_batch = 0;
  if (batch_sizes.defined()) {
    mini_batch = batch_sizes[0].item<int64_t>();
  } else {
    mini_batch = options_base.batch_first() ? input.size(0) : input.size(1);
  }
  int64_t num_directions = options_base.bidirectional() ? 2 : 1;
  int64_t real_hidden_size = options_base.proj_size() > 0
      ? options_base.proj_size()
      : options_base.hidden_size();
  return std::make_tuple(
      options_base.num_layers() * num_directions, mini_batch, real_hidden_size);
}
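
// Worked example (comment only): for an unpacked, batch_first input of shape
// (8, 5, input_size) on a 2-layer bidirectional module with hidden_size = 16
// and proj_size = 0, the expected hidden size is
// (num_layers * num_directions, mini_batch, real_hidden_size) =
// (2 * 2, 8, 16) = (4, 8, 16).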

template <typename Derived>
void RNNImplBase<Derived>::check_hidden_size(
    const Tensor& hx,
    std::tuple<int64_t, int64_t, int64_t> expected_hidden_size,
    std::string msg) const {
  auto expected_hidden_size_vec = std::vector<int64_t>({
      std::get<0>(expected_hidden_size),
      std::get<1>(expected_hidden_size),
      std::get<2>(expected_hidden_size),
  });
  if (hx.sizes() != expected_hidden_size_vec) {
    msg = std::regex_replace(
        msg, std::regex("\\{1\\}"), c10::str(expected_hidden_size_vec));
    msg = std::regex_replace(msg, std::regex("\\{2\\}"), c10::str(hx.sizes()));
    TORCH_CHECK(false, msg);
  }
}

template <typename Derived>
void RNNImplBase<Derived>::check_forward_args(
    Tensor input,
    Tensor hidden,
    Tensor batch_sizes) const {
  this->check_input(input, batch_sizes);
  auto expected_hidden_size =
      this->get_expected_hidden_size(input, batch_sizes);

  this->check_hidden_size(hidden, expected_hidden_size);
}

template <typename Derived>
Tensor RNNImplBase<Derived>::permute_hidden(
    Tensor hx,
    const Tensor& permutation) const {
  if (!permutation.defined()) {
    return hx;
  }
  return apply_permutation(hx, permutation);
}

template <typename Derived>
void RNNImplBase<Derived>::pretty_print(std::ostream& stream) const {
  const std::string name = this->name();
  const std::string name_without_impl = name.substr(0, name.size() - 4);
  stream << std::boolalpha << name_without_impl
         << "(input_size=" << options_base.input_size()
         << ", hidden_size=" << options_base.hidden_size()
         << ", num_layers=" << options_base.num_layers()
         << ", bias=" << options_base.bias()
         << ", batch_first=" << options_base.batch_first()
         << ", dropout=" << options_base.dropout()
         << ", bidirectional=" << options_base.bidirectional();
  if (options_base.proj_size() > 0) {
    stream << ", proj_size=" << options_base.proj_size();
  }
  stream << ")";
}
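
// Example of the resulting textual form (comment only; the leading name is
// whatever Module::name() reports, with the trailing "Impl" stripped). An
// LSTM configured with hidden projections would print along the lines of:
//   LSTM(input_size=10, hidden_size=20, num_layers=2, bias=true,
//        batch_first=false, dropout=0, bidirectional=false, proj_size=5)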

template <typename Derived>
std::vector<Tensor> RNNImplBase<Derived>::all_weights() const {
  std::vector<Tensor> result = {};
  auto named_parameters = this->named_parameters(/*recurse=*/false);
  for (const auto& weights : all_weights_) {
    for (const auto& weight : weights) {
      result.emplace_back(named_parameters[weight]);
    }
  }
  return result;
}

template class RNNImplBase<LSTMImpl>;
template class RNNImplBase<GRUImpl>;
template class RNNImplBase<RNNImpl>;
} // namespace detail

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

static detail::RNNOptionsBase::rnn_options_base_mode_t
compute_rnn_options_base_mode(RNNOptions::nonlinearity_t nonlinearity) {
  if (std::holds_alternative<enumtype::kTanh>(nonlinearity)) {
    return torch::kRNN_TANH;
  } else if (std::holds_alternative<enumtype::kReLU>(nonlinearity)) {
    return torch::kRNN_RELU;
  } else {
    TORCH_CHECK(
        false,
        "Unknown nonlinearity ",
        torch::enumtype::get_enum_name(nonlinearity));
  }
}

RNNImpl::RNNImpl(const RNNOptions& options_)
    : detail::RNNImplBase<RNNImpl>(
          detail::RNNOptionsBase(
              compute_rnn_options_base_mode(options_.nonlinearity()),
              options_.input_size(),
              options_.hidden_size())
              .num_layers(options_.num_layers())
              .bias(options_.bias())
              .batch_first(options_.batch_first())
              .dropout(options_.dropout())
              .bidirectional(options_.bidirectional())),
      options(options_) {}

std::tuple<Tensor, Tensor> RNNImpl::forward_helper(
    const Tensor& input,
    const Tensor& batch_sizes,
    const Tensor& sorted_indices,
    int64_t max_batch_size,
    Tensor hx) {
  if (!hx.defined()) {
    int64_t num_directions = options_base.bidirectional() ? 2 : 1;
    hx = torch::zeros(
        {options_base.num_layers() * num_directions,
         max_batch_size,
         options_base.hidden_size()},
        torch::dtype(input.dtype()).device(input.device()));
  } else {
    // Each batch of the hidden state should match the input sequence that
    // the user believes they are passing in.
    hx = this->permute_hidden(hx, sorted_indices);
  }

  this->check_forward_args(input, hx, batch_sizes);

  std::tuple<Tensor, Tensor> result;
  if (!batch_sizes.defined()) {
    if (std::holds_alternative<enumtype::kRNN_TANH>(options_base.mode())) {
      result = torch::rnn_tanh(
          input,
          hx,
          flat_weights_,
          options_base.bias(),
          options_base.num_layers(),
          options_base.dropout(),
          this->is_training(),
          options_base.bidirectional(),
          options_base.batch_first());
    } else if (std::holds_alternative<enumtype::kRNN_RELU>(
                   options_base.mode())) {
      result = torch::rnn_relu(
          input,
          hx,
          flat_weights_,
          options_base.bias(),
          options_base.num_layers(),
          options_base.dropout(),
          this->is_training(),
          options_base.bidirectional(),
          options_base.batch_first());
    } else {
      TORCH_CHECK(
          false,
          "Unknown mode: ",
          torch::enumtype::get_enum_name(options_base.mode()));
    }
  } else {
    if (std::holds_alternative<enumtype::kRNN_TANH>(options_base.mode())) {
      result = torch::rnn_tanh(
          input,
          batch_sizes,
          hx,
          flat_weights_,
          options_base.bias(),
          options_base.num_layers(),
          options_base.dropout(),
          this->is_training(),
          options_base.bidirectional());
    } else if (std::holds_alternative<enumtype::kRNN_RELU>(
                   options_base.mode())) {
      result = torch::rnn_relu(
          input,
          batch_sizes,
          hx,
          flat_weights_,
          options_base.bias(),
          options_base.num_layers(),
          options_base.dropout(),
          this->is_training(),
          options_base.bidirectional());
    } else {
      TORCH_CHECK(
          false,
          "Unknown mode: ",
          torch::enumtype::get_enum_name(options_base.mode()));
    }
  }
  auto output = std::get<0>(result);
  auto hidden = std::get<1>(result);

  return std::make_tuple(output, hidden);
}

std::tuple<Tensor, Tensor> RNNImpl::forward(const Tensor& input, Tensor hx) {
  auto batch_sizes = torch::Tensor();
  auto max_batch_size =
      options_base.batch_first() ? input.size(0) : input.size(1);
  auto sorted_indices = torch::Tensor();
  auto unsorted_indices = torch::Tensor();

  auto [output, hidden] = this->forward_helper(
      input, batch_sizes, sorted_indices, max_batch_size, std::move(hx));

  return std::make_tuple(
      output, this->permute_hidden(hidden, unsorted_indices));
}
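
// Minimal usage sketch for the unpacked path above (comment only; assumes the
// public torch::nn::RNN module holder and the default batch_first=false):
//
//   torch::nn::RNN rnn(torch::nn::RNNOptions(/*input_size=*/10,
//                                            /*hidden_size=*/20)
//                          .num_layers(2)
//                          .nonlinearity(torch::kReLU));
//   auto input = torch::randn({/*seq_len=*/5, /*batch=*/3, /*input_size=*/10});
//   // hx defaults to an undefined Tensor and is zero-initialized above.
//   auto [output, h_n] = rnn->forward(input);
//   // output: (5, 3, 20); h_n: (num_layers * num_directions, 3, 20) = (2, 3, 20)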

std::tuple<PackedSequence, Tensor> RNNImpl::forward_with_packed_input(
    const PackedSequence& packed_input,
    Tensor hx) {
  const auto& input = packed_input.data();
  const auto& batch_sizes = packed_input.batch_sizes();
  const auto& sorted_indices = packed_input.sorted_indices();
  const auto& unsorted_indices = packed_input.unsorted_indices();
  auto max_batch_size = batch_sizes[0].item<int64_t>();

  auto [output, hidden] = this->forward_helper(
      input, batch_sizes, sorted_indices, max_batch_size, std::move(hx));

  auto output_packed =
      PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices);
  return std::make_tuple(
      output_packed, this->permute_hidden(hidden, unsorted_indices));
}
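
// Sketch of the packed-input path, continuing with the `rnn` module from the
// sketch above (comment only; assumes torch::nn::utils::rnn::
// pack_padded_sequence from <torch/nn/utils/rnn.h>):
//
//   auto padded = torch::randn({/*seq_len=*/5, /*batch=*/3, /*input_size=*/10});
//   auto lengths = torch::tensor({5, 3, 2});
//   auto packed = pack_padded_sequence(padded, lengths);
//   auto [packed_out, h_n] = rnn->forward_with_packed_input(packed);
//   // packed_out carries the same batch_sizes/indices as the input; h_n is
//   // permuted back with unsorted_indices before being returned.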

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

LSTMImpl::LSTMImpl(const LSTMOptions& options_)
    : detail::RNNImplBase<LSTMImpl>(detail::RNNOptionsBase(
                                        torch::kLSTM,
                                        options_.input_size(),
                                        options_.hidden_size())
                                        .num_layers(options_.num_layers())
                                        .bias(options_.bias())
                                        .batch_first(options_.batch_first())
                                        .dropout(options_.dropout())
                                        .bidirectional(options_.bidirectional())
                                        .proj_size(options_.proj_size())),
      options(options_) {}

std::tuple<int64_t, int64_t, int64_t> LSTMImpl::get_expected_cell_size(
    const Tensor& input,
    const Tensor& batch_sizes) const {
  int64_t mini_batch = 0;
  if (batch_sizes.defined()) {
    mini_batch = batch_sizes[0].item<int64_t>();
  } else {
    mini_batch = options_base.batch_first() ? input.size(0) : input.size(1);
  }
  int64_t num_directions = options_base.bidirectional() ? 2 : 1;
  return std::make_tuple(
      options_base.num_layers() * num_directions,
      mini_batch,
      options_base.hidden_size());
}

void LSTMImpl::check_forward_args(
    const Tensor& input,
    std::tuple<Tensor, Tensor> hidden,
    const Tensor& batch_sizes) const {
  this->check_input(input, batch_sizes);
  this->check_hidden_size(
      std::get<0>(hidden),
      this->get_expected_hidden_size(input, batch_sizes),
      "Expected hidden[0] size {1}, got {2}");
  this->check_hidden_size(
      std::get<1>(hidden),
      this->get_expected_cell_size(input, batch_sizes),
      "Expected hidden[1] size {1}, got {2}");
}

std::tuple<Tensor, Tensor> LSTMImpl::permute_hidden(
    std::tuple<Tensor, Tensor> hx,
    const Tensor& permutation) const {
  if (!permutation.defined()) {
    return hx;
  }
  return std::make_tuple(
      apply_permutation(std::get<0>(hx), permutation),
      apply_permutation(std::get<1>(hx), permutation));
}

std::tuple<Tensor, std::tuple<Tensor, Tensor>> LSTMImpl::forward_helper(
    const Tensor& input,
    const Tensor& batch_sizes,
    const Tensor& sorted_indices,
    int64_t max_batch_size,
    torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
  std::tuple<Tensor, Tensor> hx;
  if (!hx_opt.has_value()) {
    int64_t num_directions = options.bidirectional() ? 2 : 1;
    int64_t real_hidden_size =
        options.proj_size() > 0 ? options.proj_size() : options.hidden_size();
    auto h_zeros = torch::zeros(
        {options.num_layers() * num_directions,
         max_batch_size,
         real_hidden_size},
        torch::dtype(input.dtype()).device(input.device()));
    auto c_zeros = torch::zeros(
        {options.num_layers() * num_directions,
         max_batch_size,
         options.hidden_size()},
        torch::dtype(input.dtype()).device(input.device()));
    hx = std::make_tuple(h_zeros, c_zeros);
  } else {
    hx = hx_opt.value();
    // Each batch of the hidden state should match the input sequence that
    // the user believes they are passing in.
    hx = this->permute_hidden(hx, sorted_indices);
  }

  this->check_forward_args(input, hx, batch_sizes);
  std::tuple<Tensor, Tensor, Tensor> result;
  if (!batch_sizes.defined()) {
    result = torch::lstm(
        input,
        {std::get<0>(hx), std::get<1>(hx)},
        flat_weights_,
        options.bias(),
        options.num_layers(),
        options.dropout(),
        this->is_training(),
        options.bidirectional(),
        options.batch_first());
  } else {
    result = torch::lstm(
        input,
        batch_sizes,
        {std::get<0>(hx), std::get<1>(hx)},
        flat_weights_,
        options.bias(),
        options.num_layers(),
        options.dropout(),
        this->is_training(),
        options.bidirectional());
  }
  auto output = std::get<0>(result);
  auto hidden = std::make_tuple(std::get<1>(result), std::get<2>(result));

  return std::make_tuple(output, hidden);
}

std::tuple<Tensor, std::tuple<Tensor, Tensor>> LSTMImpl::forward(
    const Tensor& input,
    torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
  auto batch_sizes = torch::Tensor();
  auto max_batch_size = options.batch_first() ? input.size(0) : input.size(1);
  auto sorted_indices = torch::Tensor();
  auto unsorted_indices = torch::Tensor();

  auto [output, hidden] = this->forward_helper(
      input, batch_sizes, sorted_indices, max_batch_size, std::move(hx_opt));

  return std::make_tuple(
      output, this->permute_hidden(hidden, unsorted_indices));
}
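
// Minimal usage sketch for LSTM::forward (comment only; assumes the public
// torch::nn::LSTM module holder):
//
//   torch::nn::LSTM lstm(torch::nn::LSTMOptions(/*input_size=*/10,
//                                               /*hidden_size=*/20)
//                            .num_layers(2));
//   auto input = torch::randn({/*seq_len=*/5, /*batch=*/3, /*input_size=*/10});
//   auto [output, state] = lstm->forward(input);
//   auto [h_n, c_n] = state;  // each (num_layers * num_directions, 3, 20)
//                             // when proj_size == 0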

std::tuple<PackedSequence, std::tuple<Tensor, Tensor>> LSTMImpl::
    forward_with_packed_input(
        const PackedSequence& packed_input,
        torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
  const auto& input = packed_input.data();
  const auto& batch_sizes = packed_input.batch_sizes();
  const auto& sorted_indices = packed_input.sorted_indices();
  const auto& unsorted_indices = packed_input.unsorted_indices();
  auto max_batch_size = batch_sizes[0].item<int64_t>();

  auto [output, hidden] = this->forward_helper(
      input, batch_sizes, sorted_indices, max_batch_size, std::move(hx_opt));

  auto output_packed =
      PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices);
  return std::make_tuple(
      output_packed, this->permute_hidden(hidden, unsorted_indices));
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

GRUImpl::GRUImpl(const GRUOptions& options_)
    : detail::RNNImplBase<GRUImpl>(
          detail::RNNOptionsBase(
              torch::kGRU,
              options_.input_size(),
              options_.hidden_size())
              .num_layers(options_.num_layers())
              .bias(options_.bias())
              .batch_first(options_.batch_first())
              .dropout(options_.dropout())
              .bidirectional(options_.bidirectional())),
      options(options_) {}

std::tuple<Tensor, Tensor> GRUImpl::forward_helper(
    const Tensor& input,
    const Tensor& batch_sizes,
    const Tensor& sorted_indices,
    int64_t max_batch_size,
    Tensor hx) {
  if (!hx.defined()) {
    int64_t num_directions = options.bidirectional() ? 2 : 1;
    hx = torch::zeros(
        {options.num_layers() * num_directions,
         max_batch_size,
         options.hidden_size()},
        torch::dtype(input.dtype()).device(input.device()));
  } else {
    // Each batch of the hidden state should match the input sequence that
    // the user believes they are passing in.
    hx = this->permute_hidden(hx, sorted_indices);
  }

  this->check_forward_args(input, hx, batch_sizes);
  std::tuple<Tensor, Tensor> result;
  if (!batch_sizes.defined()) {
    result = torch::gru(
        input,
        hx,
        flat_weights_,
        options.bias(),
        options.num_layers(),
        options.dropout(),
        this->is_training(),
        options.bidirectional(),
        options.batch_first());
  } else {
    result = torch::gru(
        input,
        batch_sizes,
        hx,
        flat_weights_,
        options.bias(),
        options.num_layers(),
        options.dropout(),
        this->is_training(),
        options.bidirectional());
  }
  auto output = std::get<0>(result);
  auto hidden = std::get<1>(result);

  return std::make_tuple(output, hidden);
}

std::tuple<Tensor, Tensor> GRUImpl::forward(const Tensor& input, Tensor hx) {
  auto batch_sizes = torch::Tensor();
  auto max_batch_size = options.batch_first() ? input.size(0) : input.size(1);
  auto sorted_indices = torch::Tensor();
  auto unsorted_indices = torch::Tensor();

  auto [output, hidden] = this->forward_helper(
      input, batch_sizes, sorted_indices, max_batch_size, std::move(hx));

  return std::make_tuple(
      output, this->permute_hidden(hidden, unsorted_indices));
}
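
// Minimal usage sketch for GRU::forward, here with an explicit initial hidden
// state (comment only; assumes the public torch::nn::GRU module holder):
//
//   torch::nn::GRU gru(torch::nn::GRUOptions(/*input_size=*/10,
//                                            /*hidden_size=*/20));
//   auto input = torch::randn({/*seq_len=*/5, /*batch=*/3, /*input_size=*/10});
//   auto h0 = torch::zeros({/*num_layers * num_directions=*/1, 3, 20});
//   auto [output, h_n] = gru->forward(input, h0);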

std::tuple<PackedSequence, Tensor> GRUImpl::forward_with_packed_input(
    const PackedSequence& packed_input,
    Tensor hx) {
  const auto& input = packed_input.data();
  const auto& batch_sizes = packed_input.batch_sizes();
  const auto& sorted_indices = packed_input.sorted_indices();
  const auto& unsorted_indices = packed_input.unsorted_indices();
  auto max_batch_size = batch_sizes[0].item<int64_t>();

  auto [output, hidden] = this->forward_helper(
      input, batch_sizes, sorted_indices, max_batch_size, std::move(hx));

  auto output_packed =
      PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices);
  return std::make_tuple(
      output_packed, this->permute_hidden(hidden, unsorted_indices));
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCellImplBase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

namespace detail {
template <typename Derived>
RNNCellImplBase<Derived>::RNNCellImplBase(const RNNCellOptionsBase& options_)
    : options_base(options_) {
  // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
  reset();
}

template <typename Derived>
void RNNCellImplBase<Derived>::reset() {
  weight_ih = this->register_parameter(
      "weight_ih",
      torch::empty(
          {options_base.num_chunks() * options_base.hidden_size(),
           options_base.input_size()}));
  weight_hh = this->register_parameter(
      "weight_hh",
      torch::empty(
          {options_base.num_chunks() * options_base.hidden_size(),
           options_base.hidden_size()}));

  if (options_base.bias()) {
    bias_ih = this->register_parameter(
        "bias_ih",
        torch::empty({options_base.num_chunks() * options_base.hidden_size()}));
    bias_hh = this->register_parameter(
        "bias_hh",
        torch::empty({options_base.num_chunks() * options_base.hidden_size()}));
  } else {
    bias_ih =
        this->register_parameter("bias_ih", Tensor(), /*requires_grad=*/false);
    bias_hh =
        this->register_parameter("bias_hh", Tensor(), /*requires_grad=*/false);
  }

  reset_parameters();
}

template <typename Derived>
void RNNCellImplBase<Derived>::reset_parameters() {
  const double stdv = 1.0 / std::sqrt(options_base.hidden_size());
  for (auto& weight : this->parameters()) {
    init::uniform_(weight, -stdv, stdv);
  }
}

template <typename Derived>
void RNNCellImplBase<Derived>::pretty_print(std::ostream& stream) const {
  const std::string name = this->name();
  const std::string name_without_impl = name.substr(0, name.size() - 4);
  stream << name_without_impl << "(" << options_base.input_size() << ", "
         << options_base.hidden_size();
  if (!options_base.bias()) {
    stream << ", bias=" << std::boolalpha << false;
  }
  auto nonlinearity_str = this->get_nonlinearity_str();
  if (!nonlinearity_str.empty() && nonlinearity_str != "kTanh") {
    stream << ", nonlinearity=" << nonlinearity_str;
  }
  stream << ")";
}

template <typename Derived>
void RNNCellImplBase<Derived>::check_forward_input(
    const Tensor& input,
    const std::string& name) const {
  TORCH_CHECK(
      input.dim() == 1 || input.dim() == 2,
      "Expected ",
      name.c_str(),
      " to be 1D or 2D, got ",
      input.dim(),
      "D instead");
}

template <typename Derived>
std::string RNNCellImplBase<Derived>::get_nonlinearity_str() const {
  return "";
}

template class RNNCellImplBase<LSTMCellImpl>;
template class RNNCellImplBase<GRUCellImpl>;
template class RNNCellImplBase<RNNCellImpl>;
} // namespace detail

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCell ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

RNNCellImpl::RNNCellImpl(const RNNCellOptions& options_)
    : detail::RNNCellImplBase<RNNCellImpl>(detail::RNNCellOptionsBase(
          options_.input_size(),
          options_.hidden_size(),
          options_.bias(),
          /*num_chunks=*/1)),
      options(options_) {}

Tensor RNNCellImpl::forward(const Tensor& input, Tensor hx) {
  this->check_forward_input(input, "input");
  this->check_forward_input(hx, "hidden");

  Tensor r_hx, ret;

  bool is_batched = input.dim() == 2;
  Tensor r_input = is_batched ? input : input.unsqueeze(0);

  if (!hx.defined()) {
    r_hx = torch::zeros(
        {input.size(0), options.hidden_size()},
        torch::dtype(input.dtype()).device(input.device()));
  } else {
    r_hx = is_batched ? hx : hx.unsqueeze(0);
  }

  if (std::holds_alternative<enumtype::kTanh>(options.nonlinearity())) {
    ret = torch::rnn_tanh_cell(
        r_input, r_hx, weight_ih, weight_hh, bias_ih, bias_hh);
  } else if (std::holds_alternative<enumtype::kReLU>(options.nonlinearity())) {
    ret = torch::rnn_relu_cell(
        r_input, r_hx, weight_ih, weight_hh, bias_ih, bias_hh);
  } else {
    TORCH_CHECK(
        false,
        "Unknown nonlinearity: ",
        torch::enumtype::get_enum_name(options.nonlinearity()));
  }

  if (!is_batched) {
    ret = ret.squeeze(0);
  }

  return ret;
}
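
// Usage sketch for RNNCell::forward, stepping over a sequence manually
// (comment only; assumes the public torch::nn::RNNCell module holder):
//
//   torch::nn::RNNCell cell(torch::nn::RNNCellOptions(/*input_size=*/10,
//                                                     /*hidden_size=*/20));
//   auto input = torch::randn({/*seq_len=*/6, /*batch=*/3, /*input_size=*/10});
//   torch::Tensor hx;  // undefined: zero-initialized on the first step above
//   for (int64_t t = 0; t < input.size(0); ++t) {
//     hx = cell->forward(input[t], hx);  // hx: (3, 20)
//   }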

std::string RNNCellImpl::get_nonlinearity_str() const {
  return get_enum_name(options.nonlinearity());
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTMCell ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

LSTMCellImpl::LSTMCellImpl(const LSTMCellOptions& options_)
    : detail::RNNCellImplBase<LSTMCellImpl>(detail::RNNCellOptionsBase(
          options_.input_size(),
          options_.hidden_size(),
          options_.bias(),
          /*num_chunks=*/4)),
      options(options_) {}

std::tuple<Tensor, Tensor> LSTMCellImpl::forward(
    const Tensor& input,
    torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
  this->check_forward_input(input, "input");
  if (hx_opt.has_value()) {
    this->check_forward_input(std::get<0>(hx_opt.value()), "hx[0]");
    this->check_forward_input(std::get<1>(hx_opt.value()), "hx[1]");
  }

  std::tuple<Tensor, Tensor> r_hx, ret;

  bool is_batched = input.dim() == 2;
  Tensor r_input = is_batched ? input : input.unsqueeze(0);

  if (!hx_opt.has_value()) {
    auto zeros = torch::zeros(
        {input.size(0), options.hidden_size()},
        torch::dtype(input.dtype()).device(input.device()));
    r_hx = std::make_tuple(zeros, zeros);
  } else {
    if (!is_batched) {
      r_hx = std::make_tuple(
          std::get<0>(hx_opt.value()).unsqueeze(0),
          std::get<1>(hx_opt.value()).unsqueeze(0));
    } else {
      r_hx = hx_opt.value();
    }
  }

  ret = torch::lstm_cell(
      r_input,
      {std::get<0>(r_hx), std::get<1>(r_hx)},
      weight_ih,
      weight_hh,
      bias_ih,
      bias_hh);

  if (!is_batched) {
    ret = std::make_tuple(
        std::get<0>(ret).squeeze(0), std::get<1>(ret).squeeze(0));
  }

  return ret;
}
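
// Usage sketch for LSTMCell::forward (comment only; assumes the public
// torch::nn::LSTMCell module holder):
//
//   torch::nn::LSTMCell cell(torch::nn::LSTMCellOptions(/*input_size=*/10,
//                                                       /*hidden_size=*/20));
//   auto x = torch::randn({/*batch=*/3, /*input_size=*/10});
//   auto state = cell->forward(x);    // (h, c), zero-initialized above
//   state = cell->forward(x, state);  // later steps reuse the returned state
//   auto [h, c] = state;              // each of shape (3, 20)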

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRUCell ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

GRUCellImpl::GRUCellImpl(const GRUCellOptions& options_)
    : detail::RNNCellImplBase<GRUCellImpl>(detail::RNNCellOptionsBase(
          options_.input_size(),
          options_.hidden_size(),
          options_.bias(),
          /*num_chunks=*/3)),
      options(options_) {}

Tensor GRUCellImpl::forward(const Tensor& input, Tensor hx) {
  this->check_forward_input(input, "input");
  this->check_forward_input(hx, "hidden");

  Tensor r_hx, ret;

  bool is_batched = input.dim() == 2;
  Tensor r_input = is_batched ? input : input.unsqueeze(0);

  if (!hx.defined()) {
    r_hx = torch::zeros(
        {input.size(0), options.hidden_size()},
        torch::dtype(input.dtype()).device(input.device()));
  } else {
    r_hx = is_batched ? hx : hx.unsqueeze(0);
  }

  ret = torch::gru_cell(r_input, r_hx, weight_ih, weight_hh, bias_ih, bias_hh);

  if (!is_batched) {
    ret = ret.squeeze(0);
  }

  return ret;
}

} // namespace nn
} // namespace torch