import copy
import warnings
from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch.fx import GraphModule
from torch.fx.graph_module import _USER_PRESERVED_ATTRIBUTES_KEY

from .backend_config import BackendConfig, get_tensorrt_backend_config  # noqa: F401
from .fx.convert import convert
from .fx.custom_config import ConvertCustomConfig, FuseCustomConfig, PrepareCustomConfig
from .fx.fuse import fuse  # noqa: F401
from .fx.graph_module import ObservedGraphModule  # noqa: F401
from .fx.prepare import prepare  # noqa: F401
from .fx.tracer import QuantizationTracer, Scope, ScopeContextManager  # noqa: F401
from .fx.utils import (  # noqa: F401
    get_custom_module_class_keys,
    get_skipped_module_name_and_classes,
)
from .qconfig_mapping import QConfigMapping


def attach_preserved_attrs_to_model(
    model: Union[GraphModule, torch.nn.Module],
    preserved_attrs: Dict[str, Any],
) -> None:
    """Store preserved attributes in model.meta so that they survive deepcopy."""
    model.meta[_USER_PRESERVED_ATTRIBUTES_KEY] = copy.copy(preserved_attrs)  # type: ignore[operator, index, assignment]
    # Set the preserved attributes on the model itself so that users can call
    # model.attr just as they did before FX graph mode quantization.
    for attr_name, attr in model.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items():  # type: ignore[index, union-attr]
        setattr(model, attr_name, attr)


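# A minimal sketch of how ``attach_preserved_attrs_to_model`` above is used (``Toy`` and
# the ``scale_factor`` attribute are hypothetical, not part of this file). Because the
# attributes are mirrored into ``model.meta``, they survive ``copy.deepcopy`` and remain
# reachable as plain attributes:
#
#     import copy
#     import torch
#
#     class Toy(torch.nn.Module):
#         def forward(self, x):
#             return x + 1
#
#     gm = torch.fx.symbolic_trace(Toy())
#     attach_preserved_attrs_to_model(gm, {"scale_factor": 2.0})
#     copied = copy.deepcopy(gm)
#     assert copied.scale_factor == 2.0
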
def _check_is_graph_module(model: torch.nn.Module) -> None:
    if not isinstance(model, GraphModule):
        raise ValueError(
            "input model must be a GraphModule, "
            + "got type: "
            + str(type(model))
            + ". Please make sure to follow the tutorials."
        )


def _attach_meta_to_node_if_not_exist(model: GraphModule) -> None:
    """Attach a ``meta`` field to every node of the graph if it does not already exist.

    The ``meta`` field stores metadata about a node, such as the dtype and shape of its
    output. It is populated when the program is captured by ``make_fx`` (used in the
    quantize_pt2e flow), but may be missing when the program is captured by ``torch.fx``
    symbolic tracing, so we add it here to avoid checking for it all over the place.
    """
    for node in model.graph.nodes:
        if not hasattr(node, "meta"):
            node.meta = {}


def _swap_ff_with_fxff(model: torch.nn.Module) -> None:
    r"""Swap FloatFunctional with FXFloatFunctional"""
    modules_to_swap = []
    for name, module in model.named_children():
        if isinstance(module, torch.ao.nn.quantized.FloatFunctional):
            modules_to_swap.append(name)
        else:
            _swap_ff_with_fxff(module)

    for name in modules_to_swap:
        del model._modules[name]
        model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional()


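# A minimal sketch of what ``_swap_ff_with_fxff`` above does (``Add`` is a hypothetical
# toy module, not part of this file): any ``FloatFunctional`` child is replaced in place
# with an ``FXFloatFunctional``, which is symbolically traceable:
#
#     import torch
#
#     class Add(torch.nn.Module):
#         def __init__(self) -> None:
#             super().__init__()
#             self.ff = torch.ao.nn.quantized.FloatFunctional()
#
#         def forward(self, x, y):
#             return self.ff.add(x, y)
#
#     m = Add()
#     _swap_ff_with_fxff(m)
#     assert isinstance(m.ff, torch.ao.nn.quantized.FXFloatFunctional)
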
def _fuse_fx(
    model: GraphModule,
    is_qat: bool,
    fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r"""Internal helper function to fuse modules in preparation for quantization

    Args:
        model: GraphModule object from symbolic tracing (torch.fx.symbolic_trace)
    """
    _check_is_graph_module(model)
    return fuse(
        model, is_qat, fuse_custom_config, backend_config
    )  # type: ignore[operator]


def _prepare_fx(
    model: torch.nn.Module,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
    is_qat: bool,
    example_inputs: Tuple[Any, ...],
    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
    _equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
    is_standalone_module: bool = False,
) -> GraphModule:
    r"""Internal helper function for prepare_fx

    Args:
        `model`, `qconfig_mapping`, `prepare_custom_config`, `_equalization_config`:
            see docs for :func:`~torch.ao.quantization.prepare_fx`
        `is_standalone_module`: a boolean flag that indicates whether we are
            quantizing a standalone module. A standalone module is a submodule of
            the parent module that is not inlined in the forward graph of the
            parent module; the way we quantize standalone modules is described in
            :func:`~torch.ao.quantization._prepare_standalone_module_fx`
    """
    if prepare_custom_config is None:
        prepare_custom_config = PrepareCustomConfig()
    if _equalization_config is None:
        _equalization_config = QConfigMapping()

    if isinstance(prepare_custom_config, dict):
        warnings.warn(
            "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a PrepareCustomConfig instead.",
            FutureWarning,
            stacklevel=3,
        )
        prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config)

    # swap FloatFunctional with FXFloatFunctional
    _swap_ff_with_fxff(model)

    skipped_module_names, skipped_module_classes = get_skipped_module_name_and_classes(
        prepare_custom_config, is_standalone_module
    )
    preserved_attr_names = prepare_custom_config.preserved_attributes
    preserved_attrs = {
        attr: getattr(model, attr)
        for attr in preserved_attr_names
        if hasattr(model, attr)
    }
    # symbolically trace the model
    tracer = QuantizationTracer(skipped_module_names, skipped_module_classes)  # type: ignore[arg-type]
    graph_module = GraphModule(model, tracer.trace(model))
    _attach_meta_to_node_if_not_exist(graph_module)

    fuse_custom_config = FuseCustomConfig().set_preserved_attributes(
        prepare_custom_config.preserved_attributes
    )
    graph_module = _fuse_fx(graph_module, is_qat, fuse_custom_config, backend_config)
    prepared = prepare(
        graph_module,
        qconfig_mapping,
        is_qat,
        tracer.node_name_to_scope,
        example_inputs=example_inputs,
        prepare_custom_config=prepare_custom_config,
        _equalization_config=_equalization_config,
        backend_config=backend_config,
        is_standalone_module=is_standalone_module,
    )  # type: ignore[operator]

    attach_preserved_attrs_to_model(prepared, preserved_attrs)
    return prepared


def _prepare_standalone_module_fx(
    model: torch.nn.Module,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
    is_qat: bool,
    example_inputs: Tuple[Any, ...],
    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r"""[Internal use only] Prepare a standalone module so that it can be used when
    quantizing the parent module.

    A standalone module is a submodule that is not inlined into the parent module's
    forward graph and is quantized separately as one unit.

    How the standalone module is observed is specified by `input_quantized_idxs` and
    `output_quantized_idxs` in the prepare_custom_config for the standalone module.

    Returns:

        * model(GraphModule): prepared standalone module. It has these attributes in
          model.meta:

            * `standalone_module_input_quantized_idxs(List[Int])`: a list of
              indexes for the graph inputs that are expected to be quantized,
              same as the input_quantized_idxs configuration provided
              for the standalone module
            * `standalone_module_output_quantized_idxs(List[Int])`: a list of
              indexes for the graph outputs that are quantized,
              same as the output_quantized_idxs configuration provided
              for the standalone module

    """
    return _prepare_fx(
        model,
        qconfig_mapping,
        is_qat,
        example_inputs,
        prepare_custom_config,
        backend_config=backend_config,
        is_standalone_module=True,
    )


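# A minimal sketch of how a standalone module is typically configured from the parent
# model (hedged: the exact signature of ``set_standalone_module_name`` may differ across
# versions, and ``"sub"`` is a hypothetical submodule name). ``prepare_fx`` on the parent
# then invokes ``_prepare_standalone_module_fx`` above for the registered submodule:
#
#     import torch
#     from torch.ao.quantization import get_default_qconfig_mapping
#     from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
#
#     qconfig_mapping = get_default_qconfig_mapping("fbgemm")
#     example_inputs = (torch.randn(1, 5),)
#     # observe the submodule's first input and first output as already-quantized tensors
#     standalone_prepare_config = (
#         PrepareCustomConfig()
#         .set_input_quantized_indexes([0])
#         .set_output_quantized_indexes([0])
#     )
#     parent_prepare_config = PrepareCustomConfig().set_standalone_module_name(
#         "sub", qconfig_mapping, example_inputs, standalone_prepare_config, None
#     )
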
def fuse_fx(
    model: torch.nn.Module,
    fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r"""Fuse modules like conv+bn, conv+bn+relu, etc. The model must be in eval mode.
    Fusion rules are defined in torch.ao.quantization.fx.fusion_pattern.py

    Args:

        * `model` (torch.nn.Module): a torch.nn.Module model
        * `fuse_custom_config` (FuseCustomConfig): custom configurations for fuse_fx.
            See :class:`~torch.ao.quantization.fx.custom_config.FuseCustomConfig` for more details
    Example::

        from torch.ao.quantization import fuse_fx
        m = Model().eval()
        m = fuse_fx(m)

    """
    if fuse_custom_config is None:
        fuse_custom_config = FuseCustomConfig()

    if isinstance(fuse_custom_config, dict):
        warnings.warn(
            "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported "
            "in a future version. Please pass in a FuseCustomConfig instead.",
            FutureWarning,
            stacklevel=2,
        )
        fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config)

    torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx")
    preserved_attr_names = fuse_custom_config.preserved_attributes
    preserved_attrs = {
        attr: getattr(model, attr)
        for attr in preserved_attr_names
        if hasattr(model, attr)
    }

    graph_module = torch.fx.symbolic_trace(model)
    _attach_meta_to_node_if_not_exist(graph_module)
    graph_module = _fuse_fx(graph_module, False, fuse_custom_config, backend_config)

    attach_preserved_attrs_to_model(graph_module, preserved_attrs)
    return graph_module


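# A slightly fuller fusion sketch than the docstring example above (``ConvBNReLU`` is a
# hypothetical toy module, not part of this file): in eval mode, ``fuse_fx`` folds
# conv + bn (+ relu) chains into single fused modules in the returned GraphModule:
#
#     import torch
#     from torch.ao.quantization.quantize_fx import fuse_fx
#
#     class ConvBNReLU(torch.nn.Module):
#         def __init__(self) -> None:
#             super().__init__()
#             self.conv = torch.nn.Conv2d(3, 8, 3)
#             self.bn = torch.nn.BatchNorm2d(8)
#             self.relu = torch.nn.ReLU()
#
#         def forward(self, x):
#             return self.relu(self.bn(self.conv(x)))
#
#     fused = fuse_fx(ConvBNReLU().eval())
#     print(fused.graph)  # the conv/bn/relu calls now appear as one fused module call
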
def prepare_fx(
    model: torch.nn.Module,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
    example_inputs: Tuple[Any, ...],
    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
    _equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r"""Prepare a model for post training quantization

    Args:
      * `model` (torch.nn.Module): torch.nn.Module model

      * `qconfig_mapping` (QConfigMapping): QConfigMapping object to configure how a model is
         quantized, see :class:`~torch.ao.quantization.qconfig_mapping.QConfigMapping`
         for more details

      * `example_inputs` (Tuple[Any, ...]): Example inputs for the forward function of the model,
         a tuple of positional args (keyword args can be passed as positional args as well)

      * `prepare_custom_config` (PrepareCustomConfig): customization configuration for the quantization tool.
          See :class:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig` for more details

      * `_equalization_config`: config for specifying how to perform equalization on the model

      * `backend_config` (BackendConfig): config that specifies how operators are quantized
         in a backend, this includes how the operators are observed,
         supported fusion patterns, how quantize/dequantize ops are
         inserted, supported dtypes etc. See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details

    Return:
      A GraphModule with observers (configured by qconfig_mapping), ready for calibration

    Example::

        import torch
        from torch.ao.quantization import get_default_qconfig_mapping
        from torch.ao.quantization.quantize_fx import prepare_fx

        class Submodule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
            def forward(self, x):
                x = self.linear(x)
                return x

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
                self.sub = Submodule()

            def forward(self, x):
                x = self.linear(x)
                x = self.sub(x) + x
                return x

        # initialize a floating point model
        float_model = M().eval()

        # define calibration function
        def calibrate(model, data_loader):
            model.eval()
            with torch.no_grad():
                for image, target in data_loader:
                    model(image)

        # qconfig is the configuration for how we insert observers for a particular
        # operator
        # qconfig = get_default_qconfig("fbgemm")
        # Example of customizing qconfig:
        # qconfig = torch.ao.quantization.QConfig(
        #    activation=MinMaxObserver.with_args(dtype=torch.qint8),
        #    weight=MinMaxObserver.with_args(dtype=torch.qint8))
        # `activation` and `weight` are constructors of observer modules

        # qconfig_mapping is a collection of quantization configurations: users can
        # set the qconfig for each operator (torch op calls, functional calls, module calls)
        # in the model through qconfig_mapping
        # the following call will get the qconfig_mapping that works best for models
        # that target the "fbgemm" backend
        qconfig_mapping = get_default_qconfig_mapping("fbgemm")

        # We can customize qconfig_mapping in different ways.
        # e.g. set the global qconfig, which means we will use the same qconfig for
        # all operators in the model, this can be overwritten by other settings
        # qconfig_mapping = QConfigMapping().set_global(qconfig)
        # e.g. quantize the linear submodule with a specific qconfig
        # qconfig_mapping = QConfigMapping().set_module_name("linear", qconfig)
        # e.g. quantize all nn.Linear modules with a specific qconfig
        # qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig)
        # for a more complete list, please see the docstring for the
        # :class:`torch.ao.quantization.QConfigMapping` argument

        # example_inputs is a tuple of inputs that is used to infer the type of the
        # outputs in the model
        # currently it's not used, but please make sure model(*example_inputs) runs
        example_inputs = (torch.randn(1, 3, 224, 224),)

        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
        # e.g. backend_config = get_default_backend_config("fbgemm")
        # `prepare_fx` inserts observers in the model based on qconfig_mapping and
        # backend_config. If the configuration for an operator in qconfig_mapping
        # is supported in the backend_config (meaning it's supported by the target
        # hardware), we'll insert observer modules according to the qconfig_mapping,
        # otherwise the configuration in qconfig_mapping will be ignored
        #
        # Example:
        # in qconfig_mapping, the user sets the linear module to be quantized with quint8 for
        # activation and qint8 for weight:
        # qconfig = torch.ao.quantization.QConfig(
        #     activation=MinMaxObserver.with_args(dtype=torch.quint8),
        #     weight=MinMaxObserver.with_args(dtype=torch.qint8))
        # Note: the current qconfig api does not support setting an output observer, but
        # we may extend this to support more fine grained control in the
        # future
        #
        # qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig)
        # in the backend config, the linear module also supports this configuration:
        # weighted_int8_dtype_config = DTypeConfig(
        #   input_dtype=torch.quint8,
        #   output_dtype=torch.quint8,
        #   weight_dtype=torch.qint8,
        #   bias_dtype=torch.float)

        # linear_pattern_config = BackendPatternConfig(torch.nn.Linear) \
        #    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
        #    .add_dtype_config(weighted_int8_dtype_config) \
        #    ...

        # backend_config = BackendConfig().set_backend_pattern_config(linear_pattern_config)
        # `prepare_fx` will check that the settings requested by the user in qconfig_mapping
        # are supported by the backend_config and insert observers and fake quant modules
        # in the model
        prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs)
        # Run calibration
        calibrate(prepared_model, sample_inference_data)
    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx")
    return _prepare_fx(
        model,
        qconfig_mapping,
        False,  # is_qat
        example_inputs,
        prepare_custom_config,
        _equalization_config,
        backend_config,
    )


def prepare_qat_fx(
    model: torch.nn.Module,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
    example_inputs: Tuple[Any, ...],
    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r"""Prepare a model for quantization aware training

    Args:
      * `model` (torch.nn.Module): torch.nn.Module model
      * `qconfig_mapping` (QConfigMapping): see :func:`~torch.ao.quantization.prepare_fx`
      * `example_inputs` (Tuple[Any, ...]): see :func:`~torch.ao.quantization.prepare_fx`
      * `prepare_custom_config` (PrepareCustomConfig): see :func:`~torch.ao.quantization.prepare_fx`
      * `backend_config` (BackendConfig): see :func:`~torch.ao.quantization.prepare_fx`

    Return:
      A GraphModule with fake quant modules (configured by qconfig_mapping and backend_config), ready for
      quantization aware training

    Example::

        import torch
        from torch.ao.quantization import get_default_qat_qconfig_mapping
        from torch.ao.quantization.quantize_fx import prepare_qat_fx

        class Submodule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
            def forward(self, x):
                x = self.linear(x)
                return x

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
                self.sub = Submodule()

            def forward(self, x):
                x = self.linear(x)
                x = self.sub(x) + x
                return x

        # initialize a floating point model
        float_model = M().train()
        # (optional, but preferred) load the weights from a pretrained model
        # float_model.load_weights(...)

        # define the training loop for quantization aware training
        def train_loop(model, train_data):
            model.train()
            for image, target in train_data:
                ...

        # qconfig is the configuration for how we insert observers for a particular
        # operator
        # qconfig = get_default_qconfig("fbgemm")
        # Example of customizing qconfig:
        # qconfig = torch.ao.quantization.QConfig(
        #    activation=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)),
        #    weight=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)))
        # `activation` and `weight` are constructors of observer modules

        # qconfig_mapping is a collection of quantization configurations: users can
        # set the qconfig for each operator (torch op calls, functional calls, module calls)
        # in the model through qconfig_mapping
        # the following call will get the qconfig_mapping that works best for models
        # that target the "fbgemm" backend
        qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")

        # We can customize qconfig_mapping in different ways; please take a look at
        # the docstring for :func:`~torch.ao.quantization.prepare_fx` for different ways
        # to configure this

        # example_inputs is a tuple of inputs that is used to infer the type of the
        # outputs in the model
        # currently it's not used, but please make sure model(*example_inputs) runs
        example_inputs = (torch.randn(1, 3, 224, 224),)

        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
        # e.g. backend_config = get_default_backend_config("fbgemm")
        # `prepare_qat_fx` inserts observers in the model based on qconfig_mapping and
        # backend_config, if the configuration for an operator in qconfig_mapping
        # is supported in the backend_config (meaning it's supported by the target
        # hardware), we'll insert fake_quantize modules according to the qconfig_mapping,
        # otherwise the configuration in qconfig_mapping will be ignored
        # see :func:`~torch.ao.quantization.prepare_fx` for a detailed explanation of
        # how qconfig_mapping interacts with backend_config
        prepared_model = prepare_qat_fx(float_model, qconfig_mapping, example_inputs)
        # Run training
        train_loop(prepared_model, train_data)

    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_qat_fx")
    return _prepare_fx(
        model,
        qconfig_mapping,
        True,  # is_qat
        example_inputs,
        prepare_custom_config,
        backend_config=backend_config,
    )


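# A compact end-to-end QAT sketch tying ``prepare_qat_fx`` to ``convert_fx`` (hedged:
# ``M`` is the toy module from the docstring above and the training loop is elided):
#
#     import torch
#     from torch.ao.quantization import get_default_qat_qconfig_mapping
#     from torch.ao.quantization.quantize_fx import convert_fx, prepare_qat_fx
#
#     float_model = M().train()
#     qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")
#     example_inputs = (torch.randn(1, 5),)
#     prepared = prepare_qat_fx(float_model, qconfig_mapping, example_inputs)
#     # ... run the usual training loop on `prepared` (fake quant is active) ...
#     prepared.eval()
#     quantized = convert_fx(prepared)
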
def _convert_fx(
    graph_module: GraphModule,
    is_reference: bool,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
    is_standalone_module: bool = False,
    _remove_qconfig: bool = True,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
    is_decomposed: bool = False,
) -> GraphModule:
    """`is_standalone_module`: see docs in :func:`~torch.ao.quantization.prepare_standalone_module_fx`"""
    if convert_custom_config is None:
        convert_custom_config = ConvertCustomConfig()

    if isinstance(convert_custom_config, dict):
        warnings.warn(
            "Passing a convert_custom_config_dict to convert is deprecated and will not be supported "
            "in a future version. Please pass in a ConvertCustomConfig instead.",
            FutureWarning,
            stacklevel=3,
        )
        convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config)

    _check_is_graph_module(graph_module)
    preserved_attr_names = convert_custom_config.preserved_attributes
    preserved_attrs = {
        attr: getattr(graph_module, attr)
        for attr in preserved_attr_names
        if hasattr(graph_module, attr)
    }

    quantized = convert(
        graph_module,
        is_reference,
        convert_custom_config,
        is_standalone_module,
        _remove_qconfig_flag=_remove_qconfig,
        qconfig_mapping=qconfig_mapping,
        backend_config=backend_config,
        is_decomposed=is_decomposed,
    )

    attach_preserved_attrs_to_model(quantized, preserved_attrs)
    return quantized


def convert_fx(
    graph_module: GraphModule,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
    _remove_qconfig: bool = True,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r"""Convert a calibrated or trained model to a quantized model

    Args:
        * `graph_module` (torch.fx.GraphModule): A prepared and calibrated/trained model (GraphModule)

        * `convert_custom_config` (ConvertCustomConfig): custom configurations for the convert function.
            See :class:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig` for more details

        * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert.

        * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.

           The keys must include the ones in the qconfig_mapping passed to `prepare_fx` or `prepare_qat_fx`,
           with the same values or `None`. Additional keys can be specified with values set to `None`.

           For each entry whose value is set to None, we skip quantizing that entry in the model::

            qconfig_mapping = QConfigMapping()
                .set_global(qconfig_from_prepare)
                .set_object_type(torch.add, None)  # skip quantizing torch.add
                .set_object_type(torch.nn.functional.linear, qconfig_from_prepare)
                .set_module_name("foo.bar", None)  # skip quantizing module "foo.bar"

        * `backend_config` (BackendConfig): A configuration for the backend which describes how
            operators should be quantized in the backend, this includes quantization
            mode support (static/dynamic/weight_only), dtype support (quint8/qint8 etc.),
            observer placement for each operator and fused operators.
            See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details

    Return:
        A quantized model (torch.nn.Module)

    Example::

        # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
        # convert_fx converts a calibrated/trained model to a quantized model for the
        # target hardware. This includes converting the model first to a reference
        # quantized model, and then lowering the reference quantized model to a backend.
        # Currently, the supported backends are fbgemm (onednn) and qnnpack (xnnpack);
        # they share the same set of quantized operators, so we are using the same
        # lowering procedure
        #
        # backend_config defines the corresponding reference quantized module for
        # the weighted modules in the model, e.g. nn.Linear
        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
        # e.g. backend_config = get_default_backend_config("fbgemm")
        quantized_model = convert_fx(prepared_model)

    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_fx")
    return _convert_fx(
        graph_module,
        is_reference=False,
        convert_custom_config=convert_custom_config,
        _remove_qconfig=_remove_qconfig,
        qconfig_mapping=qconfig_mapping,
        backend_config=backend_config,
    )


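# A compact end-to-end post-training quantization sketch tying ``prepare_fx``,
# calibration, and ``convert_fx`` together (hedged: ``M`` and ``calibrate`` are the toy
# module and calibration function from the prepare_fx docstring, and ``calib_loader`` is
# a hypothetical iterable of (input, target) pairs):
#
#     import torch
#     from torch.ao.quantization import get_default_qconfig_mapping
#     from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
#
#     float_model = M().eval()
#     qconfig_mapping = get_default_qconfig_mapping("fbgemm")
#     example_inputs = (torch.randn(1, 5),)
#     prepared = prepare_fx(float_model, qconfig_mapping, example_inputs)
#     calibrate(prepared, calib_loader)  # run representative data through the observers
#     quantized = convert_fx(prepared)   # lower to quantized ops for the backend
#     out = quantized(torch.randn(1, 5))
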
def convert_to_reference_fx(
    graph_module: GraphModule,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
    _remove_qconfig: bool = True,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r"""Convert a calibrated or trained model to a reference quantized model,
    see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details.
    A reference quantized model is a standard representation of a quantized model provided
    by FX Graph Mode Quantization; it can be further lowered to run on target
    hardware, such as accelerators.

    Args:
        * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule)

        * `convert_custom_config` (ConvertCustomConfig): custom configurations for the convert function.
            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

        * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert.

        * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.
            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

        * `backend_config` (BackendConfig): A configuration for the backend which describes how
            operators should be quantized in the backend. See
            :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

    Return:
        A reference quantized model (GraphModule)

    Example::

        # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
        # e.g. backend_config = get_default_backend_config("fbgemm")
        reference_quantized_model = convert_to_reference_fx(prepared_model)

    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_to_reference_fx")
    return _convert_fx(
        graph_module,
        is_reference=True,
        convert_custom_config=convert_custom_config,
        _remove_qconfig=_remove_qconfig,
        qconfig_mapping=qconfig_mapping,
        backend_config=backend_config,
    )


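# A minimal sketch of what a reference quantized model produced by
# ``convert_to_reference_fx`` above looks like (hedged: the exact node names depend on
# the model, and ``prepared`` is assumed to come from prepare_fx plus calibration).
# Weighted ops become float "reference" modules surrounded by explicit
# quantize_per_tensor / dequantize nodes, which a backend can pattern-match and lower:
#
#     reference = convert_to_reference_fx(prepared)
#     print(reference.graph)
#     # ... quantize_per_tensor -> dequantize -> linear (reference module) ->
#     #     quantize_per_tensor -> dequantize ...
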
def _convert_to_reference_decomposed_fx(
    graph_module: GraphModule,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r"""Convert a calibrated or trained model to a reference quantized model with a
    decomposed representation for quantized Tensors,
    see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details.
    A reference quantized model is a standard representation of a quantized model provided
    by FX Graph Mode Quantization; it can be further lowered to run on target
    hardware, such as accelerators.

    Note: this is not a public API

    Args:
        * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule)

        * `convert_custom_config` (ConvertCustomConfig): custom configurations for the convert function.
            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

        * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.
            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

        * `backend_config` (BackendConfig): A configuration for the backend which describes how
            operators should be quantized in the backend. See
            :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

    Return:
        A reference quantized model (GraphModule) with operators working with decomposed quantized Tensors

    Example::

        # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
        # e.g. backend_config = get_default_backend_config("fbgemm")
        reference_quantized_model = _convert_to_reference_decomposed_fx(prepared_model)

    """
    torch._C._log_api_usage_once(
        "quantization_api.quantize_fx._convert_to_reference_decomposed_fx"
    )
    return _convert_fx(
        graph_module,
        is_reference=True,
        convert_custom_config=convert_custom_config,
        _remove_qconfig=False,
        qconfig_mapping=qconfig_mapping,
        backend_config=backend_config,
        is_decomposed=True,
    )


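# A minimal sketch contrasting the decomposed representation with the one produced by
# ``convert_to_reference_fx`` (hedged: op names are illustrative and ``prepared`` is
# assumed to come from prepare_fx plus calibration). Instead of quantized Tensors, the
# graph carries plain integer Tensors plus explicit (de)quantize ops from the
# ``quantized_decomposed`` op namespace:
#
#     reference = _convert_to_reference_decomposed_fx(prepared)
#     print(reference.graph)
#     # ... torch.ops.quantized_decomposed.quantize_per_tensor ->
#     #     torch.ops.quantized_decomposed.dequantize_per_tensor -> linear ...
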
def _convert_standalone_module_fx(
    graph_module: GraphModule,
    is_reference: bool = False,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r"""[Internal use only] Convert a model produced by :func:`~torch.ao.quantization.prepare_standalone_module_fx`
    to a quantized model.

    Returns a quantized standalone module. Whether its inputs/outputs are quantized is
    specified by `input_quantized_idxs` and `output_quantized_idxs` in the
    prepare_custom_config of the standalone module; please see the docs for
    prepare_fx for details.
    """
    return _convert_fx(
        graph_module,
        is_reference,
        convert_custom_config,
        is_standalone_module=True,
    )
746