# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import math
import numbers
import warnings
import weakref
from typing import List, Optional, overload, Tuple
from typing_extensions import deprecated

import torch
from torch import _VF, Tensor
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.nn.utils.rnn import PackedSequence

from .module import Module


__all__ = [
    "RNNBase",
    "RNN",
    "LSTM",
    "GRU",
    "RNNCellBase",
    "RNNCell",
    "LSTMCell",
    "GRUCell",
]

_rnn_impls = {
    "RNN_TANH": _VF.rnn_tanh,
    "RNN_RELU": _VF.rnn_relu,
}


def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    return tensor.index_select(dim, permutation)


@deprecated(
    "`apply_permutation` is deprecated, please use `tensor.index_select(dim, permutation)` instead",
    category=FutureWarning,
)
def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    return _apply_permutation(tensor, permutation, dim)


class RNNBase(Module):
    r"""Base class for RNN modules (RNN, LSTM, GRU).

    Implements aspects of RNNs shared by the RNN, LSTM, and GRU classes, such as module initialization
    and utility methods for parameter storage management.

    .. note::
        The forward method is not implemented by the RNNBase class.

    .. note::
        LSTM and GRU classes override some methods implemented by RNNBase.
    """

    __constants__ = [
        "mode",
        "input_size",
        "hidden_size",
        "num_layers",
        "bias",
        "batch_first",
        "dropout",
        "bidirectional",
        "proj_size",
    ]
    __jit_unused_properties__ = ["all_weights"]

    mode: str
    input_size: int
    hidden_size: int
    num_layers: int
    bias: bool
    batch_first: bool
    dropout: float
    bidirectional: bool
    proj_size: int

    def __init__(
        self,
        mode: str,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
        bidirectional: bool = False,
        proj_size: int = 0,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.mode = mode
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        self.proj_size = proj_size
        self._flat_weight_refs: List[Optional[weakref.ReferenceType[Parameter]]] = []
        num_directions = 2 if bidirectional else 1

        if (
            not isinstance(dropout, numbers.Number)
            or not 0 <= dropout <= 1
            or isinstance(dropout, bool)
        ):
            raise ValueError(
                "dropout should be a number in range [0, 1] "
                "representing the probability of an element being "
                "zeroed"
            )
        if dropout > 0 and num_layers == 1:
            warnings.warn(
                "dropout option adds dropout after all but last "
                "recurrent layer, so non-zero dropout expects "
                f"num_layers greater than 1, but got dropout={dropout} and "
                f"num_layers={num_layers}"
            )

        if not isinstance(hidden_size, int):
            raise TypeError(
                f"hidden_size should be of type int, got: {type(hidden_size).__name__}"
            )
        if hidden_size <= 0:
            raise ValueError("hidden_size must be greater than zero")
        if num_layers <= 0:
            raise ValueError("num_layers must be greater than zero")
        if proj_size < 0:
            raise ValueError(
"proj_size should be a positive integer or zero to disable projections" 141 ) 142 if proj_size >= hidden_size: 143 raise ValueError("proj_size has to be smaller than hidden_size") 144 145 if mode == "LSTM": 146 gate_size = 4 * hidden_size 147 elif mode == "GRU": 148 gate_size = 3 * hidden_size 149 elif mode == "RNN_TANH": 150 gate_size = hidden_size 151 elif mode == "RNN_RELU": 152 gate_size = hidden_size 153 else: 154 raise ValueError("Unrecognized RNN mode: " + mode) 155 156 self._flat_weights_names = [] 157 self._all_weights = [] 158 for layer in range(num_layers): 159 for direction in range(num_directions): 160 real_hidden_size = proj_size if proj_size > 0 else hidden_size 161 layer_input_size = ( 162 input_size if layer == 0 else real_hidden_size * num_directions 163 ) 164 165 w_ih = Parameter( 166 torch.empty((gate_size, layer_input_size), **factory_kwargs) 167 ) 168 w_hh = Parameter( 169 torch.empty((gate_size, real_hidden_size), **factory_kwargs) 170 ) 171 b_ih = Parameter(torch.empty(gate_size, **factory_kwargs)) 172 # Second bias vector included for CuDNN compatibility. Only one 173 # bias vector is needed in standard definition. 174 b_hh = Parameter(torch.empty(gate_size, **factory_kwargs)) 175 layer_params: Tuple[Tensor, ...] = () 176 if self.proj_size == 0: 177 if bias: 178 layer_params = (w_ih, w_hh, b_ih, b_hh) 179 else: 180 layer_params = (w_ih, w_hh) 181 else: 182 w_hr = Parameter( 183 torch.empty((proj_size, hidden_size), **factory_kwargs) 184 ) 185 if bias: 186 layer_params = (w_ih, w_hh, b_ih, b_hh, w_hr) 187 else: 188 layer_params = (w_ih, w_hh, w_hr) 189 190 suffix = "_reverse" if direction == 1 else "" 191 param_names = ["weight_ih_l{}{}", "weight_hh_l{}{}"] 192 if bias: 193 param_names += ["bias_ih_l{}{}", "bias_hh_l{}{}"] 194 if self.proj_size > 0: 195 param_names += ["weight_hr_l{}{}"] 196 param_names = [x.format(layer, suffix) for x in param_names] 197 198 for name, param in zip(param_names, layer_params): 199 setattr(self, name, param) 200 self._flat_weights_names.extend(param_names) 201 self._all_weights.append(param_names) 202 203 self._init_flat_weights() 204 205 self.reset_parameters() 206 207 def _init_flat_weights(self): 208 self._flat_weights = [ 209 getattr(self, wn) if hasattr(self, wn) else None 210 for wn in self._flat_weights_names 211 ] 212 self._flat_weight_refs = [ 213 weakref.ref(w) if w is not None else None for w in self._flat_weights 214 ] 215 self.flatten_parameters() 216 217 def __setattr__(self, attr, value): 218 if hasattr(self, "_flat_weights_names") and attr in self._flat_weights_names: 219 # keep self._flat_weights up to date if you do self.weight = ... 220 idx = self._flat_weights_names.index(attr) 221 self._flat_weights[idx] = value 222 super().__setattr__(attr, value) 223 224 def flatten_parameters(self) -> None: 225 """Reset parameter data pointer so that they can use faster code paths. 226 227 Right now, this works only if the module is on the GPU and cuDNN is enabled. 228 Otherwise, it's a no-op. 
        """
        # Short-circuits if _flat_weights is only partially instantiated
        if len(self._flat_weights) != len(self._flat_weights_names):
            return

        for w in self._flat_weights:
            if not isinstance(w, Tensor):
                return
        # Short-circuits if any tensor in self._flat_weights is not acceptable to cuDNN
        # or the tensors in _flat_weights are of different dtypes

        first_fw = self._flat_weights[0]
        dtype = first_fw.dtype
        for fw in self._flat_weights:
            if (
                not isinstance(fw, Tensor)
                or not (fw.dtype == dtype)
                or not fw.is_cuda
                or not torch.backends.cudnn.is_acceptable(fw)
            ):
                return

        # If any parameters alias, we fall back to the slower, copying code path. This is
        # a sufficient check, because overlapping parameter buffers that don't completely
        # alias would break the assumptions of the uniqueness check in
        # Module.named_parameters().
        unique_data_ptrs = {p.data_ptr() for p in self._flat_weights}
        if len(unique_data_ptrs) != len(self._flat_weights):
            return

        with torch.cuda.device_of(first_fw):
            import torch.backends.cudnn.rnn as rnn

            # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
            # an inplace operation on self._flat_weights
            with torch.no_grad():
                if torch._use_cudnn_rnn_flatten_weight():
                    num_weights = 4 if self.bias else 2
                    if self.proj_size > 0:
                        num_weights += 1
                    torch._cudnn_rnn_flatten_weight(
                        self._flat_weights,
                        num_weights,
                        self.input_size,
                        rnn.get_cudnn_mode(self.mode),
                        self.hidden_size,
                        self.proj_size,
                        self.num_layers,
                        self.batch_first,
                        bool(self.bidirectional),
                    )

    def _apply(self, fn, recurse=True):
        self._flat_weight_refs = []
        ret = super()._apply(fn, recurse)

        # Resets _flat_weights
        # Note: be v. careful before removing this, as 3rd party device types
        # likely rely on this behavior to properly .to() modules like LSTM.
        self._init_flat_weights()

        return ret

    def reset_parameters(self) -> None:
        stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0
        for weight in self.parameters():
            init.uniform_(weight, -stdv, stdv)

    def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
        if not torch.jit.is_scripting():
            if (
                input.dtype != self._flat_weights[0].dtype
                and not torch._C._is_any_autocast_enabled()
            ):
                raise ValueError(
                    f"input must have the type {self._flat_weights[0].dtype}, got type {input.dtype}"
                )
        expected_input_dim = 2 if batch_sizes is not None else 3
        if input.dim() != expected_input_dim:
            raise RuntimeError(
                f"input must have {expected_input_dim} dimensions, got {input.dim()}"
            )
        if self.input_size != input.size(-1):
            raise RuntimeError(
                f"input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}"
            )

    def get_expected_hidden_size(
        self, input: Tensor, batch_sizes: Optional[Tensor]
    ) -> Tuple[int, int, int]:
        if batch_sizes is not None:
            mini_batch = int(batch_sizes[0])
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)
        num_directions = 2 if self.bidirectional else 1
        if self.proj_size > 0:
            expected_hidden_size = (
                self.num_layers * num_directions,
                mini_batch,
                self.proj_size,
            )
        else:
            expected_hidden_size = (
                self.num_layers * num_directions,
                mini_batch,
                self.hidden_size,
            )
        return expected_hidden_size

    def check_hidden_size(
        self,
        hx: Tensor,
        expected_hidden_size: Tuple[int, int, int],
        msg: str = "Expected hidden size {}, got {}",
    ) -> None:
        if hx.size() != expected_hidden_size:
            raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))

    def _weights_have_changed(self):
        # Returns True if the weight tensors have changed since the last forward pass.
        # This is the case when used with torch.func.functional_call(), for example.
        weights_changed = False
        for ref, name in zip(self._flat_weight_refs, self._flat_weights_names):
            weight = getattr(self, name) if hasattr(self, name) else None
            if weight is not None and ref is not None and ref() is not weight:
                weights_changed = True
                break
        return weights_changed

    def check_forward_args(
        self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]
    ):
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

        self.check_hidden_size(hidden, expected_hidden_size)

    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]):
        if permutation is None:
            return hx
        return _apply_permutation(hx, permutation)

    def extra_repr(self) -> str:
        s = "{input_size}, {hidden_size}"
        if self.proj_size != 0:
            s += ", proj_size={proj_size}"
        if self.num_layers != 1:
            s += ", num_layers={num_layers}"
        if self.bias is not True:
            s += ", bias={bias}"
        if self.batch_first is not False:
            s += ", batch_first={batch_first}"
        if self.dropout != 0:
            s += ", dropout={dropout}"
        if self.bidirectional is not False:
            s += ", bidirectional={bidirectional}"
        return s.format(**self.__dict__)

    def _update_flat_weights(self):
        if not torch.jit.is_scripting():
            if self._weights_have_changed():
                self._init_flat_weights()

    def __getstate__(self):
        # If weights have been changed, update the _flat_weights in __getstate__ here.
        self._update_flat_weights()
        # Don't serialize the weight references.
        state = self.__dict__.copy()
        del state["_flat_weight_refs"]
        return state

    def __setstate__(self, d):
        super().__setstate__(d)
        if "all_weights" in d:
            self._all_weights = d["all_weights"]
        # In PyTorch 1.8 we added a proj_size member variable to LSTM.
        # LSTMs that were serialized via torch.save(module) before PyTorch 1.8
        # don't have it, so to preserve compatibility we set proj_size here.
        if "proj_size" not in d:
            self.proj_size = 0

        if not isinstance(self._all_weights[0][0], str):
            num_layers = self.num_layers
            num_directions = 2 if self.bidirectional else 1
            self._flat_weights_names = []
            self._all_weights = []
            for layer in range(num_layers):
                for direction in range(num_directions):
                    suffix = "_reverse" if direction == 1 else ""
                    weights = [
                        "weight_ih_l{}{}",
                        "weight_hh_l{}{}",
                        "bias_ih_l{}{}",
                        "bias_hh_l{}{}",
                        "weight_hr_l{}{}",
                    ]
                    weights = [x.format(layer, suffix) for x in weights]
                    if self.bias:
                        if self.proj_size > 0:
                            self._all_weights += [weights]
                            self._flat_weights_names.extend(weights)
                        else:
                            self._all_weights += [weights[:4]]
                            self._flat_weights_names.extend(weights[:4])
                    else:
                        if self.proj_size > 0:
                            self._all_weights += [weights[:2]] + [weights[-1:]]
                            self._flat_weights_names.extend(
                                weights[:2] + [weights[-1:]]
                            )
                        else:
                            self._all_weights += [weights[:2]]
                            self._flat_weights_names.extend(weights[:2])
            self._flat_weights = [
                getattr(self, wn) if hasattr(self, wn) else None
                for wn in self._flat_weights_names
            ]

        self._flat_weight_refs = [
            weakref.ref(w) if w is not None else None for w in self._flat_weights
        ]

    @property
    def all_weights(self) -> List[List[Parameter]]:
        return [
            [getattr(self, weight) for weight in weights]
            for weights in self._all_weights
        ]

    def _replicate_for_data_parallel(self):
        replica = super()._replicate_for_data_parallel()
        # Need to copy these caches, otherwise the replica will share the same
        # flat weights list.
        replica._flat_weights = replica._flat_weights[:]
        replica._flat_weights_names = replica._flat_weights_names[:]
        return replica


class RNN(RNNBase):
    r"""__init__(input_size,hidden_size,num_layers=1,nonlinearity='tanh',bias=True,batch_first=False,dropout=0.0,bidirectional=False,device=None,dtype=None)

    Apply a multi-layer Elman RNN with :math:`\tanh` or :math:`\text{ReLU}`
    non-linearity to an input sequence. For each element in the input sequence,
    each layer computes the following function:

    .. math::
        h_t = \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1}W_{hh}^T + b_{hh})

    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is
    the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the
    previous layer at time `t-1` or the initial hidden state at time `0`.
    If :attr:`nonlinearity` is ``'relu'``, then :math:`\text{ReLU}` is used instead of :math:`\tanh`.

    .. code-block:: python

        # Efficient implementation equivalent to the following with bidirectional=False
        def forward(x, h_0=None):
            if batch_first:
                x = x.transpose(0, 1)
            seq_len, batch_size, _ = x.size()
            if h_0 is None:
                h_0 = torch.zeros(num_layers, batch_size, hidden_size)
            h_t_minus_1 = h_0
            h_t = h_0
            output = []
            for t in range(seq_len):
                for layer in range(num_layers):
                    h_t[layer] = torch.tanh(
                        x[t] @ weight_ih[layer].T
                        + bias_ih[layer]
                        + h_t_minus_1[layer] @ weight_hh[layer].T
                        + bias_hh[layer]
                    )
                output.append(h_t[-1])
                h_t_minus_1 = h_t
            output = torch.stack(output)
            if batch_first:
                output = output.transpose(0, 1)
            return output, h_t

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two RNNs together to form a `stacked RNN`,
            with the second RNN taking in outputs of the first RNN and
            computing the final results. Default: 1
        nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        batch_first: If ``True``, then the input and output tensors are provided
            as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
            Note that this does not apply to hidden or cell states. See the
            Inputs/Outputs sections below for details. Default: ``False``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            RNN layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False``

    Inputs: input, h_0
        * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
          :math:`(L, N, H_{in})` when ``batch_first=False`` or
          :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
          the input sequence. The input can also be a packed variable length sequence.
          See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
          :func:`torch.nn.utils.rnn.pack_sequence` for details.
        * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
          :math:`(D * \text{num\_layers}, N, H_{out})` containing the initial hidden
          state for the input sequence batch. Defaults to zeros if not provided.

    where:

    .. math::
        \begin{aligned}
            N ={} & \text{batch size} \\
            L ={} & \text{sequence length} \\
            D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
            H_{in} ={} & \text{input\_size} \\
            H_{out} ={} & \text{hidden\_size}
        \end{aligned}

    Outputs: output, h_n
        * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
          :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
          :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
          `(h_t)` from the last layer of the RNN, for each `t`. If a
          :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
          will also be a packed sequence.
        * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
          :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state
          for each element in the batch.

    Attributes:
        weight_ih_l[k]: the learnable input-hidden weights of the k-th layer,
            of shape `(hidden_size, input_size)` for `k = 0`. Otherwise, the shape is
            `(hidden_size, num_directions * hidden_size)`
        weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer,
            of shape `(hidden_size, hidden_size)`
        bias_ih_l[k]: the learnable input-hidden bias of the k-th layer,
            of shape `(hidden_size)`
        bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer,
            of shape `(hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    .. note::
        For bidirectional RNNs, forward and backward are directions 0 and 1 respectively.
        Example of splitting the output layers when ``batch_first=False``:
        ``output.view(seq_len, batch, num_directions, hidden_size)``.
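
        As a minimal sketch of that split (the sizes below are illustrative and
        assume a batched, ``batch_first=False`` input)::

            >>> rnn = nn.RNN(10, 20, bidirectional=True)
            >>> output, h_n = rnn(torch.randn(5, 3, 10))         # output: (5, 3, 2 * 20)
            >>> fwd, bwd = output.view(5, 3, 2, 20).unbind(-2)   # directions 0 and 1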

    .. note::
        ``batch_first`` argument is ignored for unbatched inputs.

    .. include:: ../cudnn_rnn_determinism.rst

    .. include:: ../cudnn_persistent_rnn.rst

    Examples::

        >>> rnn = nn.RNN(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> output, hn = rnn(input, h0)
    """

    @overload
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        nonlinearity: str = "tanh",
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
        bidirectional: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        ...

    @overload
    def __init__(self, *args, **kwargs):
        ...

    def __init__(self, *args, **kwargs):
        if "proj_size" in kwargs:
            raise ValueError(
                "proj_size argument is only supported for LSTM, not RNN or GRU"
            )
        if len(args) > 3:
            self.nonlinearity = args[3]
            args = args[:3] + args[4:]
        else:
            self.nonlinearity = kwargs.pop("nonlinearity", "tanh")
        if self.nonlinearity == "tanh":
            mode = "RNN_TANH"
        elif self.nonlinearity == "relu":
            mode = "RNN_RELU"
        else:
            raise ValueError(
                f"Unknown nonlinearity '{self.nonlinearity}'. Select from 'tanh' or 'relu'."
            )
        super().__init__(mode, *args, **kwargs)

    @overload
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(
        self, input: Tensor, hx: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        pass

    @overload
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(
        self, input: PackedSequence, hx: Optional[Tensor] = None
    ) -> Tuple[PackedSequence, Tensor]:
        pass

    def forward(self, input, hx=None):  # noqa: F811
        self._update_flat_weights()

        num_directions = 2 if self.bidirectional else 1
        orig_input = input

        if isinstance(orig_input, PackedSequence):
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = batch_sizes[0]
            # script() is unhappy when max_batch_size is different type in cond branches, so we duplicate
            if hx is None:
                hx = torch.zeros(
                    self.num_layers * num_directions,
                    max_batch_size,
                    self.hidden_size,
                    dtype=input.dtype,
                    device=input.device,
                )
            else:
                # Each batch of the hidden state should match the input sequence that
                # the user believes he/she is passing in.
                hx = self.permute_hidden(hx, sorted_indices)
        else:
            batch_sizes = None
            if input.dim() not in (2, 3):
                raise ValueError(
                    f"RNN: Expected input to be 2D or 3D, got {input.dim()}D tensor instead"
                )
            is_batched = input.dim() == 3
            batch_dim = 0 if self.batch_first else 1
            if not is_batched:
                input = input.unsqueeze(batch_dim)
                if hx is not None:
                    if hx.dim() != 2:
                        raise RuntimeError(
                            f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor"
                        )
                    hx = hx.unsqueeze(1)
            else:
                if hx is not None and hx.dim() != 3:
                    raise RuntimeError(
                        f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor"
                    )
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None
            if hx is None:
                hx = torch.zeros(
                    self.num_layers * num_directions,
                    max_batch_size,
                    self.hidden_size,
                    dtype=input.dtype,
                    device=input.device,
                )
            else:
                # Each batch of the hidden state should match the input sequence that
                # the user believes he/she is passing in.
                hx = self.permute_hidden(hx, sorted_indices)

        assert hx is not None
        self.check_forward_args(input, hx, batch_sizes)
        assert self.mode == "RNN_TANH" or self.mode == "RNN_RELU"
        if batch_sizes is None:
            if self.mode == "RNN_TANH":
                result = _VF.rnn_tanh(
                    input,
                    hx,
                    self._flat_weights,
                    self.bias,
                    self.num_layers,
                    self.dropout,
                    self.training,
                    self.bidirectional,
                    self.batch_first,
                )
            else:
                result = _VF.rnn_relu(
                    input,
                    hx,
                    self._flat_weights,
                    self.bias,
                    self.num_layers,
                    self.dropout,
                    self.training,
                    self.bidirectional,
                    self.batch_first,
                )
        else:
            if self.mode == "RNN_TANH":
                result = _VF.rnn_tanh(
                    input,
                    batch_sizes,
                    hx,
                    self._flat_weights,
                    self.bias,
                    self.num_layers,
                    self.dropout,
                    self.training,
                    self.bidirectional,
                )
            else:
                result = _VF.rnn_relu(
                    input,
                    batch_sizes,
                    hx,
                    self._flat_weights,
                    self.bias,
                    self.num_layers,
                    self.dropout,
                    self.training,
                    self.bidirectional,
                )

        output = result[0]
        hidden = result[1]

        if isinstance(orig_input, PackedSequence):
            output_packed = PackedSequence(
                output, batch_sizes, sorted_indices, unsorted_indices
            )
            return output_packed, self.permute_hidden(hidden, unsorted_indices)

        if not is_batched:  # type: ignore[possibly-undefined]
            output = output.squeeze(batch_dim)  # type: ignore[possibly-undefined]
            hidden = hidden.squeeze(1)

        return output, self.permute_hidden(hidden, unsorted_indices)


# XXX: LSTM and GRU implementation is different from RNNBase, this is because:
# 1. we want to support nn.LSTM and nn.GRU in TorchScript and TorchScript in
# its current state could not support the python Union Type or Any Type
# 2. TorchScript static typing does not allow a Function or Callable type in
# Dict values, so we have to separately call _VF instead of using _rnn_impls
# 3. This is temporary only and in the transition state that we want to make it
# on time for the release
#
# More discussion details in https://github.com/pytorch/pytorch/pull/23266
#
# TODO: remove the overriding implementations for LSTM and GRU when TorchScript
# supports expressing these two modules generally.


class LSTM(RNNBase):
    r"""__init__(input_size,hidden_size,num_layers=1,bias=True,batch_first=False,dropout=0.0,bidirectional=False,proj_size=0,device=None,dtype=None)

    Apply a multi-layer long short-term memory (LSTM) RNN to an input sequence.
    For each element in the input sequence, each layer computes the following
    function:

    .. math::
        \begin{array}{ll} \\
            i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
            f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
            g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
            o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
            c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
            h_t = o_t \odot \tanh(c_t) \\
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell
    state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{t-1}`
    is the hidden state of the layer at time `t-1` or the initial hidden
    state at time `0`, and :math:`i_t`, :math:`f_t`, :math:`g_t`,
    :math:`o_t` are the input, forget, cell, and output gates, respectively.
    :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.

    In a multilayer LSTM, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
    (:math:`l \ge 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
    dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
    variable which is :math:`0` with probability :attr:`dropout`.

    If ``proj_size > 0`` is specified, LSTM with projections will be used. This changes
    the LSTM cell in the following way. First, the dimension of :math:`h_t` will be changed from
    ``hidden_size`` to ``proj_size`` (dimensions of :math:`W_{hi}` will be changed accordingly).
    Second, the output hidden state of each layer will be multiplied by a learnable projection
    matrix: :math:`h_t = W_{hr}h_t`. Note that as a consequence of this, the output
    of the LSTM network will have a different shape as well. See Inputs/Outputs sections below for exact
    dimensions of all variables. You can find more details in https://arxiv.org/abs/1402.1128.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two LSTMs together to form a `stacked LSTM`,
            with the second LSTM taking in outputs of the first LSTM and
            computing the final results. Default: 1
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        batch_first: If ``True``, then the input and output tensors are provided
            as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
            Note that this does not apply to hidden or cell states. See the
            Inputs/Outputs sections below for details. Default: ``False``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            LSTM layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False``
        proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0
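
    For illustration, a possible configuration with projections (a brief sketch;
    the shapes noted in the comments follow the Inputs/Outputs sections below)::

        >>> rnn = nn.LSTM(10, 20, 2, proj_size=15)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 15)   # hidden state uses proj_size features
        >>> c0 = torch.randn(2, 3, 20)   # cell state keeps hidden_size features
        >>> output, (hn, cn) = rnn(input, (h0, c0))  # output: (5, 3, 15)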

    Inputs: input, (h_0, c_0)
        * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
          :math:`(L, N, H_{in})` when ``batch_first=False`` or
          :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
          the input sequence. The input can also be a packed variable length sequence.
          See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
          :func:`torch.nn.utils.rnn.pack_sequence` for details.
        * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
          :math:`(D * \text{num\_layers}, N, H_{out})` containing the
          initial hidden state for each element in the input sequence.
          Defaults to zeros if (h_0, c_0) is not provided.
        * **c_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or
          :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
          initial cell state for each element in the input sequence.
          Defaults to zeros if (h_0, c_0) is not provided.

    where:

    .. math::
        \begin{aligned}
            N ={} & \text{batch size} \\
            L ={} & \text{sequence length} \\
            D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
            H_{in} ={} & \text{input\_size} \\
            H_{cell} ={} & \text{hidden\_size} \\
            H_{out} ={} & \text{proj\_size if } \text{proj\_size}>0 \text{ otherwise hidden\_size} \\
        \end{aligned}

    Outputs: output, (h_n, c_n)
        * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
          :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
          :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
          `(h_t)` from the last layer of the LSTM, for each `t`. If a
          :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
          will also be a packed sequence. When ``bidirectional=True``, `output` will contain
          a concatenation of the forward and reverse hidden states at each time step in the sequence.
        * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
          :math:`(D * \text{num\_layers}, N, H_{out})` containing the
          final hidden state for each element in the sequence. When ``bidirectional=True``,
          `h_n` will contain a concatenation of the final forward and reverse hidden states, respectively.
        * **c_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or
          :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
          final cell state for each element in the sequence. When ``bidirectional=True``,
          `c_n` will contain a concatenation of the final forward and reverse cell states, respectively.

    Attributes:
        weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
            `(W_ii|W_if|W_ig|W_io)`, of shape `(4*hidden_size, input_size)` for `k = 0`.
            Otherwise, the shape is `(4*hidden_size, num_directions * hidden_size)`. If
            ``proj_size > 0`` was specified, the shape will be
            `(4*hidden_size, num_directions * proj_size)` for `k > 0`
        weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
            `(W_hi|W_hf|W_hg|W_ho)`, of shape `(4*hidden_size, hidden_size)`. If ``proj_size > 0``
            was specified, the shape will be `(4*hidden_size, proj_size)`.
        bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
            `(b_ii|b_if|b_ig|b_io)`, of shape `(4*hidden_size)`
        bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
            `(b_hi|b_hf|b_hg|b_ho)`, of shape `(4*hidden_size)`
        weight_hr_l[k] : the learnable projection weights of the :math:`\text{k}^{th}` layer
            of shape `(proj_size, hidden_size)`. Only present when ``proj_size > 0`` was
            specified.
        weight_ih_l[k]_reverse: Analogous to `weight_ih_l[k]` for the reverse direction.
            Only present when ``bidirectional=True``.
        weight_hh_l[k]_reverse: Analogous to `weight_hh_l[k]` for the reverse direction.
            Only present when ``bidirectional=True``.
        bias_ih_l[k]_reverse: Analogous to `bias_ih_l[k]` for the reverse direction.
            Only present when ``bidirectional=True``.
        bias_hh_l[k]_reverse: Analogous to `bias_hh_l[k]` for the reverse direction.
            Only present when ``bidirectional=True``.
        weight_hr_l[k]_reverse: Analogous to `weight_hr_l[k]` for the reverse direction.
            Only present when ``bidirectional=True`` and ``proj_size > 0`` was specified.

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    .. note::
        For bidirectional LSTMs, forward and backward are directions 0 and 1 respectively.
        Example of splitting the output layers when ``batch_first=False``:
        ``output.view(seq_len, batch, num_directions, hidden_size)``.

    .. note::
        For bidirectional LSTMs, `h_n` is not equivalent to the last element of `output`; the
        former contains the final forward and reverse hidden states, while the latter contains the
        final forward hidden state and the initial reverse hidden state.

    .. note::
        ``batch_first`` argument is ignored for unbatched inputs.

    .. note::
        ``proj_size`` should be smaller than ``hidden_size``.

    .. include:: ../cudnn_rnn_determinism.rst

    .. include:: ../cudnn_persistent_rnn.rst

    Examples::

        >>> rnn = nn.LSTM(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> c0 = torch.randn(2, 3, 20)
        >>> output, (hn, cn) = rnn(input, (h0, c0))
    """

    @overload
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
        bidirectional: bool = False,
        proj_size: int = 0,
        device=None,
        dtype=None,
    ) -> None:
        ...

    @overload
    def __init__(self, *args, **kwargs):
        ...

    def __init__(self, *args, **kwargs):
        super().__init__("LSTM", *args, **kwargs)

    def get_expected_cell_size(
        self, input: Tensor, batch_sizes: Optional[Tensor]
    ) -> Tuple[int, int, int]:
        if batch_sizes is not None:
            mini_batch = int(batch_sizes[0])
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)
        num_directions = 2 if self.bidirectional else 1
        expected_hidden_size = (
            self.num_layers * num_directions,
            mini_batch,
            self.hidden_size,
        )
        return expected_hidden_size

    # In the future, we should prevent mypy from applying contravariance rules here.
    # See torch/nn/modules/module.py::_forward_unimplemented
    def check_forward_args(
        self,
        input: Tensor,
        hidden: Tuple[Tensor, Tensor],  # type: ignore[override]
        batch_sizes: Optional[Tensor],
    ):
        self.check_input(input, batch_sizes)
        self.check_hidden_size(
            hidden[0],
            self.get_expected_hidden_size(input, batch_sizes),
            "Expected hidden[0] size {}, got {}",
        )
        self.check_hidden_size(
            hidden[1],
            self.get_expected_cell_size(input, batch_sizes),
            "Expected hidden[1] size {}, got {}",
        )

    # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
    def permute_hidden(  # type: ignore[override]
        self,
        hx: Tuple[Tensor, Tensor],
        permutation: Optional[Tensor],
    ) -> Tuple[Tensor, Tensor]:
        if permutation is None:
            return hx
        return _apply_permutation(hx[0], permutation), _apply_permutation(
            hx[1], permutation
        )

    # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
    @overload  # type: ignore[override]
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(
        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:  # noqa: F811
        pass

    # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
    @overload
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(
        self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:  # noqa: F811
        pass

    def forward(self, input, hx=None):  # noqa: F811
        self._update_flat_weights()

        orig_input = input
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        batch_sizes = None
        do_permute = False
        num_directions = 2 if self.bidirectional else 1
        real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
        if isinstance(orig_input, PackedSequence):
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = batch_sizes[0]
            if hx is None:
                h_zeros = torch.zeros(
                    self.num_layers * num_directions,
                    max_batch_size,
                    real_hidden_size,
                    dtype=input.dtype,
                    device=input.device,
                )
                c_zeros = torch.zeros(
                    self.num_layers * num_directions,
                    max_batch_size,
                    self.hidden_size,
                    dtype=input.dtype,
                    device=input.device,
                )
                hx = (h_zeros, c_zeros)
            else:
                # Each batch of the hidden state should match the input sequence that
                # the user believes he/she is passing in.
                hx = self.permute_hidden(hx, sorted_indices)
        else:
            if input.dim() not in (2, 3):
                raise ValueError(
                    f"LSTM: Expected input to be 2D or 3D, got {input.dim()}D instead"
                )
            is_batched = input.dim() == 3
            batch_dim = 0 if self.batch_first else 1
            if not is_batched:
                input = input.unsqueeze(batch_dim)
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None
            if hx is None:
                h_zeros = torch.zeros(
                    self.num_layers * num_directions,
                    max_batch_size,
                    real_hidden_size,
                    dtype=input.dtype,
                    device=input.device,
                )
                c_zeros = torch.zeros(
                    self.num_layers * num_directions,
                    max_batch_size,
                    self.hidden_size,
                    dtype=input.dtype,
                    device=input.device,
                )
                hx = (h_zeros, c_zeros)
                self.check_forward_args(input, hx, batch_sizes)
            else:
                if is_batched:
                    if hx[0].dim() != 3 or hx[1].dim() != 3:
                        msg = (
                            "For batched 3-D input, hx and cx should "
                            f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors"
                        )
                        raise RuntimeError(msg)
                else:
                    if hx[0].dim() != 2 or hx[1].dim() != 2:
                        msg = (
                            "For unbatched 2-D input, hx and cx should "
                            f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors"
                        )
                        raise RuntimeError(msg)
                    hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1))
                # Each batch of the hidden state should match the input sequence that
                # the user believes he/she is passing in.
                self.check_forward_args(input, hx, batch_sizes)
                hx = self.permute_hidden(hx, sorted_indices)

        if batch_sizes is None:
            result = _VF.lstm(
                input,
                hx,
                self._flat_weights,
                self.bias,
                self.num_layers,
                self.dropout,
                self.training,
                self.bidirectional,
                self.batch_first,
            )
        else:
            result = _VF.lstm(
                input,
                batch_sizes,
                hx,
                self._flat_weights,
                self.bias,
                self.num_layers,
                self.dropout,
                self.training,
                self.bidirectional,
            )
        output = result[0]
        hidden = result[1:]
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            output_packed = PackedSequence(
                output, batch_sizes, sorted_indices, unsorted_indices
            )
            return output_packed, self.permute_hidden(hidden, unsorted_indices)
        else:
            if not is_batched:  # type: ignore[possibly-undefined]
                output = output.squeeze(batch_dim)  # type: ignore[possibly-undefined]
                hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
            return output, self.permute_hidden(hidden, unsorted_indices)


class GRU(RNNBase):
    r"""__init__(input_size,hidden_size,num_layers=1,bias=True,batch_first=False,dropout=0.0,bidirectional=False,device=None,dtype=None)

    Apply a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
    For each element in the input sequence, each layer computes the following
    function:

    .. math::
        \begin{array}{ll}
            r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
            z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
            n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)} + b_{hn})) \\
            h_t = (1 - z_t) \odot n_t + z_t \odot h_{(t-1)}
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input
    at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer
    at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`,
    :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively.
    :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.

    In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
    (:math:`l \ge 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
    dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
    variable which is :math:`0` with probability :attr:`dropout`.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two GRUs together to form a `stacked GRU`,
            with the second GRU taking in outputs of the first GRU and
            computing the final results. Default: 1
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        batch_first: If ``True``, then the input and output tensors are provided
            as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
            Note that this does not apply to hidden or cell states. See the
            Inputs/Outputs sections below for details. Default: ``False``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            GRU layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False``

    Inputs: input, h_0
        * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
          :math:`(L, N, H_{in})` when ``batch_first=False`` or
          :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
          the input sequence. The input can also be a packed variable length sequence.
          See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
          :func:`torch.nn.utils.rnn.pack_sequence` for details.
        * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or
          :math:`(D * \text{num\_layers}, N, H_{out})`
          containing the initial hidden state for the input sequence. Defaults to zeros if not provided.

    where:

    .. math::
        \begin{aligned}
            N ={} & \text{batch size} \\
            L ={} & \text{sequence length} \\
            D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
            H_{in} ={} & \text{input\_size} \\
            H_{out} ={} & \text{hidden\_size}
        \end{aligned}

    Outputs: output, h_n
        * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
          :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
          :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
          `(h_t)` from the last layer of the GRU, for each `t`. If a
          :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
          will also be a packed sequence.
        * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or
          :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state
          for the input sequence.

    Attributes:
        weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
            (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`.
            Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)`
        weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
            (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)`
        bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
            (b_ir|b_iz|b_in), of shape `(3*hidden_size)`
        bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
            (b_hr|b_hz|b_hn), of shape `(3*hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    .. note::
        For bidirectional GRUs, forward and backward are directions 0 and 1 respectively.
        Example of splitting the output layers when ``batch_first=False``:
        ``output.view(seq_len, batch, num_directions, hidden_size)``.

    .. note::
        ``batch_first`` argument is ignored for unbatched inputs.

    .. note::
        The calculation of the new gate :math:`n_t` subtly differs from the original paper and other frameworks.
        In the original implementation, the Hadamard product :math:`(\odot)` between :math:`r_t` and the
        previous hidden state :math:`h_{(t-1)}` is done before the multiplication with the weight matrix
        `W` and addition of bias:

        .. math::
            \begin{aligned}
                n_t = \tanh(W_{in} x_t + b_{in} + W_{hn} ( r_t \odot h_{(t-1)} ) + b_{hn})
            \end{aligned}

        This is in contrast to the PyTorch implementation, which applies the Hadamard product
        after :math:`W_{hn} h_{(t-1)}`:

        .. math::
            \begin{aligned}
                n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)} + b_{hn}))
            \end{aligned}

        This implementation differs on purpose for efficiency.

    .. include:: ../cudnn_persistent_rnn.rst

    Examples::

        >>> rnn = nn.GRU(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> output, hn = rnn(input, h0)
    """

    @overload
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
        bidirectional: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        ...

    @overload
    def __init__(self, *args, **kwargs):
        ...

    def __init__(self, *args, **kwargs):
        if "proj_size" in kwargs:
            raise ValueError(
                "proj_size argument is only supported for LSTM, not RNN or GRU"
            )
        super().__init__("GRU", *args, **kwargs)

    @overload  # type: ignore[override]
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(
        self, input: Tensor, hx: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:  # noqa: F811
        pass

    @overload
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(
        self, input: PackedSequence, hx: Optional[Tensor] = None
    ) -> Tuple[PackedSequence, Tensor]:  # noqa: F811
        pass

    def forward(self, input, hx=None):  # noqa: F811
        self._update_flat_weights()

        orig_input = input
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = batch_sizes[0]
            if hx is None:
                num_directions = 2 if self.bidirectional else 1
                hx = torch.zeros(
                    self.num_layers * num_directions,
                    max_batch_size,
                    self.hidden_size,
                    dtype=input.dtype,
                    device=input.device,
                )
            else:
                # Each batch of the hidden state should match the input sequence that
                # the user believes he/she is passing in.
                hx = self.permute_hidden(hx, sorted_indices)
        else:
            batch_sizes = None
            if input.dim() not in (2, 3):
                raise ValueError(
                    f"GRU: Expected input to be 2D or 3D, got {input.dim()}D instead"
                )
            is_batched = input.dim() == 3
            batch_dim = 0 if self.batch_first else 1
            if not is_batched:
                input = input.unsqueeze(batch_dim)
                if hx is not None:
                    if hx.dim() != 2:
                        raise RuntimeError(
                            f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor"
                        )
                    hx = hx.unsqueeze(1)
            else:
                if hx is not None and hx.dim() != 3:
                    raise RuntimeError(
                        f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor"
                    )
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None
            if hx is None:
                num_directions = 2 if self.bidirectional else 1
                hx = torch.zeros(
                    self.num_layers * num_directions,
                    max_batch_size,
                    self.hidden_size,
                    dtype=input.dtype,
                    device=input.device,
                )
            else:
                # Each batch of the hidden state should match the input sequence that
                # the user believes he/she is passing in.
                hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)
        if batch_sizes is None:
            result = _VF.gru(
                input,
                hx,
                self._flat_weights,
                self.bias,
                self.num_layers,
                self.dropout,
                self.training,
                self.bidirectional,
                self.batch_first,
            )
        else:
            result = _VF.gru(
                input,
                batch_sizes,
                hx,
                self._flat_weights,
                self.bias,
                self.num_layers,
                self.dropout,
                self.training,
                self.bidirectional,
            )
        output = result[0]
        hidden = result[1]

        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        if isinstance(orig_input, PackedSequence):
            output_packed = PackedSequence(
                output, batch_sizes, sorted_indices, unsorted_indices
            )
            return output_packed, self.permute_hidden(hidden, unsorted_indices)
        else:
            if not is_batched:  # type: ignore[possibly-undefined]
                output = output.squeeze(batch_dim)  # type: ignore[possibly-undefined]
                hidden = hidden.squeeze(1)

            return output, self.permute_hidden(hidden, unsorted_indices)


class RNNCellBase(Module):
    __constants__ = ["input_size", "hidden_size", "bias"]

    input_size: int
    hidden_size: int
    bias: bool
    weight_ih: Tensor
    weight_hh: Tensor
    # WARNING: bias_ih and bias_hh purposely not defined here.
    # See https://github.com/pytorch/pytorch/issues/39670

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool,
        num_chunks: int,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.weight_ih = Parameter(
            torch.empty((num_chunks * hidden_size, input_size), **factory_kwargs)
        )
        self.weight_hh = Parameter(
            torch.empty((num_chunks * hidden_size, hidden_size), **factory_kwargs)
        )
        if bias:
            self.bias_ih = Parameter(
                torch.empty(num_chunks * hidden_size, **factory_kwargs)
            )
            self.bias_hh = Parameter(
                torch.empty(num_chunks * hidden_size, **factory_kwargs)
            )
        else:
            self.register_parameter("bias_ih", None)
            self.register_parameter("bias_hh", None)

        self.reset_parameters()

    def extra_repr(self) -> str:
        s = "{input_size}, {hidden_size}"
        if "bias" in self.__dict__ and self.bias is not True:
            s += ", bias={bias}"
        if "nonlinearity" in self.__dict__ and self.nonlinearity != "tanh":
            s += ", nonlinearity={nonlinearity}"
        return s.format(**self.__dict__)

    def reset_parameters(self) -> None:
        stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0
        for weight in self.parameters():
            init.uniform_(weight, -stdv, stdv)


class RNNCell(RNNCellBase):
    r"""An Elman RNN cell with tanh or ReLU non-linearity.

    .. math::

        h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh})

    If :attr:`nonlinearity` is `'relu'`, then ReLU is used in place of tanh.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``

    Inputs: input, hidden
        - **input**: tensor containing input features
        - **hidden**: tensor containing the initial hidden state
          Defaults to zero if not provided.

    Outputs: h'
        - **h'** of shape `(batch, hidden_size)`: tensor containing the next hidden state
          for each element in the batch

    Shape:
        - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where
          :math:`H_{in}` = `input_size`.
        - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden
          state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided.
        - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state.

    Attributes:
        weight_ih: the learnable input-hidden weights, of shape
            `(hidden_size, input_size)`
        weight_hh: the learnable hidden-hidden weights, of shape
            `(hidden_size, hidden_size)`
        bias_ih: the learnable input-hidden bias, of shape `(hidden_size)`
        bias_hh: the learnable hidden-hidden bias, of shape `(hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    Examples::

        >>> rnn = nn.RNNCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx = rnn(input[i], hx)
        ...     output.append(hx)
    """

    __constants__ = ["input_size", "hidden_size", "bias", "nonlinearity"]
    nonlinearity: str

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        nonlinearity: str = "tanh",
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs)
        self.nonlinearity = nonlinearity

    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        if input.dim() not in (1, 2):
            raise ValueError(
                f"RNNCell: Expected input to be 1D or 2D, got {input.dim()}D instead"
            )
        if hx is not None and hx.dim() not in (1, 2):
            raise ValueError(
                f"RNNCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead"
            )
        is_batched = input.dim() == 2
        if not is_batched:
            input = input.unsqueeze(0)

        if hx is None:
            hx = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
        else:
            hx = hx.unsqueeze(0) if not is_batched else hx

        if self.nonlinearity == "tanh":
            ret = _VF.rnn_tanh_cell(
                input,
                hx,
                self.weight_ih,
                self.weight_hh,
                self.bias_ih,
                self.bias_hh,
            )
        elif self.nonlinearity == "relu":
            ret = _VF.rnn_relu_cell(
                input,
                hx,
                self.weight_ih,
                self.weight_hh,
                self.bias_ih,
                self.bias_hh,
            )
        else:
            ret = input  # TODO: remove when jit supports exception flow
            raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}")

        if not is_batched:
            ret = ret.squeeze(0)

        return ret


class LSTMCell(RNNCellBase):
    r"""A long short-term memory (LSTM) cell.


class LSTMCell(RNNCellBase):
    r"""A long short-term memory (LSTM) cell.

    .. math::

        \begin{array}{ll}
        i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
        f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
        g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
        o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
        c' = f \odot c + i \odot g \\
        h' = o \odot \tanh(c') \\
        \end{array}

    where :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        bias: If ``False``, then the layer does not use bias weights `b_ih` and
            `b_hh`. Default: ``True``

    Inputs: input, (h_0, c_0)
        - **input** of shape `(batch, input_size)` or `(input_size)`: tensor containing input features
        - **h_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial hidden state
        - **c_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial cell state

          If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero.

    Outputs: (h_1, c_1)
        - **h_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next hidden state
        - **c_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next cell state

    Attributes:
        weight_ih: the learnable input-hidden weights, of shape
            `(4*hidden_size, input_size)`
        weight_hh: the learnable hidden-hidden weights, of shape
            `(4*hidden_size, hidden_size)`
        bias_ih: the learnable input-hidden bias, of shape `(4*hidden_size)`
        bias_hh: the learnable hidden-hidden bias, of shape `(4*hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

        On certain ROCm devices, when using float16 inputs this module will use
        :ref:`different precision<fp16_on_mi200>` for backward.

    Examples::

        >>> rnn = nn.LSTMCell(10, 20)  # (input_size, hidden_size)
        >>> input = torch.randn(2, 3, 10)  # (time_steps, batch, input_size)
        >>> hx = torch.randn(3, 20)  # (batch, hidden_size)
        >>> cx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(input.size()[0]):
        ...     hx, cx = rnn(input[i], (hx, cx))
        ...     output.append(hx)
        >>> output = torch.stack(output, dim=0)
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs)

    def forward(
        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[Tensor, Tensor]:
        if input.dim() not in (1, 2):
            raise ValueError(
                f"LSTMCell: Expected input to be 1D or 2D, got {input.dim()}D instead"
            )
        if hx is not None:
            for idx, value in enumerate(hx):
                if value.dim() not in (1, 2):
                    raise ValueError(
                        f"LSTMCell: Expected hx[{idx}] to be 1D or 2D, got {value.dim()}D instead"
                    )
        is_batched = input.dim() == 2
        if not is_batched:
            input = input.unsqueeze(0)

        if hx is None:
            zeros = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
            hx = (zeros, zeros)
        else:
            hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx

        ret = _VF.lstm_cell(
            input,
            hx,
            self.weight_ih,
            self.weight_hh,
            self.bias_ih,
            self.bias_hh,
        )

        if not is_batched:
            ret = (ret[0].squeeze(0), ret[1].squeeze(0))
        return ret
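

# Illustrative sketch (hypothetical helper, not part of this module's API):
# when no (h_0, c_0) tuple is supplied, LSTMCell starts from zero hidden and
# cell states, and it always returns the (h_1, c_1) pair -- unlike RNNCell and
# GRUCell, which return a single hidden-state tensor.
def _demo_lstm_cell_state_pair() -> None:
    import torch.nn as nn

    cell = nn.LSTMCell(10, 20)
    x = torch.randn(3, 10)             # (batch, input_size)

    h1, c1 = cell(x)                   # zero (h_0, c_0) by default
    assert h1.shape == (3, 20) and c1.shape == (3, 20)

    # subsequent steps feed the state pair back in
    h2, c2 = cell(torch.randn(3, 10), (h1, c1))
    assert h2.shape == (3, 20) and c2.shape == (3, 20)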


class GRUCell(RNNCellBase):
    r"""A gated recurrent unit (GRU) cell.

    .. math::

        \begin{array}{ll}
        r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
        z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
        n = \tanh(W_{in} x + b_{in} + r \odot (W_{hn} h + b_{hn})) \\
        h' = (1 - z) \odot n + z \odot h
        \end{array}

    where :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        bias: If ``False``, then the layer does not use bias weights `b_ih` and
            `b_hh`. Default: ``True``

    Inputs: input, hidden
        - **input**: tensor containing input features
        - **hidden**: tensor containing the initial hidden
          state for each element in the batch.
          Defaults to zero if not provided.

    Outputs: h'
        - **h'**: tensor containing the next hidden state
          for each element in the batch

    Shape:
        - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where
          :math:`H_{in}` = `input_size`.
        - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden
          state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided.
        - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state.

    Attributes:
        weight_ih: the learnable input-hidden weights, of shape
            `(3*hidden_size, input_size)`
        weight_hh: the learnable hidden-hidden weights, of shape
            `(3*hidden_size, hidden_size)`
        bias_ih: the learnable input-hidden bias, of shape `(3*hidden_size)`
        bias_hh: the learnable hidden-hidden bias, of shape `(3*hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

        On certain ROCm devices, when using float16 inputs this module will use
        :ref:`different precision<fp16_on_mi200>` for backward.

    Examples::

        >>> rnn = nn.GRUCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx = rnn(input[i], hx)
        ...     output.append(hx)
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs)

    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        if input.dim() not in (1, 2):
            raise ValueError(
                f"GRUCell: Expected input to be 1D or 2D, got {input.dim()}D instead"
            )
        if hx is not None and hx.dim() not in (1, 2):
            raise ValueError(
                f"GRUCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead"
            )
        is_batched = input.dim() == 2
        if not is_batched:
            input = input.unsqueeze(0)

        if hx is None:
            hx = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
        else:
            hx = hx.unsqueeze(0) if not is_batched else hx

        ret = _VF.gru_cell(
            input,
            hx,
            self.weight_ih,
            self.weight_hh,
            self.bias_ih,
            self.bias_hh,
        )

        if not is_batched:
            ret = ret.squeeze(0)

        return ret
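

# Illustrative sketch (hypothetical helper, not part of this module's API): a
# manual re-derivation of the GRUCell update from its stacked weights, assuming
# the documented chunk order W_ir|W_iz|W_in (reset, update, new gates). It is a
# sanity check that the fused _VF.gru_cell kernel matches the equations in the
# GRUCell docstring above.
def _demo_gru_cell_manual_update() -> None:
    import torch.nn as nn
    import torch.nn.functional as F

    torch.manual_seed(0)
    cell = nn.GRUCell(10, 20)
    x, h = torch.randn(3, 10), torch.randn(3, 20)

    # split the stacked projections into the (r, z, n) chunks
    gi = F.linear(x, cell.weight_ih, cell.bias_ih).chunk(3, dim=1)  # (i_r, i_z, i_n)
    gh = F.linear(h, cell.weight_hh, cell.bias_hh).chunk(3, dim=1)  # (h_r, h_z, h_n)
    r = torch.sigmoid(gi[0] + gh[0])
    z = torch.sigmoid(gi[1] + gh[1])
    n = torch.tanh(gi[2] + r * gh[2])
    h_manual = (1 - z) * n + z * h

    assert torch.allclose(cell(x, h), h_manual, atol=1e-5)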