# Owner(s): ["oncall: quantization"]

import io
from typing import Dict

import torch
import torch._C
from torch.ao.quantization import default_dynamic_qconfig, per_channel_dynamic_qconfig
from torch.ao.quantization.quantize_jit import (
    _prepare_ondevice_dynamic_jit,
    _quantize_ondevice_dynamic_jit,
    convert_dynamic_jit,
    prepare_dynamic_jit,
)
from torch.jit.mobile import _load_for_lite_interpreter, LiteScriptModule
from torch.testing import FileCheck
from torch.testing._internal.common_quantization import (
    get_script_module,
    LinearAddModel,
)
from torch.testing._internal.common_utils import TestCase
from torch.utils import bundled_inputs as bundled_inputs


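# Test fixture: a small two-linear module whose first layer reuses an
# externally provided weight Parameter.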
class myMod(torch.nn.Module):
    def __init__(self, weight):
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 5).float()
        self.fc1.weight = weight
        self.fc2 = torch.nn.Linear(5, 5).float()

    def forward(self, x):
        return self.fc2(self.fc1(x))


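# Test fixture: conv -> myMod (two linears) -> functional linear, i.e. the
# three linear ops that the dynamic PTQ tests below expect to be quantized
# (the conv is not dynamically quantized).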
class MyConvLinearModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 5, 3)
        weight = torch.nn.Parameter(torch.ones(5, 5))
        self.weight1 = torch.nn.Parameter(torch.ones(5, 5))
        self.mymod = myMod(weight)

    def forward(self, x):
        conv_output = self.conv(x)
        y = self.mymod(conv_output)
        z = torch.nn.functional.linear(y, self.weight1)
        return z

    def get_example_inputs(self):
        return (torch.rand(1, 3, 12, 7),)


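# Helpers shared by the on-device PTQ tests: scripting models, inserting
# observers, running full on-device dynamic quantization, and inspecting the
# resulting TorchScript graphs.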
class OnDevicePTQUtils:
    observer_module_name = ["MinMaxObserver", "PerChannelMinMaxObserver"]

    @staticmethod
    def insert_observers(model, qconfig_dict):
        inputs = model.get_example_inputs()
        scripted_model = get_script_module(model, False, inputs)
        scripted_model = _prepare_ondevice_dynamic_jit(scripted_model, qconfig_dict)
        return scripted_model

    @staticmethod
    def ptq_dynamic_quantize(model, qconfig_dict):
        inputs = model.get_example_inputs()
        m = get_script_module(model, False, inputs)
        m = _quantize_ondevice_dynamic_jit(m, qconfig_dict, "forward", True)
        return m

    @staticmethod
    def find_observer_modules(m):
        observer_modules = []
        for child_module in m.children():
            if child_module.original_name in OnDevicePTQUtils.observer_module_name:
                observer_modules.append(child_module)
        return observer_modules

    @staticmethod
    def is_value_type_observer(value):
        type_name = value.type()
        for observer_type in OnDevicePTQUtils.observer_module_name:
            if observer_type in type_name.str():
                return True
        return False

    @staticmethod
    def is_calculate_qparam(node):
        if node.kind() == "prim::CallMethod":
            if node.s("name") == "calculate_qparams":
                return True
        return False

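    # Walk back from a quantized::linear_prepack node through the
    # quantize_per_tensor/per_channel op to the prim::GetAttr that produces
    # the fp weight, and return that attribute's name.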
    @staticmethod
    def get_linear_packed_param_fp_weight(node):
        weight = node.inputsAt(0).node()
        if (
            weight.kind() != "aten::quantize_per_tensor"
            and weight.kind() != "aten::quantize_per_channel"
        ):
            raise ValueError("Quantized weight must be produced.")
        fp_weight = weight.inputsAt(0).node()
        assert (
            fp_weight.kind() == "prim::GetAttr"
        ), "Weight must be an attribute of the module."
        fp_weight_name = fp_weight.s("name")
        return fp_weight_name

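    # Given a quantized::linear_prepack node, report whether its weight was
    # quantized per channel (rather than per tensor) by inspecting the op that
    # produced the weight input.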
    @staticmethod
    def is_per_channel_quantized_packed_param(node):
        assert (
            node.kind() == "quantized::linear_prepack"
        ), "Node must correspond to linear_prepack."
        weight = node.inputsAt(0).node()
        assert weight.kind() in (
            "aten::quantize_per_tensor",
            "aten::quantize_per_channel",
        ), "Weight must be produced by a quantize op."
        return weight.kind() != "aten::quantize_per_tensor"


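# Checks observer insertion for on-device dynamic PTQ: the expected number and
# type of observers are added, observe_forward/reset_observers_forward methods
# are generated, and the forward graph keeps the same size as the original.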
class TestOnDeviceDynamicPTQInsertObservers(TestCase):
    def _check_num_and_type_of_observers(self, model, num_observers):
        qconfig_dict = {"": default_dynamic_qconfig}
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        observer_modules = OnDevicePTQUtils.find_observer_modules(scripted_model)
        self.assertTrue(len(observer_modules) == num_observers)
        for observer in observer_modules:
            self.assertTrue(observer.original_name == "MinMaxObserver")

        qconfig_dict = {"": per_channel_dynamic_qconfig}
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        observer_modules = OnDevicePTQUtils.find_observer_modules(scripted_model)
        self.assertTrue(len(observer_modules) == num_observers)
        for observer in observer_modules:
            self.assertTrue(observer.original_name == "PerChannelMinMaxObserver")

    def _check_observer_method(self, model, num_observers):
        qconfig_dict = {"": default_dynamic_qconfig}
        inputs = model.get_example_inputs()
        orig_scripted_model = get_script_module(model, False, inputs)
        torch._C._jit_pass_inline(orig_scripted_model.graph)
        orig_forward_graph = orig_scripted_model.graph.str()
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        quant_forward_graph = scripted_model.graph.str()
        # Exact graph matching is difficult, so just compare the number of
        # lines instead of implementing full graph matching.
        self.assertEqual(
            len(orig_forward_graph.splitlines()), len(quant_forward_graph.splitlines())
        )
        observe_method = scripted_model.observe_forward.graph
        FileCheck().check_count(
            'prim::CallMethod[name="forward"](%_observer', num_observers, exactly=True
        ).run(observe_method)
        reset_observers_method = scripted_model.reset_observers_forward.graph
        FileCheck().check_count(
            'prim::CallMethod[name="reset_min_max_vals"](%_observer',
            num_observers,
            exactly=True,
        ).run(reset_observers_method)

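    # A weight-only observer is an observer forward call whose observed value
    # comes directly from a prim::GetAttr, i.e. a module attribute (the weight).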
    def _observer_is_weight_only(self, node):
        if (node.kind() == "prim::CallMethod") and node.s("name") == "forward":
            if OnDevicePTQUtils.is_value_type_observer(node.inputsAt(0)):
                return node.inputsAt(1).node().kind() == "prim::GetAttr"
        return False

    def test_num_observers(self):
        model = LinearAddModel()
        self._check_num_and_type_of_observers(model, 2)
        model = MyConvLinearModule()
        self._check_num_and_type_of_observers(model, 3)

    def test_observe_method(self):
        model = MyConvLinearModule()
        self._check_observer_method(model, 3)

    def test_weight_only_observers(self):
        model = MyConvLinearModule()
        qconfig_dict = {"": default_dynamic_qconfig}
        inputs = model.get_example_inputs()
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        observe_forward_graph = scripted_model.observe_forward.graph
        num_weight_only_observers = 0
        for node in observe_forward_graph.nodes():
            if self._observer_is_weight_only(node):
                num_weight_only_observers += 1
        self.assertEqual(num_weight_only_observers, 3)


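# Checks the quantize_forward method generated for on-device dynamic PTQ: it
# should contain the expected quantize and calculate_qparams nodes, no leftover
# observer calls, and it should run end to end after observation.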
class TestOnDeviceDynamicPTQInsertQuantDequant(TestCase):
    def _validate_quant_dequant_nodes(self, model, num_nodes, per_channel=0):
        quantize_forward_graph = model.quantize_forward.graph
        quantize_per_tensor = quantize_per_channel = 0
        for n in quantize_forward_graph.nodes():
            if "aten::quantize_per_tensor" in n.kind():
                quantize_per_tensor += 1
            if "aten::quantize_per_channel" in n.kind():
                quantize_per_channel += 1
        self.assertEqual(quantize_per_tensor + quantize_per_channel, num_nodes)

    def _validate_calculate_qparams(self, model, num_nodes):
        quantize_forward_graph = model.quantize_forward.graph
        num_calculate_qparams = 0
        for n in quantize_forward_graph.nodes():
            if OnDevicePTQUtils.is_calculate_qparam(n):
                num_calculate_qparams += 1
        self.assertEqual(num_calculate_qparams, num_nodes)

    def _validate_no_observer_forward(self, model):
        quantize_forward_graph = model.quantize_forward.graph
        for n in quantize_forward_graph.nodes():
            if (n.kind() == "prim::CallMethod") and n.s("name") == "forward":
                if OnDevicePTQUtils.is_value_type_observer(n.inputsAt(0)):
                    return False
        return True

    def _check_quant_dequant_and_calc_qparams(self, model, num_nodes):
        qconfig_dict = {"": default_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_quant_dequant_nodes(m, num_nodes)
        self._validate_calculate_qparams(m, num_nodes)
        self.assertTrue(self._validate_no_observer_forward(m))

        qconfig_dict = {"": per_channel_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_quant_dequant_nodes(m, num_nodes, num_nodes)
        self._validate_calculate_qparams(m, num_nodes)
        self.assertTrue(self._validate_no_observer_forward(m))

    def _check_quantize_forward_runs(self, model):
        inputs = model.get_example_inputs()
        qconfig_dict = {"": default_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)

        qconfig_dict = {"": per_channel_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        # observe_forward must run first to record the stats needed to produce
        # correct scales and zero points
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)

    def test_num_quant_dequant_nodes(self):
        model = LinearAddModel()
        self._check_quant_dequant_and_calc_qparams(model, 2)
        model = MyConvLinearModule()
        self._check_quant_dequant_and_calc_qparams(model, 3)

    def test_quantize_forward_runs(self):
        model = LinearAddModel()
        self._check_quantize_forward_runs(model)
        model = MyConvLinearModule()
        self._check_quantize_forward_runs(model)


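# Checks the finalized model: packed params are created and stored, the fp
# weights are reset, quantized_forward uses quantized::linear_dynamic, numerics
# match the reference (off-device) dynamic PTQ flow, and the serialization and
# on-device (lite interpreter) APIs work.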
class TestOnDeviceDynamicPTQFinalize(TestCase):
    def _validate_packed_params(self, model, num_nodes, per_channel=0):
        quantize_forward_graph = model.quantize_forward.graph
        quantize_per_tensor = quantize_per_channel = 0
        linear_prepack = 0
        linear_prepack_uses = 0
        for n in quantize_forward_graph.nodes():
            if n.kind() == "prim::SetAttr":
                maybe_packed_param_value = n.inputsAt(1)
                maybe_packed_param = maybe_packed_param_value.node()
                if maybe_packed_param.kind() == "quantized::linear_prepack":
                    linear_prepack += 1
                    linear_prepack_uses += len(maybe_packed_param_value.uses())
                    if OnDevicePTQUtils.is_per_channel_quantized_packed_param(
                        maybe_packed_param
                    ):
                        quantize_per_channel += 1
                    else:
                        quantize_per_tensor += 1
        self.assertEqual(quantize_per_tensor + quantize_per_channel, num_nodes)
        self.assertEqual(quantize_per_channel, per_channel)
        self.assertEqual(linear_prepack, num_nodes)
        self.assertEqual(linear_prepack_uses, num_nodes)

    def _validate_no_linear_unpack(self, model):
        quantize_forward_graph = model.quantize_forward.graph
        for n in quantize_forward_graph.nodes():
            if n.kind() == "quantized::linear_unpack":
                return False
        return True

    def _validate_setattr_fp_weights(self, model, num_nodes):
        quantize_forward_graph = model.quantize_forward.graph
        fp_weights_setattr = 0
        fp_weight_names = []
        for n in quantize_forward_graph.nodes():
            if n.kind() == "prim::SetAttr":
                maybe_packed_param = n.inputsAt(1).node()
                if maybe_packed_param.kind() == "quantized::linear_prepack":
                    weight_name = OnDevicePTQUtils.get_linear_packed_param_fp_weight(
                        maybe_packed_param
                    )
                    fp_weight_names.append(weight_name)

        for n in quantize_forward_graph.nodes():
            # This detects the pattern
            #   %x = prim::Constant
            #   prim::SetAttr[name=<weight_name>](%module, %x)
            # thus making sure that the original fp weights are reset.
            if n.kind() == "prim::SetAttr":
                weight_name = n.s("name")
                if weight_name in fp_weight_names:
                    maybe_constant = n.inputsAt(1).node()
                    if maybe_constant.kind() == "prim::Constant":
                        fp_weights_setattr += 1
        self.assertEqual(fp_weights_setattr, num_nodes)

    def _validate_quantized_forward(self, model, num_nodes):
        quantized_forward_graph = model.quantized_forward.graph
        quantize_per_tensor = quantize_per_channel = 0
        quantized_linear_dynamic = 0
        linear_packed_params = 0
        num_setattr = 0
        for n in quantized_forward_graph.nodes():
            if "aten::quantize_per_tensor" in n.kind():
                quantize_per_tensor += 1
            if "aten::quantize_per_channel" in n.kind():
                quantize_per_channel += 1
            if "quantized::linear_dynamic" in n.kind():
                quantized_linear_dynamic += 1
            if n.kind() == "prim::GetAttr":
                output = n.outputsAt(0)
                output_type = output.type()
                if "LinearPackedParamsBase" in output_type.str():
                    linear_packed_params += 1
            if n.kind() == "prim::SetAttr":
                num_setattr += 1
        self.assertEqual(quantize_per_tensor, 0)
        self.assertEqual(quantize_per_channel, 0)
        self.assertEqual(quantized_linear_dynamic, num_nodes)
        self.assertEqual(linear_packed_params, num_nodes)
        # self.assertEqual(num_setattr, 0)

    def _check_quantize_forward(self, model, num_nodes):
        qconfig_dict = {"": default_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_packed_params(m, num_nodes)
        self.assertTrue(self._validate_no_linear_unpack(m))
        self._validate_setattr_fp_weights(m, num_nodes)

        qconfig_dict = {"": per_channel_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_packed_params(m, num_nodes, num_nodes)
        self.assertTrue(self._validate_no_linear_unpack(m))
        self._validate_setattr_fp_weights(m, num_nodes)

    def _check_quantized_forward(self, model, num_nodes):
        qconfig_dict = {"": default_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_quantized_forward(m, num_nodes)

        qconfig_dict = {"": per_channel_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_quantized_forward(m, num_nodes)

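    # Compare on-device dynamic PTQ against the reference off-device dynamic
    # JIT flow (prepare_dynamic_jit + convert_dynamic_jit); the outputs must
    # match, and calling the original forward afterwards is expected to raise,
    # since the fp weights have been reset by quantize_forward.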
    def _check_against_ref_dynamic_ptq(self, model):
        model.eval()
        inputs = model.get_example_inputs()
        ref_m = torch.jit.script(model)
        torch._C._jit_pass_inline(ref_m.graph)
        qconfig_dict = {"": default_dynamic_qconfig}
        ref_m = prepare_dynamic_jit(ref_m, qconfig_dict)
        ref_m = convert_dynamic_jit(ref_m)
        ref_output = ref_m(*inputs)

        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)
        output = m.quantized_forward(*inputs)
        self.assertTrue(torch.allclose(ref_output, output))
        thrown = False
        try:
            m(*inputs)
        except Exception:
            thrown = True
        self.assertTrue(thrown)

        # test with per channel quant
        ref_m = torch.jit.script(model)
        torch._C._jit_pass_inline(ref_m.graph)
        qconfig_dict = {"": per_channel_dynamic_qconfig}
        ref_m = prepare_dynamic_jit(ref_m, qconfig_dict)
        ref_m = convert_dynamic_jit(ref_m)
        ref_output = ref_m(*inputs)

        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)
        output = m.quantized_forward(*inputs)
        self.assertTrue(torch.allclose(ref_output, output))
        thrown = False
        try:
            m(*inputs)
        except Exception:
            thrown = True
        self.assertTrue(thrown)

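    # Round-trip checks against a reference dynamic-PTQ model. With
    # check_device_side_api=False, save/load the TorchScript module and drive
    # the observe/quantize/quantized_forward methods from Python. With
    # check_device_side_api=True, save for the lite interpreter and quantize
    # via torch._C._quantize_ondevice_ptq_dynamic, which must also strip the
    # helper methods from the loaded module.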
    def _check_serdes_and_device_side_api_helper(
        self, model, check_device_side_api=False
    ):
        model.eval()
        inputs = model.get_example_inputs()
        ref_m = torch.jit.script(model)
        torch._C._jit_pass_inline(ref_m.graph)
        qconfig_dict = {"": default_dynamic_qconfig}
        ref_m = prepare_dynamic_jit(ref_m, qconfig_dict)
        ref_m = convert_dynamic_jit(ref_m)
        buffer = io.BytesIO()
        torch.jit.save(ref_m, buffer)
        buffer.seek(0)
        ref_m = torch.jit.load(buffer)
        ref_output = ref_m(*inputs)

        if not check_device_side_api:
            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
            buffer = io.BytesIO()
            torch.jit.save(m, buffer)
            buffer.seek(0)
            m = torch.jit.load(buffer)
            m.reset_observers_forward()
            m.observe_forward(*inputs)
            m.quantize_forward(*inputs)
            output = m.quantized_forward(*inputs)
            self.assertTrue(torch.allclose(ref_output, output))
        else:
            # check for lite interpreter
            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
            (first_input,) = inputs
            rand_input = bundled_inputs.bundle_randn(
                first_input.size(), dtype=first_input.dtype
            )
            m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input,)])
            buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
            buffer.seek(0)
            m = _load_for_lite_interpreter(buffer)
            torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward")
            self.assertFalse(m.find_method("quantized_forward"))
            self.assertFalse(m.find_method("quantize_forward"))
            self.assertFalse(m.find_method("observe_forward"))
            self.assertFalse(m.find_method("reset_observers_forward"))
            output = m(*inputs)
            self.assertTrue(torch.allclose(ref_output, output))

            # Now serialize to flatbuffer, load it back, and check
            dict: Dict[str, str] = {}
            bytes = torch._C._save_mobile_module_to_bytes(m._c, dict)
            m = LiteScriptModule(torch._C._load_mobile_module_from_bytes(bytes))
            fb_output = m(*inputs)
            self.assertTrue(torch.allclose(ref_output, fb_output))

        model.eval()
        inputs = model.get_example_inputs()
        ref_m = torch.jit.script(model)
        torch._C._jit_pass_inline(ref_m.graph)
        qconfig_dict = {"": per_channel_dynamic_qconfig}
        ref_m = prepare_dynamic_jit(ref_m, qconfig_dict)
        ref_m = convert_dynamic_jit(ref_m)
        buffer = io.BytesIO()
        torch.jit.save(ref_m, buffer)
        buffer.seek(0)
        ref_m = torch.jit.load(buffer)
        ref_output = ref_m(*inputs)

        if not check_device_side_api:
            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
            buffer = io.BytesIO()
            torch.jit.save(m, buffer)
            buffer.seek(0)
            m = torch.jit.load(buffer)
            m.reset_observers_forward()
            m.observe_forward(*inputs)
            m.quantize_forward(*inputs)
            output = m.quantized_forward(*inputs)
            self.assertTrue(torch.allclose(ref_output, output))
        else:
            # check for lite interpreter
            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
            (first_input,) = inputs
            rand_input = bundled_inputs.bundle_randn(
                first_input.size(), dtype=first_input.dtype
            )
            m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input,)])
            buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
            buffer.seek(0)
            m = _load_for_lite_interpreter(buffer)
            torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward")
            self.assertFalse(m.find_method("quantized_forward"))
            self.assertFalse(m.find_method("quantize_forward"))
            self.assertFalse(m.find_method("observe_forward"))
            self.assertFalse(m.find_method("reset_observers_forward"))
            output = m(*inputs)
            self.assertTrue(torch.allclose(ref_output, output))

    def _check_serialization_deserialization(self, model):
        self._check_serdes_and_device_side_api_helper(model, False)

    def _check_device_side_api(self, model):
        self._check_serdes_and_device_side_api_helper(model, True)

    def test_quantize_forward(self):
        model = LinearAddModel()
        self._check_quantize_forward(model, 2)
        model = MyConvLinearModule()
        self._check_quantize_forward(model, 3)

    def test_quantized_forward(self):
        model = LinearAddModel()
        self._check_quantized_forward(model, 2)
        model = MyConvLinearModule()
        self._check_quantized_forward(model, 3)

    def test_against_offdevice_dynamic_ptq(self):
        model = LinearAddModel()
        self._check_against_ref_dynamic_ptq(model)
        model = MyConvLinearModule()
        self._check_against_ref_dynamic_ptq(model)

    def test_serialization_deserialization(self):
        model = MyConvLinearModule()
        self._check_serialization_deserialization(model)

    def test_device_side_api(self):
        model = MyConvLinearModule()
        self._check_device_side_api(model)