1 #include <torch/nn/modules/rnn.h>
2 
3 #include <torch/nn/init.h>
4 #include <torch/types.h>
5 #include <torch/utils.h>
6 
7 #include <c10/util/Exception.h>
8 #include <c10/util/irange.h>
9 
10 #include <cmath>
11 #include <cstdint>
12 #include <regex>
13 #include <string>
14 #include <tuple>
15 #include <unordered_set>
16 #include <utility>
17 #include <vector>
18 
19 using namespace torch::nn::utils::rnn;
20 
21 namespace torch {
22 namespace nn {
23 
24 /// These must line up with the CUDNN mode codes:
25 /// https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnRNNMode_t
26 enum class CuDNNMode { RNN_RELU = 0, RNN_TANH = 1, LSTM = 2, GRU = 3 };
27 
28 static CuDNNMode get_cudnn_mode_for_rnn(
29     detail::RNNOptionsBase::rnn_options_base_mode_t mode) {
30   if (std::holds_alternative<enumtype::kRNN_RELU>(mode)) {
31     return CuDNNMode::RNN_RELU;
32   } else if (std::holds_alternative<enumtype::kRNN_TANH>(mode)) {
33     return CuDNNMode::RNN_TANH;
34   } else if (std::holds_alternative<enumtype::kLSTM>(mode)) {
35     return CuDNNMode::LSTM;
36   } else if (std::holds_alternative<enumtype::kGRU>(mode)) {
37     return CuDNNMode::GRU;
38   } else {
39     TORCH_CHECK(false, "Unknown mode: ", torch::enumtype::get_enum_name(mode));
40   }
41 }
42 
43 static Tensor apply_permutation(
44     const Tensor& tensor,
45     const Tensor& permutation,
46     int64_t dim = 1) {
47   return tensor.index_select(dim, permutation);
48 }
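// Note: apply_permutation reorders the batch dimension of a hidden state so it
// matches the batch ordering of a packed sequence. A minimal sketch of the
// effect, assuming hx of shape (num_layers * num_directions, batch, hidden)
// and a permutation such as the sorted_indices of a PackedSequence:
//
//   // permutation = {2, 0, 1}  ->  reordered[:, i, :] == hx[:, permutation[i], :]
//   auto reordered = hx.index_select(/*dim=*/1, permutation);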
49 
50 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNImplBase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
51 namespace detail {
52 template <typename Derived>
53 RNNImplBase<Derived>::RNNImplBase(const RNNOptionsBase& options_)
54     : options_base(options_) {
55   // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
56   reset();
57 }
58 
59 template <typename Derived>
60 void RNNImplBase<Derived>::reset() {
61   const int64_t num_directions = options_base.bidirectional() ? 2 : 1;
62 
63   TORCH_CHECK(
64       0 <= options_base.dropout() && options_base.dropout() <= 1,
65       "dropout should be a number in range [0, 1] ",
66       "representing the probability of an element being ",
67       "zeroed");
68 
69   if (options_base.dropout() > 0 && options_base.num_layers() == 1) {
70     TORCH_WARN(
71         "dropout option adds dropout after all but last ",
72         "recurrent layer, so non-zero dropout expects ",
73         "num_layers greater than 1, but got dropout=",
74         options_base.dropout(),
75         " and ",
76         "num_layers=",
77         options_base.num_layers());
78   }
79 
80   TORCH_CHECK(
81       options_base.hidden_size() > 0, "hidden_size must be greater than zero");
82 
83   TORCH_CHECK(
84       options_base.num_layers() > 0, "num_layers must be greater than zero");
85 
86   TORCH_CHECK(
87       0 <= options_base.proj_size() &&
88           options_base.proj_size() < options_base.hidden_size(),
89       "proj_size has to be a positive integer, smaller than ",
90       "hidden_size or zero to disable projections");
91 
92   if (options_base.proj_size() > 0) {
93     TORCH_CHECK(
94         std::get_if<enumtype::kLSTM>(&options_base.mode()),
95         "proj_size argument is only supported for LSTM, not RNN or GRU");
96   }
97 
98   int64_t gate_size = 0;
99   if (std::holds_alternative<enumtype::kLSTM>(options_base.mode())) {
100     gate_size = 4 * options_base.hidden_size();
101   } else if (std::holds_alternative<enumtype::kGRU>(options_base.mode())) {
102     gate_size = 3 * options_base.hidden_size();
103     // NOLINTNEXTLINE(bugprone-branch-clone)
104   } else if (std::holds_alternative<enumtype::kRNN_TANH>(options_base.mode())) {
105     gate_size = options_base.hidden_size();
106   } else if (std::holds_alternative<enumtype::kRNN_RELU>(options_base.mode())) {
107     gate_size = options_base.hidden_size();
108   } else {
109     TORCH_CHECK(
110         false,
111         "Unrecognized RNN mode: " +
112             torch::enumtype::get_enum_name(options_base.mode()));
113   }
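// The gate multipliers above follow the standard cell definitions: an LSTM
// packs its input/forget/cell/output gates into a (4 * hidden_size)-row weight
// matrix, a GRU packs its reset/update/new gates into (3 * hidden_size) rows,
// and a vanilla RNN (tanh or ReLU) uses a single (hidden_size)-row transform.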
114 
115   flat_weights_names_ = {};
116   all_weights_ = {};
117 
118   for (const auto layer : c10::irange(options_base.num_layers())) {
119     for (const auto direction : c10::irange(num_directions)) {
120       int64_t real_hidden_size = options_base.proj_size() > 0
121           ? options_base.proj_size()
122           : options_base.hidden_size();
123       int64_t layer_input_size = layer == 0 ? options_base.input_size()
124                                             : real_hidden_size * num_directions;
125 
126       auto w_ih = torch::empty({gate_size, layer_input_size});
127       auto w_hh = torch::empty({gate_size, real_hidden_size});
128       auto b_ih = torch::empty({gate_size});
129       // Second bias vector included for CuDNN compatibility. Only one
130       // bias vector is needed in standard definition.
131       auto b_hh = torch::empty({gate_size});
132       std::vector<Tensor> layer_params = {w_ih, w_hh};
133 
134       std::string suffix = direction == 1 ? "_reverse" : "";
135       std::vector<std::string> param_names = {
136           "weight_ih_l{layer}{suffix}", "weight_hh_l{layer}{suffix}"};
137       if (options_base.bias()) {
138         param_names.emplace_back("bias_ih_l{layer}{suffix}");
139         param_names.emplace_back("bias_hh_l{layer}{suffix}");
140         layer_params.emplace_back(b_ih);
141         layer_params.emplace_back(b_hh);
142       }
143       if (options_base.proj_size() > 0) {
144         auto w_hr = torch::empty(
145             {options_base.proj_size(), options_base.hidden_size()});
146         layer_params.emplace_back(w_hr);
147         param_names.emplace_back("weight_hr_l{layer}{suffix}");
148       }
149       for (auto& param_name : param_names) {
150         std::string x = std::regex_replace(
151             param_name, std::regex("\\{layer\\}"), c10::str(layer));
152         param_name =
153             std::regex_replace(x, std::regex("\\{suffix\\}"), c10::str(suffix));
154       }
155 
156       for (const auto i : c10::irange(param_names.size())) {
157         this->register_parameter(param_names[i], std::move(layer_params[i]));
158       }
159       flat_weights_names_.insert(
160           flat_weights_names_.end(), param_names.begin(), param_names.end());
161       all_weights_.emplace_back(std::move(param_names));
162     }
163   }
164 
165   flat_weights_ = {};
166   for (const auto& wn : flat_weights_names_) {
167     auto named_parameters = this->named_parameters(/*recurse=*/false);
168     if (named_parameters.contains(wn)) {
169       flat_weights_.emplace_back(named_parameters[wn]);
170     } else {
171       flat_weights_.emplace_back();
172     }
173   }
174 
175   this->flatten_parameters();
176   this->reset_parameters();
177 }
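// Illustrative example (made-up sizes): for a module built as
// LSTM(LSTMOptions(10, 20).num_layers(2).bidirectional(true)), the loops above
// register parameters such as
//
//   weight_ih_l0 (80, 10)   weight_ih_l0_reverse (80, 10)
//   weight_hh_l0 (80, 20)   weight_hh_l0_reverse (80, 20)
//   bias_ih_l0   (80)       bias_hh_l0           (80)      (plus "_reverse" twins)
//   weight_ih_l1 (80, 40)   ...
//
// where 80 == 4 * hidden_size (the LSTM gate_size) and layer 1 sees
// 40 == 2 * hidden_size inputs because layer 0 is bidirectional.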
178 
179 template <typename Derived>
180 void RNNImplBase<Derived>::flatten_parameters() {
181   // Resets parameter data pointer so that they can use faster code paths.
182   //
183   // Right now, this works only if the module is on the GPU and cuDNN is
184   // enabled. Otherwise, it's a no-op.
185 
186   // Short-circuits if flat_weights_ is only partially instantiated
187   if (flat_weights_.size() != flat_weights_names_.size()) {
188     return;
189   }
190 
191   // Short-circuits if any tensor in self.flat_weights_ is not acceptable to
192   // cuDNN or the tensors in flat_weights_ are of different dtypes
193 
194   auto first_fw = flat_weights_[0];
195   auto dtype = first_fw.dtype();
196   for (const auto& fw : flat_weights_) {
197     if (!(fw.dtype() == dtype) || !fw.is_cuda() ||
198         !torch::cudnn_is_acceptable(fw)) {
199       return;
200     }
201   }
202 
203   // If any parameters alias, we fall back to the slower, copying code path.
204   // This is a sufficient check, because overlapping parameter buffers that
205   // don't completely alias would break the assumptions of the uniqueness check
206   // in Module::named_parameters().
207   std::unordered_set<void*> unique_data_ptrs;
208   for (const auto& p : flat_weights_) {
209     unique_data_ptrs.emplace(p.data_ptr());
210   }
211   if (unique_data_ptrs.size() != flat_weights_.size()) {
212     return;
213   }
214 
215   {
216     torch::DeviceGuard device_guard(first_fw.device());
217 
218     // Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
219     // an inplace operation on self.flat_weights_
220     {
221       torch::NoGradGuard no_grad;
222       if (torch::_use_cudnn_rnn_flatten_weight()) {
223         int64_t num_weights = options_base.bias() ? 4 : 2;
224         if (options_base.proj_size() > 0) {
225           ++num_weights;
226         }
227         torch::_cudnn_rnn_flatten_weight(
228             flat_weights_,
229             num_weights,
230             options_base.input_size(),
231             static_cast<int64_t>(get_cudnn_mode_for_rnn(options_base.mode())),
232             options_base.hidden_size(),
233             options_base.proj_size(),
234             options_base.num_layers(),
235             options_base.batch_first(),
236             options_base.bidirectional());
237       }
238     }
239   }
240 }
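// flatten_parameters() only takes effect when every weight is a CUDA tensor of
// one common dtype that cuDNN accepts and no two weights alias each other;
// otherwise it returns early and the RNN falls back to the non-fused, copying
// kernels. It is re-run from the to() overloads below because moving or
// casting the parameters invalidates any previously flattened weight buffer.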
241 
242 template <typename Derived>
243 void RNNImplBase<Derived>::reset_flat_weights() {
244   flat_weights_ = {};
245   for (const auto& wn : flat_weights_names_) {
246     auto named_parameters = this->named_parameters(/*recurse=*/false);
247     if (named_parameters.contains(wn)) {
248       flat_weights_.emplace_back(named_parameters[wn]);
249     } else {
250       flat_weights_.emplace_back();
251     }
252   }
253 }
254 
255 template <typename Derived>
256 void RNNImplBase<Derived>::to(
257     torch::Device device,
258     torch::Dtype dtype,
259     bool non_blocking) {
260   nn::Module::to(device, dtype, non_blocking);
261   reset_flat_weights();
262   flatten_parameters();
263 }
264 
265 template <typename Derived>
266 void RNNImplBase<Derived>::to(torch::Dtype dtype, bool non_blocking) {
267   nn::Module::to(dtype, non_blocking);
268   reset_flat_weights();
269   flatten_parameters();
270 }
271 
272 template <typename Derived>
273 void RNNImplBase<Derived>::to(torch::Device device, bool non_blocking) {
274   nn::Module::to(device, non_blocking);
275   reset_flat_weights();
276   flatten_parameters();
277 }
278 
279 template <typename Derived>
280 void RNNImplBase<Derived>::reset_parameters() {
281   const double stdv = 1.0 / std::sqrt(options_base.hidden_size());
282   for (auto& weight : this->parameters()) {
283     init::uniform_(weight, -stdv, stdv);
284   }
285 }
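// This matches the default initialization of the Python modules: every
// parameter is drawn i.i.d. from U(-k, k) with k = 1 / sqrt(hidden_size).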
286 
287 template <typename Derived>
288 void RNNImplBase<Derived>::check_input(
289     const Tensor& input,
290     const Tensor& batch_sizes) const {
291   int64_t expected_input_dim = batch_sizes.defined() ? 2 : 3;
292   TORCH_CHECK(
293       input.dim() == expected_input_dim,
294       "input must have ",
295       expected_input_dim,
296       " dimensions, got ",
297       input.dim());
298   TORCH_CHECK(
299       options_base.input_size() == input.size(-1),
300       "input.size(-1) must be equal to input_size. Expected ",
301       options_base.input_size(),
302       ", got ",
303       input.size(-1));
304 }
305 
306 template <typename Derived>
307 std::tuple<int64_t, int64_t, int64_t> RNNImplBase<Derived>::
308     get_expected_hidden_size(const Tensor& input, const Tensor& batch_sizes)
309         const {
310   int64_t mini_batch = 0;
311   if (batch_sizes.defined()) {
312     mini_batch = batch_sizes[0].item<int64_t>();
313   } else {
314     mini_batch = options_base.batch_first() ? input.size(0) : input.size(1);
315   }
316   int64_t num_directions = options_base.bidirectional() ? 2 : 1;
317   int64_t real_hidden_size = options_base.proj_size() > 0
318       ? options_base.proj_size()
319       : options_base.hidden_size();
320   return std::make_tuple(
321       options_base.num_layers() * num_directions, mini_batch, real_hidden_size);
322 }
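// Worked example (made-up sizes): with num_layers=2, bidirectional=true,
// proj_size=0, hidden_size=20 and a batch_first input of shape (3, 5, 10),
// the expected hidden size is (num_layers * num_directions, mini_batch,
// real_hidden_size) = (4, 3, 20).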
323 
324 template <typename Derived>
325 void RNNImplBase<Derived>::check_hidden_size(
326     const Tensor& hx,
327     std::tuple<int64_t, int64_t, int64_t> expected_hidden_size,
328     std::string msg) const {
329   auto expected_hidden_size_vec = std::vector<int64_t>({
330       std::get<0>(expected_hidden_size),
331       std::get<1>(expected_hidden_size),
332       std::get<2>(expected_hidden_size),
333   });
334   if (hx.sizes() != expected_hidden_size_vec) {
335     msg = std::regex_replace(
336         msg, std::regex("\\{1\\}"), c10::str(expected_hidden_size_vec));
337     msg = std::regex_replace(msg, std::regex("\\{2\\}"), c10::str(hx.sizes()));
338     TORCH_CHECK(false, msg);
339   }
340 }
341 
342 template <typename Derived>
343 void RNNImplBase<Derived>::check_forward_args(
344     Tensor input,
345     Tensor hidden,
346     Tensor batch_sizes) const {
347   this->check_input(input, batch_sizes);
348   auto expected_hidden_size =
349       this->get_expected_hidden_size(input, batch_sizes);
350 
351   this->check_hidden_size(hidden, expected_hidden_size);
352 }
353 
354 template <typename Derived>
355 Tensor RNNImplBase<Derived>::permute_hidden(
356     Tensor hx,
357     const Tensor& permutation) const {
358   if (!permutation.defined()) {
359     return hx;
360   }
361   return apply_permutation(hx, permutation);
362 }
363 
364 template <typename Derived>
365 void RNNImplBase<Derived>::pretty_print(std::ostream& stream) const {
366   const std::string name = this->name();
367   const std::string name_without_impl = name.substr(0, name.size() - 4);
368   stream << std::boolalpha << name_without_impl
369          << "(input_size=" << options_base.input_size()
370          << ", hidden_size=" << options_base.hidden_size()
371          << ", num_layers=" << options_base.num_layers()
372          << ", bias=" << options_base.bias()
373          << ", batch_first=" << options_base.batch_first()
374          << ", dropout=" << options_base.dropout()
375          << ", bidirectional=" << options_base.bidirectional();
376   if (options_base.proj_size() > 0) {
377     stream << ", proj_size=" << options_base.proj_size();
378   }
379   stream << ")";
380 }
381 
382 template <typename Derived>
383 std::vector<Tensor> RNNImplBase<Derived>::all_weights() const {
384   std::vector<Tensor> result = {};
385   auto named_parameters = this->named_parameters(/*recurse=*/false);
386   for (const auto& weights : all_weights_) {
387     for (const auto& weight : weights) {
388       result.emplace_back(named_parameters[weight]);
389     }
390   }
391   return result;
392 }
393 
394 template class RNNImplBase<LSTMImpl>;
395 template class RNNImplBase<GRUImpl>;
396 template class RNNImplBase<RNNImpl>;
397 } // namespace detail
398 
399 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
400 
401 static detail::RNNOptionsBase::rnn_options_base_mode_t
402 compute_rnn_options_base_mode(RNNOptions::nonlinearity_t nonlinearity) {
403   if (std::holds_alternative<enumtype::kTanh>(nonlinearity)) {
404     return torch::kRNN_TANH;
405   } else if (std::holds_alternative<enumtype::kReLU>(nonlinearity)) {
406     return torch::kRNN_RELU;
407   } else {
408     TORCH_CHECK(
409         false,
410         "Unknown nonlinearity ",
411         torch::enumtype::get_enum_name(nonlinearity));
412   }
413 }
414 
415 RNNImpl::RNNImpl(const RNNOptions& options_)
416     : detail::RNNImplBase<RNNImpl>(
417           detail::RNNOptionsBase(
418               compute_rnn_options_base_mode(options_.nonlinearity()),
419               options_.input_size(),
420               options_.hidden_size())
421               .num_layers(options_.num_layers())
422               .bias(options_.bias())
423               .batch_first(options_.batch_first())
424               .dropout(options_.dropout())
425               .bidirectional(options_.bidirectional())),
426       options(options_) {}
427 
428 std::tuple<Tensor, Tensor> RNNImpl::forward_helper(
429     const Tensor& input,
430     const Tensor& batch_sizes,
431     const Tensor& sorted_indices,
432     int64_t max_batch_size,
433     Tensor hx) {
434   if (!hx.defined()) {
435     int64_t num_directions = options_base.bidirectional() ? 2 : 1;
436     hx = torch::zeros(
437         {options_base.num_layers() * num_directions,
438          max_batch_size,
439          options_base.hidden_size()},
440         torch::dtype(input.dtype()).device(input.device()));
441   } else {
442     // Each batch of the hidden state should match the input sequence that
443     // the user believes they are passing in.
444     hx = this->permute_hidden(hx, sorted_indices);
445   }
446 
447   this->check_forward_args(input, hx, batch_sizes);
448 
449   std::tuple<Tensor, Tensor> result;
450   if (!batch_sizes.defined()) {
451     if (std::holds_alternative<enumtype::kRNN_TANH>(options_base.mode())) {
452       result = torch::rnn_tanh(
453           input,
454           hx,
455           flat_weights_,
456           options_base.bias(),
457           options_base.num_layers(),
458           options_base.dropout(),
459           this->is_training(),
460           options_base.bidirectional(),
461           options_base.batch_first());
462     } else if (std::holds_alternative<enumtype::kRNN_RELU>(
463                    options_base.mode())) {
464       result = torch::rnn_relu(
465           input,
466           hx,
467           flat_weights_,
468           options_base.bias(),
469           options_base.num_layers(),
470           options_base.dropout(),
471           this->is_training(),
472           options_base.bidirectional(),
473           options_base.batch_first());
474     } else {
475       TORCH_CHECK(
476           false,
477           "Unknown mode: ",
478           torch::enumtype::get_enum_name(options_base.mode()));
479     }
480   } else {
481     if (std::holds_alternative<enumtype::kRNN_TANH>(options_base.mode())) {
482       result = torch::rnn_tanh(
483           input,
484           batch_sizes,
485           hx,
486           flat_weights_,
487           options_base.bias(),
488           options_base.num_layers(),
489           options_base.dropout(),
490           this->is_training(),
491           options_base.bidirectional());
492     } else if (std::holds_alternative<enumtype::kRNN_RELU>(
493                    options_base.mode())) {
494       result = torch::rnn_relu(
495           input,
496           batch_sizes,
497           hx,
498           flat_weights_,
499           options_base.bias(),
500           options_base.num_layers(),
501           options_base.dropout(),
502           this->is_training(),
503           options_base.bidirectional());
504     } else {
505       TORCH_CHECK(
506           false,
507           "Unknown mode: ",
508           torch::enumtype::get_enum_name(options_base.mode()));
509     }
510   }
511   auto output = std::get<0>(result);
512   auto hidden = std::get<1>(result);
513 
514   return std::make_tuple(output, hidden);
515 }
516 
517 std::tuple<Tensor, Tensor> RNNImpl::forward(const Tensor& input, Tensor hx) {
518   auto batch_sizes = torch::Tensor();
519   auto max_batch_size =
520       options_base.batch_first() ? input.size(0) : input.size(1);
521   auto sorted_indices = torch::Tensor();
522   auto unsorted_indices = torch::Tensor();
523 
524   auto [output, hidden] = this->forward_helper(
525       input, batch_sizes, sorted_indices, max_batch_size, std::move(hx));
526 
527   return std::make_tuple(
528       output, this->permute_hidden(hidden, unsorted_indices));
529 }
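// A minimal usage sketch for the plain-tensor path above, assuming the default
// tanh nonlinearity and a zero-initialized hidden state:
//
//   torch::nn::RNN rnn(torch::nn::RNNOptions(10, 20).num_layers(2));
//   auto input = torch::randn({5, 3, 10});    // (seq_len, batch, input_size)
//   auto [output, h_n] = rnn->forward(input); // output: (5, 3, 20), h_n: (2, 3, 20)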
530 
531 std::tuple<PackedSequence, Tensor> RNNImpl::forward_with_packed_input(
532     const PackedSequence& packed_input,
533     Tensor hx) {
534   const auto& input = packed_input.data();
535   const auto& batch_sizes = packed_input.batch_sizes();
536   const auto& sorted_indices = packed_input.sorted_indices();
537   const auto& unsorted_indices = packed_input.unsorted_indices();
538   auto max_batch_size = batch_sizes[0].item<int64_t>();
539 
540   auto [output, hidden] = this->forward_helper(
541       input, batch_sizes, sorted_indices, max_batch_size, std::move(hx));
542 
543   auto output_packed =
544       PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices);
545   return std::make_tuple(
546       output_packed, this->permute_hidden(hidden, unsorted_indices));
547 }
548 
549 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
550 
551 LSTMImpl::LSTMImpl(const LSTMOptions& options_)
552     : detail::RNNImplBase<LSTMImpl>(detail::RNNOptionsBase(
553                                         torch::kLSTM,
554                                         options_.input_size(),
555                                         options_.hidden_size())
556                                         .num_layers(options_.num_layers())
557                                         .bias(options_.bias())
558                                         .batch_first(options_.batch_first())
559                                         .dropout(options_.dropout())
560                                         .bidirectional(options_.bidirectional())
561                                         .proj_size(options_.proj_size())),
562       options(options_) {}
563 
564 std::tuple<int64_t, int64_t, int64_t> LSTMImpl::get_expected_cell_size(
565     const Tensor& input,
566     const Tensor& batch_sizes) const {
567   int64_t mini_batch = 0;
568   if (batch_sizes.defined()) {
569     mini_batch = batch_sizes[0].item<int64_t>();
570   } else {
571     mini_batch = options_base.batch_first() ? input.size(0) : input.size(1);
572   }
573   int64_t num_directions = options_base.bidirectional() ? 2 : 1;
574   return std::make_tuple(
575       options_base.num_layers() * num_directions,
576       mini_batch,
577       options_base.hidden_size());
578 }
579 
580 void LSTMImpl::check_forward_args(
581     const Tensor& input,
582     std::tuple<Tensor, Tensor> hidden,
583     const Tensor& batch_sizes) const {
584   this->check_input(input, batch_sizes);
585   this->check_hidden_size(
586       std::get<0>(hidden),
587       this->get_expected_hidden_size(input, batch_sizes),
588       "Expected hidden[0] size {1}, got {2}");
589   this->check_hidden_size(
590       std::get<1>(hidden),
591       this->get_expected_cell_size(input, batch_sizes),
592       "Expected hidden[1] size {1}, got {2}");
593 }
594 
595 std::tuple<Tensor, Tensor> LSTMImpl::permute_hidden(
596     std::tuple<Tensor, Tensor> hx,
597     const Tensor& permutation) const {
598   if (!permutation.defined()) {
599     return hx;
600   }
601   return std::make_tuple(
602       apply_permutation(std::get<0>(hx), permutation),
603       apply_permutation(std::get<1>(hx), permutation));
604 }
605 
606 std::tuple<Tensor, std::tuple<Tensor, Tensor>> LSTMImpl::forward_helper(
607     const Tensor& input,
608     const Tensor& batch_sizes,
609     const Tensor& sorted_indices,
610     int64_t max_batch_size,
611     torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
612   std::tuple<Tensor, Tensor> hx;
613   if (!hx_opt.has_value()) {
614     int64_t num_directions = options.bidirectional() ? 2 : 1;
615     int64_t real_hidden_size =
616         options.proj_size() > 0 ? options.proj_size() : options.hidden_size();
617     auto h_zeros = torch::zeros(
618         {options.num_layers() * num_directions,
619          max_batch_size,
620          real_hidden_size},
621         torch::dtype(input.dtype()).device(input.device()));
622     auto c_zeros = torch::zeros(
623         {options.num_layers() * num_directions,
624          max_batch_size,
625          options.hidden_size()},
626         torch::dtype(input.dtype()).device(input.device()));
627     hx = std::make_tuple(h_zeros, c_zeros);
628   } else {
629     hx = hx_opt.value();
630     // Each batch of the hidden state should match the input sequence that
631     // the user believes they are passing in.
632     hx = this->permute_hidden(hx, sorted_indices);
633   }
634 
635   this->check_forward_args(input, hx, batch_sizes);
636   std::tuple<Tensor, Tensor, Tensor> result;
637   if (!batch_sizes.defined()) {
638     result = torch::lstm(
639         input,
640         {std::get<0>(hx), std::get<1>(hx)},
641         flat_weights_,
642         options.bias(),
643         options.num_layers(),
644         options.dropout(),
645         this->is_training(),
646         options.bidirectional(),
647         options.batch_first());
648   } else {
649     result = torch::lstm(
650         input,
651         batch_sizes,
652         {std::get<0>(hx), std::get<1>(hx)},
653         flat_weights_,
654         options.bias(),
655         options.num_layers(),
656         options.dropout(),
657         this->is_training(),
658         options.bidirectional());
659   }
660   auto output = std::get<0>(result);
661   auto hidden = std::make_tuple(std::get<1>(result), std::get<2>(result));
662 
663   return std::make_tuple(output, hidden);
664 }
665 
666 std::tuple<Tensor, std::tuple<Tensor, Tensor>> LSTMImpl::forward(
667     const Tensor& input,
668     torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
669   auto batch_sizes = torch::Tensor();
670   auto max_batch_size = options.batch_first() ? input.size(0) : input.size(1);
671   auto sorted_indices = torch::Tensor();
672   auto unsorted_indices = torch::Tensor();
673 
674   auto [output, hidden] = this->forward_helper(
675       input, batch_sizes, sorted_indices, max_batch_size, std::move(hx_opt));
676 
677   return std::make_tuple(
678       output, this->permute_hidden(hidden, unsorted_indices));
679 }
680 
681 std::tuple<PackedSequence, std::tuple<Tensor, Tensor>> LSTMImpl::
682     forward_with_packed_input(
683         const PackedSequence& packed_input,
684         torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
685   const auto& input = packed_input.data();
686   const auto& batch_sizes = packed_input.batch_sizes();
687   const auto& sorted_indices = packed_input.sorted_indices();
688   const auto& unsorted_indices = packed_input.unsorted_indices();
689   auto max_batch_size = batch_sizes[0].item<int64_t>();
690 
691   auto [output, hidden] = this->forward_helper(
692       input, batch_sizes, sorted_indices, max_batch_size, std::move(hx_opt));
693 
694   auto output_packed =
695       PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices);
696   return std::make_tuple(
697       output_packed, this->permute_hidden(hidden, unsorted_indices));
698 }
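// A usage sketch for the packed-input path, assuming sequence lengths already
// sorted in decreasing order (the pack_padded_sequence default):
//
//   torch::nn::LSTM lstm(torch::nn::LSTMOptions(10, 20));
//   auto padded = torch::randn({5, 3, 10});      // (seq_len, batch, input_size)
//   auto lengths = torch::tensor({5, 4, 2});
//   auto packed =
//       torch::nn::utils::rnn::pack_padded_sequence(padded, lengths);
//   auto [packed_out, state] = lstm->forward_with_packed_input(packed);
//   auto [h_n, c_n] = state;                     // each of shape (1, 3, 20)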
699 
700 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
701 
702 GRUImpl::GRUImpl(const GRUOptions& options_)
703     : detail::RNNImplBase<GRUImpl>(
704           detail::RNNOptionsBase(
705               torch::kGRU,
706               options_.input_size(),
707               options_.hidden_size())
708               .num_layers(options_.num_layers())
709               .bias(options_.bias())
710               .batch_first(options_.batch_first())
711               .dropout(options_.dropout())
712               .bidirectional(options_.bidirectional())),
713       options(options_) {}
714 
715 std::tuple<Tensor, Tensor> GRUImpl::forward_helper(
716     const Tensor& input,
717     const Tensor& batch_sizes,
718     const Tensor& sorted_indices,
719     int64_t max_batch_size,
720     Tensor hx) {
721   if (!hx.defined()) {
722     int64_t num_directions = options.bidirectional() ? 2 : 1;
723     hx = torch::zeros(
724         {options.num_layers() * num_directions,
725          max_batch_size,
726          options.hidden_size()},
727         torch::dtype(input.dtype()).device(input.device()));
728   } else {
729     // Each batch of the hidden state should match the input sequence that
730     // the user believes they are passing in.
731     hx = this->permute_hidden(hx, sorted_indices);
732   }
733 
734   this->check_forward_args(input, hx, batch_sizes);
735   std::tuple<Tensor, Tensor> result;
736   if (!batch_sizes.defined()) {
737     result = torch::gru(
738         input,
739         hx,
740         flat_weights_,
741         options.bias(),
742         options.num_layers(),
743         options.dropout(),
744         this->is_training(),
745         options.bidirectional(),
746         options.batch_first());
747   } else {
748     result = torch::gru(
749         input,
750         batch_sizes,
751         hx,
752         flat_weights_,
753         options.bias(),
754         options.num_layers(),
755         options.dropout(),
756         this->is_training(),
757         options.bidirectional());
758   }
759   auto output = std::get<0>(result);
760   auto hidden = std::get<1>(result);
761 
762   return std::make_tuple(output, hidden);
763 }
764 
765 std::tuple<Tensor, Tensor> GRUImpl::forward(const Tensor& input, Tensor hx) {
766   auto batch_sizes = torch::Tensor();
767   auto max_batch_size = options.batch_first() ? input.size(0) : input.size(1);
768   auto sorted_indices = torch::Tensor();
769   auto unsorted_indices = torch::Tensor();
770 
771   auto [output, hidden] = this->forward_helper(
772       input, batch_sizes, sorted_indices, max_batch_size, std::move(hx));
773 
774   return std::make_tuple(
775       output, this->permute_hidden(hidden, unsorted_indices));
776 }
777 
778 std::tuple<PackedSequence, Tensor> GRUImpl::forward_with_packed_input(
779     const PackedSequence& packed_input,
780     Tensor hx) {
781   const auto& input = packed_input.data();
782   const auto& batch_sizes = packed_input.batch_sizes();
783   const auto& sorted_indices = packed_input.sorted_indices();
784   const auto& unsorted_indices = packed_input.unsorted_indices();
785   auto max_batch_size = batch_sizes[0].item<int64_t>();
786 
787   auto [output, hidden] = this->forward_helper(
788       input, batch_sizes, sorted_indices, max_batch_size, std::move(hx));
789 
790   auto output_packed =
791       PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices);
792   return std::make_tuple(
793       output_packed, this->permute_hidden(hidden, unsorted_indices));
794 }
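// A bidirectional GRU sketch, illustrating how num_directions shows up in the
// output shapes:
//
//   torch::nn::GRU gru(torch::nn::GRUOptions(10, 20).bidirectional(true));
//   auto input = torch::randn({5, 3, 10});
//   auto [output, h_n] = gru->forward(input);
//   // output: (5, 3, 40) -- both directions concatenated on the last dim
//   // h_n:    (2, 3, 20) -- num_layers * num_directions = 2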
795 
796 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCellImplBase
797 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
798 
799 namespace detail {
800 template <typename Derived>
801 RNNCellImplBase<Derived>::RNNCellImplBase(const RNNCellOptionsBase& options_)
802     : options_base(options_) {
803   // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
804   reset();
805 }
806 
807 template <typename Derived>
808 void RNNCellImplBase<Derived>::reset() {
809   weight_ih = this->register_parameter(
810       "weight_ih",
811       torch::empty(
812           {options_base.num_chunks() * options_base.hidden_size(),
813            options_base.input_size()}));
814   weight_hh = this->register_parameter(
815       "weight_hh",
816       torch::empty(
817           {options_base.num_chunks() * options_base.hidden_size(),
818            options_base.hidden_size()}));
819 
820   if (options_base.bias()) {
821     bias_ih = this->register_parameter(
822         "bias_ih",
823         torch::empty({options_base.num_chunks() * options_base.hidden_size()}));
824     bias_hh = this->register_parameter(
825         "bias_hh",
826         torch::empty({options_base.num_chunks() * options_base.hidden_size()}));
827   } else {
828     bias_ih =
829         this->register_parameter("bias_ih", Tensor(), /*requires_grad=*/false);
830     bias_hh =
831         this->register_parameter("bias_hh", Tensor(), /*requires_grad=*/false);
832   }
833 
834   reset_parameters();
835 }
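// num_chunks is 1 for RNNCell, 3 for GRUCell and 4 for LSTMCell, so e.g. an
// LSTMCell with input_size=10 and hidden_size=20 registers weight_ih of shape
// (80, 10), weight_hh of shape (80, 20), and biases of shape (80).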
836 
837 template <typename Derived>
838 void RNNCellImplBase<Derived>::reset_parameters() {
839   const double stdv = 1.0 / std::sqrt(options_base.hidden_size());
840   for (auto& weight : this->parameters()) {
841     init::uniform_(weight, -stdv, stdv);
842   }
843 }
844 
845 template <typename Derived>
846 void RNNCellImplBase<Derived>::pretty_print(std::ostream& stream) const {
847   const std::string name = this->name();
848   const std::string name_without_impl = name.substr(0, name.size() - 4);
849   stream << name_without_impl << "(" << options_base.input_size() << ", "
850          << options_base.hidden_size();
851   if (!options_base.bias()) {
852     stream << ", bias=" << std::boolalpha << false;
853   }
854   auto nonlinearity_str = this->get_nonlinearity_str();
855   if (!nonlinearity_str.empty() && nonlinearity_str != "kTanh") {
856     stream << ", nonlinearity=" << nonlinearity_str;
857   }
858   stream << ")";
859 }
860 
861 template <typename Derived>
862 void RNNCellImplBase<Derived>::check_forward_input(
863     const Tensor& input,
864     const std::string& name) const {
865   TORCH_CHECK(
866       input.dim() == 1 || input.dim() == 2,
867       "Expected ",
868       name.c_str(),
869       " to be 1D or 2D, got ",
870       input.dim(),
871       "D instead");
872 }
873 
874 template <typename Derived>
875 std::string RNNCellImplBase<Derived>::get_nonlinearity_str() const {
876   return "";
877 }
878 
879 template class RNNCellImplBase<LSTMCellImpl>;
880 template class RNNCellImplBase<GRUCellImpl>;
881 template class RNNCellImplBase<RNNCellImpl>;
882 } // namespace detail
883 
884 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCell
885 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
886 
887 RNNCellImpl::RNNCellImpl(const RNNCellOptions& options_)
888     : detail::RNNCellImplBase<RNNCellImpl>(detail::RNNCellOptionsBase(
889           options_.input_size(),
890           options_.hidden_size(),
891           options_.bias(),
892           /*num_chunks=*/1)),
893       options(options_) {}
894 
895 Tensor RNNCellImpl::forward(const Tensor& input, Tensor hx) {
896   this->check_forward_input(input, "input");
897   this->check_forward_input(hx, "hidden");
898 
899   Tensor r_hx, ret;
900 
901   bool is_batched = input.dim() == 2;
902   Tensor r_input = is_batched ? input : input.unsqueeze(0);
903 
904   if (!hx.defined()) {
905     r_hx = torch::zeros(
906         {input.size(0), options.hidden_size()},
907         torch::dtype(input.dtype()).device(input.device()));
908   } else {
909     r_hx = is_batched ? hx : hx.unsqueeze(0);
910   }
911 
912   if (std::holds_alternative<enumtype::kTanh>(options.nonlinearity())) {
913     ret = torch::rnn_tanh_cell(
914         r_input, r_hx, weight_ih, weight_hh, bias_ih, bias_hh);
915   } else if (std::holds_alternative<enumtype::kReLU>(options.nonlinearity())) {
916     ret = torch::rnn_relu_cell(
917         r_input, r_hx, weight_ih, weight_hh, bias_ih, bias_hh);
918   } else {
919     TORCH_CHECK(
920         false,
921         "Unknown nonlinearity: ",
922         torch::enumtype::get_enum_name(options.nonlinearity()));
923   }
924 
925   if (!is_batched) {
926     ret = ret.squeeze(0);
927   }
928 
929   return ret;
930 }
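// A sketch of stepping an RNNCell over a sequence manually, assuming batched
// (2D) inputs at every step:
//
//   torch::nn::RNNCell cell(10, 20);
//   auto input = torch::randn({5, 3, 10});  // (seq_len, batch, input_size)
//   auto hx = torch::zeros({3, 20});
//   for (int64_t t = 0; t < input.size(0); ++t) {
//     hx = cell->forward(input[t], hx);     // hx stays (batch, hidden_size)
//   }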
931 
932 std::string RNNCellImpl::get_nonlinearity_str() const {
933   return get_enum_name(options.nonlinearity());
934 }
935 
936 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTMCell
937 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
938 
939 LSTMCellImpl::LSTMCellImpl(const LSTMCellOptions& options_)
940     : detail::RNNCellImplBase<LSTMCellImpl>(detail::RNNCellOptionsBase(
941           options_.input_size(),
942           options_.hidden_size(),
943           options_.bias(),
944           /*num_chunks=*/4)),
945       options(options_) {}
946 
947 std::tuple<Tensor, Tensor> LSTMCellImpl::forward(
948     const Tensor& input,
949     torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
950   this->check_forward_input(input, "input");
951   if (hx_opt.has_value()) {
952     this->check_forward_input(std::get<0>(hx_opt.value()), "hx[0]");
953     this->check_forward_input(std::get<1>(hx_opt.value()), "hx[1]");
954   }
955 
956   std::tuple<Tensor, Tensor> r_hx, ret;
957 
958   bool is_batched = input.dim() == 2;
959   Tensor r_input = is_batched ? input : input.unsqueeze(0);
960 
961   if (!hx_opt.has_value()) {
962     auto zeros = torch::zeros(
963         {input.size(0), options.hidden_size()},
964         torch::dtype(input.dtype()).device(input.device()));
965     r_hx = std::make_tuple(zeros, zeros);
966   } else {
967     if (!is_batched) {
968       r_hx = std::make_tuple(
969           std::get<0>(hx_opt.value()).unsqueeze(0),
970           std::get<1>(hx_opt.value()).unsqueeze(0));
971     } else {
972       r_hx = hx_opt.value();
973     }
974   }
975 
976   ret = torch::lstm_cell(
977       r_input,
978       {std::get<0>(r_hx), std::get<1>(r_hx)},
979       weight_ih,
980       weight_hh,
981       bias_ih,
982       bias_hh);
983 
984   if (!is_batched) {
985     ret = std::make_tuple(
986         std::get<0>(ret).squeeze(0), std::get<1>(ret).squeeze(0));
987   }
988 
989   return ret;
990 }
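// A single LSTMCell step with an explicit (h, c) state, as a sketch:
//
//   torch::nn::LSTMCell cell(10, 20);
//   auto x = torch::randn({3, 10});
//   auto h = torch::zeros({3, 20});
//   auto c = torch::zeros({3, 20});
//   std::tie(h, c) = cell->forward(x, std::make_tuple(h, c));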
991 
992 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRUCell
993 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
994 
995 GRUCellImpl::GRUCellImpl(const GRUCellOptions& options_)
996     : detail::RNNCellImplBase<GRUCellImpl>(detail::RNNCellOptionsBase(
997           options_.input_size(),
998           options_.hidden_size(),
999           options_.bias(),
1000           /*num_chunks=*/3)),
1001       options(options_) {}
1002 
1003 Tensor GRUCellImpl::forward(const Tensor& input, Tensor hx) {
1004   this->check_forward_input(input, "input");
1005   this->check_forward_input(hx, "hidden");
1006 
1007   Tensor r_hx, ret;
1008 
1009   bool is_batched = input.dim() == 2;
1010   Tensor r_input = is_batched ? input : input.unsqueeze(0);
1011 
1012   if (!hx.defined()) {
1013     r_hx = torch::zeros(
1014         {input.size(0), options.hidden_size()},
1015         torch::dtype(input.dtype()).device(input.device()));
1016   } else {
1017     r_hx = is_batched ? hx : hx.unsqueeze(0);
1018   }
1019 
1020   ret = torch::gru_cell(r_input, r_hx, weight_ih, weight_hh, bias_ih, bias_hh);
1021 
1022   if (!is_batched) {
1023     ret = ret.squeeze(0);
1024   }
1025 
1026   return ret;
1027 }
1028 
1029 } // namespace nn
1030 } // namespace torch
1031