import copy
import warnings
from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch.fx import GraphModule
from torch.fx.graph_module import _USER_PRESERVED_ATTRIBUTES_KEY

from .backend_config import BackendConfig, get_tensorrt_backend_config  # noqa: F401
from .fx.convert import convert
from .fx.custom_config import ConvertCustomConfig, FuseCustomConfig, PrepareCustomConfig
from .fx.fuse import fuse  # noqa: F401
from .fx.graph_module import ObservedGraphModule  # noqa: F401
from .fx.prepare import prepare  # noqa: F401
from .fx.tracer import QuantizationTracer, Scope, ScopeContextManager  # noqa: F401
from .fx.utils import (  # noqa: F401
    get_custom_module_class_keys,
    get_skipped_module_name_and_classes,
)
from .qconfig_mapping import QConfigMapping


def attach_preserved_attrs_to_model(
    model: Union[GraphModule, torch.nn.Module],
    preserved_attrs: Dict[str, Any],
) -> None:
    """Store preserved attributes in model.meta so that they are preserved during deepcopy"""
    model.meta[_USER_PRESERVED_ATTRIBUTES_KEY] = copy.copy(preserved_attrs)  # type: ignore[operator, index, assignment]
    # set the preserved attributes in the model so that user can call
    # model.attr as they do before calling fx graph mode quantization
    for attr_name, attr in model.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items():  # type: ignore[index, union-attr]
        setattr(model, attr_name, attr)


def _check_is_graph_module(model: torch.nn.Module) -> None:
    if not isinstance(model, GraphModule):
        raise ValueError(
            "input model must be a GraphModule, "
            + "Got type:"
            + str(type(model))
            + " Please make "
            + "sure to follow the tutorials."
        )


def _attach_meta_to_node_if_not_exist(model: GraphModule) -> None:
    """Attach a meta field to all nodes of the graph if it does not exist.
    The meta field stores meta information about the node, such as the dtype
    and shape of the node's output. It only exists if the program was captured
    by make_fx (used in the quantize_pt2e flow); if the program was captured by
    torch.fx symbolic tracing, this field may not exist, so we add it here to
    avoid checking for it all over the place.
    """
    for node in model.graph.nodes:
        if not hasattr(node, "meta"):
            node.meta = {}


def _swap_ff_with_fxff(model: torch.nn.Module) -> None:
    r"""Swap FloatFunctional with FXFloatFunctional"""
    modules_to_swap = []
    for name, module in model.named_children():
        if isinstance(module, torch.ao.nn.quantized.FloatFunctional):
            modules_to_swap.append(name)
        else:
            _swap_ff_with_fxff(module)

    for name in modules_to_swap:
        del model._modules[name]
        model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional()

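# A minimal usage sketch of the swap above (illustrative only; the class name
# `M` and the attribute name `ff` are hypothetical):
#
#     class M(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.ff = torch.ao.nn.quantized.FloatFunctional()
#
#         def forward(self, x, y):
#             return self.ff.add(x, y)
#
#     m = M()
#     _swap_ff_with_fxff(m)
#     # m.ff is now a torch.ao.nn.quantized.FXFloatFunctional, which is
#     # symbolically traceable, unlike FloatFunctional
#     assert isinstance(m.ff, torch.ao.nn.quantized.FXFloatFunctional)
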
def _fuse_fx(
    model: GraphModule,
    is_qat: bool,
    fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r"""Internal helper function to fuse modules in preparation for quantization

    Args:
        model: GraphModule object from symbolic tracing (torch.fx.symbolic_trace)
    """
    _check_is_graph_module(model)
    return fuse(
        model, is_qat, fuse_custom_config, backend_config
    )  # type: ignore[operator]


def _prepare_fx(
    model: torch.nn.Module,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
    is_qat: bool,
    example_inputs: Tuple[Any, ...],
    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
    _equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
    is_standalone_module: bool = False,
) -> GraphModule:
    r"""Internal helper function for prepare_fx

    Args:
        `model`, `qconfig_mapping`, `prepare_custom_config`, `_equalization_config`:
            see docs for :func:`~torch.ao.quantization.prepare_fx`
        `is_standalone_module`: a boolean flag that indicates whether we are
            quantizing a standalone module or not. A standalone module
            is a submodule of the parent module that is not inlined in the
            forward graph of the parent module; the way we quantize a standalone
            module is described in
            :func:`~torch.ao.quantization._prepare_standalone_module_fx`
    """
    if prepare_custom_config is None:
        prepare_custom_config = PrepareCustomConfig()
    if _equalization_config is None:
        _equalization_config = QConfigMapping()

    if isinstance(prepare_custom_config, dict):
        warnings.warn(
            "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a PrepareCustomConfig instead.",
            FutureWarning,
            stacklevel=3,
        )
        prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config)

    # swap FloatFunctional with FXFloatFunctional
    _swap_ff_with_fxff(model)

    skipped_module_names, skipped_module_classes = get_skipped_module_name_and_classes(
        prepare_custom_config, is_standalone_module
    )
    preserved_attr_names = prepare_custom_config.preserved_attributes
    preserved_attrs = {
        attr: getattr(model, attr)
        for attr in preserved_attr_names
        if hasattr(model, attr)
    }
    # symbolically trace the model
    tracer = QuantizationTracer(skipped_module_names, skipped_module_classes)  # type: ignore[arg-type]
    graph_module = GraphModule(model, tracer.trace(model))
    _attach_meta_to_node_if_not_exist(graph_module)

    fuse_custom_config = FuseCustomConfig().set_preserved_attributes(
        prepare_custom_config.preserved_attributes
    )
    graph_module = _fuse_fx(graph_module, is_qat, fuse_custom_config, backend_config)
    prepared = prepare(
        graph_module,
        qconfig_mapping,
        is_qat,
        tracer.node_name_to_scope,
        example_inputs=example_inputs,
        prepare_custom_config=prepare_custom_config,
        _equalization_config=_equalization_config,
        backend_config=backend_config,
        is_standalone_module=is_standalone_module,
    )  # type: ignore[operator]

    attach_preserved_attrs_to_model(prepared, preserved_attrs)
    return prepared


def _prepare_standalone_module_fx(
    model: torch.nn.Module,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
    is_qat: bool,
    example_inputs: Tuple[Any, ...],
    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r"""[Internal use only] Prepare a standalone module, so that it can be used when quantizing the
    parent module.
    standalone_module means it is a submodule that is not inlined in the parent module,
    and will be quantized separately as one unit.

    How the standalone module is observed is specified by `input_quantized_idxs` and
    `output_quantized_idxs` in the prepare_custom_config for the standalone module

    Returns:

        * model(GraphModule): prepared standalone module. It has these attributes in
          model.meta:

            * `standalone_module_input_quantized_idxs(List[Int])`: a list of
              indices for the graph inputs that are expected to be quantized,
              same as the input_quantized_idxs configuration provided
              for the standalone module
            * `standalone_module_output_quantized_idxs(List[Int])`: a list of
              indices for the graph outputs that are quantized,
              same as the output_quantized_idxs configuration provided
              for the standalone module

    """
    return _prepare_fx(
        model,
        qconfig_mapping,
        is_qat,
        example_inputs,
        prepare_custom_config,
        backend_config=backend_config,
        is_standalone_module=True,
    )

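# Sketch of how a parent module would opt a submodule into the standalone path
# (illustrative only; assumes the PrepareCustomConfig.set_standalone_module_name,
# set_input_quantized_indexes and set_output_quantized_indexes methods, and a
# hypothetical submodule named "sub" on parent_model):
#
#     child_prepare_custom_config = (
#         PrepareCustomConfig()
#         .set_input_quantized_indexes([0])   # graph input 0 arrives already quantized
#         .set_output_quantized_indexes([0])  # graph output 0 stays quantized
#     )
#     parent_prepare_custom_config = PrepareCustomConfig().set_standalone_module_name(
#         "sub", qconfig_mapping, (torch.randn(1, 5),), child_prepare_custom_config, None
#     )
#     prepared = prepare_fx(
#         parent_model, qconfig_mapping, example_inputs,
#         prepare_custom_config=parent_prepare_custom_config,
#     )
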
def fuse_fx(
    model: torch.nn.Module,
    fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r"""Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode.
    Fusion rules are defined in torch.ao.quantization.fx.fusion_pattern.py

    Args:

        * `model` (torch.nn.Module): a torch.nn.Module model
        * `fuse_custom_config` (FuseCustomConfig): custom configurations for fuse_fx.
            See :class:`~torch.ao.quantization.fx.custom_config.FuseCustomConfig` for more details

    Example::

        from torch.ao.quantization import fuse_fx
        m = Model().eval()
        m = fuse_fx(m)

    """
    if fuse_custom_config is None:
        fuse_custom_config = FuseCustomConfig()

    if isinstance(fuse_custom_config, dict):
        warnings.warn(
            "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported "
            "in a future version. Please pass in a FuseCustomConfig instead.",
            FutureWarning,
            stacklevel=2,
        )
        fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config)

    torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx")
    preserved_attr_names = fuse_custom_config.preserved_attributes
    preserved_attrs = {
        attr: getattr(model, attr)
        for attr in preserved_attr_names
        if hasattr(model, attr)
    }

    graph_module = torch.fx.symbolic_trace(model)
    _attach_meta_to_node_if_not_exist(graph_module)
    graph_module = _fuse_fx(graph_module, False, fuse_custom_config, backend_config)

    attach_preserved_attrs_to_model(graph_module, preserved_attrs)
    return graph_module

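# Symbolic tracing only keeps what the forward graph uses, so attributes set on
# the module but not used in forward would normally be dropped. The
# preserved-attributes machinery above keeps them around. A small sketch
# (the attribute name "version" is hypothetical):
#
#     m = Model().eval()
#     m.version = "1.0"
#     m = fuse_fx(m, FuseCustomConfig().set_preserved_attributes(["version"]))
#     print(m.version)  # still accessible on the fused GraphModule
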
def prepare_fx(
    model: torch.nn.Module,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
    example_inputs: Tuple[Any, ...],
    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
    _equalization_config: Optional[Union[QConfigMapping, Dict[str, Any]]] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r"""Prepare a model for post training quantization

    Args:
      * `model` (torch.nn.Module): torch.nn.Module model

      * `qconfig_mapping` (QConfigMapping): QConfigMapping object to configure how a model is
        quantized, see :class:`~torch.ao.quantization.qconfig_mapping.QConfigMapping`
        for more details

      * `example_inputs` (Tuple[Any, ...]): Example inputs for forward function of the model,
        Tuple of positional args (keyword args can be passed as positional args as well)

      * `prepare_custom_config` (PrepareCustomConfig): customization configuration for quantization tool.
        See :class:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig` for more details

      * `_equalization_config`: config for specifying how to perform equalization on the model

      * `backend_config` (BackendConfig): config that specifies how operators are quantized
        in a backend, this includes how the operators are observed,
        supported fusion patterns, how quantize/dequantize ops are
        inserted, supported dtypes etc. See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details

    Return:
      A GraphModule with observers (configured by qconfig_mapping), ready for calibration

    Example::

        import torch
        from torch.ao.quantization import get_default_qconfig_mapping
        from torch.ao.quantization.quantize_fx import prepare_fx

        class Submodule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                x = self.linear(x)
                return x

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
                self.sub = Submodule()

            def forward(self, x):
                x = self.linear(x)
                x = self.sub(x) + x
                return x

        # initialize a floating point model
        float_model = M().eval()

        # define calibration function
        def calibrate(model, data_loader):
            model.eval()
            with torch.no_grad():
                for image, target in data_loader:
                    model(image)

        # qconfig is the configuration for how we insert observers for a particular
        # operator
        # qconfig = get_default_qconfig("fbgemm")
        # Example of customizing qconfig:
        # qconfig = torch.ao.quantization.QConfig(
        #    activation=MinMaxObserver.with_args(dtype=torch.qint8),
        #    weight=MinMaxObserver.with_args(dtype=torch.qint8))
        # `activation` and `weight` are constructors of observer module

        # qconfig_mapping is a collection of quantization configurations, user can
        # set the qconfig for each operator (torch op calls, functional calls, module calls)
        # in the model through qconfig_mapping
        # the following call will get the qconfig_mapping that works best for models
        # that target "fbgemm" backend
        qconfig_mapping = get_default_qconfig_mapping("fbgemm")

        # We can customize qconfig_mapping in different ways.
        # e.g. set the global qconfig, which means we will use the same qconfig for
        # all operators in the model, this can be overwritten by other settings
        # qconfig_mapping = QConfigMapping().set_global(qconfig)
        # e.g. quantize the linear submodule with a specific qconfig
        # qconfig_mapping = QConfigMapping().set_module_name("linear", qconfig)
        # e.g. quantize all nn.Linear modules with a specific qconfig
        # qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig)
        # for a more complete list, please see the docstring for the
        # :class:`torch.ao.quantization.QConfigMapping` argument

        # example_inputs is a tuple of inputs, that is used to infer the type of the
        # outputs in the model
        # currently it's not used, but please make sure model(*example_inputs) runs
        example_inputs = (torch.randn(1, 3, 224, 224),)

        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
        # e.g. backend_config = get_default_backend_config("fbgemm")
        # `prepare_fx` inserts observers in the model based on qconfig_mapping and
        # backend_config. If the configuration for an operator in qconfig_mapping
        # is supported in the backend_config (meaning it's supported by the target
        # hardware), we'll insert observer modules according to the qconfig_mapping
        # otherwise the configuration in qconfig_mapping will be ignored
        #
        # Example:
        # in qconfig_mapping, user sets linear module to be quantized with quint8 for
        # activation and qint8 for weight:
        # qconfig = torch.ao.quantization.QConfig(
        #     activation=MinMaxObserver.with_args(dtype=torch.quint8),
        #     weight=MinMaxObserver.with_args(dtype=torch.qint8))
        # Note: current qconfig api does not support setting output observer, but
        # we may extend this to support more fine grained control in the
        # future
        #
        # qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig)
        # in backend config, linear module also supports this configuration:
        # weighted_int8_dtype_config = DTypeConfig(
        #   input_dtype=torch.quint8,
        #   output_dtype=torch.quint8,
        #   weight_dtype=torch.qint8,
        #   bias_dtype=torch.float)

        # linear_pattern_config = BackendPatternConfig(torch.nn.Linear) \
        #    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
        #    .add_dtype_config(weighted_int8_dtype_config) \
        #    ...

        # backend_config = BackendConfig().set_backend_pattern_config(linear_pattern_config)
        # `prepare_fx` will check that the setting requested by the user in qconfig_mapping
        # is supported by the backend_config and insert observers and fake quant modules
        # in the model
        prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs)
        # Run calibration
        calibrate(prepared_model, sample_inference_data)
    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx")
    return _prepare_fx(
        model,
        qconfig_mapping,
        False,  # is_qat
        example_inputs,
        prepare_custom_config,
        _equalization_config,
        backend_config,
    )

def prepare_qat_fx(
    model: torch.nn.Module,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
    example_inputs: Tuple[Any, ...],
    prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r"""Prepare a model for quantization aware training

    Args:
      * `model` (torch.nn.Module): torch.nn.Module model
      * `qconfig_mapping` (QConfigMapping): see :func:`~torch.ao.quantization.prepare_fx`
      * `example_inputs` (Tuple[Any, ...]): see :func:`~torch.ao.quantization.prepare_fx`
      * `prepare_custom_config` (PrepareCustomConfig): see :func:`~torch.ao.quantization.prepare_fx`
      * `backend_config` (BackendConfig): see :func:`~torch.ao.quantization.prepare_fx`

    Return:
      A GraphModule with fake quant modules (configured by qconfig_mapping and backend_config), ready for
      quantization aware training

    Example::

        import torch
        from torch.ao.quantization import get_default_qat_qconfig_mapping
        from torch.ao.quantization.quantize_fx import prepare_qat_fx

        class Submodule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                x = self.linear(x)
                return x

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)
                self.sub = Submodule()

            def forward(self, x):
                x = self.linear(x)
                x = self.sub(x) + x
                return x

        # initialize a floating point model
        float_model = M().train()
        # (optional, but preferred) load the weights from pretrained model
        # float_model.load_weights(...)

        # define the training loop for quantization aware training
        def train_loop(model, train_data):
            model.train()
            for image, target in train_data:
                ...

        # qconfig is the configuration for how we insert observers for a particular
        # operator
        # qconfig = get_default_qconfig("fbgemm")
        # Example of customizing qconfig:
        # qconfig = torch.ao.quantization.QConfig(
        #    activation=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)),
        #    weight=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)))
        # `activation` and `weight` are constructors of observer module

        # qconfig_mapping is a collection of quantization configurations, user can
        # set the qconfig for each operator (torch op calls, functional calls, module calls)
        # in the model through qconfig_mapping
        # the following call will get the qconfig_mapping that works best for models
        # that target "fbgemm" backend
        qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")

        # We can customize qconfig_mapping in different ways, please take a look at
        # the docstring for :func:`~torch.ao.quantization.prepare_fx` for different ways
        # to configure this

        # example_inputs is a tuple of inputs, that is used to infer the type of the
        # outputs in the model
        # currently it's not used, but please make sure model(*example_inputs) runs
        example_inputs = (torch.randn(1, 3, 224, 224),)

        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
        # e.g. backend_config = get_default_backend_config("fbgemm")
        # `prepare_qat_fx` inserts observers in the model based on qconfig_mapping and
        # backend_config, if the configuration for an operator in qconfig_mapping
        # is supported in the backend_config (meaning it's supported by the target
        # hardware), we'll insert fake_quantize modules according to the qconfig_mapping
        # otherwise the configuration in qconfig_mapping will be ignored
        # see :func:`~torch.ao.quantization.prepare_fx` for a detailed explanation of
        # how qconfig_mapping interacts with backend_config
        prepared_model = prepare_qat_fx(float_model, qconfig_mapping, example_inputs)
        # Run training
        train_loop(prepared_model, train_data)

    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_qat_fx")
    return _prepare_fx(
        model,
        qconfig_mapping,
        True,  # is_qat
        example_inputs,
        prepare_custom_config,
        backend_config=backend_config,
    )

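# End-to-end QAT sketch combining the pieces above (illustrative only;
# float_model, qconfig_mapping, example_inputs, train_loop and train_data are
# placeholders from the docstring example):
#
#     prepared_model = prepare_qat_fx(float_model, qconfig_mapping, example_inputs)
#     train_loop(prepared_model, train_data)        # fake-quant aware training
#     prepared_model.eval()
#     quantized_model = convert_fx(prepared_model)  # lower to a quantized model
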
def _convert_fx(
    graph_module: GraphModule,
    is_reference: bool,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
    is_standalone_module: bool = False,
    _remove_qconfig: bool = True,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
    is_decomposed: bool = False,
) -> GraphModule:
    """`is_standalone_module`: see docs in :func:`~torch.ao.quantization.prepare_standalone_module_fx`"""
    if convert_custom_config is None:
        convert_custom_config = ConvertCustomConfig()

    if isinstance(convert_custom_config, dict):
        warnings.warn(
            "Passing a convert_custom_config_dict to convert is deprecated and will not be supported "
            "in a future version. Please pass in a ConvertCustomConfig instead.",
            FutureWarning,
            stacklevel=3,
        )
        convert_custom_config = ConvertCustomConfig.from_dict(convert_custom_config)

    _check_is_graph_module(graph_module)
    preserved_attr_names = convert_custom_config.preserved_attributes
    preserved_attrs = {
        attr: getattr(graph_module, attr)
        for attr in preserved_attr_names
        if hasattr(graph_module, attr)
    }

    quantized = convert(
        graph_module,
        is_reference,
        convert_custom_config,
        is_standalone_module,
        _remove_qconfig_flag=_remove_qconfig,
        qconfig_mapping=qconfig_mapping,
        backend_config=backend_config,
        is_decomposed=is_decomposed,
    )

    attach_preserved_attrs_to_model(quantized, preserved_attrs)
    return quantized

def convert_fx(
    graph_module: GraphModule,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
    _remove_qconfig: bool = True,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r"""Convert a calibrated or trained model to a quantized model

    Args:
        * `graph_module` (torch.fx.GraphModule): A prepared and calibrated/trained model (GraphModule)

        * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function.
            See :class:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig` for more details

        * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert.

        * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.

            The keys must include the ones in the qconfig_mapping passed to `prepare_fx` or `prepare_qat_fx`,
            with the same values or `None`. Additional keys can be specified with values set to `None`.

            For each entry whose value is set to None, we skip quantizing that entry in the model::

                qconfig_mapping = (QConfigMapping()
                    .set_global(qconfig_from_prepare)
                    .set_object_type(torch.nn.functional.add, None)  # skip quantizing torch.nn.functional.add
                    .set_object_type(torch.nn.functional.linear, qconfig_from_prepare)
                    .set_module_name("foo.bar", None)  # skip quantizing module "foo.bar"
                )

        * `backend_config` (BackendConfig): A configuration for the backend which describes how
            operators should be quantized in the backend, this includes quantization
            mode support (static/dynamic/weight_only), dtype support (quint8/qint8 etc.),
            observer placement for each operator and fused operators.
            See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details

    Return:
        A quantized model (torch.nn.Module)

    Example::

        # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
        # convert_fx converts a calibrated/trained model to a quantized model for the
        # target hardware, this includes converting the model first to a reference
        # quantized model, and then lowering the reference quantized model to a backend
        # Currently, the supported backends are fbgemm (onednn), qnnpack (xnnpack) and
        # they share the same set of quantized operators, so we are using the same
        # lowering procedure
        #
        # backend_config defines the corresponding reference quantized module for
        # the weighted modules in the model, e.g. nn.Linear
        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
        # e.g. backend_config = get_default_backend_config("fbgemm")
        quantized_model = convert_fx(prepared_model)

    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_fx")
    return _convert_fx(
        graph_module,
        is_reference=False,
        convert_custom_config=convert_custom_config,
        _remove_qconfig=_remove_qconfig,
        qconfig_mapping=qconfig_mapping,
        backend_config=backend_config,
    )

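# End-to-end post-training quantization sketch tying prepare_fx, calibration and
# convert_fx together (illustrative only; float_model, qconfig_mapping,
# example_inputs, calibrate and data_loader are placeholders from the prepare_fx
# docstring example):
#
#     prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs)
#     calibrate(prepared_model, data_loader)        # run representative data through observers
#     quantized_model = convert_fx(prepared_model)  # fold observer stats into quantized ops
#     out = quantized_model(*example_inputs)
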
def convert_to_reference_fx(
    graph_module: GraphModule,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
    _remove_qconfig: bool = True,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r"""Convert a calibrated or trained model to a reference quantized model,
    see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details.
    A reference quantized model is a standard representation of a quantized model provided
    by FX Graph Mode Quantization; it can be further lowered to run on the target
    hardware, like accelerators.

    Args:
        * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule)

        * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function.
            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

        * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert.

        * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.
            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

        * `backend_config` (BackendConfig): A configuration for the backend which describes how
            operators should be quantized in the backend. See
            :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

    Return:
        A reference quantized model (GraphModule)

    Example::

        # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
        # e.g. backend_config = get_default_backend_config("fbgemm")
        reference_quantized_model = convert_to_reference_fx(prepared_model)

    """
    torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_to_reference_fx")
    return _convert_fx(
        graph_module,
        is_reference=True,
        convert_custom_config=convert_custom_config,
        _remove_qconfig=_remove_qconfig,
        qconfig_mapping=qconfig_mapping,
        backend_config=backend_config,
    )

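# Conceptually, a reference quantized model expresses each quantized operation as
# a (dequantize -> float op -> quantize) pattern instead of a fused backend
# kernel. For a linear layer the graph looks roughly like this (a simplified
# sketch, not the exact node names produced by convert):
#
#     x_q -> dequantize -> linear (reference module holding a quantized weight) -> quantize -> y_q
#
# Lowering passes can then pattern-match these subgraphs and replace them with
# backend-specific quantized kernels.
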
def _convert_to_reference_decomposed_fx(
    graph_module: GraphModule,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
    qconfig_mapping: Union[QConfigMapping, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r"""Convert a calibrated or trained model to a reference quantized model, with
    a decomposed representation for quantized Tensors,
    see https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md for more details.
    A reference quantized model is a standard representation of a quantized model provided
    by FX Graph Mode Quantization; it can be further lowered to run on the target
    hardware, like accelerators.

    Note: this is not public API

    Args:
        * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule)

        * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function.
            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

        * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.
            See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

        * `backend_config` (BackendConfig): A configuration for the backend which describes how
            operators should be quantized in the backend. See
            :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.

    Return:
        A reference quantized model (GraphModule) with operators working with decomposed quantized Tensors

    Example::

        # prepared_model: the model after prepare_fx/prepare_qat_fx and calibration/training
        # TODO: add backend_config after we split the backend_config for fbgemm and qnnpack
        # e.g. backend_config = get_default_backend_config("fbgemm")
        reference_quantized_model = _convert_to_reference_decomposed_fx(prepared_model)

    """
    torch._C._log_api_usage_once(
        "quantization_api.quantize_fx._convert_to_reference_decomposed_fx"
    )
    return _convert_fx(
        graph_module,
        is_reference=True,
        convert_custom_config=convert_custom_config,
        _remove_qconfig=False,
        qconfig_mapping=qconfig_mapping,
        backend_config=backend_config,
        is_decomposed=True,
    )

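# In the decomposed representation, quantized Tensors are modeled as regular
# (e.g. int8) Tensors plus explicit quantize/dequantize ops in the graph,
# rather than Tensors with a quantized dtype. Roughly (a sketch; op names from
# the quantized_decomposed namespace, scale/zero_point arguments omitted):
#
#     x_fp32 -> quantized_decomposed.quantize_per_tensor -> x_int8
#     x_int8 -> quantized_decomposed.dequantize_per_tensor -> x_fp32 -> float op -> ...
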
def _convert_standalone_module_fx(
    graph_module: GraphModule,
    is_reference: bool = False,
    convert_custom_config: Union[ConvertCustomConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r"""[Internal use only] Convert a model produced by :func:`~torch.ao.quantization.prepare_standalone_module_fx`
    to a quantized model

    Returns a quantized standalone module. Whether the input/output is quantized is
    specified by the prepare_custom_config (via input_quantized_idxs and
    output_quantized_idxs); please see the docs for prepare_fx for details.
    """
    return _convert_fx(
        graph_module,
        is_reference,
        convert_custom_config,
        is_standalone_module=True,
    )