1"""
2We will recreate all the RNN modules as we require the modules to be decomposed
3into its building blocks to be able to observe.
4"""
5
6# mypy: allow-untyped-defs
7
8import numbers
9import warnings
10from typing import Optional, Tuple
11
12import torch
13from torch import Tensor
14
15
16__all__ = ["LSTMCell", "LSTM"]
17
18
19class LSTMCell(torch.nn.Module):
20    r"""A quantizable long short-term memory (LSTM) cell.
21
22    For the description and the argument types, please, refer to :class:`~torch.nn.LSTMCell`
23
24    Examples::
25
26        >>> import torch.ao.nn.quantizable as nnqa
27        >>> rnn = nnqa.LSTMCell(10, 20)
28        >>> input = torch.randn(6, 10)
29        >>> hx = torch.randn(3, 20)
30        >>> cx = torch.randn(3, 20)
31        >>> output = []
32        >>> for i in range(6):
33        ...     hx, cx = rnn(input[i], (hx, cx))
34        ...     output.append(hx)
35    """
36    _FLOAT_MODULE = torch.nn.LSTMCell
37
38    def __init__(
39        self,
40        input_dim: int,
41        hidden_dim: int,
42        bias: bool = True,
43        device=None,
44        dtype=None,
45    ) -> None:
46        factory_kwargs = {"device": device, "dtype": dtype}
47        super().__init__()
48        self.input_size = input_dim
49        self.hidden_size = hidden_dim
50        self.bias = bias
51
52        self.igates = torch.nn.Linear(
53            input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs
54        )
55        self.hgates = torch.nn.Linear(
56            hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs
57        )
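        # Elementwise adds/muls are wrapped in FloatFunctional so that
        # observers can be attached to them during the quantization flow.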
        self.gates = torch.ao.nn.quantized.FloatFunctional()

        self.input_gate = torch.nn.Sigmoid()
        self.forget_gate = torch.nn.Sigmoid()
        self.cell_gate = torch.nn.Tanh()
        self.output_gate = torch.nn.Sigmoid()

        self.fgate_cx = torch.ao.nn.quantized.FloatFunctional()
        self.igate_cgate = torch.ao.nn.quantized.FloatFunctional()
        self.fgate_cx_igate_cgate = torch.ao.nn.quantized.FloatFunctional()

        self.ogate_cy = torch.ao.nn.quantized.FloatFunctional()

        self.initial_hidden_state_qparams: Tuple[float, int] = (1.0, 0)
        self.initial_cell_state_qparams: Tuple[float, int] = (1.0, 0)
        self.hidden_state_dtype: torch.dtype = torch.quint8
        self.cell_state_dtype: torch.dtype = torch.quint8

    def forward(
        self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[Tensor, Tensor]:
        if hidden is None or hidden[0] is None or hidden[1] is None:
            hidden = self.initialize_hidden(x.shape[0], x.is_quantized)
        hx, cx = hidden

        igates = self.igates(x)
        hgates = self.hgates(hx)
        gates = self.gates.add(igates, hgates)

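        # Split the fused gate pre-activations into the four LSTM gates
        # (input, forget, cell candidate, output) and apply the standard
        # LSTM update: c' = f * c + i * g and h' = o * tanh(c').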
        input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)

        input_gate = self.input_gate(input_gate)
        forget_gate = self.forget_gate(forget_gate)
        cell_gate = self.cell_gate(cell_gate)
        out_gate = self.output_gate(out_gate)

        fgate_cx = self.fgate_cx.mul(forget_gate, cx)
        igate_cgate = self.igate_cgate.mul(input_gate, cell_gate)
        fgate_cx_igate_cgate = self.fgate_cx_igate_cgate.add(fgate_cx, igate_cgate)
        cy = fgate_cx_igate_cgate

        # TODO: make this tanh a member of the module so its qparams can be configured
        tanh_cy = torch.tanh(cy)
        hy = self.ogate_cy.mul(out_gate, tanh_cy)
        return hy, cy

    def initialize_hidden(
        self, batch_size: int, is_quantized: bool = False
    ) -> Tuple[Tensor, Tensor]:
        h, c = torch.zeros((batch_size, self.hidden_size)), torch.zeros(
            (batch_size, self.hidden_size)
        )
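        # If the input is quantized, quantize the zero states using the
        # (configurable) initial qparams and dtypes set in __init__.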
        if is_quantized:
            (h_scale, h_zp) = self.initial_hidden_state_qparams
            (c_scale, c_zp) = self.initial_cell_state_qparams
            h = torch.quantize_per_tensor(
                h, scale=h_scale, zero_point=h_zp, dtype=self.hidden_state_dtype
            )
            c = torch.quantize_per_tensor(
                c, scale=c_scale, zero_point=c_zp, dtype=self.cell_state_dtype
            )
        return h, c

    def _get_name(self):
        return "QuantizableLSTMCell"

    @classmethod
    def from_params(cls, wi, wh, bi=None, bh=None):
        """Uses the weights and biases to create a new LSTM cell.

        Args:
            wi, wh: Weights for the input and hidden layers
            bi, bh: Biases for the input and hidden layers
        """
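        # As in torch.nn.LSTMCell, wi is expected to have shape
        # (4 * hidden_size, input_size) and wh (4 * hidden_size, hidden_size).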
        assert (bi is None) == (bh is None)  # Either both None or both have values
        input_size = wi.shape[1]
        hidden_size = wh.shape[1]
        cell = cls(input_dim=input_size, hidden_dim=hidden_size, bias=(bi is not None))
        cell.igates.weight = torch.nn.Parameter(wi)
        if bi is not None:
            cell.igates.bias = torch.nn.Parameter(bi)
        cell.hgates.weight = torch.nn.Parameter(wh)
        if bh is not None:
            cell.hgates.bias = torch.nn.Parameter(bh)
        return cell

    @classmethod
    def from_float(cls, other, use_precomputed_fake_quant=False):
        assert type(other) == cls._FLOAT_MODULE
        assert hasattr(other, "qconfig"), "The float module must have 'qconfig'"
        observed = cls.from_params(
            other.weight_ih, other.weight_hh, other.bias_ih, other.bias_hh
        )
        observed.qconfig = other.qconfig
        observed.igates.qconfig = other.qconfig
        observed.hgates.qconfig = other.qconfig
        return observed


class _LSTMSingleLayer(torch.nn.Module):
    r"""A single one-directional LSTM layer.

    The difference between a layer and a cell is that the layer can process a
    sequence, while the cell only expects a single time step.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.cell = LSTMCell(input_dim, hidden_dim, bias=bias, **factory_kwargs)

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        result = []
        seq_len = x.shape[0]
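        # Unroll the cell over the time (first) dimension, feeding the
        # updated (h, c) back in at every step and collecting the outputs.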
        for i in range(seq_len):
            hidden = self.cell(x[i], hidden)
            result.append(hidden[0])  # type: ignore[index]
        result_tensor = torch.stack(result, 0)
        return result_tensor, hidden

    @classmethod
    def from_params(cls, *args, **kwargs):
        cell = LSTMCell.from_params(*args, **kwargs)
        layer = cls(cell.input_size, cell.hidden_size, cell.bias)
        layer.cell = cell
        return layer


class _LSTMLayer(torch.nn.Module):
    r"""A single (optionally bidirectional) LSTM layer."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        bias: bool = True,
        batch_first: bool = False,
        bidirectional: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        self.layer_fw = _LSTMSingleLayer(
            input_dim, hidden_dim, bias=bias, **factory_kwargs
        )
        if self.bidirectional:
            self.layer_bw = _LSTMSingleLayer(
                input_dim, hidden_dim, bias=bias, **factory_kwargs
            )

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        if self.batch_first:
            x = x.transpose(0, 1)
        if hidden is None:
            hx_fw, cx_fw = (None, None)
        else:
            hx_fw, cx_fw = hidden
        hidden_bw: Optional[Tuple[Tensor, Tensor]] = None
        if self.bidirectional:
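            # For the bidirectional case the incoming hx/cx stack both
            # directions along dim 0: index 0 is the forward state and
            # index 1 is the backward state.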
            if hx_fw is None:
                hx_bw = None
            else:
                hx_bw = hx_fw[1]
                hx_fw = hx_fw[0]
            if cx_fw is None:
                cx_bw = None
            else:
                cx_bw = cx_fw[1]
                cx_fw = cx_fw[0]
            if hx_bw is not None and cx_bw is not None:
                hidden_bw = hx_bw, cx_bw
        if hx_fw is None and cx_fw is None:
            hidden_fw = None
        else:
            hidden_fw = torch.jit._unwrap_optional(hx_fw), torch.jit._unwrap_optional(
                cx_fw
            )
        result_fw, hidden_fw = self.layer_fw(x, hidden_fw)

        if hasattr(self, "layer_bw") and self.bidirectional:
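            # Run the backward direction on the time-reversed input, then
            # flip its output back so that it aligns with the forward results.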
            x_reversed = x.flip(0)
            result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw)
            result_bw = result_bw.flip(0)

            result = torch.cat([result_fw, result_bw], result_fw.dim() - 1)
            if hidden_fw is None and hidden_bw is None:
                h = None
                c = None
            elif hidden_fw is None:
                (h, c) = torch.jit._unwrap_optional(hidden_bw)
            elif hidden_bw is None:
                (h, c) = torch.jit._unwrap_optional(hidden_fw)
            else:
                h = torch.stack([hidden_fw[0], hidden_bw[0]], 0)  # type: ignore[list-item]
                c = torch.stack([hidden_fw[1], hidden_bw[1]], 0)  # type: ignore[list-item]
        else:
            result = result_fw
            h, c = torch.jit._unwrap_optional(hidden_fw)  # type: ignore[assignment]

        if self.batch_first:
            result.transpose_(0, 1)

        return result, (h, c)

    @classmethod
    def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs):
        r"""
        There is no FP equivalent of this class. This function is here just to
        mimic the behavior of the `prepare` step within the
        `torch.ao.quantization` flow.
        """
        assert hasattr(other, "qconfig") or (qconfig is not None)

        input_size = kwargs.get("input_size", other.input_size)
        hidden_size = kwargs.get("hidden_size", other.hidden_size)
        bias = kwargs.get("bias", other.bias)
        batch_first = kwargs.get("batch_first", other.batch_first)
        bidirectional = kwargs.get("bidirectional", other.bidirectional)

        layer = cls(input_size, hidden_size, bias, batch_first, bidirectional)
        layer.qconfig = getattr(other, "qconfig", qconfig)
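        # nn.LSTM stores its per-layer parameters flat on the module as
        # weight_ih_l{k} / weight_hh_l{k} (with a `_reverse` suffix for the
        # backward direction), so they are looked up by name here.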
        wi = getattr(other, f"weight_ih_l{layer_idx}")
        wh = getattr(other, f"weight_hh_l{layer_idx}")
        bi = getattr(other, f"bias_ih_l{layer_idx}", None)
        bh = getattr(other, f"bias_hh_l{layer_idx}", None)

        layer.layer_fw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)

        if other.bidirectional:
            wi = getattr(other, f"weight_ih_l{layer_idx}_reverse")
            wh = getattr(other, f"weight_hh_l{layer_idx}_reverse")
            bi = getattr(other, f"bias_ih_l{layer_idx}_reverse", None)
            bh = getattr(other, f"bias_hh_l{layer_idx}_reverse", None)
            layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
        return layer


class LSTM(torch.nn.Module):
    r"""A quantizable long short-term memory (LSTM).

    For the description and the argument types, please refer to :class:`~torch.nn.LSTM`

    Attributes:
        layers : instances of `_LSTMLayer`

    .. note::
        To access the weights and biases, you need to access them per layer.
        See examples below.

    Examples::

        >>> import torch.ao.nn.quantizable as nnqa
        >>> rnn = nnqa.LSTM(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> c0 = torch.randn(2, 3, 20)
        >>> output, (hn, cn) = rnn(input, (h0, c0))
        >>> # To get the weights:
        >>> # xdoctest: +SKIP
        >>> print(rnn.layers[0].weight_ih)
        tensor([[...]])
        >>> print(rnn.layers[0].weight_hh)
        AssertionError: There is no reverse path in the non-bidirectional layer
    """
    _FLOAT_MODULE = torch.nn.LSTM

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        bias: bool = True,
        batch_first: bool = False,
        dropout: float = 0.0,
        bidirectional: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        self.training = False  # Default to eval mode. If we want to train, we will explicitly set it to training mode.
        num_directions = 2 if bidirectional else 1

        if (
            not isinstance(dropout, numbers.Number)
            or not 0 <= dropout <= 1
            or isinstance(dropout, bool)
        ):
            raise ValueError(
                "dropout should be a number in range [0, 1] "
                "representing the probability of an element being "
                "zeroed"
            )
        if dropout > 0:
            warnings.warn(
                "dropout option for quantizable LSTM is ignored. "
                "If you are training, please use the nn.LSTM version "
                "followed by the `prepare` step."
            )
            if num_layers == 1:
                warnings.warn(
                    "dropout option adds dropout after all but last "
                    "recurrent layer, so non-zero dropout expects "
                    f"num_layers greater than 1, but got dropout={dropout} "
                    f"and num_layers={num_layers}"
                )

        layers = [
            _LSTMLayer(
                self.input_size,
                self.hidden_size,
                self.bias,
                batch_first=False,
                bidirectional=self.bidirectional,
                **factory_kwargs,
            )
        ]
        for layer in range(1, num_layers):
            layers.append(
                _LSTMLayer(
                    self.hidden_size,
                    self.hidden_size,
                    self.bias,
                    batch_first=False,
                    bidirectional=self.bidirectional,
                    **factory_kwargs,
                )
            )
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        if self.batch_first:
            x = x.transpose(0, 1)

        max_batch_size = x.size(1)
        num_directions = 2 if self.bidirectional else 1
        if hidden is None:
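            # No initial state was provided: build zero states (quantized if
            # the input is quantized) and reuse them for every layer.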
            zeros = torch.zeros(
                num_directions,
                max_batch_size,
                self.hidden_size,
                dtype=torch.float,
                device=x.device,
            )
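            # squeeze_ drops the leading num_directions dim only when it is 1
            # (unidirectional); for the bidirectional case it is a no-op.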
            zeros.squeeze_(0)
            if x.is_quantized:
                zeros = torch.quantize_per_tensor(
                    zeros, scale=1.0, zero_point=0, dtype=x.dtype
                )
            hxcx = [(zeros, zeros) for _ in range(self.num_layers)]
        else:
            hidden_non_opt = torch.jit._unwrap_optional(hidden)
            if isinstance(hidden_non_opt[0], Tensor):
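                # Flat (h_0, c_0) tensors of shape
                # (num_layers * num_directions, batch, hidden_size) are split
                # into one (h, c) pair per layer.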
                hx = hidden_non_opt[0].reshape(
                    self.num_layers, num_directions, max_batch_size, self.hidden_size
                )
                cx = hidden_non_opt[1].reshape(
                    self.num_layers, num_directions, max_batch_size, self.hidden_size
                )
                hxcx = [
                    (hx[idx].squeeze(0), cx[idx].squeeze(0))
                    for idx in range(self.num_layers)
                ]
            else:
                hxcx = hidden_non_opt

        hx_list = []
        cx_list = []
        for idx, layer in enumerate(self.layers):
            x, (h, c) = layer(x, hxcx[idx])
            hx_list.append(torch.jit._unwrap_optional(h))
            cx_list.append(torch.jit._unwrap_optional(c))
        hx_tensor = torch.stack(hx_list)
        cx_tensor = torch.stack(cx_list)

        # The bidirectional case stacks an extra (direction) dimension;
        # collapse it so the returned states have the usual
        # (num_layers * num_directions, batch, hidden_size) layout.
        hx_tensor = hx_tensor.reshape(-1, hx_tensor.shape[-2], hx_tensor.shape[-1])
        cx_tensor = cx_tensor.reshape(-1, cx_tensor.shape[-2], cx_tensor.shape[-1])

        if self.batch_first:
            x = x.transpose(0, 1)

        return x, (hx_tensor, cx_tensor)

    def _get_name(self):
        return "QuantizableLSTM"

    @classmethod
    def from_float(cls, other, qconfig=None):
        assert isinstance(other, cls._FLOAT_MODULE)
        assert hasattr(other, "qconfig") or qconfig
        observed = cls(
            other.input_size,
            other.hidden_size,
            other.num_layers,
            other.bias,
            other.batch_first,
            other.dropout,
            other.bidirectional,
        )
        observed.qconfig = getattr(other, "qconfig", qconfig)
        for idx in range(other.num_layers):
            observed.layers[idx] = _LSTMLayer.from_float(
                other, idx, qconfig, batch_first=False
            )

        # Prepare the model
        if other.training:
            observed.train()
            observed = torch.ao.quantization.prepare_qat(observed, inplace=True)
        else:
            observed.eval()
            observed = torch.ao.quantization.prepare(observed, inplace=True)
        return observed

    @classmethod
    def from_observed(cls, other):
        # The whole flow is float -> observed -> quantized
        # This class does float -> observed only
        raise NotImplementedError(
            "It looks like you are trying to convert a "
            "non-quantizable LSTM module. Please see "
            "the examples on quantizable LSTMs."
        )