# Owner(s): ["oncall: quantization"]

import torch
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
from torch.testing._internal.jit_utils import JitTestCase


class TestDeprecatedJitQuantized(JitTestCase):
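    """Checks that the removed torch.jit.quantized APIs raise clear errors.

    quantize_rnn_cell_modules, quantize_rnn_modules, and quantize_linear_modules
    are no longer supported; most tests here build a small fp32 module and assert
    that calling the deprecated entry point raises a RuntimeError saying so.
    test_erase_class_tensor_shapes additionally checks that the shape-erasure JIT
    pass handles a class attribute holding a packed quantized weight.
    """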
    @skipIfNoFBGEMM
    def test_rnn_cell_quantized(self):
        d_in, d_hid = 2, 2

        for cell in [
            torch.nn.LSTMCell(d_in, d_hid).float(),
            torch.nn.GRUCell(d_in, d_hid).float(),
            torch.nn.RNNCell(d_in, d_hid).float(),
        ]:
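            # num_chunks is the number of gate blocks stacked along dim 0 of
            # weight_ih / weight_hh: 4 for LSTMCell, 3 for GRUCell, 1 for RNNCell.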
            if isinstance(cell, torch.nn.LSTMCell):
                num_chunks = 4
            elif isinstance(cell, torch.nn.GRUCell):
                num_chunks = 3
            elif isinstance(cell, torch.nn.RNNCell):
                num_chunks = 1

            # Replace parameter values s.t. the range of values is exactly
            # 255, thus we will have 0 quantization error in the quantized
            # GEMM call. This is for testing purposes.
            #
            # Note that the current implementation does not support
            # accumulation values outside of the range representable by a
            # 16 bit integer, instead resulting in a saturated value. We
            # must take care that in our test we do not end up with a dot
            # product that overflows the int16 range (max 32767), e.g.
            # (255*127+255*127) = 64770. So, we hardcode the test values
            # here and ensure a mix of signedness.
            vals = [
                [100, -155],
                [100, -155],
                [-155, 100],
                [-155, 100],
                [100, -155],
                [-155, 100],
                [-155, 100],
                [100, -155],
            ]
            vals = vals[: d_hid * num_chunks]
            cell.weight_ih = torch.nn.Parameter(
                torch.tensor(vals, dtype=torch.float), requires_grad=False
            )
            cell.weight_hh = torch.nn.Parameter(
                torch.tensor(vals, dtype=torch.float), requires_grad=False
            )

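            # The legacy dynamic-quantization entry point has been removed from
            # torch.jit.quantized; calling it should fail with a clear error.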
            with self.assertRaisesRegex(
                RuntimeError,
                "quantize_rnn_cell_modules function is no longer supported",
            ):
                cell = torch.jit.quantized.quantize_rnn_cell_modules(cell)

    @skipIfNoFBGEMM
    def test_rnn_quantized(self):
        d_in, d_hid = 2, 2

        for cell in [
            torch.nn.LSTM(d_in, d_hid).float(),
            torch.nn.GRU(d_in, d_hid).float(),
        ]:
            # Replace parameter values s.t. the range of values is exactly
            # 255, thus we will have 0 quantization error in the quantized
            # GEMM call. This is for testing purposes.
            #
            # Note that the current implementation does not support
            # accumulation values outside of the range representable by a
            # 16 bit integer, instead resulting in a saturated value. We
            # must take care that in our test we do not end up with a dot
            # product that overflows the int16 range (max 32767), e.g.
            # (255*127+255*127) = 64770. So, we hardcode the test values
            # here and ensure a mix of signedness.
            vals = [
                [100, -155],
                [100, -155],
                [-155, 100],
                [-155, 100],
                [100, -155],
                [-155, 100],
                [-155, 100],
                [100, -155],
            ]
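            # weight_ih_l0 / weight_hh_l0 stack one d_hid-row chunk per gate:
            # 4 chunks for LSTM, 3 for GRU.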
            if isinstance(cell, torch.nn.LSTM):
                num_chunks = 4
            elif isinstance(cell, torch.nn.GRU):
                num_chunks = 3
            vals = vals[: d_hid * num_chunks]
            cell.weight_ih_l0 = torch.nn.Parameter(
                torch.tensor(vals, dtype=torch.float), requires_grad=False
            )
            cell.weight_hh_l0 = torch.nn.Parameter(
                torch.tensor(vals, dtype=torch.float), requires_grad=False
            )

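            # Both the int8 and fp16 variants of the removed quantize_rnn_modules
            # API should raise rather than return a quantized module.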
            with self.assertRaisesRegex(
                RuntimeError, "quantize_rnn_modules function is no longer supported"
            ):
                cell_int8 = torch.jit.quantized.quantize_rnn_modules(
                    cell, dtype=torch.int8
                )

            with self.assertRaisesRegex(
                RuntimeError, "quantize_rnn_modules function is no longer supported"
            ):
                cell_fp16 = torch.jit.quantized.quantize_rnn_modules(
                    cell, dtype=torch.float16
                )

    if "fbgemm" in torch.backends.quantized.supported_engines:

        def test_quantization_modules(self):
            K1, N1 = 2, 2

            class FooBar(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.linear1 = torch.nn.Linear(K1, N1).float()

                def forward(self, x):
                    x = self.linear1(x)
                    return x

            fb = FooBar()
            fb.linear1.weight = torch.nn.Parameter(
                torch.tensor([[-150, 100], [100, -150]], dtype=torch.float),
                requires_grad=False,
            )
            fb.linear1.bias = torch.nn.Parameter(
                torch.zeros_like(fb.linear1.bias), requires_grad=False
            )

            x = (torch.rand(1, K1).float() - 0.5) / 10.0
            value = torch.tensor([[100, -150]], dtype=torch.float)

            y_ref = fb(value)

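            # y_ref is the eager fp32 output; the deprecated quantize_linear_modules
            # calls below should raise instead of producing a quantized counterpart.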
            with self.assertRaisesRegex(
                RuntimeError, "quantize_linear_modules function is no longer supported"
            ):
                fb_int8 = torch.jit.quantized.quantize_linear_modules(fb)

            with self.assertRaisesRegex(
                RuntimeError, "quantize_linear_modules function is no longer supported"
            ):
                fb_fp16 = torch.jit.quantized.quantize_linear_modules(fb, torch.float16)

    @skipIfNoFBGEMM
    def test_erase_class_tensor_shapes(self):
        class Linear(torch.nn.Module):
            def __init__(self, in_features, out_features):
                super().__init__()
                qweight = torch._empty_affine_quantized(
                    [out_features, in_features],
                    scale=1,
                    zero_point=0,
                    dtype=torch.qint8,
                )
                self._packed_weight = torch.ops.quantized.linear_prepack(qweight)

            @torch.jit.export
            def __getstate__(self):
                return (
                    torch.ops.quantized.linear_unpack(self._packed_weight)[0],
                    self.training,
                )

            def forward(self):
                return self._packed_weight

            @torch.jit.export
            def __setstate__(self, state):
                self._packed_weight = torch.ops.quantized.linear_prepack(state[0])
                self.training = state[1]

            @property
            def weight(self):
                return torch.ops.quantized.linear_unpack(self._packed_weight)[0]

            @weight.setter
            def weight(self, w):
                self._packed_weight = torch.ops.quantized.linear_prepack(w)

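        # Scripting the module records tensor shapes in its graph; the pass below
        # must erase that shape information without breaking on the class-typed
        # packed-weight attribute.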
        with torch._jit_internal._disable_emit_hooks():
            x = torch.jit.script(Linear(10, 10))
            torch._C._jit_pass_erase_shape_information(x.graph)


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_quantization.py TESTNAME\n\n"
        "instead."
    )