# Owner(s): ["oncall: quantization"]

# Torch
import torch
import torch.ao.nn.quantized.functional as qF
import torch.nn.functional as F

# NumPy
import numpy as np

# Testing utils
from hypothesis import assume, given
from hypothesis import strategies as st
from torch.testing._internal.common_quantization import (
    QuantizationTestCase,
    _make_conv_test_input,
)
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.testing._internal.common_utils import (
    IS_PPC,
    TEST_WITH_UBSAN,
)

class TestQuantizedFunctionalOps(QuantizationTestCase):
    def test_relu_api(self):
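        # F.relu on a quantized tensor should dispatch to the quantized
        # kernel and match torch.relu on the same input.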
        X = torch.arange(-5, 5, dtype=torch.float)
        scale = 2.0
        zero_point = 1
        qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, dtype=torch.quint8)
        qY = torch.relu(qX)
        qY_hat = F.relu(qX)
        self.assertEqual(qY, qY_hat)

    def _test_conv_api_impl(
        self, qconv_fn, conv_fn, batch_size, in_channels_per_group,
        input_feature_map_size, out_channels_per_group, groups, kernel_size,
        stride, padding, dilation, X_scale, X_zero_point, W_scale, W_zero_point,
        Y_scale, Y_zero_point, use_bias, use_channelwise,
    ):
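        # Reject shape combinations that would produce an empty output: along
        # each spatial dim, the padded input must cover at least one dilated
        # kernel extent, i.e. dilation * (kernel - 1) + 1 elements.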
        for i in range(len(kernel_size)):
            assume(input_feature_map_size[i] + 2 * padding[i]
                   >= dilation[i] * (kernel_size[i] - 1) + 1)
        (X, X_q, W, W_q, b) = _make_conv_test_input(
            batch_size, in_channels_per_group, input_feature_map_size,
            out_channels_per_group, groups, kernel_size, X_scale,
            X_zero_point, W_scale, W_zero_point, use_bias, use_channelwise)

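        # Reference path: run the float conv, then quantize its output.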
        Y_exp = conv_fn(X, W, b, stride, padding, dilation, groups)
        Y_exp = torch.quantize_per_tensor(
            Y_exp, scale=Y_scale, zero_point=Y_zero_point, dtype=torch.quint8)
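        # Path under test: run the quantized functional conv directly on the
        # quantized input and weight.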
        Y_act = qconv_fn(
            X_q, W_q, b, stride, padding, dilation, groups,
            padding_mode="zeros", scale=Y_scale, zero_point=Y_zero_point)

        # Make sure the results match.
        # assert_array_almost_equal compares using the following formula:
        #     abs(desired - actual) < 1.5 * 10**(-decimal)
        # (https://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_almost_equal.html)
        # We use decimal = 0 to ignore off-by-1 differences between the
        # reference and the test. Off-by-1 differences arise from the order of
        # the rounding and zero_point-addition operations: if the reference
        # adds the zero_point and then rounds while the test rounds and then
        # adds the zero_point, the results may differ by 1.
        # For example, round(2.5) + 1 is 3 while round(2.5 + 1) is 4, assuming
        # the rounding mode is round-to-nearest, ties-to-even.
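        # A quick check of that example (Python's built-in round is also
        # ties-to-even):
        #     >>> round(2.5) + 1, round(2.5 + 1)
        #     (3, 4)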
        np.testing.assert_array_almost_equal(
            Y_exp.int_repr().numpy(), Y_act.int_repr().numpy(), decimal=0)

    @given(batch_size=st.integers(1, 3),
           in_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
           L=st.integers(4, 16),
           out_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
           groups=st.integers(1, 4),
           kernel=st.integers(1, 7),
           stride=st.integers(1, 2),
           pad=st.integers(0, 2),
           dilation=st.integers(1, 2),
           X_scale=st.floats(1.2, 1.6),
           X_zero_point=st.integers(0, 4),
           W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2),
           W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2),
           Y_scale=st.floats(4.2, 5.6),
           Y_zero_point=st.integers(0, 4),
           use_bias=st.booleans(),
           use_channelwise=st.booleans(),
           qengine=st.sampled_from(("qnnpack", "fbgemm")))
    def test_conv1d_api(
        self, batch_size, in_channels_per_group, L, out_channels_per_group,
        groups, kernel, stride, pad, dilation,
        X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point,
        use_bias, use_channelwise, qengine,
    ):
        # Tests the correctness of the conv1d function.
        if qengine not in torch.backends.quantized.supported_engines:
            return
        if qengine == 'qnnpack':
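            # QNNPACK is not exercised on PPC or under UBSAN builds.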
            if IS_PPC or TEST_WITH_UBSAN:
                return
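            # The QNNPACK path here does not cover per-channel quantized
            # weights, so fall back to per-tensor weight quantization.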
            use_channelwise = False

        input_feature_map_size = (L, )
        kernel_size = (kernel, )
        stride = (stride, )
        padding = (pad, )
        dilation = (dilation, )

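        # Route quantized operator dispatch to the engine under test.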
        with override_quantized_engine(qengine):
            qconv_fn = qF.conv1d
            conv_fn = F.conv1d
            self._test_conv_api_impl(
                qconv_fn, conv_fn, batch_size, in_channels_per_group,
                input_feature_map_size, out_channels_per_group, groups,
                kernel_size, stride, padding, dilation, X_scale, X_zero_point,
                W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias,
                use_channelwise)

    @given(batch_size=st.integers(1, 3),
           in_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
           H=st.integers(4, 16),
           W=st.integers(4, 16),
           out_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
           groups=st.integers(1, 4),
           kernel_h=st.integers(1, 7),
           kernel_w=st.integers(1, 7),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
           dilation=st.integers(1, 2),
           X_scale=st.floats(1.2, 1.6),
           X_zero_point=st.integers(0, 4),
           W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2),
           W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2),
           Y_scale=st.floats(4.2, 5.6),
           Y_zero_point=st.integers(0, 4),
           use_bias=st.booleans(),
           use_channelwise=st.booleans(),
           qengine=st.sampled_from(("qnnpack", "fbgemm")))
    def test_conv2d_api(
        self, batch_size, in_channels_per_group, H, W, out_channels_per_group,
        groups, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, dilation,
        X_scale, X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point,
        use_bias, use_channelwise, qengine,
    ):
        # Tests the correctness of the conv2d function.

        if qengine not in torch.backends.quantized.supported_engines:
            return
        if qengine == 'qnnpack':
            if IS_PPC or TEST_WITH_UBSAN:
                return

        input_feature_map_size = (H, W)
        kernel_size = (kernel_h, kernel_w)
        stride = (stride_h, stride_w)
        padding = (pad_h, pad_w)
        dilation = (dilation, dilation)

        with override_quantized_engine(qengine):
            qconv_fn = qF.conv2d
            conv_fn = F.conv2d
            self._test_conv_api_impl(
                qconv_fn, conv_fn, batch_size, in_channels_per_group,
                input_feature_map_size, out_channels_per_group, groups,
                kernel_size, stride, padding, dilation, X_scale, X_zero_point,
                W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias,
                use_channelwise)

    @given(batch_size=st.integers(1, 3),
           in_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
           D=st.integers(4, 8),
           H=st.integers(4, 8),
           W=st.integers(4, 8),
           out_channels_per_group=st.sampled_from([2, 4, 5, 8, 16, 32]),
           groups=st.integers(1, 4),
           kernel_d=st.integers(1, 4),
           kernel_h=st.integers(1, 4),
           kernel_w=st.integers(1, 4),
           stride_d=st.integers(1, 2),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_d=st.integers(0, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
           dilation=st.integers(1, 2),
           X_scale=st.floats(1.2, 1.6),
           X_zero_point=st.integers(0, 4),
           W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2),
           W_zero_point=st.lists(st.integers(-5, 5), min_size=1, max_size=2),
           Y_scale=st.floats(4.2, 5.6),
           Y_zero_point=st.integers(0, 4),
           use_bias=st.booleans(),
           use_channelwise=st.booleans(),
           qengine=st.sampled_from(("fbgemm",)))
    def test_conv3d_api(
        self, batch_size, in_channels_per_group, D, H, W,
        out_channels_per_group, groups, kernel_d, kernel_h, kernel_w,
        stride_d, stride_h, stride_w, pad_d, pad_h, pad_w, dilation, X_scale,
        X_zero_point, W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias,
        use_channelwise, qengine,
    ):
        # Tests the correctness of the conv3d function.
        # Currently conv3d is only supported by the FBGEMM engine, hence the
        # qengine strategy samples only "fbgemm".

        if qengine not in torch.backends.quantized.supported_engines:
            return

        input_feature_map_size = (D, H, W)
        kernel_size = (kernel_d, kernel_h, kernel_w)
        stride = (stride_d, stride_h, stride_w)
        padding = (pad_d, pad_h, pad_w)
        dilation = (dilation, dilation, dilation)

        with override_quantized_engine(qengine):
            qconv_fn = qF.conv3d
            conv_fn = F.conv3d
            self._test_conv_api_impl(
                qconv_fn, conv_fn, batch_size, in_channels_per_group,
                input_feature_map_size, out_channels_per_group, groups,
                kernel_size, stride, padding, dilation, X_scale, X_zero_point,
                W_scale, W_zero_point, Y_scale, Y_zero_point, use_bias,
                use_channelwise)

    @given(N=st.integers(1, 10),
           C=st.integers(1, 10),
           H=st.integers(4, 8),
           H_out=st.integers(4, 8),
           W=st.integers(4, 8),
           W_out=st.integers(4, 8),
           scale=st.floats(.1, 2),
           zero_point=st.integers(0, 4))
    def test_grid_sample(self, N, C, H, H_out, W, W_out, scale, zero_point):
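        # F.grid_sample should accept a quantized input directly; compare its
        # integer representation against quantizing the float reference output.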
        X = torch.rand(N, C, H, W)
        X_q = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, dtype=torch.quint8)
        grid = torch.rand(N, H_out, W_out, 2)

        out = F.grid_sample(X_q, grid)
        out_exp = torch.quantize_per_tensor(F.grid_sample(X, grid), scale=scale, zero_point=zero_point, dtype=torch.quint8)
        np.testing.assert_array_almost_equal(
            out.int_repr().numpy(), out_exp.int_repr().numpy(), decimal=0)