# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import math
import numbers
import warnings
import weakref
from typing import List, Optional, overload, Tuple
from typing_extensions import deprecated

import torch
from torch import _VF, Tensor
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.nn.utils.rnn import PackedSequence

from .module import Module


__all__ = [
    "RNNBase",
    "RNN",
    "LSTM",
    "GRU",
    "RNNCellBase",
    "RNNCell",
    "LSTMCell",
    "GRUCell",
]

_rnn_impls = {
    "RNN_TANH": _VF.rnn_tanh,
    "RNN_RELU": _VF.rnn_relu,
}


def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    return tensor.index_select(dim, permutation)


@deprecated(
    "`apply_permutation` is deprecated, please use `tensor.index_select(dim, permutation)` instead",
    category=FutureWarning,
)
def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    return _apply_permutation(tensor, permutation, dim)
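

# A minimal sketch of the replacement suggested by the deprecation message above
# (the shapes below are illustrative): permute the batch dimension of a hidden
# state directly with `Tensor.index_select`.
#
#     h = torch.randn(2, 3, 20)         # (num_layers, batch, hidden_size)
#     perm = torch.tensor([2, 0, 1])    # new ordering of the batch dimension
#     h_perm = h.index_select(1, perm)  # equivalent to _apply_permutation(h, perm)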


class RNNBase(Module):
    r"""Base class for RNN modules (RNN, LSTM, GRU).

    Implements aspects of RNNs shared by the RNN, LSTM, and GRU classes, such as module initialization
    and utility methods for parameter storage management.

    .. note::
        The forward method is not implemented by the RNNBase class.

    .. note::
        LSTM and GRU classes override some methods implemented by RNNBase.
    """

    __constants__ = [
        "mode",
        "input_size",
        "hidden_size",
        "num_layers",
        "bias",
        "batch_first",
        "dropout",
        "bidirectional",
        "proj_size",
    ]
    __jit_unused_properties__ = ["all_weights"]

    mode: str
    input_size: int
    hidden_size: int
    num_layers: int
    bias: bool
    batch_first: bool
    dropout: float
    bidirectional: bool
    proj_size: int

    def __init__(
        self,
        mode: str,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
        bidirectional: bool = False,
        proj_size: int = 0,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.mode = mode
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        self.proj_size = proj_size
        self._flat_weight_refs: List[Optional[weakref.ReferenceType[Parameter]]] = []
        num_directions = 2 if bidirectional else 1

        if (
            not isinstance(dropout, numbers.Number)
            or not 0 <= dropout <= 1
            or isinstance(dropout, bool)
        ):
            raise ValueError(
                "dropout should be a number in range [0, 1] "
                "representing the probability of an element being "
                "zeroed"
            )
        if dropout > 0 and num_layers == 1:
            warnings.warn(
                "dropout option adds dropout after all but last "
                "recurrent layer, so non-zero dropout expects "
                f"num_layers greater than 1, but got dropout={dropout} and "
                f"num_layers={num_layers}"
            )

        if not isinstance(hidden_size, int):
            raise TypeError(
                f"hidden_size should be of type int, got: {type(hidden_size).__name__}"
            )
        if hidden_size <= 0:
            raise ValueError("hidden_size must be greater than zero")
        if num_layers <= 0:
            raise ValueError("num_layers must be greater than zero")
        if proj_size < 0:
            raise ValueError(
                "proj_size should be a positive integer or zero to disable projections"
            )
        if proj_size >= hidden_size:
            raise ValueError("proj_size has to be smaller than hidden_size")

        if mode == "LSTM":
            gate_size = 4 * hidden_size
        elif mode == "GRU":
            gate_size = 3 * hidden_size
        elif mode == "RNN_TANH":
            gate_size = hidden_size
        elif mode == "RNN_RELU":
            gate_size = hidden_size
        else:
            raise ValueError("Unrecognized RNN mode: " + mode)

        self._flat_weights_names = []
        self._all_weights = []
        for layer in range(num_layers):
            for direction in range(num_directions):
                real_hidden_size = proj_size if proj_size > 0 else hidden_size
                layer_input_size = (
                    input_size if layer == 0 else real_hidden_size * num_directions
                )

                w_ih = Parameter(
                    torch.empty((gate_size, layer_input_size), **factory_kwargs)
                )
                w_hh = Parameter(
                    torch.empty((gate_size, real_hidden_size), **factory_kwargs)
                )
                b_ih = Parameter(torch.empty(gate_size, **factory_kwargs))
                # A second bias vector is included for CuDNN compatibility. Only one
                # bias vector is needed in the standard definition.
                b_hh = Parameter(torch.empty(gate_size, **factory_kwargs))
                layer_params: Tuple[Tensor, ...] = ()
                if self.proj_size == 0:
                    if bias:
                        layer_params = (w_ih, w_hh, b_ih, b_hh)
                    else:
                        layer_params = (w_ih, w_hh)
                else:
                    w_hr = Parameter(
                        torch.empty((proj_size, hidden_size), **factory_kwargs)
                    )
                    if bias:
                        layer_params = (w_ih, w_hh, b_ih, b_hh, w_hr)
                    else:
                        layer_params = (w_ih, w_hh, w_hr)

                suffix = "_reverse" if direction == 1 else ""
                param_names = ["weight_ih_l{}{}", "weight_hh_l{}{}"]
                if bias:
                    param_names += ["bias_ih_l{}{}", "bias_hh_l{}{}"]
                if self.proj_size > 0:
                    param_names += ["weight_hr_l{}{}"]
                param_names = [x.format(layer, suffix) for x in param_names]

                for name, param in zip(param_names, layer_params):
                    setattr(self, name, param)
                self._flat_weights_names.extend(param_names)
                self._all_weights.append(param_names)

        self._init_flat_weights()

        self.reset_parameters()

    def _init_flat_weights(self):
        self._flat_weights = [
            getattr(self, wn) if hasattr(self, wn) else None
            for wn in self._flat_weights_names
        ]
        self._flat_weight_refs = [
            weakref.ref(w) if w is not None else None for w in self._flat_weights
        ]
        self.flatten_parameters()

    def __setattr__(self, attr, value):
        if hasattr(self, "_flat_weights_names") and attr in self._flat_weights_names:
            # keep self._flat_weights up to date if you do self.weight = ...
            idx = self._flat_weights_names.index(attr)
            self._flat_weights[idx] = value
        super().__setattr__(attr, value)

    def flatten_parameters(self) -> None:
        """Reset parameter data pointers so that they can use faster code paths.

        Right now, this works only if the module is on the GPU and cuDNN is enabled.
        Otherwise, it's a no-op.
        """
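        # Usage sketch (illustrative): on a CUDA module, the call below re-packs the
        # per-layer weights into one contiguous buffer so cuDNN's fused kernels can be
        # used; on CPU, or when cuDNN cannot accept the tensors, the method returns early.
        #
        #     lstm = nn.LSTM(10, 20, 2).cuda()
        #     lstm.flatten_parameters()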
        # Short-circuits if _flat_weights is only partially instantiated
        if len(self._flat_weights) != len(self._flat_weights_names):
            return

        for w in self._flat_weights:
            if not isinstance(w, Tensor):
                return
        # Short-circuits if any tensor in self._flat_weights is not acceptable to cuDNN
        # or the tensors in _flat_weights are of different dtypes

        first_fw = self._flat_weights[0]
        dtype = first_fw.dtype
        for fw in self._flat_weights:
            if (
                not isinstance(fw, Tensor)
                or not (fw.dtype == dtype)
                or not fw.is_cuda
                or not torch.backends.cudnn.is_acceptable(fw)
            ):
                return

        # If any parameters alias, we fall back to the slower, copying code path. This is
        # a sufficient check, because overlapping parameter buffers that don't completely
        # alias would break the assumptions of the uniqueness check in
        # Module.named_parameters().
        unique_data_ptrs = {p.data_ptr() for p in self._flat_weights}
        if len(unique_data_ptrs) != len(self._flat_weights):
            return

        with torch.cuda.device_of(first_fw):
            import torch.backends.cudnn.rnn as rnn

            # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
            # an inplace operation on self._flat_weights
            with torch.no_grad():
                if torch._use_cudnn_rnn_flatten_weight():
                    num_weights = 4 if self.bias else 2
                    if self.proj_size > 0:
                        num_weights += 1
                    torch._cudnn_rnn_flatten_weight(
                        self._flat_weights,
                        num_weights,
                        self.input_size,
                        rnn.get_cudnn_mode(self.mode),
                        self.hidden_size,
                        self.proj_size,
                        self.num_layers,
                        self.batch_first,
                        bool(self.bidirectional),
                    )

    def _apply(self, fn, recurse=True):
        self._flat_weight_refs = []
        ret = super()._apply(fn, recurse)

        # Resets _flat_weights
        # Note: be very careful before removing this, as 3rd party device types
        # likely rely on this behavior to properly .to() modules like LSTM.
        self._init_flat_weights()

        return ret

    def reset_parameters(self) -> None:
        stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0
        for weight in self.parameters():
            init.uniform_(weight, -stdv, stdv)

    def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
        if not torch.jit.is_scripting():
            if (
                input.dtype != self._flat_weights[0].dtype
                and not torch._C._is_any_autocast_enabled()
            ):
                raise ValueError(
                    f"input must have the type {self._flat_weights[0].dtype}, got type {input.dtype}"
                )
        expected_input_dim = 2 if batch_sizes is not None else 3
        if input.dim() != expected_input_dim:
            raise RuntimeError(
                f"input must have {expected_input_dim} dimensions, got {input.dim()}"
            )
        if self.input_size != input.size(-1):
            raise RuntimeError(
                f"input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}"
            )

    def get_expected_hidden_size(
        self, input: Tensor, batch_sizes: Optional[Tensor]
    ) -> Tuple[int, int, int]:
        if batch_sizes is not None:
            mini_batch = int(batch_sizes[0])
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)
        num_directions = 2 if self.bidirectional else 1
        if self.proj_size > 0:
            expected_hidden_size = (
                self.num_layers * num_directions,
                mini_batch,
                self.proj_size,
            )
        else:
            expected_hidden_size = (
                self.num_layers * num_directions,
                mini_batch,
                self.hidden_size,
            )
        return expected_hidden_size

    def check_hidden_size(
        self,
        hx: Tensor,
        expected_hidden_size: Tuple[int, int, int],
        msg: str = "Expected hidden size {}, got {}",
    ) -> None:
        if hx.size() != expected_hidden_size:
            raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))

    def _weights_have_changed(self):
        # Returns True if the weight tensors have changed since the last forward pass.
        # This is the case when used with torch.func.functional_call(), for example.
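        # A minimal sketch of such a call (illustrative, not part of this module):
        # functional_call runs forward with substitute tensors, so the weak references
        # recorded in _flat_weight_refs no longer point at the tensors actually used.
        #
        #     rnn = torch.nn.RNN(4, 8)
        #     params = {k: v.detach().clone() for k, v in rnn.named_parameters()}
        #     out, h = torch.func.functional_call(rnn, params, (torch.randn(5, 2, 4),))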
        weights_changed = False
        for ref, name in zip(self._flat_weight_refs, self._flat_weights_names):
            weight = getattr(self, name) if hasattr(self, name) else None
            if weight is not None and ref is not None and ref() is not weight:
                weights_changed = True
                break
        return weights_changed

    def check_forward_args(
        self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]
    ):
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

        self.check_hidden_size(hidden, expected_hidden_size)

    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]):
        if permutation is None:
            return hx
        return _apply_permutation(hx, permutation)

    def extra_repr(self) -> str:
        s = "{input_size}, {hidden_size}"
        if self.proj_size != 0:
            s += ", proj_size={proj_size}"
        if self.num_layers != 1:
            s += ", num_layers={num_layers}"
        if self.bias is not True:
            s += ", bias={bias}"
        if self.batch_first is not False:
            s += ", batch_first={batch_first}"
        if self.dropout != 0:
            s += ", dropout={dropout}"
        if self.bidirectional is not False:
            s += ", bidirectional={bidirectional}"
        return s.format(**self.__dict__)

    def _update_flat_weights(self):
        if not torch.jit.is_scripting():
            if self._weights_have_changed():
                self._init_flat_weights()

    def __getstate__(self):
        # If weights have been changed, update the _flat_weights in __getstate__ here.
        self._update_flat_weights()
        # Don't serialize the weight references.
        state = self.__dict__.copy()
        del state["_flat_weight_refs"]
        return state

    def __setstate__(self, d):
        super().__setstate__(d)
        if "all_weights" in d:
            self._all_weights = d["all_weights"]
        # In PyTorch 1.8 we added a proj_size member variable to LSTM.
        # LSTMs that were serialized via torch.save(module) before PyTorch 1.8
        # don't have it, so to preserve compatibility we set proj_size here.
        if "proj_size" not in d:
            self.proj_size = 0

        if not isinstance(self._all_weights[0][0], str):
            num_layers = self.num_layers
            num_directions = 2 if self.bidirectional else 1
            self._flat_weights_names = []
            self._all_weights = []
            for layer in range(num_layers):
                for direction in range(num_directions):
                    suffix = "_reverse" if direction == 1 else ""
                    weights = [
                        "weight_ih_l{}{}",
                        "weight_hh_l{}{}",
                        "bias_ih_l{}{}",
                        "bias_hh_l{}{}",
                        "weight_hr_l{}{}",
                    ]
                    weights = [x.format(layer, suffix) for x in weights]
                    if self.bias:
                        if self.proj_size > 0:
                            self._all_weights += [weights]
                            self._flat_weights_names.extend(weights)
                        else:
                            self._all_weights += [weights[:4]]
                            self._flat_weights_names.extend(weights[:4])
                    else:
                        if self.proj_size > 0:
                            self._all_weights += [weights[:2]] + [weights[-1:]]
                            self._flat_weights_names.extend(
                                weights[:2] + [weights[-1:]]
                            )
                        else:
                            self._all_weights += [weights[:2]]
                            self._flat_weights_names.extend(weights[:2])
            self._flat_weights = [
                getattr(self, wn) if hasattr(self, wn) else None
                for wn in self._flat_weights_names
            ]

        self._flat_weight_refs = [
            weakref.ref(w) if w is not None else None for w in self._flat_weights
        ]

    @property
    def all_weights(self) -> List[List[Parameter]]:
        return [
            [getattr(self, weight) for weight in weights]
            for weights in self._all_weights
        ]

    def _replicate_for_data_parallel(self):
        replica = super()._replicate_for_data_parallel()
        # Need to copy these caches, otherwise the replica will share the same
        # flat weights list.
        replica._flat_weights = replica._flat_weights[:]
        replica._flat_weights_names = replica._flat_weights_names[:]
        return replica


class RNN(RNNBase):
    r"""__init__(input_size,hidden_size,num_layers=1,nonlinearity='tanh',bias=True,batch_first=False,dropout=0.0,bidirectional=False,device=None,dtype=None)

    Apply a multi-layer Elman RNN with :math:`\tanh` or :math:`\text{ReLU}`
    non-linearity to an input sequence. For each element in the input sequence,
    each layer computes the following function:

    .. math::
        h_t = \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1}W_{hh}^T + b_{hh})

    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is
    the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the
    layer at time `t-1` or the initial hidden state at time `0`.
    If :attr:`nonlinearity` is ``'relu'``, then :math:`\text{ReLU}` is used instead of :math:`\tanh`.

    .. code-block:: python

        # Efficient implementation equivalent to the following with bidirectional=False
        def forward(x, h_0=None):
            if batch_first:
                x = x.transpose(0, 1)
            seq_len, batch_size, _ = x.size()
            if h_0 is None:
                h_0 = torch.zeros(num_layers, batch_size, hidden_size)
            h_t_minus_1 = h_0
            h_t = h_0
            output = []
            for t in range(seq_len):
                for layer in range(num_layers):
                    # Layer 0 consumes the input; deeper layers consume the output
                    # of the layer below at the same time step.
                    input_t = x[t] if layer == 0 else h_t[layer - 1]
                    h_t[layer] = torch.tanh(
                        input_t @ weight_ih[layer].T
                        + bias_ih[layer]
                        + h_t_minus_1[layer] @ weight_hh[layer].T
                        + bias_hh[layer]
                    )
                output.append(h_t[-1])
                h_t_minus_1 = h_t
            output = torch.stack(output)
            if batch_first:
                output = output.transpose(0, 1)
            return output, h_t

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two RNNs together to form a `stacked RNN`,
            with the second RNN taking in outputs of the first RNN and
            computing the final results. Default: 1
        nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        batch_first: If ``True``, then the input and output tensors are provided
            as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
            Note that this does not apply to hidden or cell states. See the
            Inputs/Outputs sections below for details.  Default: ``False``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            RNN layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False``

    Inputs: input, h_0
        * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
          :math:`(L, N, H_{in})` when ``batch_first=False`` or
          :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
          the input sequence.  The input can also be a packed variable length sequence.
          See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
          :func:`torch.nn.utils.rnn.pack_sequence` for details.
        * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
          :math:`(D * \text{num\_layers}, N, H_{out})` containing the initial hidden
          state for the input sequence batch. Defaults to zeros if not provided.

        where:

        .. math::
            \begin{aligned}
                N ={} & \text{batch size} \\
                L ={} & \text{sequence length} \\
                D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
                H_{in} ={} & \text{input\_size} \\
                H_{out} ={} & \text{hidden\_size}
            \end{aligned}

    Outputs: output, h_n
        * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
          :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
          :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
          `(h_t)` from the last layer of the RNN, for each `t`. If a
          :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
          will also be a packed sequence.
        * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
          :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state
          for each element in the batch.

    Attributes:
        weight_ih_l[k]: the learnable input-hidden weights of the k-th layer,
            of shape `(hidden_size, input_size)` for `k = 0`. Otherwise, the shape is
            `(hidden_size, num_directions * hidden_size)`
        weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer,
            of shape `(hidden_size, hidden_size)`
        bias_ih_l[k]: the learnable input-hidden bias of the k-th layer,
            of shape `(hidden_size)`
        bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer,
            of shape `(hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    .. note::
        For bidirectional RNNs, forward and backward are directions 0 and 1 respectively.
        Example of splitting the output layers when ``batch_first=False``:
        ``output.view(seq_len, batch, num_directions, hidden_size)``.

    .. note::
        ``batch_first`` argument is ignored for unbatched inputs.

    .. include:: ../cudnn_rnn_determinism.rst

    .. include:: ../cudnn_persistent_rnn.rst

    Examples::

        >>> rnn = nn.RNN(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> output, hn = rnn(input, h0)
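        >>> # Splitting a bidirectional output into per-direction slices, as in the
        >>> # note above (an illustrative sketch; h_0 defaults to zeros when omitted):
        >>> birnn = nn.RNN(10, 20, 2, bidirectional=True)
        >>> bi_output, bi_hn = birnn(input)
        >>> directions = bi_output.view(5, 3, 2, 20)  # (seq_len, batch, num_directions, hidden_size)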
    """

    @overload
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        nonlinearity: str = "tanh",
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
        bidirectional: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        ...

    @overload
    def __init__(self, *args, **kwargs):
        ...

    def __init__(self, *args, **kwargs):
        if "proj_size" in kwargs:
            raise ValueError(
                "proj_size argument is only supported for LSTM, not RNN or GRU"
            )
        if len(args) > 3:
            self.nonlinearity = args[3]
            args = args[:3] + args[4:]
        else:
            self.nonlinearity = kwargs.pop("nonlinearity", "tanh")
        if self.nonlinearity == "tanh":
            mode = "RNN_TANH"
        elif self.nonlinearity == "relu":
            mode = "RNN_RELU"
        else:
            raise ValueError(
                f"Unknown nonlinearity '{self.nonlinearity}'. Select from 'tanh' or 'relu'."
            )
        super().__init__(mode, *args, **kwargs)

    @overload
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(
        self, input: Tensor, hx: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        pass

    @overload
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(
        self, input: PackedSequence, hx: Optional[Tensor] = None
    ) -> Tuple[PackedSequence, Tensor]:
        pass

    def forward(self, input, hx=None):  # noqa: F811
        self._update_flat_weights()

        num_directions = 2 if self.bidirectional else 1
        orig_input = input

        if isinstance(orig_input, PackedSequence):
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = batch_sizes[0]
            # script() is unhappy when max_batch_size is different type in cond branches, so we duplicate
            if hx is None:
                hx = torch.zeros(
                    self.num_layers * num_directions,
                    max_batch_size,
                    self.hidden_size,
                    dtype=input.dtype,
                    device=input.device,
                )
            else:
                # Each batch of the hidden state should match the input sequence that
                # the user believes they are passing in.
                hx = self.permute_hidden(hx, sorted_indices)
        else:
            batch_sizes = None
            if input.dim() not in (2, 3):
                raise ValueError(
                    f"RNN: Expected input to be 2D or 3D, got {input.dim()}D tensor instead"
                )
            is_batched = input.dim() == 3
            batch_dim = 0 if self.batch_first else 1
            if not is_batched:
                input = input.unsqueeze(batch_dim)
                if hx is not None:
                    if hx.dim() != 2:
                        raise RuntimeError(
                            f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor"
                        )
                    hx = hx.unsqueeze(1)
            else:
                if hx is not None and hx.dim() != 3:
                    raise RuntimeError(
                        f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor"
                    )
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None
            if hx is None:
                hx = torch.zeros(
                    self.num_layers * num_directions,
                    max_batch_size,
                    self.hidden_size,
                    dtype=input.dtype,
                    device=input.device,
                )
            else:
                # Each batch of the hidden state should match the input sequence that
                # the user believes they are passing in.
                hx = self.permute_hidden(hx, sorted_indices)

        assert hx is not None
        self.check_forward_args(input, hx, batch_sizes)
        assert self.mode == "RNN_TANH" or self.mode == "RNN_RELU"
        if batch_sizes is None:
            if self.mode == "RNN_TANH":
                result = _VF.rnn_tanh(
                    input,
                    hx,
                    self._flat_weights,
                    self.bias,
                    self.num_layers,
                    self.dropout,
                    self.training,
                    self.bidirectional,
                    self.batch_first,
                )
            else:
                result = _VF.rnn_relu(
                    input,
                    hx,
                    self._flat_weights,
                    self.bias,
                    self.num_layers,
                    self.dropout,
                    self.training,
                    self.bidirectional,
                    self.batch_first,
                )
        else:
            if self.mode == "RNN_TANH":
                result = _VF.rnn_tanh(
                    input,
                    batch_sizes,
                    hx,
                    self._flat_weights,
                    self.bias,
                    self.num_layers,
                    self.dropout,
                    self.training,
                    self.bidirectional,
                )
            else:
                result = _VF.rnn_relu(
                    input,
                    batch_sizes,
                    hx,
                    self._flat_weights,
                    self.bias,
                    self.num_layers,
                    self.dropout,
                    self.training,
                    self.bidirectional,
                )

        output = result[0]
        hidden = result[1]

        if isinstance(orig_input, PackedSequence):
            output_packed = PackedSequence(
                output, batch_sizes, sorted_indices, unsorted_indices
            )
            return output_packed, self.permute_hidden(hidden, unsorted_indices)

        if not is_batched:  # type: ignore[possibly-undefined]
            output = output.squeeze(batch_dim)  # type: ignore[possibly-undefined]
            hidden = hidden.squeeze(1)

        return output, self.permute_hidden(hidden, unsorted_indices)


# XXX: The LSTM and GRU implementations differ from RNNBase because:
# 1. we want to support nn.LSTM and nn.GRU in TorchScript, and TorchScript in
#    its current state could not support the python Union Type or Any Type
# 2. TorchScript static typing does not allow a Function or Callable type in
#    Dict values, so we have to separately call _VF instead of using _rnn_impls
# 3. this is only temporary, a transitional state intended to make it in time
#    for the release
#
# More discussion details in https://github.com/pytorch/pytorch/pull/23266
#
# TODO: remove the overriding implementations for LSTM and GRU when TorchScript
# supports expressing these two modules generally.


class LSTM(RNNBase):
    r"""__init__(input_size,hidden_size,num_layers=1,bias=True,batch_first=False,dropout=0.0,bidirectional=False,proj_size=0,device=None,dtype=None)

    Apply a multi-layer long short-term memory (LSTM) RNN to an input sequence.
    For each element in the input sequence, each layer computes the following
    function:

    .. math::
        \begin{array}{ll} \\
            i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
            f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
            g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
            o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
            c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
            h_t = o_t \odot \tanh(c_t) \\
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell
    state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{t-1}`
    is the hidden state of the layer at time `t-1` or the initial hidden
    state at time `0`, and :math:`i_t`, :math:`f_t`, :math:`g_t`,
    :math:`o_t` are the input, forget, cell, and output gates, respectively.
    :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.

    In a multilayer LSTM, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
    (:math:`l \ge 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
    dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
    variable which is :math:`0` with probability :attr:`dropout`.

    If ``proj_size > 0`` is specified, LSTM with projections will be used. This changes
    the LSTM cell in the following way. First, the dimension of :math:`h_t` will be changed from
    ``hidden_size`` to ``proj_size`` (dimensions of :math:`W_{hi}` will be changed accordingly).
    Second, the output hidden state of each layer will be multiplied by a learnable projection
    matrix: :math:`h_t = W_{hr}h_t`. Note that as a consequence of this, the output
    of the LSTM network will have a different shape as well. See the Inputs/Outputs sections below for the exact
    dimensions of all variables. You can find more details in https://arxiv.org/abs/1402.1128.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two LSTMs together to form a `stacked LSTM`,
            with the second LSTM taking in outputs of the first LSTM and
            computing the final results. Default: 1
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        batch_first: If ``True``, then the input and output tensors are provided
            as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
            Note that this does not apply to hidden or cell states. See the
            Inputs/Outputs sections below for details.  Default: ``False``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            LSTM layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False``
        proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0

    Inputs: input, (h_0, c_0)
        * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
          :math:`(L, N, H_{in})` when ``batch_first=False`` or
          :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
          the input sequence.  The input can also be a packed variable length sequence.
          See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
          :func:`torch.nn.utils.rnn.pack_sequence` for details.
        * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
          :math:`(D * \text{num\_layers}, N, H_{out})` containing the
          initial hidden state for each element in the input sequence.
          Defaults to zeros if (h_0, c_0) is not provided.
        * **c_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or
          :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
          initial cell state for each element in the input sequence.
          Defaults to zeros if (h_0, c_0) is not provided.

        where:

        .. math::
            \begin{aligned}
                N ={} & \text{batch size} \\
                L ={} & \text{sequence length} \\
                D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
                H_{in} ={} & \text{input\_size} \\
                H_{cell} ={} & \text{hidden\_size} \\
                H_{out} ={} & \text{proj\_size if } \text{proj\_size}>0 \text{ otherwise hidden\_size} \\
            \end{aligned}

    Outputs: output, (h_n, c_n)
        * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
          :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
          :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
          `(h_t)` from the last layer of the LSTM, for each `t`. If a
          :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
          will also be a packed sequence. When ``bidirectional=True``, `output` will contain
          a concatenation of the forward and reverse hidden states at each time step in the sequence.
        * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or
          :math:`(D * \text{num\_layers}, N, H_{out})` containing the
          final hidden state for each element in the sequence. When ``bidirectional=True``,
          `h_n` will contain a concatenation of the final forward and reverse hidden states, respectively.
        * **c_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or
          :math:`(D * \text{num\_layers}, N, H_{cell})` containing the
          final cell state for each element in the sequence. When ``bidirectional=True``,
          `c_n` will contain a concatenation of the final forward and reverse cell states, respectively.

    Attributes:
        weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
            `(W_ii|W_if|W_ig|W_io)`, of shape `(4*hidden_size, input_size)` for `k = 0`.
            Otherwise, the shape is `(4*hidden_size, num_directions * hidden_size)`. If
            ``proj_size > 0`` was specified, the shape will be
            `(4*hidden_size, num_directions * proj_size)` for `k > 0`
        weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
            `(W_hi|W_hf|W_hg|W_ho)`, of shape `(4*hidden_size, hidden_size)`. If ``proj_size > 0``
            was specified, the shape will be `(4*hidden_size, proj_size)`.
        bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
            `(b_ii|b_if|b_ig|b_io)`, of shape `(4*hidden_size)`
        bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
            `(b_hi|b_hf|b_hg|b_ho)`, of shape `(4*hidden_size)`
        weight_hr_l[k] : the learnable projection weights of the :math:`\text{k}^{th}` layer
            of shape `(proj_size, hidden_size)`. Only present when ``proj_size > 0`` was
            specified.
        weight_ih_l[k]_reverse: Analogous to `weight_ih_l[k]` for the reverse direction.
            Only present when ``bidirectional=True``.
        weight_hh_l[k]_reverse:  Analogous to `weight_hh_l[k]` for the reverse direction.
            Only present when ``bidirectional=True``.
        bias_ih_l[k]_reverse:  Analogous to `bias_ih_l[k]` for the reverse direction.
            Only present when ``bidirectional=True``.
        bias_hh_l[k]_reverse:  Analogous to `bias_hh_l[k]` for the reverse direction.
            Only present when ``bidirectional=True``.
        weight_hr_l[k]_reverse:  Analogous to `weight_hr_l[k]` for the reverse direction.
            Only present when ``bidirectional=True`` and ``proj_size > 0`` was specified.

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    .. note::
        For bidirectional LSTMs, forward and backward are directions 0 and 1 respectively.
        Example of splitting the output layers when ``batch_first=False``:
        ``output.view(seq_len, batch, num_directions, hidden_size)``.

    .. note::
        For bidirectional LSTMs, `h_n` is not equivalent to the last element of `output`; the
        former contains the final forward and reverse hidden states, while the latter contains the
        final forward hidden state and the initial reverse hidden state.

    .. note::
        ``batch_first`` argument is ignored for unbatched inputs.

    .. note::
        ``proj_size`` should be smaller than ``hidden_size``.

    .. include:: ../cudnn_rnn_determinism.rst

    .. include:: ../cudnn_persistent_rnn.rst

    Examples::

        >>> rnn = nn.LSTM(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> c0 = torch.randn(2, 3, 20)
        >>> output, (hn, cn) = rnn(input, (h0, c0))
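        >>> # A sketch of the projection variant described above (illustrative values):
        >>> # with proj_size=10, h_n and output use proj_size while c_n keeps hidden_size.
        >>> proj_lstm = nn.LSTM(10, 20, 2, proj_size=10)
        >>> h0_p = torch.randn(2, 3, 10)  # (D * num_layers, N, proj_size)
        >>> c0_p = torch.randn(2, 3, 20)  # (D * num_layers, N, hidden_size)
        >>> output_p, (hn_p, cn_p) = proj_lstm(input, (h0_p, c0_p))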
    """

    @overload
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
        bidirectional: bool = False,
        proj_size: int = 0,
        device=None,
        dtype=None,
    ) -> None:
        ...

    @overload
    def __init__(self, *args, **kwargs):
        ...

    def __init__(self, *args, **kwargs):
        super().__init__("LSTM", *args, **kwargs)

    def get_expected_cell_size(
        self, input: Tensor, batch_sizes: Optional[Tensor]
    ) -> Tuple[int, int, int]:
        if batch_sizes is not None:
            mini_batch = int(batch_sizes[0])
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)
        num_directions = 2 if self.bidirectional else 1
        expected_hidden_size = (
            self.num_layers * num_directions,
            mini_batch,
            self.hidden_size,
        )
        return expected_hidden_size

    # In the future, we should prevent mypy from applying contravariance rules here.
    # See torch/nn/modules/module.py::_forward_unimplemented
    def check_forward_args(
        self,
        input: Tensor,
        hidden: Tuple[Tensor, Tensor],  # type: ignore[override]
        batch_sizes: Optional[Tensor],
    ):
        self.check_input(input, batch_sizes)
        self.check_hidden_size(
            hidden[0],
            self.get_expected_hidden_size(input, batch_sizes),
            "Expected hidden[0] size {}, got {}",
        )
        self.check_hidden_size(
            hidden[1],
            self.get_expected_cell_size(input, batch_sizes),
            "Expected hidden[1] size {}, got {}",
        )

    # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
    def permute_hidden(  # type: ignore[override]
        self,
        hx: Tuple[Tensor, Tensor],
        permutation: Optional[Tensor],
    ) -> Tuple[Tensor, Tensor]:
        if permutation is None:
            return hx
        return _apply_permutation(hx[0], permutation), _apply_permutation(
            hx[1], permutation
        )

    # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
    @overload  # type: ignore[override]
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(
        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:  # noqa: F811
        pass

    # Same as above, see torch/nn/modules/module.py::_forward_unimplemented
    @overload
    @torch._jit_internal._overload_method  # noqa: F811
    def forward(
        self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:  # noqa: F811
        pass

    def forward(self, input, hx=None):  # noqa: F811
        self._update_flat_weights()

        orig_input = input
        # xxx: isinstance check needs to be in conditional for TorchScript to compile
        batch_sizes = None
        do_permute = False
        num_directions = 2 if self.bidirectional else 1
        real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
        if isinstance(orig_input, PackedSequence):
            input, batch_sizes, sorted_indices, unsorted_indices = input
            max_batch_size = batch_sizes[0]
            if hx is None:
                h_zeros = torch.zeros(
                    self.num_layers * num_directions,
                    max_batch_size,
                    real_hidden_size,
                    dtype=input.dtype,
                    device=input.device,
                )
                c_zeros = torch.zeros(
                    self.num_layers * num_directions,
                    max_batch_size,
                    self.hidden_size,
                    dtype=input.dtype,
                    device=input.device,
                )
                hx = (h_zeros, c_zeros)
            else:
                # Each batch of the hidden state should match the input sequence that
                # the user believes they are passing in.
                hx = self.permute_hidden(hx, sorted_indices)
        else:
            if input.dim() not in (2, 3):
                raise ValueError(
                    f"LSTM: Expected input to be 2D or 3D, got {input.dim()}D instead"
                )
            is_batched = input.dim() == 3
            batch_dim = 0 if self.batch_first else 1
            if not is_batched:
                input = input.unsqueeze(batch_dim)
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None
            if hx is None:
                h_zeros = torch.zeros(
                    self.num_layers * num_directions,
                    max_batch_size,
                    real_hidden_size,
                    dtype=input.dtype,
                    device=input.device,
                )
                c_zeros = torch.zeros(
                    self.num_layers * num_directions,
                    max_batch_size,
                    self.hidden_size,
                    dtype=input.dtype,
                    device=input.device,
                )
                hx = (h_zeros, c_zeros)
                self.check_forward_args(input, hx, batch_sizes)
            else:
                if is_batched:
                    if hx[0].dim() != 3 or hx[1].dim() != 3:
                        msg = (
                            "For batched 3-D input, hx and cx should "
                            f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors"
                        )
                        raise RuntimeError(msg)
                else:
                    if hx[0].dim() != 2 or hx[1].dim() != 2:
                        msg = (
                            "For unbatched 2-D input, hx and cx should "
                            f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors"
                        )
                        raise RuntimeError(msg)
                    hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1))
                # Each batch of the hidden state should match the input sequence that
                # the user believes they are passing in.
1119                self.check_forward_args(input, hx, batch_sizes)
1120                hx = self.permute_hidden(hx, sorted_indices)
1121
1122        if batch_sizes is None:
1123            result = _VF.lstm(
1124                input,
1125                hx,
1126                self._flat_weights,
1127                self.bias,
1128                self.num_layers,
1129                self.dropout,
1130                self.training,
1131                self.bidirectional,
1132                self.batch_first,
1133            )
1134        else:
1135            result = _VF.lstm(
1136                input,
1137                batch_sizes,
1138                hx,
1139                self._flat_weights,
1140                self.bias,
1141                self.num_layers,
1142                self.dropout,
1143                self.training,
1144                self.bidirectional,
1145            )
1146        output = result[0]
1147        hidden = result[1:]
1148        # xxx: isinstance check needs to be in conditional for TorchScript to compile
1149        if isinstance(orig_input, PackedSequence):
1150            output_packed = PackedSequence(
1151                output, batch_sizes, sorted_indices, unsorted_indices
1152            )
1153            return output_packed, self.permute_hidden(hidden, unsorted_indices)
1154        else:
1155            if not is_batched:  # type: ignore[possibly-undefined]
1156                output = output.squeeze(batch_dim)  # type: ignore[possibly-undefined]
1157                hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
1158            return output, self.permute_hidden(hidden, unsorted_indices)
1159
1160
1161class GRU(RNNBase):
1162    r"""__init__(input_size,hidden_size,num_layers=1,bias=True,batch_first=False,dropout=0.0,bidirectional=False,device=None,dtype=None)
1163
1164    Apply a multi-layer gated recurrent unit (GRU) RNN to an input sequence.
1165    For each element in the input sequence, each layer computes the following
1166    function:
1167
1168    .. math::
1169        \begin{array}{ll}
1170            r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
1171            z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
1172            n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) \\
1173            h_t = (1 - z_t) \odot n_t + z_t \odot h_{(t-1)}
1174        \end{array}
1175
1176    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input
1177    at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer
1178    at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`,
1179    :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively.
1180    :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.
1181
1182    In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
1183    (:math:`l \ge 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
1184    dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
1185    variable which is :math:`0` with probability :attr:`dropout`.
1186
1187    Args:
1188        input_size: The number of expected features in the input `x`
1189        hidden_size: The number of features in the hidden state `h`
1190        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
1191            would mean stacking two GRUs together to form a `stacked GRU`,
1192            with the second GRU taking in outputs of the first GRU and
1193            computing the final results. Default: 1
1194        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
1195            Default: ``True``
1196        batch_first: If ``True``, then the input and output tensors are provided
1197            as `(batch, seq, feature)` instead of `(seq, batch, feature)`.
1198            Note that this does not apply to hidden or cell states. See the
1199            Inputs/Outputs sections below for details.  Default: ``False``
1200        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
1201            GRU layer except the last layer, with dropout probability equal to
1202            :attr:`dropout`. Default: 0
1203        bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False``
1204
1205    Inputs: input, h_0
1206        * **input**: tensor of shape :math:`(L, H_{in})` for unbatched input,
1207          :math:`(L, N, H_{in})` when ``batch_first=False`` or
1208          :math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of
1209          the input sequence.  The input can also be a packed variable length sequence.
1210          See :func:`torch.nn.utils.rnn.pack_padded_sequence` or
1211          :func:`torch.nn.utils.rnn.pack_sequence` for details.
1212        * **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or
1213          :math:`(D * \text{num\_layers}, N, H_{out})`
1214          containing the initial hidden state for the input sequence. Defaults to zeros if not provided.
1215
1216        where:
1217
1218        .. math::
1219            \begin{aligned}
1220                N ={} & \text{batch size} \\
1221                L ={} & \text{sequence length} \\
1222                D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
1223                H_{in} ={} & \text{input\_size} \\
1224                H_{out} ={} & \text{hidden\_size}
1225            \end{aligned}
1226
1227    Outputs: output, h_n
1228        * **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input,
1229          :math:`(L, N, D * H_{out})` when ``batch_first=False`` or
1230          :math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features
1231          `(h_t)` from the last layer of the GRU, for each `t`. If a
1232          :class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output
1233          will also be a packed sequence.
1234        * **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or
1235          :math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state
1236          for the input sequence.
1237
1238    Attributes:
1239        weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
1240            (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`.
1241            Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)`
1242        weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
1243            (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)`
1244        bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
1245            (b_ir|b_iz|b_in), of shape `(3*hidden_size)`
1246        bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
1247            (b_hr|b_hz|b_hn), of shape `(3*hidden_size)`
1248
1249    .. note::
1250        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
1251        where :math:`k = \frac{1}{\text{hidden\_size}}`
1252
1253    .. note::
1254        For bidirectional GRUs, forward and backward are directions 0 and 1 respectively.
1255        Example of splitting the output layers when ``batch_first=False``:
1256        ``output.view(seq_len, batch, num_directions, hidden_size)``.
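
        A minimal sketch of recovering the per-direction outputs for a batched,
        ``batch_first=False`` bidirectional GRU (shapes and variable names below are
        illustrative)::

            >>> rnn = nn.GRU(10, 20, num_layers=1, bidirectional=True)
            >>> output, h_n = rnn(torch.randn(5, 3, 10))
            >>> directions = output.view(5, 3, 2, 20)
            >>> fwd, bwd = directions[..., 0, :], directions[..., 1, :]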
1257
1258    .. note::
1259        ``batch_first`` argument is ignored for unbatched inputs.
1260
1261    .. note::
1262        The calculation of new gate :math:`n_t` subtly differs from the original paper and other frameworks.
1263        In the original implementation, the Hadamard product :math:`(\odot)` between :math:`r_t` and the
1264        previous hidden state :math:`h_{(t-1)}` is done before the multiplication with the weight matrix
1265        `W` and addition of bias:
1266
1267        .. math::
1268            \begin{aligned}
1269                n_t = \tanh(W_{in} x_t + b_{in} + W_{hn} ( r_t \odot h_{(t-1)} ) + b_{hn})
1270            \end{aligned}
1271
1272        This is in contrast to the PyTorch implementation, where the Hadamard product is applied after :math:`W_{hn} h_{(t-1)}`:
1273
1274        .. math::
1275            \begin{aligned}
1276                n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)} + b_{hn}))
1277            \end{aligned}
1278
1279        This implementation differs on purpose for efficiency.
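
        A minimal sketch of the two variants for a single step; the tensors below
        (``W_in``, ``W_hn``, ``b_in``, ``b_hn``, ``r``, ``x``, ``h``) are illustrative
        stand-ins, not the module's parameters::

            >>> hs = 4
            >>> x, h, r = torch.randn(hs), torch.randn(hs), torch.rand(hs)
            >>> W_in, W_hn = torch.randn(hs, hs), torch.randn(hs, hs)
            >>> b_in, b_hn = torch.randn(hs), torch.randn(hs)
            >>> # original formulation: reset gate applied before the hidden-side matmul
            >>> n_paper = torch.tanh(W_in @ x + b_in + W_hn @ (r * h) + b_hn)
            >>> # PyTorch formulation: reset gate applied after the matmul and bias
            >>> n_torch = torch.tanh(W_in @ x + b_in + r * (W_hn @ h + b_hn))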
1280
1281    .. include:: ../cudnn_persistent_rnn.rst
1282
1283    Examples::
1284
1285        >>> rnn = nn.GRU(10, 20, 2)
1286        >>> input = torch.randn(5, 3, 10)
1287        >>> h0 = torch.randn(2, 3, 20)
1288        >>> output, hn = rnn(input, h0)
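        >>> # shapes follow the Inputs/Outputs sections above (a quick check):
        >>> output.shape, hn.shape
        (torch.Size([5, 3, 20]), torch.Size([2, 3, 20]))
        >>> # packed variable-length input is also accepted (a minimal sketch):
        >>> packed = torch.nn.utils.rnn.pack_padded_sequence(input, lengths=[5, 3, 2])
        >>> packed_output, hn = rnn(packed, h0)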
1289    """
1290
1291    @overload
1292    def __init__(
1293        self,
1294        input_size: int,
1295        hidden_size: int,
1296        num_layers: int = 1,
1297        bias: bool = True,
1298        batch_first: bool = False,
1299        dropout: float = 0.0,
1300        bidirectional: bool = False,
1301        device=None,
1302        dtype=None,
1303    ) -> None:
1304        ...
1305
1306    @overload
1307    def __init__(self, *args, **kwargs):
1308        ...
1309
1310    def __init__(self, *args, **kwargs):
1311        if "proj_size" in kwargs:
1312            raise ValueError(
1313                "proj_size argument is only supported for LSTM, not RNN or GRU"
1314            )
1315        super().__init__("GRU", *args, **kwargs)
1316
1317    @overload  # type: ignore[override]
1318    @torch._jit_internal._overload_method  # noqa: F811
1319    def forward(
1320        self, input: Tensor, hx: Optional[Tensor] = None
1321    ) -> Tuple[Tensor, Tensor]:  # noqa: F811
1322        pass
1323
1324    @overload
1325    @torch._jit_internal._overload_method  # noqa: F811
1326    def forward(
1327        self, input: PackedSequence, hx: Optional[Tensor] = None
1328    ) -> Tuple[PackedSequence, Tensor]:  # noqa: F811
1329        pass
1330
1331    def forward(self, input, hx=None):  # noqa: F811
1332        self._update_flat_weights()
1333
1334        orig_input = input
1335        # xxx: isinstance check needs to be in conditional for TorchScript to compile
1336        if isinstance(orig_input, PackedSequence):
1337            input, batch_sizes, sorted_indices, unsorted_indices = input
1338            max_batch_size = batch_sizes[0]
1339            if hx is None:
1340                num_directions = 2 if self.bidirectional else 1
1341                hx = torch.zeros(
1342                    self.num_layers * num_directions,
1343                    max_batch_size,
1344                    self.hidden_size,
1345                    dtype=input.dtype,
1346                    device=input.device,
1347                )
1348            else:
1349                # Each batch of the hidden state should match the input sequence that
1350                # the user believes they are passing in.
1351                hx = self.permute_hidden(hx, sorted_indices)
1352        else:
1353            batch_sizes = None
1354            if input.dim() not in (2, 3):
1355                raise ValueError(
1356                    f"GRU: Expected input to be 2D or 3D, got {input.dim()}D instead"
1357                )
1358            is_batched = input.dim() == 3
1359            batch_dim = 0 if self.batch_first else 1
1360            if not is_batched:
1361                input = input.unsqueeze(batch_dim)
1362                if hx is not None:
1363                    if hx.dim() != 2:
1364                        raise RuntimeError(
1365                            f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor"
1366                        )
1367                    hx = hx.unsqueeze(1)
1368            else:
1369                if hx is not None and hx.dim() != 3:
1370                    raise RuntimeError(
1371                        f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor"
1372                    )
1373            max_batch_size = input.size(0) if self.batch_first else input.size(1)
1374            sorted_indices = None
1375            unsorted_indices = None
1376            if hx is None:
1377                num_directions = 2 if self.bidirectional else 1
1378                hx = torch.zeros(
1379                    self.num_layers * num_directions,
1380                    max_batch_size,
1381                    self.hidden_size,
1382                    dtype=input.dtype,
1383                    device=input.device,
1384                )
1385            else:
1386                # Each batch of the hidden state should match the input sequence that
1387                # the user believes they are passing in.
1388                hx = self.permute_hidden(hx, sorted_indices)
1389
1390        self.check_forward_args(input, hx, batch_sizes)
1391        if batch_sizes is None:
1392            result = _VF.gru(
1393                input,
1394                hx,
1395                self._flat_weights,
1396                self.bias,
1397                self.num_layers,
1398                self.dropout,
1399                self.training,
1400                self.bidirectional,
1401                self.batch_first,
1402            )
1403        else:
1404            result = _VF.gru(
1405                input,
1406                batch_sizes,
1407                hx,
1408                self._flat_weights,
1409                self.bias,
1410                self.num_layers,
1411                self.dropout,
1412                self.training,
1413                self.bidirectional,
1414            )
1415        output = result[0]
1416        hidden = result[1]
1417
1418        # xxx: isinstance check needs to be in conditional for TorchScript to compile
1419        if isinstance(orig_input, PackedSequence):
1420            output_packed = PackedSequence(
1421                output, batch_sizes, sorted_indices, unsorted_indices
1422            )
1423            return output_packed, self.permute_hidden(hidden, unsorted_indices)
1424        else:
1425            if not is_batched:  # type: ignore[possibly-undefined]
1426                output = output.squeeze(batch_dim)  # type: ignore[possibly-undefined]
1427                hidden = hidden.squeeze(1)
1428
1429            return output, self.permute_hidden(hidden, unsorted_indices)
1430
1431
1432class RNNCellBase(Module):
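    # Shared machinery for RNNCell, LSTMCell and GRUCell. The num_chunks argument of
    # __init__ is the number of gate blocks stacked along dim 0 of weight_ih / weight_hh:
    # 1 for RNNCell, 3 for GRUCell (r, z, n) and 4 for LSTMCell (i, f, g, o).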
1433    __constants__ = ["input_size", "hidden_size", "bias"]
1434
1435    input_size: int
1436    hidden_size: int
1437    bias: bool
1438    weight_ih: Tensor
1439    weight_hh: Tensor
1440    # WARNING: bias_ih and bias_hh purposely not defined here.
1441    # See https://github.com/pytorch/pytorch/issues/39670
1442
1443    def __init__(
1444        self,
1445        input_size: int,
1446        hidden_size: int,
1447        bias: bool,
1448        num_chunks: int,
1449        device=None,
1450        dtype=None,
1451    ) -> None:
1452        factory_kwargs = {"device": device, "dtype": dtype}
1453        super().__init__()
1454        self.input_size = input_size
1455        self.hidden_size = hidden_size
1456        self.bias = bias
1457        self.weight_ih = Parameter(
1458            torch.empty((num_chunks * hidden_size, input_size), **factory_kwargs)
1459        )
1460        self.weight_hh = Parameter(
1461            torch.empty((num_chunks * hidden_size, hidden_size), **factory_kwargs)
1462        )
1463        if bias:
1464            self.bias_ih = Parameter(
1465                torch.empty(num_chunks * hidden_size, **factory_kwargs)
1466            )
1467            self.bias_hh = Parameter(
1468                torch.empty(num_chunks * hidden_size, **factory_kwargs)
1469            )
1470        else:
1471            self.register_parameter("bias_ih", None)
1472            self.register_parameter("bias_hh", None)
1473
1474        self.reset_parameters()
1475
1476    def extra_repr(self) -> str:
1477        s = "{input_size}, {hidden_size}"
1478        if "bias" in self.__dict__ and self.bias is not True:
1479            s += ", bias={bias}"
1480        if "nonlinearity" in self.__dict__ and self.nonlinearity != "tanh":
1481            s += ", nonlinearity={nonlinearity}"
1482        return s.format(**self.__dict__)
1483
1484    def reset_parameters(self) -> None:
1485        stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0
1486        for weight in self.parameters():
1487            init.uniform_(weight, -stdv, stdv)
1488
1489
1490class RNNCell(RNNCellBase):
1491    r"""An Elman RNN cell with tanh or ReLU non-linearity.
1492
1493    .. math::
1494
1495        h' = \tanh(W_{ih} x + b_{ih}  +  W_{hh} h + b_{hh})
1496
1497    If :attr:`nonlinearity` is `'relu'`, then ReLU is used in place of tanh.
1498
1499    Args:
1500        input_size: The number of expected features in the input `x`
1501        hidden_size: The number of features in the hidden state `h`
1502        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
1503            Default: ``True``
1504        nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'``
1505
1506    Inputs: input, hidden
1507        - **input**: tensor containing input features
1508        - **hidden**: tensor containing the initial hidden state.
1509          Defaults to zero if not provided.
1510
1511    Outputs: h'
1512        - **h'** of shape `(batch, hidden_size)`: tensor containing the next hidden state
1513          for each element in the batch
1514
1515    Shape:
1516        - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where
1517          :math:`H_{in}` = `input_size`.
1518        - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden
1519          state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided.
1520        - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state.
1521
1522    Attributes:
1523        weight_ih: the learnable input-hidden weights, of shape
1524            `(hidden_size, input_size)`
1525        weight_hh: the learnable hidden-hidden weights, of shape
1526            `(hidden_size, hidden_size)`
1527        bias_ih: the learnable input-hidden bias, of shape `(hidden_size)`
1528        bias_hh: the learnable hidden-hidden bias, of shape `(hidden_size)`
1529
1530    .. note::
1531        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
1532        where :math:`k = \frac{1}{\text{hidden\_size}}`
1533
1534    Examples::
1535
1536        >>> rnn = nn.RNNCell(10, 20)
1537        >>> input = torch.randn(6, 3, 10)
1538        >>> hx = torch.randn(3, 20)
1539        >>> output = []
1540        >>> for i in range(6):
1541        ...     hx = rnn(input[i], hx)
1542        ...     output.append(hx)
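        >>> # unbatched 1-D input is also accepted, per the Shape section (a minimal sketch):
        >>> hx_unbatched = rnn(torch.randn(10))
        >>> hx_unbatched.shape
        torch.Size([20])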
1543    """
1544
1545    __constants__ = ["input_size", "hidden_size", "bias", "nonlinearity"]
1546    nonlinearity: str
1547
1548    def __init__(
1549        self,
1550        input_size: int,
1551        hidden_size: int,
1552        bias: bool = True,
1553        nonlinearity: str = "tanh",
1554        device=None,
1555        dtype=None,
1556    ) -> None:
1557        factory_kwargs = {"device": device, "dtype": dtype}
1558        super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs)
1559        self.nonlinearity = nonlinearity
1560
1561    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
1562        if input.dim() not in (1, 2):
1563            raise ValueError(
1564                f"RNNCell: Expected input to be 1D or 2D, got {input.dim()}D instead"
1565            )
1566        if hx is not None and hx.dim() not in (1, 2):
1567            raise ValueError(
1568                f"RNNCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead"
1569            )
1570        is_batched = input.dim() == 2
1571        if not is_batched:
1572            input = input.unsqueeze(0)
1573
1574        if hx is None:
1575            hx = torch.zeros(
1576                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
1577            )
1578        else:
1579            hx = hx.unsqueeze(0) if not is_batched else hx
1580
1581        if self.nonlinearity == "tanh":
1582            ret = _VF.rnn_tanh_cell(
1583                input,
1584                hx,
1585                self.weight_ih,
1586                self.weight_hh,
1587                self.bias_ih,
1588                self.bias_hh,
1589            )
1590        elif self.nonlinearity == "relu":
1591            ret = _VF.rnn_relu_cell(
1592                input,
1593                hx,
1594                self.weight_ih,
1595                self.weight_hh,
1596                self.bias_ih,
1597                self.bias_hh,
1598            )
1599        else:
1600            ret = input  # TODO: remove when jit supports exception flow
1601            raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}")
1602
1603        if not is_batched:
1604            ret = ret.squeeze(0)
1605
1606        return ret
1607
1608
1609class LSTMCell(RNNCellBase):
1610    r"""A long short-term memory (LSTM) cell.
1611
1612    .. math::
1613
1614        \begin{array}{ll}
1615        i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
1616        f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
1617        g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
1618        o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
1619        c' = f \odot c + i \odot g \\
1620        h' = o \odot \tanh(c') \\
1621        \end{array}
1622
1623    where :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.
1624
1625    Args:
1626        input_size: The number of expected features in the input `x`
1627        hidden_size: The number of features in the hidden state `h`
1628        bias: If ``False``, then the layer does not use bias weights `b_ih` and
1629            `b_hh`. Default: ``True``
1630
1631    Inputs: input, (h_0, c_0)
1632        - **input** of shape `(batch, input_size)` or `(input_size)`: tensor containing input features
1633        - **h_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial hidden state
1634        - **c_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial cell state
1635
1636          If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero.
1637
1638    Outputs: (h_1, c_1)
1639        - **h_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next hidden state
1640        - **c_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next cell state
1641
1642    Attributes:
1643        weight_ih: the learnable input-hidden weights, of shape
1644            `(4*hidden_size, input_size)`
1645        weight_hh: the learnable hidden-hidden weights, of shape
1646            `(4*hidden_size, hidden_size)`
1647        bias_ih: the learnable input-hidden bias, of shape `(4*hidden_size)`
1648        bias_hh: the learnable hidden-hidden bias, of shape `(4*hidden_size)`
1649
1650    .. note::
1651        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
1652        where :math:`k = \frac{1}{\text{hidden\_size}}`
1653
1654    On certain ROCm devices, when using float16 inputs, this module will use :ref:`different precision<fp16_on_mi200>` for backward.
1655
1656    Examples::
1657
1658        >>> rnn = nn.LSTMCell(10, 20)  # (input_size, hidden_size)
1659        >>> input = torch.randn(2, 3, 10)  # (time_steps, batch, input_size)
1660        >>> hx = torch.randn(3, 20)  # (batch, hidden_size)
1661        >>> cx = torch.randn(3, 20)
1662        >>> output = []
1663        >>> for i in range(input.size()[0]):
1664        ...     hx, cx = rnn(input[i], (hx, cx))
1665        ...     output.append(hx)
1666        >>> output = torch.stack(output, dim=0)
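        >>> # unbatched 1-D input is also accepted (a minimal sketch):
        >>> h_1, c_1 = rnn(torch.randn(10))
        >>> h_1.shape, c_1.shape
        (torch.Size([20]), torch.Size([20]))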
1667    """
1668
1669    def __init__(
1670        self,
1671        input_size: int,
1672        hidden_size: int,
1673        bias: bool = True,
1674        device=None,
1675        dtype=None,
1676    ) -> None:
1677        factory_kwargs = {"device": device, "dtype": dtype}
1678        super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs)
1679
1680    def forward(
1681        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
1682    ) -> Tuple[Tensor, Tensor]:
1683        if input.dim() not in (1, 2):
1684            raise ValueError(
1685                f"LSTMCell: Expected input to be 1D or 2D, got {input.dim()}D instead"
1686            )
1687        if hx is not None:
1688            for idx, value in enumerate(hx):
1689                if value.dim() not in (1, 2):
1690                    raise ValueError(
1691                        f"LSTMCell: Expected hx[{idx}] to be 1D or 2D, got {value.dim()}D instead"
1692                    )
1693        is_batched = input.dim() == 2
1694        if not is_batched:
1695            input = input.unsqueeze(0)
1696
1697        if hx is None:
1698            zeros = torch.zeros(
1699                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
1700            )
1701            hx = (zeros, zeros)
1702        else:
1703            hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx
1704
1705        ret = _VF.lstm_cell(
1706            input,
1707            hx,
1708            self.weight_ih,
1709            self.weight_hh,
1710            self.bias_ih,
1711            self.bias_hh,
1712        )
1713
1714        if not is_batched:
1715            ret = (ret[0].squeeze(0), ret[1].squeeze(0))
1716        return ret
1717
1718
1719class GRUCell(RNNCellBase):
1720    r"""A gated recurrent unit (GRU) cell.
1721
1722    .. math::
1723
1724        \begin{array}{ll}
1725        r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
1726        z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
1727        n = \tanh(W_{in} x + b_{in} + r \odot (W_{hn} h + b_{hn})) \\
1728        h' = (1 - z) \odot n + z \odot h
1729        \end{array}
1730
1731    where :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.
1732
1733    Args:
1734        input_size: The number of expected features in the input `x`
1735        hidden_size: The number of features in the hidden state `h`
1736        bias: If ``False``, then the layer does not use bias weights `b_ih` and
1737            `b_hh`. Default: ``True``
1738
1739    Inputs: input, hidden
1740        - **input**: tensor containing input features
1741        - **hidden**: tensor containing the initial hidden
1742          state for each element in the batch.
1743          Defaults to zero if not provided.
1744
1745    Outputs: h'
1746        - **h'**: tensor containing the next hidden state
1747          for each element in the batch
1748
1749    Shape:
1750        - input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where
1751          :math:`H_{in}` = `input_size`.
1752        - hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden
1753          state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided.
1754        - output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state.
1755
1756    Attributes:
1757        weight_ih: the learnable input-hidden weights, of shape
1758            `(3*hidden_size, input_size)`
1759        weight_hh: the learnable hidden-hidden weights, of shape
1760            `(3*hidden_size, hidden_size)`
1761        bias_ih: the learnable input-hidden bias, of shape `(3*hidden_size)`
1762        bias_hh: the learnable hidden-hidden bias, of shape `(3*hidden_size)`
1763
1764    .. note::
1765        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
1766        where :math:`k = \frac{1}{\text{hidden\_size}}`
1767
1768    On certain ROCm devices, when using float16 inputs, this module will use :ref:`different precision<fp16_on_mi200>` for backward.
1769
1770    Examples::
1771
1772        >>> rnn = nn.GRUCell(10, 20)
1773        >>> input = torch.randn(6, 3, 10)
1774        >>> hx = torch.randn(3, 20)
1775        >>> output = []
1776        >>> for i in range(6):
1777        ...     hx = rnn(input[i], hx)
1778        ...     output.append(hx)
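        >>> # a sketch checking the update equations above against the cell's own
        >>> # parameters; the chunk order is reset | update | new:
        >>> cell = nn.GRUCell(3, 3)
        >>> x, h = torch.randn(3), torch.randn(3)
        >>> W_ir, W_iz, W_in = cell.weight_ih.chunk(3)
        >>> W_hr, W_hz, W_hn = cell.weight_hh.chunk(3)
        >>> b_ir, b_iz, b_in = cell.bias_ih.chunk(3)
        >>> b_hr, b_hz, b_hn = cell.bias_hh.chunk(3)
        >>> r = torch.sigmoid(W_ir @ x + b_ir + W_hr @ h + b_hr)
        >>> z = torch.sigmoid(W_iz @ x + b_iz + W_hz @ h + b_hz)
        >>> n = torch.tanh(W_in @ x + b_in + r * (W_hn @ h + b_hn))
        >>> torch.allclose((1 - z) * n + z * h, cell(x, h), atol=1e-6)
        True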
1779    """
1780
1781    def __init__(
1782        self,
1783        input_size: int,
1784        hidden_size: int,
1785        bias: bool = True,
1786        device=None,
1787        dtype=None,
1788    ) -> None:
1789        factory_kwargs = {"device": device, "dtype": dtype}
1790        super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs)
1791
1792    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
1793        if input.dim() not in (1, 2):
1794            raise ValueError(
1795                f"GRUCell: Expected input to be 1D or 2D, got {input.dim()}D instead"
1796            )
1797        if hx is not None and hx.dim() not in (1, 2):
1798            raise ValueError(
1799                f"GRUCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead"
1800            )
1801        is_batched = input.dim() == 2
1802        if not is_batched:
1803            input = input.unsqueeze(0)
1804
1805        if hx is None:
1806            hx = torch.zeros(
1807                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
1808            )
1809        else:
1810            hx = hx.unsqueeze(0) if not is_batched else hx
1811
1812        ret = _VF.gru_cell(
1813            input,
1814            hx,
1815            self.weight_ih,
1816            self.weight_hh,
1817            self.bias_ih,
1818            self.bias_hh,
1819        )
1820
1821        if not is_batched:
1822            ret = ret.squeeze(0)
1823
1824        return ret
1825