# mypy: allow-untyped-defs

import torch
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization.quant_type import QuantType
from torch.jit._recursive import wrap_cpp_module


__all__ = [
    "script_qconfig",
    "script_qconfig_dict",
    "fuse_conv_bn_jit",
    "prepare_jit",
    "prepare_dynamic_jit",
    "convert_jit",
    "convert_dynamic_jit",
    "quantize_jit",
    "quantize_dynamic_jit",
]


def _check_is_script_module(model):
    if not isinstance(model, torch.jit.ScriptModule):
        raise ValueError("input must be a script module, got: " + str(type(model)))


def _check_forward_method(model):
    if not model._c._has_method("forward"):
        raise ValueError("input script module does not have forward method")


def script_qconfig(qconfig):
33    r"""Instantiate the activation and weight observer modules and script
34    them, these observer module instances will be deepcopied during
35    prepare_jit step.
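
    Example (a minimal sketch; assumes the default FBGEMM qconfig):
    ```python
    from torch.ao.quantization import get_default_qconfig

    qconfig = get_default_qconfig('fbgemm')
    scripted_qconfig = script_qconfig(qconfig)
    # scripted_qconfig.activation / .weight now hold scripted observer
    # module instances ready to be deep-copied by prepare_jit
    ```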
36    """
37    return QConfig(
38        activation=torch.jit.script(qconfig.activation())._c,
39        weight=torch.jit.script(qconfig.weight())._c,
40    )
41
42
43def script_qconfig_dict(qconfig_dict):
44    r"""Helper function used by `prepare_jit`.
45    Apply `script_qconfig` for all entries in `qconfig_dict` that is
46    not None.
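
    Example (a minimal sketch; `'fc'` is a hypothetical submodule name):
    ```python
    from torch.ao.quantization import get_default_qconfig

    qconfig_dict = {'': get_default_qconfig('fbgemm'), 'fc': None}
    scripted_qconfig_dict = script_qconfig_dict(qconfig_dict)
    # None entries are passed through unchanged
    ```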
47    """
48    return {k: script_qconfig(v) if v else None for k, v in qconfig_dict.items()}
49
50
51def fuse_conv_bn_jit(model, inplace=False):
52    r"""Fuse conv - bn module
53    Works for eval model only.
54
55    Args:
56        model: TorchScript model from scripting or tracing
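
    Example (a minimal sketch; `float_model` is a hypothetical eval-mode
    model containing Conv-BN patterns):
    ```python
    import torch

    ts_model = torch.jit.script(float_model.eval())
    fused_model = fuse_conv_bn_jit(ts_model)
    ```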
57    """
58    torch._C._log_api_usage_once("quantization_api.quantize_jit.fuse_conv_bn_jit")
59    model_c = model._c
60    model_c = torch._C._jit_pass_fold_convbn(model_c)
61    if inplace:
62        model._reconstruct(model_c)
63    else:
64        model = wrap_cpp_module(model_c)
65    return model
66
67
68def _prepare_jit(model, qconfig_dict, inplace=False, quant_type=QuantType.STATIC):
69    _check_is_script_module(model)
70    _check_forward_method(model)
71    if not all(isinstance(x, str) for x in qconfig_dict.keys()):
        raise ValueError("qconfig_dict should only contain names (str) as keys.")
    scripted_qconfig_dict = script_qconfig_dict(qconfig_dict)
    model = fuse_conv_bn_jit(model, inplace)
    model_c = torch._C._jit_pass_insert_observers(
        model._c, "forward", scripted_qconfig_dict, inplace, quant_type
    )
    if inplace:
        model._reconstruct(model_c)
    else:
        model = wrap_cpp_module(model_c)
    return model


def _prepare_ondevice_jit(
    model,
    qconfig_dict,
    method_name="forward",
    inplace=False,
    quant_type=QuantType.STATIC,
):
    _check_is_script_module(model)
    if not all(isinstance(x, str) for x in qconfig_dict.keys()):
        raise ValueError("qconfig_dict should only contain names (str) as keys.")
    scripted_qconfig_dict = script_qconfig_dict(qconfig_dict)
    method_graph = model._c._get_method(method_name).graph
    torch._C._jit_pass_inline(method_graph)
    model = fuse_conv_bn_jit(model, inplace)
    model_c = torch._C._jit_pass_insert_observer_method_for_ondevice_ptq(
        model._c, method_name, scripted_qconfig_dict, inplace, quant_type
    )
    if inplace:
        model._reconstruct(model_c)
    else:
        model = wrap_cpp_module(model_c)
    return model


def prepare_jit(model, qconfig_dict, inplace=False):
    torch._C._log_api_usage_once("quantization_api.quantize_jit.prepare_jit")
    return _prepare_jit(model, qconfig_dict, inplace, quant_type=QuantType.STATIC)


def prepare_dynamic_jit(model, qconfig_dict, inplace=False):
    torch._C._log_api_usage_once("quantization_api.quantize_jit.prepare_dynamic_jit")
    return _prepare_jit(model, qconfig_dict, inplace, quant_type=QuantType.DYNAMIC)


def _prepare_ondevice_dynamic_jit(
    model, qconfig_dict, method_name="forward", inplace=False
):
    return _prepare_ondevice_jit(
        model, qconfig_dict, method_name, inplace, quant_type=QuantType.DYNAMIC
    )


def _convert_jit(
    model, inplace=False, debug=False, quant_type=QuantType.STATIC, preserved_attrs=None
):
    _check_is_script_module(model)
    model.eval()
    model_c = model._c
    model_c = torch._C._jit_pass_insert_quant_dequant(
        model_c, "forward", inplace, debug, quant_type
    )
    if not debug:
        is_xpu = all(p.device.type == "xpu" for p in model.parameters())
        if not is_xpu:
            # Moving model parameters to CPU since quantized operators
            # are only supported on CPU and XPU right now
            model.cpu()
        if preserved_attrs is None:
            preserved_attrs = []
        model_c = torch._C._jit_pass_quant_finalize(
            model_c, quant_type, preserved_attrs
        )
    if inplace:
        model._reconstruct(model_c)
    else:
        model = wrap_cpp_module(model_c)
    torch._C._jit_pass_constant_propagation(model.graph)
    torch._C._jit_pass_dce(model.graph)
    return model


def _convert_ondevice_jit(
    model, method_name, inplace=False, debug=False, quant_type=QuantType.STATIC
):
    _check_is_script_module(model)
    assert (
        quant_type == QuantType.DYNAMIC
    ), "While this API should work for static quant, it is only tested for dynamic quant."
    assert not method_name.startswith(
        "observe_"
    ), "Pass in a valid method to be quantized, e.g. forward"
    observe_method_name = "observe_" + method_name
    quantize_method_name = "quantize_" + method_name
    model_c = model._c
    model_c = torch._C._jit_pass_insert_quant_dequant_for_ondevice_ptq(
        model._c, observe_method_name, inplace, debug, QuantType.DYNAMIC
    )
    model_c = torch._C._jit_pass_quant_finalize_for_ondevice_ptq(
        model_c, QuantType.DYNAMIC, quantize_method_name
    )
    if inplace:
        model._reconstruct(model_c)
    else:
        model = wrap_cpp_module(model_c)
    return model


def convert_jit(model, inplace=False, debug=False, preserved_attrs=None):
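    r"""Convert a model prepared/calibrated with `prepare_jit` into a
    statically quantized TorchScript model.

    Example (a minimal sketch; `MyModel` and `calibration_data` are
    hypothetical):
    ```python
    import torch
    from torch.ao.quantization import get_default_qconfig
    from torch.ao.quantization.quantize_jit import prepare_jit, convert_jit

    ts_model = torch.jit.script(MyModel().eval())
    prepared = prepare_jit(ts_model, {'': get_default_qconfig('fbgemm')})
    for data in calibration_data:
        prepared(data)  # calibration runs record observer statistics
    quantized = convert_jit(prepared)
    ```
    """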
    torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_jit")
    return _convert_jit(
        model,
        inplace,
        debug,
        quant_type=QuantType.STATIC,
        preserved_attrs=preserved_attrs,
    )


def convert_dynamic_jit(model, inplace=False, debug=False, preserved_attrs=None):
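    r"""Convert a model prepared with `prepare_dynamic_jit` into a
    dynamically quantized TorchScript model.

    Example (a minimal sketch; `MyModel` is hypothetical; dynamic
    quantization needs no calibration step):
    ```python
    import torch
    from torch.ao.quantization import per_channel_dynamic_qconfig
    from torch.ao.quantization.quantize_jit import (
        prepare_dynamic_jit,
        convert_dynamic_jit,
    )

    ts_model = torch.jit.script(MyModel().eval())
    prepared = prepare_dynamic_jit(ts_model, {'': per_channel_dynamic_qconfig})
    quantized = convert_dynamic_jit(prepared)
    ```
    """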
    torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_dynamic_jit")
    return _convert_jit(
        model,
        inplace,
        debug,
        quant_type=QuantType.DYNAMIC,
        preserved_attrs=preserved_attrs,
    )


def _convert_ondevice_dynamic_jit(model, method_name, inplace=False, debug=False):
    return _convert_ondevice_jit(
        model, method_name, inplace, debug, quant_type=QuantType.DYNAMIC
    )


def _quantize_ondevice_dynamic_jit_impl(
    model, qconfig_dict, method_name, inplace=False
):
    model = _prepare_ondevice_dynamic_jit(model, qconfig_dict, method_name, inplace)
    model = _convert_ondevice_dynamic_jit(model, method_name, inplace)
    return model


def _quantize_jit(
    model,
    qconfig_dict,
    run_fn=None,
    run_args=None,
    inplace=False,
    debug=False,
    quant_type=QuantType.STATIC,
):
    # Always do an inplace convert because the model is already
    # copied in prepare_jit when inplace is False
    if quant_type == QuantType.DYNAMIC:
        model = prepare_dynamic_jit(model, qconfig_dict, inplace)
        model = convert_dynamic_jit(model, True, debug)
    else:
        assert (
            run_fn
        ), "Must provide calibration function for post training static quantization"
        assert (
            run_args
        ), "Must provide calibration dataset for post training static quantization"
        model = prepare_jit(model, qconfig_dict, inplace)
        run_fn(model, *run_args)
        model = convert_jit(model, True, debug)

    torch._C._jit_pass_constant_propagation(model.graph)
    torch._C._jit_pass_dce(model.graph)
    return model


def quantize_jit(model, qconfig_dict, run_fn, run_args, inplace=False, debug=False):
249    r"""Quantize the input float TorchScript model with
250    post training static quantization.
251
252    First it will prepare the model for calibration, then it calls
253    `run_fn` which will run the calibration step, after that we will
254    convert the model to a quantized model.
255
256    Args:
257        `model`: input float TorchScript model
258        `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as key and
259        qconfig for that module as value, empty key means the qconfig will be applied
260        to whole model unless it's overwritten by more specific configurations, the
261        qconfig for each module is either found in the dictionary or fallback to
262         the qconfig of parent module.
263
264        Right now qconfig_dict is the only way to configure how the model is quantized,
265        and it is done in the granularity of module, that is, we only support one type
266        of qconfig for each torch.nn.Module, and the qconfig for sub module will
267        override the qconfig for parent module, empty string means global configuration.
268        `run_fn`: a calibration function for calibrating the prepared model
269        `run_args`: positional arguments for `run_fn`
270        `inplace`: carry out model transformations in-place, the original module is
271        mutated
272        `debug`: flag for producing a debug friendly model (preserve weight attribute)
273
274    Return:
275        Quantized TorchSciprt model.
276
277    Example:
278    ```python
279    import torch
280    from torch.ao.quantization import get_default_qconfig
281    from torch.ao.quantization import quantize_jit
282
283    ts_model = torch.jit.script(float_model.eval())  # or torch.jit.trace(float_model, input)
284    qconfig = get_default_qconfig('fbgemm')
285    def calibrate(model, data_loader):
286        model.eval()
287        with torch.no_grad():
288            for image, target in data_loader:
289                model(image)
290
291    quantized_model = quantize_jit(
292        ts_model,
293        {'': qconfig},
294        calibrate,
295        [data_loader_test])
296    ```
297    """
    torch._C._log_api_usage_once("quantization_api.quantize_jit.quantize_jit")
    return _quantize_jit(
        model,
        qconfig_dict,
        run_fn,
        run_args,
        inplace,
        debug,
        quant_type=QuantType.STATIC,
    )


def quantize_dynamic_jit(model, qconfig_dict, inplace=False, debug=False):
311    r"""Quantize the input float TorchScript model with
312    post training dynamic quantization.
313    Currently only qint8 quantization of torch.nn.Linear is supported.
314
315    Args:
316        `model`: input float TorchScript model
317        `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as key and
318        qconfig for that module as value, please see detailed
319        descriptions in :func:`~torch.ao.quantization.quantize_jit`
320        `inplace`: carry out model transformations in-place, the original module is
321        mutated
322        `debug`: flag for producing a debug friendly model (preserve weight attribute)
323
324    Return:
325        Quantized TorchSciprt model.
326
327    Example:
328    ```python
329    import torch
330    from torch.ao.quantization import per_channel_dynamic_qconfig
331    from torch.ao.quantization import quantize_dynamic_jit
332
333    ts_model = torch.jit.script(float_model.eval())  # or torch.jit.trace(float_model, input)
334    qconfig = get_default_qconfig('fbgemm')
335    def calibrate(model, data_loader):
336        model.eval()
337        with torch.no_grad():
338            for image, target in data_loader:
339                model(image)
340
341    quantized_model = quantize_dynamic_jit(
342        ts_model,
343        {'': qconfig},
344        calibrate,
345        [data_loader_test])
346    ```
347    """
    torch._C._log_api_usage_once("quantization_api.quantize_jit.quantize_dynamic_jit")
    return _quantize_jit(
        model, qconfig_dict, inplace=inplace, debug=debug, quant_type=QuantType.DYNAMIC
    )


def _quantize_ondevice_dynamic_jit(
    model, qconfig_dict, method_name="forward", inplace=False
):
357    r"""Prepares the input float TorchScript model with
358    *on-device* post training dynamic quantization.
359    Currently only qint8 quantization of torch.nn.Linear is supported.
360
361    Args:
362        `model`: input float TorchScript model
363        `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as key and
364        qconfig for that module as value, please see detailed
365        `method_name`: Name of the method within the model, to be prepared for quantization
366        descriptions in :func:`~torch.ao.quantization.quantize_jit`
367        `inplace`: carry out model transformations in-place, the original module is
368        mutated
369
370    Return:
371        TorchScript model that is ready for on device quantization.
372        This means that the returned
373        model has:
374        - Method is inlined.
375        - Model has observer modules inserted in the model.
376        - Model has packed params inserted in the model. However they are empty as in they dont
377          contain valid quantized weights.
378        - observe_<method_name> is added that observe the values to be quantized.
379        - reset_observers_<method_name> to reset observers.
380        - quantize_<method_name> is added to the model.
381          - This method extract scale, zero points.
382          - Quantizes observed weights.
383          - Creates packed params from it and update the attribute of the model with the new values
384            for the packed params.
385          - Reset the original fp32 weights with empty tensor using SetAttr.
386        - quantized_<method_name> is added to the model.
387          - This method uses quantized weights and quantized linear ops instead of fp32 op.
388          - This method should be used for inference post PTQ.
389        - Note that all method's signatures should be the same as method_name.
390
391        Later on device:
392        - Run reset_observers_<method_name>
393        - Run observe_<method_name>
394        - Run quantize_<method_name>
395        - Now model can be saved and loaded later.
396        - Run model with quantized_<method_name>
397
398    Example:
399    ```python
400    import torch
401    from torch.ao.quantization import per_channel_dynamic_qconfig
402    from torch.ao.quantization.quantize_jit import _quantize_ondevice_dynamic_jit
403
404    ts_model = torch.jit.script(float_model.eval())  # or torch.jit.trace(float_model, input)
405    qconfig = get_default_qconfig('fbgemm')
406    quant_ready_model = _quantize_ondevice_dynamic_jit(
407        ts_model,
408        {'': qconfig},
409        'forward',
410        True)
411    ```
412    """
    return _quantize_ondevice_dynamic_jit_impl(
        model, qconfig_dict, method_name, inplace=inplace
    )