# Owner(s): ["oncall: quantization"]

# Torch
import torch
from torch.ao.quantization import (
    MinMaxObserver,
    PerChannelMinMaxObserver,
    MovingAverageMinMaxObserver,
    MovingAveragePerChannelMinMaxObserver,
    HistogramObserver,
    RecordingObserver,
    PlaceholderObserver,
    NoopObserver,
    FakeQuantize,
    FixedQParamsObserver,
    default_debug_qconfig,
    default_observer,
    default_histogram_observer,
    default_per_channel_weight_observer,
    prepare,
    prepare_qat,
    convert,
    QConfig,
    FusedMovingAvgObsFakeQuantize,
    get_embedding_qat_module_mappings,
    get_embedding_static_quant_module_mappings,
)
from torch.ao.quantization.quantize import _get_observer_dict

import torch.nn as nn

# Standard library
import copy
import io
import itertools
import unittest
import math
import numpy as np

# Testing utils
from hypothesis import given, settings
from hypothesis import strategies as st
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()
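# Hypothesis' per-example deadline is disabled (and asserted so above) to keep
# slow quantized kernels or a cold CUDA context from flaking the
# property-based tests below with DeadlineExceeded errors.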
from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA
from torch.testing._internal.common_utils import TestCase, skipIfTorchDynamo
from torch.testing._internal.common_quantization import (
    QuantizationTestCase,
    AnnotatedSingleLayerLinearModel,
    test_only_eval_fn,
    SingleLayerLinearModel,
)

from torch.testing._internal.common_quantized import (
    override_quantized_engine,
    supported_qengines,
    override_qengines,
    _fake_quantize_per_channel_affine_reference,
    _fake_quantize_per_channel_affine_grad_reference,
    to_tensor,
)

from torch.testing._internal.common_quantization import (
    DeFusedEmbeddingBagLinear,
)

NP_RANDOM_SEED = 19
tolerance = 1e-6
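
# A minimal sketch of the observer workflow exercised throughout this file
# (illustrative only; this helper is not called by any test): observers record
# statistics on each forward pass, and calculate_qparams() turns those
# statistics into a (scale, zero_point) pair.
def _example_observer_workflow():
    obs = MinMaxObserver(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
    obs(torch.randn(4, 4))  # record min/max from a calibration batch
    scale, zero_point = obs.calculate_qparams()
    return scale, zero_point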

class TestObserver(QuantizationTestCase):
    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8, torch.qint32)),
           qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
           reduce_range=st.booleans())
    def test_per_tensor_observers(self, qdtype, qscheme, reduce_range):
        # reduce_range cannot be true for symmetric quantization with uint8,
        # and is not supported for qint32
        if (qdtype == torch.quint8 and qscheme == torch.per_tensor_symmetric) or qdtype == torch.qint32:
            reduce_range = False
        ObserverList = [MinMaxObserver(dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range),
                        MovingAverageMinMaxObserver(averaging_constant=0.5,
                                                    dtype=qdtype,
                                                    qscheme=qscheme,
                                                    reduce_range=reduce_range)]

        def _get_ref_params(reduce_range, qscheme, dtype, input_scale, min_val, max_val):
            eps = torch.tensor([tolerance])
            if dtype == torch.qint8:
                if reduce_range:
                    quant_min, quant_max = -64, 63
                else:
                    quant_min, quant_max = -128, 127
            elif dtype == torch.quint8:
                if reduce_range:
                    quant_min, quant_max = 0, 127
                else:
                    quant_min, quant_max = 0, 255
            elif dtype == torch.qint32:
                quant_min, quant_max = -1 * (2 ** 31), (2 ** 31) - 1

            min_val_neg = torch.tensor([0.])
            max_val_pos = torch.tensor([input_scale * max_val]) if qdtype is torch.qint32 else torch.tensor([max_val])

            scale, zero_point = 1.0, 0
            if qscheme == torch.per_tensor_symmetric or qscheme == torch.per_channel_symmetric:
                scale = torch.max(-min_val_neg, max_val_pos) / (float(quant_max - quant_min) / 2)
                scale = torch.max(scale, eps)
                if dtype == torch.quint8:
                    zero_point = 128
            else:
                scale = torch.max((max_val_pos - min_val_neg) / float(quant_max - quant_min), eps)
                zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
                zero_point = torch.clamp(zero_point, quant_min, quant_max)
            return scale, zero_point

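        # Worked example of the reference math (values from the assertions
        # below): with qint8 (full range), min_val = 1.0 and max_val = 8.0,
        # min_val_neg clamps to 0, so the symmetric scale is
        # max(0, 8) / 127.5 ~= 0.0627 and the affine scale is
        # (8 - 0) / 255 ~= 0.0314 with zero_point clamped to -128.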
        for myobs in ObserverList:
            # calculate_qparams() should emit a warning for observers that have seen no data
            qparams = myobs.calculate_qparams()
            input_scale = 2**16 if qdtype is torch.qint32 else 1
            if type(myobs) == MinMaxObserver:
                x = torch.tensor([1.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0]) * input_scale
                y = torch.tensor([4.0, 5.0, 5.0, 6.0, 7.0, 8.0]) * input_scale
            else:
                # With averaging_constant=0.5, the moving average of the
                # min/max of this x and y lands on the same extremes (1.0 and
                # 8.0) as the tensors used for the min/max observer above
                x = torch.tensor([0.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0]) * input_scale
                y = torch.tensor([2.0, 5.0, 5.0, 6.0, 7.0, 10.0]) * input_scale

            result = myobs(x)
            result = myobs(y)
            self.assertEqual(result, y)
            self.assertEqual(myobs.min_val, 1.0 * input_scale)
            self.assertEqual(myobs.max_val, 8.0 * input_scale)
            qparams = myobs.calculate_qparams()
            ref_scale, ref_zero_point = _get_ref_params(reduce_range, qscheme, qdtype, input_scale, 1.0, 8.0)

            self.assertEqual(qparams[1].item(), ref_zero_point)
            self.assertEqual(qparams[0].item(), ref_scale, atol=1e-5, rtol=0)
            state_dict = myobs.state_dict()
            b = io.BytesIO()
            torch.save(state_dict, b)
            for weights_only in [True, False]:
                b.seek(0)
                loaded_dict = torch.load(b, weights_only=weights_only)
                for key in state_dict:
                    self.assertEqual(state_dict[key], loaded_dict[key])
                loaded_obs = MinMaxObserver(dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range)
                loaded_obs.load_state_dict(loaded_dict)
                loaded_qparams = loaded_obs.calculate_qparams()
                self.assertEqual(myobs.min_val, loaded_obs.min_val)
                self.assertEqual(myobs.max_val, loaded_obs.max_val)
                self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams())


    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           qscheme=st.sampled_from((torch.per_channel_affine, torch.per_channel_symmetric, torch.per_channel_affine_float_qparams)),
           ch_axis=st.sampled_from((0, 1, 2, 3)), reduce_range=st.booleans())
    def test_per_channel_observers(self, qdtype, qscheme, ch_axis, reduce_range):
        # reduce_range cannot be true for symmetric quantization with uint8,
        # nor for float qparams
        if qscheme == torch.per_channel_affine_float_qparams:
            reduce_range = False
        if qdtype == torch.quint8 and qscheme == torch.per_channel_symmetric:
            reduce_range = False
        ObserverList = [PerChannelMinMaxObserver(reduce_range=reduce_range,
                                                 ch_axis=ch_axis,
                                                 dtype=qdtype,
                                                 qscheme=qscheme),
                        MovingAveragePerChannelMinMaxObserver(averaging_constant=0.5,
                                                              reduce_range=reduce_range,
                                                              ch_axis=ch_axis,
                                                              dtype=qdtype,
                                                              qscheme=qscheme)]

        for myobs in ObserverList:
            # Calculate qparams should work for empty observers
            qparams = myobs.calculate_qparams()
            x = torch.tensor(
                [
                    [[[1.0, 2.0], [2.0, 2.5]], [[3.0, 4.0], [4.5, 6.0]]],
                    [[[-4.0, -3.0], [5.0, 5.0]], [[6.0, 3.0], [7.0, 8.0]]],
                ]
            )
            if type(myobs) == MovingAveragePerChannelMinMaxObserver:
                # Scaling the input tensor to model change in min/max values
                # across batches
                result = myobs(0.5 * x)
                result = myobs(1.5 * x)
                self.assertEqual(result, 1.5 * x)
            else:
                result = myobs(x)
                self.assertEqual(result, x)

            qparams = myobs.calculate_qparams()
            ref_min_vals = [[1.0, -4.0], [-4.0, 3.0], [-4.0, 2.0], [-4.0, -3.0]]
            ref_max_vals = [[6.0, 8.0], [5.0, 8.0], [6.0, 8.0], [7.0, 8.0]]
            per_channel_symmetric_ref_scales = [
                [0.04705882, 0.06274509],
                [0.03921569, 0.0627451],
                [0.04705882, 0.0627451],
                [0.05490196, 0.0627451],
            ]
            per_channel_affine_ref_scales = [
                [0.02352941, 0.04705882],
                [0.03529412, 0.03137255],
                [0.03921569, 0.03137255],
                [0.04313726, 0.04313726],
            ]
            per_channel_affine_qint8_zp = [
                [-128, -43],
                [-15, -128],
                [-26, -128],
                [-35, -58],
            ]
            per_channel_affine_float_qparams_ref_scales = [
                [0.0196, 0.0471],
                [0.0353, 0.0196],
                [0.0392, 0.0235],
                [0.0431, 0.0431],
            ]
            per_channel_affine_quint8_zp = [[0, 85], [113, 0], [102, 0], [93, 70]]

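            # The reference values above follow the observer's qparam math:
            # e.g. for ch_axis=0, channel 1 (min=-4.0, max=8.0), the qint8
            # symmetric scale is max(4.0, 8.0) / 127.5 ~= 0.0627451.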
            self.assertEqual(myobs.min_val, ref_min_vals[ch_axis])
            self.assertEqual(myobs.max_val, ref_max_vals[ch_axis])
            if qscheme == torch.per_channel_symmetric:
                ref_scales = per_channel_symmetric_ref_scales[ch_axis]
                ref_zero_points = [0, 0] if qdtype is torch.qint8 else [128, 128]
            elif qscheme == torch.per_channel_affine_float_qparams:
                ref_scales = per_channel_affine_float_qparams_ref_scales[ch_axis]
                ref_zero_points = [-1 * ref_min_vals[ch_axis][i] / ref_scales[i] for i in range(len(ref_scales))]
            else:
                ref_scales = per_channel_affine_ref_scales[ch_axis]
                ref_zero_points = (
                    per_channel_affine_qint8_zp[ch_axis]
                    if qdtype is torch.qint8
                    else per_channel_affine_quint8_zp[ch_axis]
                )

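            # reduce_range halves the number of quantization levels
            # (e.g. 255 -> 127 steps), so the reference scales grow by
            # 255/127 and the zero_points are (roughly) halved.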
            if reduce_range:
                ref_scales = [s * 255 / 127 for s in ref_scales]
                ref_zero_points = [math.floor(z / 2) for z in ref_zero_points]
            self.assertEqual(qparams[0], torch.tensor(ref_scales, dtype=qparams[0].dtype), rtol=1e-5, atol=0.0001)
            if qscheme == torch.per_channel_affine_float_qparams:
                self.assertEqual(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype), rtol=1e-5, atol=1)
            else:
                self.assertEqual(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype))


            # Test for serializability
            state_dict = myobs.state_dict()
            b = io.BytesIO()
            torch.save(state_dict, b)
            b.seek(0)
            loaded_dict = torch.load(b)
            for key in state_dict:
                self.assertEqual(state_dict[key], loaded_dict[key])
            loaded_obs = PerChannelMinMaxObserver(reduce_range=reduce_range, ch_axis=ch_axis, dtype=qdtype, qscheme=qscheme)
            loaded_obs.load_state_dict(loaded_dict)
            loaded_qparams = loaded_obs.calculate_qparams()
            self.assertEqual(myobs.min_val, loaded_obs.min_val)
            self.assertEqual(myobs.max_val, loaded_obs.max_val)
            self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams())


    def test_observer_scriptable(self):
        obs_list = [MinMaxObserver(), MovingAverageMinMaxObserver()]
        for obs in obs_list:
            scripted = torch.jit.script(obs)

            x = torch.rand(3, 4)
            obs(x)
            scripted(x)
            self.assertEqual(obs.calculate_qparams(), scripted.calculate_qparams())

            buf = io.BytesIO()
            torch.jit.save(scripted, buf)
            buf.seek(0)
            loaded = torch.jit.load(buf)
            self.assertEqual(obs.calculate_qparams(), loaded.calculate_qparams())

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @override_qengines
    def test_state_dict_respects_device_affinity(self):
        """
        Tests that loading from a state dict loads buffers to the correct
        device.
        """
        device_cpu = torch.device('cpu')
        device_cuda = torch.device('cuda:0')
        test_cases = itertools.product(
            [device_cpu, device_cuda],
            [device_cpu, device_cuda],
            [MinMaxObserver, MovingAverageMinMaxObserver,
             PerChannelMinMaxObserver,
             MovingAveragePerChannelMinMaxObserver,
             # TODO: enable this (separate PR)
             # HistogramObserver,
             PlaceholderObserver, RecordingObserver, NoopObserver,
             FakeQuantize])

        for device_source, device_target, obs_cls in test_cases:
            # calibrated source model
            model = obs_cls()
            model.to(device_source)
            model(torch.randn(4, 1, 4, 4, device=device_source))
            # target model
            model2 = obs_cls()
            model2.to(device_target)
            model2.load_state_dict(model.state_dict())
            # verify that buffers stayed on model2's device
            model_devices = {p.device for p in model2.parameters()} | \
                {p.device for p in model2.buffers()}
            # some observers do not have any buffers, so lessEqual instead of
            # Equal
            self.assertLessEqual(len(model_devices), 1)
            if len(model_devices) == 1:
                model_device = next(iter(model_devices))
                self.assertEqual(model_device, device_target)

    def test_histogram_observer_consistent_buffer_shape(self):
        """
        Ensures that the buffer shapes do not change from uninitialized to
        initialized states for HistogramObserver.
        """
        obs = HistogramObserver()
        min_shape_before = obs.min_val.shape
        max_shape_before = obs.max_val.shape
        for _ in range(2):
            obs(torch.randn(4, 4, 4, 4))
        self.assertEqual(min_shape_before, obs.min_val.shape)
        self.assertEqual(max_shape_before, obs.max_val.shape)

    def test_histogram_observer_ignore_infinity(self):
        """
        Ensures that HistogramObserver doesn't record values of infinity
        """
        obs = HistogramObserver()
        obs2 = HistogramObserver()
        x = torch.randn(4, 4, 4, 4)
        obs(x * torch.inf)
        obs(x)
        obs2(x)
        obs(x * torch.inf)
        self.assertTrue(obs.min_val != -torch.inf and obs.max_val != torch.inf)
        self.assertEqual(obs.histogram, obs2.histogram)

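    # The huge-magnitude mask below mimics an additive softmax mask (as used
    # in e.g. attention); quantizing and dequantizing it must leave the
    # softmax result unchanged.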
    def test_histogram_observer_handle_close_to_infinity(self):
        for sign in [-1, 1]:
            obser = HistogramObserver.with_args(reduce_range=False)()
            mask = torch.tensor([-3.4028234663852886 * 10**30, 0, 0, 0]) * sign
            obser(mask)
            obser(mask - sign)
            scale, zp = obser.calculate_qparams()

            input = torch.randn(1, 4)
            ref_result = torch.softmax(input + mask, dim=1)

            quant_mask = torch.quantize_per_tensor(mask, scale, zp, torch.quint8)
            dequant_mask = quant_mask.dequantize()
            result = torch.softmax(input + dequant_mask, dim=1)
            self.assertEqual(result, ref_result)

    def test_histogram_observer_handle_OOM_due_to_close_min_max_value(self):
        obser = HistogramObserver.with_args(reduce_range=False)()
        # A very close min and max value in the first forward() pass of the
        # observer tends to cause OOM in the following pass, due to the
        # allocation of the histogram tensor during _combine_histograms().
        # With a sanity check on the size of the histogram tensor, we expect
        # the histogram observer to keep working by resetting the histogram.
        x1 = torch.tensor([0, 1e-9])
        obser(x1)

        x2 = torch.tensor([2.0, 3.0])
        obser(x2)

    def test_histogram_observer_save_load_state_dict(self):
        """
        Smoke test on saving/loading state_dict
        """
        obs1 = HistogramObserver()
        obs1(torch.randn(4, 4, 4, 4))
        obs2 = HistogramObserver()
        obs2.load_state_dict(obs1.state_dict())
        self.assertEqual(obs2.min_val.shape, torch.Size([]))
        self.assertEqual(obs2.max_val.shape, torch.Size([]))


    def test_save_load_state_dict_script(self):
        """
        Tests that we can save and load state_dict for observers that are scripted
        in a quantized model.
        """
        obs_list = [MinMaxObserver, MovingAverageMinMaxObserver, HistogramObserver]

        for obs in obs_list:
            model = SingleLayerLinearModel().eval()
            qconfig = QConfig(activation=default_observer, weight=obs)
            qconfig_dict = {'' : qconfig}
            scripted = torch.jit.script(model)
            scripted = torch.ao.quantization.prepare_jit(scripted, qconfig_dict)
            x = torch.rand(5, 5)
            scripted(x)
            obs_dict = torch.ao.quantization.get_observer_state_dict(scripted)

            # Load stats
            scripted_2 = torch.jit.script(model)
            scripted_2 = torch.ao.quantization.prepare_jit(scripted_2, qconfig_dict)
            torch.ao.quantization.load_observer_state_dict(scripted_2, obs_dict)
            # Verify that state_dict matches exactly with original one.
            self.assertEqual(scripted.state_dict(), scripted_2.state_dict())


    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_observer_qparams_respects_device_affinity(self):
        """
        Ensure that the scale and zero_point returned by the observer
        are on the same device as the input tensor.
        """
        observerList = [MinMaxObserver(),
                        MovingAverageMinMaxObserver(),
                        PerChannelMinMaxObserver(),
                        MovingAveragePerChannelMinMaxObserver()]
        for obs in observerList:
            device = torch.device('cuda:1')
            x = torch.randn(1, 2, device=device)
            obs.to(device)
            result = obs(x)
            scale, zero_point = obs.calculate_qparams()

            self.assertEqual(x.device, scale.device)
            self.assertEqual(x.device, zero_point.device)

    def test_zero_numel(self):
        obs_list = [MinMaxObserver, MovingAverageMinMaxObserver,
                    PerChannelMinMaxObserver,
                    MovingAveragePerChannelMinMaxObserver, HistogramObserver,
                    FakeQuantize, FixedQParamsObserver]
        for obs_cls in obs_list:
            if obs_cls is FixedQParamsObserver:
                obs = obs_cls(0.1, 0)
            else:
                obs = obs_cls()
            x = torch.tensor([])
            # verify no crash
            x = obs(x)

    def test_dynamic_quant_observer(self):
        obs = MovingAverageMinMaxObserver(averaging_constant=1, is_dynamic=True)
        x = torch.randn((3, 3))
        obs(x)
        params = obs.calculate_qparams()
        for _ in range(20):
            obs(10 * torch.randn((3, 3)))
            self.assertNotEqual(params, obs.calculate_qparams())
            obs(x)
            self.assertEqual(params, obs.calculate_qparams())

    def test_dynamic_quant_observer_matching_choose_qparams(self):
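        # torch._choose_qparams_per_tensor is what dynamic quantization uses
        # to pick qparams on the fly, so a dynamic observer fed the same
        # tensor should reproduce its (scale, zero_point) exactly.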
        obs = MovingAverageMinMaxObserver(averaging_constant=1, is_dynamic=True)
        for x in [torch.randn(3, 3), torch.rand(3, 3, 3), torch.randn(3, 3, 3, 3)]:
            obs(x)
            params = obs.calculate_qparams()
            scale, zero_point = torch._choose_qparams_per_tensor(x)
            self.assertEqual(scale, params[0])
            self.assertEqual(zero_point, params[1])

    def test_per_channel_observers_load_state_dict(self):
        observer_list = [PerChannelMinMaxObserver, MovingAveragePerChannelMinMaxObserver]

        for obs_cls in observer_list:
            obs = obs_cls()
            obs(torch.randn((32, 32)))
            new_obs = obs_cls()
            # make sure the state_dict can be loaded
            new_obs.load_state_dict(obs.state_dict())
            self.assertTrue(torch.equal(obs.min_val, new_obs.min_val))
            self.assertTrue(torch.equal(obs.max_val, new_obs.max_val))

# HistogramObserver that works like it does on master
class _ReferenceHistogramObserver(HistogramObserver):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @torch.jit.ignore
    def _non_linear_param_search(self):
        r"""Non-linear parameter search.

        An approximation for L2 error minimization for selecting min/max.
        By selecting new min/max, we filter out outliers in input distribution.
        This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in
        caffe2/quantization/server/norm_minimization.cc
        """
        def _get_norm(delta_begin, delta_end, density, norm_type):
            r"""
            Compute the norm of the values uniformly distributed between
            delta_begin and delta_end.

            norm = density * (integral_{begin, end} x^2)
                 = density * (end^3 - begin^3) / 3
            """
            assert norm_type == "L2", "Only L2 norms are currently supported"
            norm = 0.0
            if norm_type == "L2":
                norm = (
                    delta_end * delta_end * delta_end
                    - delta_begin * delta_begin * delta_begin
                ) / 3
            return density * norm

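        # Sanity check for _get_norm (illustrative): a uniform density d over
        # [a, b] contributes d * (b**3 - a**3) / 3, i.e. the integral of
        # d * x**2 over [a, b]; with d = 1 over [0, 1] the error mass is 1/3.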
        def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
            r"""
            Compute the quantization error if we use start_bin to end_bin as the
            min and max to do the quantization.
            """
            bin_width = (self.max_val.item() - self.min_val.item()) / self.bins

            norm = 0.0
            dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
            if dst_bin_width == 0.0:
                return 0.0
            for src_bin in range(self.bins):
                # distances from the beginning of first dst_bin to the beginning and
                # end of src_bin
                src_bin_begin = (src_bin - next_start_bin) * bin_width
                src_bin_end = src_bin_begin + bin_width

                # which dst_bins the beginning and end of src_bin belong to?
                dst_bin_of_begin = min(
                    self.dst_nbins - 1, max(0.0, math.floor(src_bin_begin / dst_bin_width))
                )
                dst_bin_of_end = min(
                    self.dst_nbins - 1, max(0.0, math.floor(src_bin_end / dst_bin_width))
                )
                dst_bin_of_begin_center = (
                    dst_bin_of_begin * dst_bin_width + dst_bin_width / 2
                )

                density = self.histogram[src_bin] / bin_width
                if dst_bin_of_begin == dst_bin_of_end:
                    # if src_bin is entirely within 1 dst_bin
                    delta_begin = src_bin_begin - dst_bin_of_begin_center
                    delta_end = src_bin_end - dst_bin_of_begin_center
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
                else:
                    delta_begin = src_bin_begin - dst_bin_of_begin_center
                    delta_end = dst_bin_width / 2
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)

                    norm = norm + (dst_bin_of_end - dst_bin_of_begin - 1) * _get_norm(
                        -dst_bin_width / 2, dst_bin_width / 2, density, norm_type
                    )

                    dst_bin_of_end_center = (
                        dst_bin_of_end * dst_bin_width + dst_bin_width / 2
                    )

                    delta_begin = -dst_bin_width / 2
                    delta_end = src_bin_end - dst_bin_of_end_center
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
            return norm

        assert self.histogram.size()[0] == self.bins, "bins mismatch"
        bin_width = (self.max_val - self.min_val) / self.bins

        # cumulative sum
        total = torch.sum(self.histogram).item()
        cSum = torch.cumsum(self.histogram, dim=0)

        stepsize = 1e-5  # granularity
        alpha = 0.0  # lower bound
        beta = 1.0  # upper bound
        start_bin = 0
        end_bin = self.bins - 1
        norm_min = float("inf")

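        # Greedy quantile search: at each step, tighten whichever quantile
        # bound (lower or upper) would trim more histogram bins, recompute the
        # L2 quantization error for the narrowed range, and stop as soon as
        # the error starts to increase.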
        while alpha < beta:
            # Find the next step
            next_alpha = alpha + stepsize
            next_beta = beta - stepsize

            # find the left and right bins between the quantile bounds
            l = start_bin
            r = end_bin
            while l < end_bin and cSum[l] < next_alpha * total:
                l = l + 1
            while r > start_bin and cSum[r] > next_beta * total:
                r = r - 1

            # decide the next move
            next_start_bin = start_bin
            next_end_bin = end_bin
            if (l - start_bin) > (end_bin - r):
                # move the start bin
                next_start_bin = l
                alpha = next_alpha
            else:
                # move the end bin
                next_end_bin = r
                beta = next_beta

            if next_start_bin == start_bin and next_end_bin == end_bin:
                continue

            # calculate the quantization error using next_start_bin and next_end_bin
            norm = _compute_quantization_error(next_start_bin, next_end_bin, "L2")

            if norm > norm_min:
                break
            norm_min = norm
            start_bin = next_start_bin
            end_bin = next_end_bin

        new_min = self.min_val + bin_width * start_bin
        new_max = self.min_val + bin_width * (end_bin + 1)
        return new_min, new_max

class TestRecordHistogramObserver(QuantizationTestCase):
    # TODO: move this to quantize.py
    def test_record_observer(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = AnnotatedSingleLayerLinearModel()
                model.qconfig = default_debug_qconfig
                model = prepare(model)
                # run the evaluation and dump all tensors
                test_only_eval_fn(model, self.calib_data)
                test_only_eval_fn(model, self.calib_data)
                observer_dict = {}
                _get_observer_dict(model, observer_dict)

                self.assertTrue('fc1.module.activation_post_process' in observer_dict.keys(),
                                'observer is not recorded in the dict')
                self.assertEqual(len(observer_dict['fc1.module.activation_post_process'].get_tensor_value()),
                                 2 * len(self.calib_data))
                self.assertEqual(observer_dict['fc1.module.activation_post_process'].get_tensor_value()[0],
                                 model(self.calib_data[0][0]))

    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)))
    def test_observer_scriptable(self, qdtype):
        obs = RecordingObserver(dtype=qdtype)
        scripted = torch.jit.script(obs)

        x = torch.rand(3, 4)
        obs(x)
        scripted(x)
        self.assertTrue(torch.equal(obs.get_tensor_value()[0], scripted.get_tensor_value()[0]))
        buf = io.BytesIO()
        torch.jit.save(scripted, buf)
        buf.seek(0)
        loaded = torch.jit.load(buf)
        self.assertTrue(torch.equal(obs.get_tensor_value()[0], loaded.get_tensor_value()[0]))

class TestHistogramObserver(QuantizationTestCase):
    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           qscheme=st.sampled_from(
               (torch.per_tensor_affine, torch.per_tensor_symmetric))
           )
    def test_observer_scriptable(self, qdtype, qscheme):
        ob_list = [
            HistogramObserver(dtype=qdtype, qscheme=qscheme),
            default_histogram_observer()
        ]
        for obs in ob_list:
            scripted = torch.jit.script(obs)

            x = torch.rand(3, 4)
            obs(x)
            scripted(x)
            self.assertTrue(torch.equal(obs.histogram, scripted.histogram))
            buf = io.BytesIO()
            torch.jit.save(scripted, buf)
            buf.seek(0)
            loaded = torch.jit.load(buf)
            self.assertTrue(torch.equal(obs.histogram, loaded.histogram))

    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
           qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
           reduce_range=st.booleans())
    @settings(max_examples=10)
    def test_histogram_observer(self, qdtype, qscheme, reduce_range):
        myobs = HistogramObserver(bins=3, dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range)
        # Calculate qparams should work for empty observers
        qparams = myobs.calculate_qparams()
        x = torch.tensor([2.0, 3.0, 4.0, 5.0], requires_grad=True)
        y = torch.tensor([5.0, 6.0, 7.0, 8.0])
        out_x = myobs(x)
        self.assertTrue(out_x.requires_grad)
        myobs(y)
        self.assertEqual(myobs.min_val, 2.0)
        self.assertEqual(myobs.max_val, 8.0)
        self.assertEqual(myobs.histogram, [2., 3., 3.])

        qparams = myobs.calculate_qparams()

        if reduce_range:
            if qscheme == torch.per_tensor_symmetric:
                ref_scale = 0.0470588 * 255 / 127
                ref_zero_point = 0 if qdtype is torch.qint8 else 128
            else:
                ref_scale = 0.0235294 * 255 / 127
                ref_zero_point = -64 if qdtype is torch.qint8 else 0
        else:
            if qscheme == torch.per_tensor_symmetric:
                ref_scale = 0.0470588
                ref_zero_point = 0 if qdtype is torch.qint8 else 128
            else:
                ref_scale = 0.0235294
                ref_zero_point = -128 if qdtype is torch.qint8 else 0

        self.assertEqual(qparams[1].item(), ref_zero_point)
        self.assertEqual(qparams[0].item(), ref_scale, atol=1e-5, rtol=0)
        # Test for serializability
        state_dict = myobs.state_dict()
        b = io.BytesIO()
        torch.save(state_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in state_dict:
            self.assertEqual(state_dict[key], loaded_dict[key])
        loaded_obs = HistogramObserver(bins=3, dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range)
        loaded_obs.load_state_dict(loaded_dict)
        loaded_qparams = loaded_obs.calculate_qparams()
        self.assertEqual(myobs.min_val, loaded_obs.min_val)
        self.assertEqual(myobs.max_val, loaded_obs.max_val)
        self.assertEqual(myobs.histogram, loaded_obs.histogram)
        self.assertEqual(myobs.bins, loaded_obs.bins)
        self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams())

    def test_histogram_observer_one_sided(self):
        myobs = HistogramObserver(bins=8, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True)
        x = torch.tensor([0.0, 0.3, 1.2, 1.7])
        y = torch.tensor([0.1, 1.3, 2.0, 2.7])
        myobs(x)
        myobs(y)
        self.assertEqual(myobs.min_val, 0)
        qparams = myobs.calculate_qparams()
        self.assertEqual(qparams[1].item(), 0)

    def test_histogram_observer_same_inputs(self):
        myobs = HistogramObserver(bins=3, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric,
                                  reduce_range=False)
        w = torch.ones(4, requires_grad=True)
        x = torch.zeros(4, requires_grad=True)
        y = torch.tensor([2.0, 3.0, 4.0, 5.0], requires_grad=True)
        z = torch.tensor([5.0, 6.0, 7.0, 8.0])
        myobs(w)
        myobs(x)
        myobs(x)
        myobs(y)
        myobs(z)
        qparams = myobs.calculate_qparams()
        self.assertEqual(myobs.min_val, 0.0)
        self.assertEqual(myobs.max_val, 8.0)
        self.assertEqual(myobs.histogram, [13.25, 3.75, 3.])

    @skipIfTorchDynamo("too slow")
    @given(N=st.sampled_from([10, 1000]),
           bins=st.sampled_from([256, 512, 1024, 2048]),
           dtype=st.sampled_from([torch.qint8, torch.quint8]),
           qscheme=st.sampled_from([torch.per_tensor_affine, torch.per_tensor_symmetric]),
           reduce_range=st.booleans())
    def test_histogram_observer_against_reference(self, N, bins, dtype, qscheme, reduce_range):

        ref_obs = _ReferenceHistogramObserver(bins=bins, dtype=dtype, qscheme=qscheme, reduce_range=reduce_range)
        my_obs = HistogramObserver(bins=bins, dtype=dtype, qscheme=qscheme, reduce_range=reduce_range)

        for _ in range(10):
            X = torch.randn(N)
            my_obs(X)
            ref_obs(X)
            self.assertEqual(my_obs.histogram, ref_obs.histogram)
            self.assertEqual(my_obs.min_val, ref_obs.min_val)
            self.assertEqual(my_obs.max_val, ref_obs.max_val)

        ref_qparams = ref_obs.calculate_qparams()
        my_qparams = my_obs.calculate_qparams()

        for i in range(0, bins, 200):
            for j in range(i + 5, bins, 200):
                ref_qe = ref_obs._compute_quantization_error(i, j)
                qe = my_obs._compute_quantization_error(i, j)
                self.assertEqual(ref_qe, qe)

        self.assertEqual(ref_qparams, my_qparams)

    def test_histogram_observer_extreme_inputs(self):
        """
        Ensures that the HistogramObserver is able to work correctly in
        a rare case: extremely small max values
789        """
790        obs = HistogramObserver()
791        test_input = torch.tensor(
792            [0.0, 0.0, 4.58e-41, 4.58e-41]
793        )
        # Make sure it runs; two passes are required by the behavior of forward():
        # the first pass initializes min_val & max_val, and the second calls _adjust_min_max
        obs(test_input)
        obs(test_input)

    def test_histogram_observer_correct_numel(self):
        for i in range(1, 10):
            obs = HistogramObserver()
            obs(torch.randn(i, i))
            self.assertEqual(obs.histogram.sum().item(), i**2)

    def test_histogram_observer_single_inputs(self):
        # Make sure that if we pass single valued tensors to the observer, the code runs
        observer = HistogramObserver(bins=10)
        a = torch.FloatTensor([1])
        b = torch.FloatTensor([3])
        c = torch.FloatTensor([2])
        d = torch.FloatTensor([4])

        observer(a)
        observer(b)
        observer(c)
        observer(d)

        self.assertEqual(observer.min_val, 1)
        self.assertEqual(observer.max_val, 4)
        self.assertEqual(torch.sum(observer.histogram), 4)

    def test_histogram_observer_update_within_range_succeeds(self):
        # test if an update within the existing range actually updates
        myobs = HistogramObserver(bins=10)
        x = torch.tensor([0.0, 3.0, 4.0, 9.0])
        y = torch.tensor([2.0, 3.0, 7.0, 8.0])
        myobs(x)
        myobs(y)
        self.assertEqual(myobs.min_val, 0.0)
        self.assertEqual(myobs.max_val, 9.0)
        self.assertEqual(myobs.histogram, [1., 0., 1., 2., 1., 0., 0., 1., 1., 1.])

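# A minimal HistogramObserver usage sketch (illustrative; it mirrors the tests
# above rather than adding new coverage): the observer accumulates a histogram
# over forward passes, then calculate_qparams() searches for the clipping
# range that minimizes L2 quantization error before deriving scale/zero_point.
#
#   obs = HistogramObserver(bins=256, dtype=torch.quint8)
#   obs(torch.randn(16, 16))
#   scale, zero_point = obs.calculate_qparams()
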
class TestFakeQuantize(TestCase):
    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.per_channel_tensor(shapes=hu.array_shapes(2, 5,),
           qparams=hu.qparams(dtypes=torch.qint8)))
    def test_fq_module_per_channel(self, device, X):
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        X.requires_grad_()
        fq_module = FakeQuantize(default_per_channel_weight_observer, quant_min, quant_max, ch_axis=axis).to(device)
        Y_prime = fq_module(X)
        assert fq_module.scale is not None
        assert fq_module.zero_point is not None
        Y = _fake_quantize_per_channel_affine_reference(X, fq_module.scale,
                                                        fq_module.zero_point, axis, quant_min, quant_max)
        np.testing.assert_allclose(Y.cpu().detach().numpy(), Y_prime.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

        # Test backward
        dout = torch.rand_like(X, dtype=torch.float, device=device)
        Y_prime.backward(dout)
        dX = _fake_quantize_per_channel_affine_grad_reference(dout, X, fq_module.scale,
                                                              fq_module.zero_point, axis, quant_min, quant_max)
        np.testing.assert_allclose(dX.cpu().numpy(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

    def test_fq_serializable_per_channel(self):
        observer = default_per_channel_weight_observer
        quant_min = -128
        quant_max = 127
        fq_module = FakeQuantize(observer, quant_min, quant_max)
        X = torch.tensor([[-5, -3.5, -2, 0, 3, 5, 7], [1, 3, 2, 5, 6.5, 8, 10]], dtype=torch.float32)
        y_ref = fq_module(X)
        state_dict = fq_module.state_dict()
        self.assertEqual(state_dict['scale'], [0.054902, 0.078431])
        self.assertEqual(state_dict['zero_point'], [0, 0])
        b = io.BytesIO()
        torch.save(state_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in state_dict:
            self.assertEqual(state_dict[key], loaded_dict[key])

    def test_quant_min_max_override(self):
        observer = default_per_channel_weight_observer
        # test no override
        fq_module = FakeQuantize(observer)
        self.assertEqual(fq_module.activation_post_process.quant_min, -128)
        self.assertEqual(fq_module.activation_post_process.quant_max, 127)
        # test quant_min/quant_max override
        fq_module = FakeQuantize(observer, quant_min=0, quant_max=127)
        self.assertEqual(fq_module.activation_post_process.quant_min, 0)
        self.assertEqual(fq_module.activation_post_process.quant_max, 127)

def _get_buffer_ids(module):
    """
    Object addresses stay constant if and only if all modifications are in-place
    """
    return [id(v) for k, v in module._buffers.items()]

class TestDistributed(QuantizationTestCase):

    def test_observers_preserve_buffers(self):
        """
        Tests that observers only modify buffers in place. Note: this is important
        because nn.DataParallel depends on this assumption to work correctly.
        However, DataParallel does not expose IDs of the replicas, so we test it
        without DataParallel in order to easily access the object IDs.
        """
        observer_types = [
            torch.ao.quantization.MinMaxObserver.with_args(dtype=torch.qint8),
            torch.ao.quantization.MovingAverageMinMaxObserver.with_args(dtype=torch.qint8),
            torch.ao.quantization.PerChannelMinMaxObserver.with_args(dtype=torch.qint8),
            torch.ao.quantization.MovingAveragePerChannelMinMaxObserver.with_args(dtype=torch.qint8),
            torch.ao.quantization.HistogramObserver.with_args(dtype=torch.qint8),
            torch.ao.quantization.RecordingObserver.with_args(dtype=torch.qint8),
            torch.ao.quantization.PlaceholderObserver.with_args(dtype=torch.float16),
        ]

        for observer_type in observer_types:
            observer = observer_type()
            buffer_ids_before = _get_buffer_ids(observer)
            for _i in range(5):
                inputs = torch.rand((4, 4, 4))
                observer(inputs)
            buffer_ids_after = _get_buffer_ids(observer)
            self.assertEqual(
                buffer_ids_before,
                buffer_ids_after,
                msg=f"{str(observer)}: Buffers must be modified in place")

    def test_fake_quant_preserves_buffers(self):
        """
        Tests that fake quant only modifies buffers in place. Note: this is important
        because nn.DataParallel depends on this assumption to work correctly.
        However, DataParallel does not expose IDs of the replicas, so we test it
        without DataParallel in order to easily access the object IDs.
        """
        model = torch.ao.quantization.FakeQuantize()
        buffer_ids_before = _get_buffer_ids(model)
        for _i in range(5):
            inputs = torch.rand((4, 4, 4))
            model(inputs)
        model.apply(torch.ao.quantization.enable_fake_quant)
        model.apply(torch.ao.quantization.disable_fake_quant)
        model.apply(torch.ao.quantization.enable_observer)
        model.apply(torch.ao.quantization.disable_observer)
        buffer_ids_after = _get_buffer_ids(model)
        self.assertEqual(
            buffer_ids_before,
            buffer_ids_after,
            msg="FakeQuant: Buffers must be modified in place")

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    def test_qat_data_parallel(self):
        """
        Tests that doing QAT in nn.DataParallel does not crash.
        """
        if 'fbgemm' not in torch.backends.quantized.supported_engines:
            return
        with override_quantized_engine('fbgemm'):
            device = torch.device('cuda')

            model = nn.Sequential(
                torch.ao.quantization.QuantStub(),
                nn.Conv2d(3, 1, 1, bias=False),
                nn.BatchNorm2d(1),
                nn.ReLU(),
                nn.Conv2d(1, 2, 3, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(2),
                nn.AvgPool2d(14),
                nn.Sigmoid(),
                torch.ao.quantization.DeQuantStub(),
            )

            torch.ao.quantization.fuse_modules_qat(model, [['1', '2', '3'], ['4', '5']], inplace=True)

            model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
            torch.ao.quantization.prepare_qat(model, inplace=True)
            model = nn.DataParallel(model, device_ids=[0, 1])
            model.to(device)
            model.train()

            for epoch in range(3):
                inputs = torch.rand(2, 3, 28, 28).to(device)
                model(inputs)
                if epoch >= 1:
                    model.apply(torch.ao.quantization.disable_observer)
                if epoch >= 2:
                    model.apply(torch.ao.nn.intrinsic.qat.freeze_bn_stats)
                quant_model = copy.deepcopy(model.module)
                quant_model = torch.ao.quantization.convert(quant_model.eval().cpu(), inplace=False)
                with torch.no_grad():
                    out = quant_model(torch.rand(1, 3, 28, 28))

    def test_qat_convbn_fused_syncbn_replacement(self):
        """
        Tests that SyncBatchNorm replacement works for fused ConvBN.
        """
        if 'fbgemm' not in torch.backends.quantized.supported_engines:
            return
        with override_quantized_engine('fbgemm'):
            # create conv-bn
            class Model(nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.conv = nn.Conv2d(4, 1, 3, padding=1)
                    self.bn = nn.BatchNorm2d(1)

                def forward(self, x):
                    x = self.conv(x)
                    x = self.bn(x)
                    return x

            model = Model()
            # fuse it
            fused_model = torch.ao.quantization.fuse_modules_qat(
                model,
                [['conv', 'bn']],
            )
            # convert to QAT
            fused_model.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
            torch.ao.quantization.prepare_qat(fused_model, inplace=True)
            # replace with DDP
            fused_model = nn.SyncBatchNorm.convert_sync_batchnorm(fused_model)
            self.assertTrue(
                isinstance(fused_model.conv.bn, nn.SyncBatchNorm),
                "Expected BN to be converted to SyncBN")

    def test_syncbn_preserves_qconfig(self):
        """
        Makes sure that if a BatchNorm is not fused and a qconfig exists,
        converting the module to SyncBatchNorm preserves the qconfig.
1028        """
1029        m = nn.Sequential(
1030            nn.Conv2d(1, 1, 1),
1031            nn.BatchNorm2d(1),
1032        )
1033        m[1].qconfig = torch.ao.quantization.default_qconfig
1034        m = torch.nn.SyncBatchNorm.convert_sync_batchnorm(m)
1035        self.assertTrue(
1036            hasattr(m[1], "qconfig"),
1037            "missing qconfig after SyncBatchNorm conversion")
1038
1039    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
1040    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
1041    @override_qengines
1042    def test_device_affinity(self):
1043        """
1044        Tests that converting a model to QAT respects device affinity
1045        """
1046        class Model(nn.Module):
1047
1048            def __init__(self) -> None:
1049                super().__init__()
1050                self.conv = nn.Conv2d(1, 1, 1)
1051                self.bn = nn.BatchNorm2d(1)
1052                self.relu = nn.ReLU()
1053
1054            def forward(self, x):
1055                x = self.conv(x)
1056                x = self.bn(x)
1057                x = self.relu(x)
1058                return x
1059
1060        model = Model()
1061        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(torch.backends.quantized.engine)
1062        device = torch.device('cuda:0')
1063        model.to(device)
1064        torch.ao.quantization.prepare_qat(model, inplace=True)
1065        model_devices = {p.device for p in model.parameters()} | \
1066            {p.device for p in model.buffers()}
1067        self.assertEqual(len(model_devices), 1)
1068        model_device = next(iter(model_devices))
1069        self.assertEqual(model_device, device)
1070
1071        # ensure that running an input on CUDA works without any needed changes
1072        input = torch.randn(4, 1, 4, 4, device=device)
1073        model(input)
1074
1075class TestFusedObsFakeQuantModule(TestCase):
1076    @given(
1077        device=st.sampled_from(
1078            ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
1079        )
1080    )
1081    @settings(deadline=None)
1082    def test_fused_obs_fq_module(self, device):
1083        # Set up the parameters
1084        x = torch.randn(5, 5, device=device)
1085        running_min_op = torch.tensor(float("inf"), device=device)
1086        running_max_op = torch.tensor(float("-inf"), device=device)
1087        avg_const = 0.01
1088        scale = torch.tensor([1.0], device=device)
1089        zero_point = torch.tensor([0], dtype=torch.int, device=device)
1090
1091        # Run the forward on the Module
1092        mod = FusedMovingAvgObsFakeQuantize()
1093        torch.ao.quantization.enable_fake_quant(mod)
1094        torch.ao.quantization.enable_observer(mod)
1095        mod.to(device)
1096        out = mod(x)
1097
1098        # Run the operator directly
1099        pt_op = torch.fused_moving_avg_obs_fake_quant
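        # Positional arguments, in the order this test passes them (the names
        # here are descriptive assumptions, not the op's documented signature):
        # input, observer_enabled, fake_quant_enabled, running_min,
        # running_max, scale, zero_point, averaging_const, quant_min,
        # quant_max, ch_axis, per_row_fake_quant.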
1100
1101        out_ref = pt_op(
1102            x,
1103            mod.observer_enabled,
1104            mod.fake_quant_enabled,
1105            running_min_op,
1106            running_max_op,
1107            scale,
1108            zero_point,
1109            avg_const,
1110            0,
1111            255,
1112            0,
1113            False,
1114        )
1115
1116        # Compare params with reference
1117        torch.testing.assert_close(out, out_ref)
1118        torch.testing.assert_close(
1119            running_min_op, mod.activation_post_process.min_val
1120        )
1121        torch.testing.assert_close(
1122            running_max_op, mod.activation_post_process.max_val
1123        )
1124
1125    @given(
1126        device=st.sampled_from(
1127            ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
1128        )
1129    )
1130    @settings(deadline=None)
1131    def test_fused_obs_fq_moving_avg_module(self, device):
1132        # Set up the parameters
1133        running_min_op = torch.tensor(float("inf"), device=device)
1134        running_max_op = torch.tensor(float("-inf"), device=device)
1135        avg_const = 0.001
1136        scale = torch.tensor([1.0], device=device)
1137        zero_point = torch.tensor([0], dtype=torch.int, device=device)
1138
1139        mod = FusedMovingAvgObsFakeQuantize(averaging_constant=0.001)
1140        mod.to(device)
1141        mod.observer_enabled[0] = 0
1142        mod.fake_quant_enabled[0] = 0
1143
1144        for i in range(10):
1145            x = torch.randn(5, 5, device=device)
1146            if i > 2:
1147                mod.observer_enabled[0] = 1
1148            if i > 4:
1149                mod.fake_quant_enabled[0] = 1
1150            # Run the forward on the Module
1151            out = mod(x)
1152
1153            # Run the operator directly
1154            pt_op = torch.fused_moving_avg_obs_fake_quant
1155
1156            out_ref = pt_op(
1157                x,
1158                mod.observer_enabled,
1159                mod.fake_quant_enabled,
1160                running_min_op,
1161                running_max_op,
1162                scale,
1163                zero_point,
1164                avg_const,
1165                0,
1166                255,
1167                0,
1168                False,
1169            )
1170
1171            # Compare params with reference
1172            torch.testing.assert_close(out, out_ref)
1173            torch.testing.assert_close(
1174                running_min_op, mod.activation_post_process.min_val
1175            )
1176            torch.testing.assert_close(
1177                running_max_op, mod.activation_post_process.max_val
1178            )
1179
1180    @given(
1181        device=st.sampled_from(
1182            ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
1183        )
1184    )
1185    @settings(deadline=None)
1186    def test_compare_fused_obs_fq_oss_module(self, device):
1187        mod = FusedMovingAvgObsFakeQuantize()
1188        torch.ao.quantization.enable_fake_quant(mod)
1189        torch.ao.quantization.enable_observer(mod)
1190        mod.to(device)
1191
1192        mod_ref = FakeQuantize()
1193        torch.ao.quantization.enable_fake_quant(mod_ref)
1194        torch.ao.quantization.enable_observer(mod_ref)
1195        mod_ref.to(device)
1196
1197        for i in range(10):
1198            x = torch.randn(5, 5, device=device)
1199            out = mod(x)
1200            out_ref = mod_ref(x)
1201            torch.testing.assert_close(out, out_ref)
1202            torch.testing.assert_close(
1203                mod_ref.activation_post_process.min_val,
1204                mod.activation_post_process.min_val,
1205            )
1206            torch.testing.assert_close(
1207                mod_ref.activation_post_process.max_val,
1208                mod.activation_post_process.max_val,
1209            )

    def test_fused_mod_per_channel(self):
        devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
        m = 5
        n = 10
        for device in devices:
            running_min_op = torch.empty(m, device=device).fill_(float("inf"))
            running_max_op = torch.empty(m, device=device).fill_(float("-inf"))
            avg_const = 0.001
            scale = torch.empty(m, device=device).fill_(0.1)
            zero_point = torch.empty(m, dtype=torch.int, device=device).fill_(0)
            obs = FusedMovingAvgObsFakeQuantize.with_args(
                averaging_constant=avg_const,
                observer=MovingAveragePerChannelMinMaxObserver,
            )
            mod = obs()
            mod = torch.jit.script(mod)
            mod.to(device)

            for i in range(10):
                x = torch.randn(m, n, device=device)
                # Enable the observer after iteration 2 and the fake quant
                # after iteration 4, so all on/off combinations get exercised.
                if i > 2:
                    mod.observer_enabled[0] = 1
                if i > 4:
                    mod.fake_quant_enabled[0] = 1
                # Run the forward on the Module
                out = mod(x)

                # Run the operator directly
                pt_op = torch.fused_moving_avg_obs_fake_quant

                out_ref = pt_op(
                    x,
                    mod.observer_enabled,
                    mod.fake_quant_enabled,
                    running_min_op,
                    running_max_op,
                    scale,
                    zero_point,
                    avg_const,
                    0,      # quant_min
                    255,    # quant_max
                    0,      # ch_axis
                    True,   # per_row_fake_quant
                    False,  # symmetric_quant
                )
                # Compare params with reference
                torch.testing.assert_close(out, out_ref)
                if mod.observer_enabled[0]:
                    torch.testing.assert_close(
                        running_min_op, mod.activation_post_process.min_val
                    )
                    torch.testing.assert_close(
                        running_max_op, mod.activation_post_process.max_val
                    )
                if mod.fake_quant_enabled[0]:
                    torch.testing.assert_close(scale, mod.scale)
                    torch.testing.assert_close(zero_point, mod.zero_point)

            torch.testing.assert_close(mod.state_dict()['activation_post_process.min_val'], running_min_op)
            torch.testing.assert_close(mod.state_dict()['activation_post_process.max_val'], running_max_op)
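
    # A minimal sketch, not a test: the per-channel test above depends on the
    # fused module being TorchScript-compatible, so a scripted instance can
    # also be serialized and restored in-memory. The round-trip shown is an
    # illustrative assumption about usage, not something the tests require.
    def _fused_mod_script_roundtrip_sketch(self):
        mod = torch.jit.script(
            FusedMovingAvgObsFakeQuantize(observer=MovingAveragePerChannelMinMaxObserver)
        )
        buf = io.BytesIO()
        torch.jit.save(mod, buf)
        buf.seek(0)
        return torch.jit.load(buf)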

    def test_fused_mod_reduce_range(self):
        # reduce_range should halve the effective quint8 range from [0, 255]
        # to [0, 127] on the attached observer.
        obs = FusedMovingAvgObsFakeQuantize(quant_min=0, quant_max=255, dtype=torch.quint8, reduce_range=True)
        self.assertEqual(obs.activation_post_process.quant_min, 0)
        self.assertEqual(obs.activation_post_process.quant_max, 127)

    def test_embedding_bag_qat_config(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.emb1 = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
                                                  include_last_offset=True, scale_grad_by_freq=False, mode='sum')
                self.emb2 = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
                                                  include_last_offset=True, scale_grad_by_freq=False, mode='sum')

            def forward(self, indices):
                return torch.cat((self.emb1(indices), self.emb2(indices)))

        qconfigs = [torch.ao.quantization.default_embedding_qat_qconfig,
                    torch.ao.quantization.default_embedding_qat_qconfig_4bit]
        for qconfig in qconfigs:
            model = Model().train()
            indices = torch.randint(0, 10, (5, 12))

            model.qconfig = qconfig

            quant_model = prepare_qat(model,
                                      mapping=get_embedding_qat_module_mappings())

            count_fake_quant = 0
            for name, mod in quant_model.named_modules():
                if name.endswith('weight_fake_quant'):
                    count_fake_quant += 1
                    self.assertEqual(type(mod), FakeQuantize)
            self.assertEqual(count_fake_quant, 2)

            quant_model(indices)

            # Ensure that EmbeddingBags have float zero_point values
            self.assertEqual(quant_model.emb1.weight_fake_quant.zero_point.dtype, torch.float32)
            self.assertEqual(quant_model.emb2.weight_fake_quant.zero_point.dtype, torch.float32)

            inference_gm = convert(quant_model.eval().cpu(),
                                   mapping=get_embedding_static_quant_module_mappings())

            # Ensure that EmbeddingBags are now quantized with the appropriate bitwidth.
            self.assertEqual(type(inference_gm.emb1), torch.ao.nn.quantized.EmbeddingBag)
            self.assertEqual(type(inference_gm.emb2), torch.ao.nn.quantized.EmbeddingBag)
            self.assertEqual(inference_gm.emb1.dtype, qconfig.weight().dtype)
            self.assertEqual(inference_gm.emb2.dtype, qconfig.weight().dtype)
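
    # A minimal sketch, not a test: the dtype asserted on the converted
    # EmbeddingBags above comes straight from the qconfig's weight fake-quant,
    # so it can be inspected without building a model at all. Which dtypes the
    # 8-bit and 4-bit qconfigs carry is left for the reader to verify.
    def _embedding_qconfig_dtype_sketch(self):
        q8 = torch.ao.quantization.default_embedding_qat_qconfig
        q4 = torch.ao.quantization.default_embedding_qat_qconfig_4bit
        return q8.weight().dtype, q4.weight().dtype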

    def test_embedding_qat_config(self):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                model = DeFusedEmbeddingBagLinear()
                indices = torch.randint(0, 10, (5, 12))
                quant_model = prepare_qat(model,
                                          mapping=get_embedding_qat_module_mappings())

                count_fake_quant = 0
                count_activation_postproc = 0
                for name, mod in quant_model.named_modules():
                    if name.endswith('weight_fake_quant'):
                        count_fake_quant += 1
                    if name.count('activation_post_process') == 1 and 'weight_fake_quant' not in name:
                        count_activation_postproc += 1
                # One weight fake-quant for the embedding, one for the linear layer.
                self.assertEqual(count_fake_quant, 2)
                # One activation post-process for the embedding (a no-op), one
                # for the quant stub, and one for the linear layer.
                self.assertEqual(count_activation_postproc, 3)

                self.assertEqual(type(quant_model.emb.weight_fake_quant), FakeQuantize)
                self.assertEqual(quant_model.emb.weight_fake_quant.zero_point.dtype, torch.float32)
                self.assertEqual(type(quant_model.emb.activation_post_process), NoopObserver)
                self.assertEqual(type(quant_model.linear.weight_fake_quant), FusedMovingAvgObsFakeQuantize)
                self.assertEqual(type(quant_model.linear.activation_post_process), FusedMovingAvgObsFakeQuantize)

                quant_model(indices)
                inference_gm = convert(quant_model,
                                       mapping=get_embedding_static_quant_module_mappings())
                # Ensure that Embedding is now quantized
                self.assertEqual(type(inference_gm.emb), torch.ao.nn.quantized.Embedding)
                # Ensure that Linear is now quantized
                self.assertEqual(type(inference_gm.linear), torch.ao.nn.quantized.Linear)
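
    # An illustrative helper, not used by the tests above: factors out the
    # weight_fake_quant counting loop that several QAT-config tests in this
    # class repeat inline.
    @staticmethod
    def _count_weight_fake_quants(model):
        return sum(
            1 for name, _ in model.named_modules()
            if name.endswith('weight_fake_quant')
        )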

    def test_default_fused_qat_config(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(2, 2)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.linear(x)
                x = self.relu(x)
                return x

        for qengine in ["fbgemm", "qnnpack"]:
            model = Model()
            model.linear.weight = torch.nn.Parameter(torch.randn(2, 2))
            sample_input = torch.randn(2, 2)
            model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine, version=1)
            ref_model = torch.ao.quantization.QuantWrapper(model)
            ref_model = torch.ao.quantization.prepare_qat(ref_model)
            ref_model(sample_input)
            # Count every fused fake-quant module: one for the linear weight
            # plus the activation post-processes inserted by prepare_qat.
            count_fake_quant = 0
            for name, mod in ref_model.named_modules():
                if name.endswith('weight_fake_quant'):
                    count_fake_quant += 1
                    self.assertEqual(type(mod), FusedMovingAvgObsFakeQuantize)

                if name.count('activation_post_process') == 1 and 'weight_fake_quant' not in name:
                    count_fake_quant += 1
                    self.assertEqual(type(mod), FusedMovingAvgObsFakeQuantize)

            self.assertEqual(count_fake_quant, 3)

            # fbgemm's default QAT qconfig reduces the activation range and
            # observes weights per channel; qnnpack's does neither.
            if qengine == "fbgemm":
                lower_bnd = 0
                upper_bnd = 127
                obs2match = MovingAveragePerChannelMinMaxObserver
            else:
                lower_bnd = 0
                upper_bnd = 255
                obs2match = MovingAverageMinMaxObserver

            self.assertEqual(ref_model.quant.activation_post_process.activation_post_process.quant_min, lower_bnd)
            self.assertEqual(ref_model.quant.activation_post_process.activation_post_process.quant_max, upper_bnd)
            self.assertEqual(type(ref_model.module.linear.weight_fake_quant.activation_post_process),
                             obs2match)
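
    # A minimal sketch, not a test: the backend-dependent bounds asserted above
    # originate in the default QAT qconfig itself, so they can be read off an
    # activation fake-quant instance without preparing a model. The "fbgemm"
    # choice here is an illustrative assumption.
    def _default_qat_qconfig_bounds_sketch(self):
        qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm", version=1)
        act_fq = qconfig.activation()
        return (type(act_fq),
                act_fq.activation_post_process.quant_min,
                act_fq.activation_post_process.quant_max)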

if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_quantization.py TESTNAME\n\n"
                       "instead.")