1# mypy: allow-untyped-defs 2import logging 3import operator 4from typing import List, Optional, Tuple, Union 5 6import torch 7import torch.export._trace 8from torch._ops import OpOverload 9from torch.ao.quantization.fx._decomposed import ( 10 dequantize_per_channel, 11 dequantize_per_tensor, 12 quantize_per_tensor, 13) 14from torch.ao.quantization.utils import calculate_qmin_qmax 15from torch.fx.graph_module import _assign_attr 16 17 18log = logging.getLogger(__name__) 19 20# Those values will need to be carried over multiple operators. 21_INPUT_Q_DTYPE: Optional[Union[torch.dtype, torch.fx.Node]] = None 22_SCALE: Optional[Union[float, torch.fx.Node]] = None 23_ZERO_POINT: Optional[Union[float, torch.fx.Node]] = None 24 25 26def int_to_valid_dtype(val: int) -> torch.dtype: 27 from torch._export.converter import _TORCH_ENUM_TO_DTYPE # No circular import. 28 29 if isinstance(val, torch.dtype): 30 return val 31 dtype = _TORCH_ENUM_TO_DTYPE[val] 32 if dtype == torch.quint8: 33 return torch.uint8 34 elif dtype == torch.qint8: 35 return torch.int8 36 return dtype 37 38 39def fx_enum_to_dtype(gm: torch.fx.GraphModule, val: int) -> torch.fx.Node: 40 return gm.graph.call_function(int_to_valid_dtype, (val,)) 41 42 43def insert_quantized_node( 44 gm: torch.fx.GraphModule, 45 val_node: torch.fx.Node, 46 scale_node: Union[float, torch.fx.Node], 47 zero_point_node: Union[float, torch.fx.Node], 48 qmin_node: Union[float, int, torch.fx.Node], 49 qmax_node: Union[float, int, torch.fx.Node], 50 dtype_node: Union[torch.dtype, torch.fx.Node], 51 qscheme: Optional[torch.qscheme], 52) -> torch.fx.Node: 53 return gm.graph.call_function( 54 quantize_per_tensor, 55 ( 56 val_node, 57 scale_node, 58 zero_point_node, 59 qmin_node, 60 qmax_node, 61 dtype_node, 62 ), 63 ) 64 65 66def get_dequantized( 67 val: torch.Tensor, 68 scale: Union[float, torch.Tensor], 69 zero_point: Union[float, torch.Tensor], 70 qmin: Union[float, int], 71 qmax: Union[float, int], 72 dtype: torch.dtype, 73 axis: Optional[int], 74 qscheme: Optional[torch.qscheme], 75) -> torch.Tensor: 76 if qscheme is torch.per_tensor_affine: 77 return dequantize_per_tensor( 78 val, 79 scale, 80 zero_point, 81 qmin, 82 qmax, 83 dtype, 84 ) 85 elif qscheme is torch.per_channel_affine: 86 return dequantize_per_channel( 87 val, 88 scale, 89 zero_point, 90 axis, 91 qmin, 92 qmax, 93 dtype, 94 ) 95 else: 96 raise RuntimeError(f"Unsupported dequantization scheme: {qscheme}") 97 98 99def insert_dequantized_node( 100 gm: torch.fx.GraphModule, 101 val_node: torch.fx.Node, 102 scale_node: Union[float, torch.fx.Node], 103 zero_point_node: Union[float, torch.fx.Node], 104 qmin_node: Union[float, int, torch.fx.Node], 105 qmax_node: Union[float, int, torch.fx.Node], 106 dtype_node: Union[torch.dtype, torch.fx.Node], 107 axis_node: Optional[Union[int, torch.fx.Node]], 108 qscheme: Optional[torch.qscheme], 109) -> torch.fx.Node: 110 if qscheme is torch.per_tensor_affine: 111 return gm.graph.call_function( 112 dequantize_per_tensor, 113 ( 114 val_node, 115 scale_node, 116 zero_point_node, 117 qmin_node, 118 qmax_node, 119 dtype_node, 120 ), 121 ) 122 elif qscheme is torch.per_channel_affine: 123 return gm.graph.call_function( 124 dequantize_per_channel, 125 ( 126 val_node, 127 scale_node, 128 zero_point_node, 129 axis_node, 130 qmin_node, 131 qmax_node, 132 dtype_node, 133 ), 134 ) 135 else: 136 raise RuntimeError(f"Unsupported dequantization scheme: {qscheme}") 137 138 139def get_qmin_qmax(dtype: torch.dtype) -> Tuple[Union[int, float], Union[int, float]]: 140 return calculate_qmin_qmax(None, None, False, dtype, False) # type: ignore[arg-type] 141 142 143def insert_qmin_qmax_node( 144 gm: torch.fx.GraphModule, dtype_node: Union[torch.dtype, torch.fx.Node] 145) -> Tuple[torch.fx.Node, torch.fx.Node]: 146 q_min_max_node = gm.graph.call_function( 147 calculate_qmin_qmax, (None, None, False, dtype_node, False) 148 ) 149 qmin_node = gm.graph.call_function(operator.getitem, (q_min_max_node, 0)) 150 qmax_node = gm.graph.call_function(operator.getitem, (q_min_max_node, 1)) 151 return qmin_node, qmax_node 152 153 154def get_script_object( 155 gm: torch.nn.Module, node: torch.fx.Node 156) -> torch._C.ScriptObject: 157 assert isinstance(node, torch.fx.Node) 158 assert node.op == "get_attr" 159 attr_name = node.target 160 assert isinstance(attr_name, str) 161 162 mod = gm 163 for attr in attr_name.split("."): 164 mod = getattr(mod, attr) 165 assert isinstance(mod, torch._C.ScriptObject) 166 return mod 167 168 169def insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject( 170 gm: torch.fx.GraphModule, 171 param_node: torch.fx.Node, 172) -> Tuple[torch.fx.Node, Optional[torch.fx.Node]]: 173 """Directly inline tensor from a get_attr fx node.""" 174 mod = get_script_object(gm, param_node) 175 w_qtensor, b_qtensor = mod.unpack() # type: ignore[attr-defined] 176 w_attr_name, b_attr_name = ( 177 f"dequantized_{param_node.target}_w", 178 f"dequantized_{param_node.target}_b", 179 ) 180 return insert_weight_and_bias_get_attr_node( 181 gm, w_qtensor, b_qtensor, w_attr_name, b_attr_name 182 ) 183 184 185def insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor( 186 gm: torch.fx.GraphModule, 187 get_attr_to_weight_node: torch.fx.Node, 188 get_attr_to_bias_node: Optional[torch.fx.Node], 189) -> Tuple[torch.fx.Node, Optional[torch.fx.Node]]: 190 assert isinstance(get_attr_to_weight_node.target, str) 191 w_qtensor = getattr(gm, get_attr_to_weight_node.target) 192 w_attr_name = f"dequantized_{get_attr_to_weight_node.target}_w" 193 194 if get_attr_to_bias_node is not None: 195 assert isinstance(get_attr_to_bias_node.target, str) 196 b_qtensor = getattr(gm, get_attr_to_bias_node.target) 197 b_attr_name = f"dequantized_{get_attr_to_bias_node.target}_b" 198 else: 199 b_qtensor, b_attr_name = None, "" 200 201 return insert_weight_and_bias_get_attr_node( 202 gm, w_qtensor, b_qtensor, w_attr_name, b_attr_name 203 ) 204 205 206def insert_weight_and_bias_get_attr_node( 207 gm: torch.fx.GraphModule, 208 w_qtensor: torch.Tensor, 209 b_qtensor: Optional[torch.Tensor], 210 w_attr_name: str, 211 b_attr_name: str, 212) -> Tuple[torch.fx.Node, Optional[torch.fx.Node]]: 213 w_tensor = get_tensor_from_qtensor(w_qtensor) 214 _assign_attr(w_tensor, gm, w_attr_name) 215 w_tensor_attr = gm.graph.get_attr(w_attr_name) 216 217 if b_qtensor is not None: 218 b_tensor = get_tensor_from_qtensor(b_qtensor, dequant=False) 219 _assign_attr(b_tensor, gm, b_attr_name) 220 b_tensor_attr = gm.graph.get_attr(b_attr_name) 221 else: 222 b_tensor_attr = None 223 224 return w_tensor_attr, b_tensor_attr 225 226 227def get_tensor_from_qtensor( 228 qtensor: torch.Tensor, dequant: bool = True 229) -> torch.Tensor: 230 # Manual conversion because qint8 is not used anymore. 231 if qtensor.dtype in [torch.qint8, torch.quint8]: 232 tensor = qtensor.int_repr() 233 else: 234 tensor = qtensor 235 236 # Weights need dequantization with scaling and zero_point adjustment, but 237 # bias does not need that. 238 if dequant: 239 qscheme = qtensor.qscheme() 240 if qscheme == torch.per_channel_affine: 241 scale, zero_point, axis = ( 242 qtensor.q_per_channel_scales(), 243 qtensor.q_per_channel_zero_points(), 244 qtensor.q_per_channel_axis(), 245 ) 246 else: 247 scale, zero_point, axis = ( 248 qtensor.q_scale(), # type: ignore[assignment] 249 qtensor.q_zero_point(), # type: ignore[assignment] 250 None, 251 ) 252 dtype = tensor.dtype 253 qmin, qmax = get_qmin_qmax(dtype) 254 return get_dequantized( 255 tensor, scale, zero_point, qmin, qmax, dtype, axis, qscheme 256 ) 257 return tensor 258 259 260def insert_fused_activation_node( 261 gm: torch.fx.GraphModule, opname: str, fx_node: torch.fx.Node 262) -> torch.fx.Node: 263 if opname in ["conv1d_relu", "conv2d_relu", "linear_relu", "add_relu", "mul_relu"]: 264 fx_node = gm.graph.call_function(torch.ops.aten.relu, (fx_node,)) 265 return fx_node 266 267 268def _conv1d_op_with_squeeze( 269 inp: torch.Tensor, 270 weight: torch.Tensor, 271 bias: Optional[torch.Tensor], 272 stride: List[int], 273 padding: List[int], 274 dilation: List[int], 275 groups: int, 276) -> torch.Tensor: 277 # In quantized version, conv1d is emulated using conv2d with squeeze and unsqueeze 278 # operations before and after the conv2d operation to match the dimension of weights. 279 # Reference: https://github.com/pytorch/pytorch/blob/eca0cb0fbe84bb0a34fa94afe261bceecd52c436/aten/src/ATen/native/quantized/cpu/qconv.cpp#L1827 # noqa: B950 280 s_inp = torch.ops.aten.unsqueeze(inp, 2) 281 conv1d_res = torch.ops.aten.conv2d( 282 s_inp, 283 weight, 284 bias, 285 stride, 286 padding, 287 dilation, 288 groups, 289 ) 290 uns_conv1d_res = torch.ops.aten.squeeze(conv1d_res, 2) 291 return uns_conv1d_res 292 293 294def _transform_conv_with_packedparam(gm: torch.fx.GraphModule, node: torch.fx.Node): 295 """Conv specfic transformation function.""" 296 assert isinstance(node.target, torch._ops.OpOverload) 297 opname = node.target._opname 298 scale_node, zero_point_node = node.args[2], node.args[3] 299 300 op_f = ( 301 torch.ops.aten.conv2d 302 if opname in ["conv2d", "conv2d_relu"] 303 else _conv1d_op_with_squeeze 304 ) 305 306 inp_node, param_node = node.args[0], node.args[1] 307 assert isinstance(inp_node, torch.fx.Node) 308 assert isinstance(param_node, torch.fx.Node) 309 310 if param_node.op == "call_function": 311 # Using Conv2dPrepackParam from conv_prepack. 312 # We directly skip the packing call and inline weights and bias. 313 w_node, b_node = param_node.args[0], param_node.args[1] 314 assert isinstance(w_node, torch.fx.Node) 315 assert b_node is None or isinstance(b_node, torch.fx.Node) 316 ( 317 param_0, 318 param_1, 319 ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor( 320 gm, w_node, b_node 321 ) 322 op_res_node = gm.graph.call_function( 323 op_f, (inp_node, param_0, param_1, *param_node.args[2:]) 324 ) 325 else: 326 # Using ConvPrepackedParam. 327 param = get_script_object(gm, param_node) 328 ( 329 param_0, 330 param_1, 331 ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject( 332 gm, param_node 333 ) # type: ignore[assignment] 334 op_res_node = gm.graph.call_function( 335 op_f, 336 ( 337 inp_node, 338 param_0, 339 param_1, 340 param.stride(), # type: ignore[attr-defined] 341 param.padding(), # type: ignore[attr-defined] 342 param.dilation(), # type: ignore[attr-defined] 343 param.groups(), # type: ignore[attr-defined] 344 ), 345 ) 346 return op_res_node, scale_node, zero_point_node 347 348 349def _transform_linear_with_packedparam(gm: torch.fx.GraphModule, node: torch.fx.Node): 350 """Linear specfic transformation function.""" 351 scale_node, zero_point_node = node.args[2], node.args[3] 352 353 inp_node, param_node = node.args[0], node.args[1] 354 assert isinstance(inp_node, torch.fx.Node) 355 assert isinstance(param_node, torch.fx.Node) 356 357 if param_node.op == "call_function": 358 # Using LinearPrepackParam from linear_prepack. 359 # We directly skip the packing call and inline weights and bias. 360 w_node, b_node = param_node.args[0], param_node.args[1] 361 assert isinstance(w_node, torch.fx.Node) 362 assert b_node is None or isinstance(b_node, torch.fx.Node) 363 ( 364 param_0, 365 param_1, 366 ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor( 367 gm, w_node, b_node 368 ) 369 op_res_node = gm.graph.call_function( 370 torch.ops.aten.linear, (inp_node, param_0, param_1, *param_node.args[2:]) 371 ) 372 else: 373 # Using LinearPackedParams. 374 ( 375 param_0, 376 param_1, 377 ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject( 378 gm, param_node 379 ) # type: ignore[assignment] 380 op_res_node = gm.graph.call_function( 381 torch.ops.aten.linear, (inp_node, param_0, param_1) 382 ) 383 return op_res_node, scale_node, zero_point_node 384 385 386def _transform_op_where_last_two_arguments_are_scale_and_zero_point( 387 gm: torch.fx.GraphModule, node: torch.fx.Node 388): 389 """ 390 This transformation function can be used for function where the last two 391 parameters are scale and zero point. Additionally, the function's parameters 392 do not need any unpacking. 393 """ 394 to_standard_op = { 395 "mul": torch.ops.aten.mul, 396 "mul_relu": torch.ops.aten.mul, 397 "add": torch.ops.aten.add, 398 "add_relu": torch.ops.aten.add, 399 "softmax": torch.ops.aten.softmax, 400 "cat": torch.ops.aten.cat, 401 "hardswish": torch.ops.aten.hardswish, 402 } 403 404 assert isinstance(node.target, torch._ops.OpOverload) 405 opname, args = node.target._opname, node.args 406 scale_node, zero_point_node = args[-2], args[-1] 407 op_res_node = gm.graph.call_function(to_standard_op[opname], tuple(args[:-2])) 408 return op_res_node, scale_node, zero_point_node 409 410 411def _transform_scalar_arithmetic(gm: torch.fx.GraphModule, node: torch.fx.Node): 412 """Transform scalar overload for basic arithmetic.""" 413 to_standard_op = { 414 "mul": torch.ops.aten.mul.Scalar, 415 "add": torch.ops.aten.add.Scalar, 416 } 417 assert isinstance(node.target, torch._ops.OpOverload) 418 opname, args = node.target._opname, node.args 419 op_res_node = gm.graph.call_function(to_standard_op[opname], args) 420 return op_res_node, _SCALE, _ZERO_POINT 421 422 423def _transform_prepacked_op(gm: torch.fx.GraphModule, node: torch.fx.Node): 424 """ 425 Transformation for functions under prepacked namespace, where they share 426 the same handling logic that [...]OpContext contains all parameters. 427 """ 428 assert isinstance(node.target, torch._ops.OpOverload) 429 opname, args = node.target._opname, node.args 430 op_f = None 431 if opname == "conv2d_clamp_run": 432 op_f = torch.ops.aten.conv2d 433 elif opname == "linear_clamp_run": 434 op_f = torch.ops.aten.linear 435 else: 436 raise RuntimeError(f"Invalid operator {opname}") 437 438 assert isinstance(args[1], torch.fx.Node) 439 so = get_script_object(gm, args[1]) 440 441 func_args = [] 442 func_args += [args[0]] 443 func_args += so.unpack()[:2] # type: ignore[attr-defined] 444 if opname == "conv2d_clamp_run": 445 func_args += torch.ops.prepacked.unpack_prepacked_sizes_conv2d(so)[2:] 446 447 op_res_node = gm.graph.call_function(op_f, tuple(func_args)) 448 return op_res_node 449 450 451def _transform_batch_norm(gm: torch.fx.GraphModule, node: torch.fx.Node): 452 args = node.args 453 scale_node, zero_point_node = args[-2], args[-1] 454 op_res_node = gm.graph.call_function( 455 torch.ops.aten.native_batch_norm, (*args[:-3], False, 0.1, args[-3]) 456 ) 457 op_res_node = gm.graph.call_function(operator.getitem, (op_res_node, 0)) 458 return op_res_node, scale_node, zero_point_node 459 460 461def fx_transform_quantized_op_to_standard_op( 462 gm: torch.fx.GraphModule, node: torch.fx.Node 463) -> torch.fx.Node: 464 global _SCALE, _ZERO_POINT, _INPUT_Q_DTYPE 465 466 assert isinstance(node.target, torch._ops.OpOverload) 467 opname, overload = node.target._opname, node.target._overloadname 468 469 key = f"{opname}.{overload}" 470 opname_to_transform_f = { 471 "conv1d.new": _transform_conv_with_packedparam, 472 "conv1d_relu.new": _transform_conv_with_packedparam, 473 "conv1d.default": _transform_conv_with_packedparam, 474 "conv1d_relu.default": _transform_conv_with_packedparam, 475 "conv2d.new": _transform_conv_with_packedparam, 476 "conv2d_relu.new": _transform_conv_with_packedparam, 477 "conv2d.default": _transform_conv_with_packedparam, 478 "conv2d_relu.default": _transform_conv_with_packedparam, 479 "linear.default": _transform_linear_with_packedparam, 480 "linear_relu.default": _transform_linear_with_packedparam, 481 "add.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, 482 "add_relu.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, 483 "mul.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, 484 "mul_relu.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, 485 "softmax.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, 486 "cat.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, 487 "hardswish.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point, 488 "batch_norm2d.default": _transform_batch_norm, 489 "mul.Scalar": _transform_scalar_arithmetic, 490 "add.Scalar": _transform_scalar_arithmetic, 491 } 492 493 if f"{key}" not in opname_to_transform_f: 494 raise RuntimeError(f"Unsupported quantized op during transformation: {key}") 495 496 op_res_node, scale_node, zero_point_node = opname_to_transform_f[f"{key}"](gm, node) 497 498 # Add fused activation layer. 499 op_res_node = insert_fused_activation_node(gm, opname, op_res_node) 500 _SCALE, _ZERO_POINT = scale_node, zero_point_node 501 502 assert _INPUT_Q_DTYPE is not None 503 qmin_node, qmax_node = insert_qmin_qmax_node(gm, _INPUT_Q_DTYPE) 504 q_fx_node = insert_quantized_node( 505 gm, 506 op_res_node, 507 scale_node, 508 zero_point_node, 509 qmin_node, 510 qmax_node, 511 _INPUT_Q_DTYPE, 512 torch.per_tensor_affine, 513 ) 514 dq_fx_node = insert_dequantized_node( 515 gm, 516 q_fx_node, 517 scale_node, 518 zero_point_node, 519 qmin_node, 520 qmax_node, 521 _INPUT_Q_DTYPE, 522 None, 523 torch.per_tensor_affine, 524 ) 525 return dq_fx_node 526 527 528def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule): 529 """ 530 Replace legacy quantized ops (aten.quantize_per_tensor, quantized.conv) with 531 PT2 ops (quantize_decomposed.quantize_per_tensor, aten.conv). 532 533 Before: x || -> aten.q || -> quantized.conv2d || -> quantized.linear || -> aten.dq || -> y 534 535 After: x || -> qd.q -> qd.dq || -> aten.conv2d -> qd.q -> qd.dq || aten.linear -> qd.q -> qd.dq || -> y 536 537 (qd == quantized_decomposed library, q = quantize, dq = dequantize) 538 ^ 539 | 540 getattr(w), getattr(b) from Conv2dParamPrepack 541 542 During each iteration, the transformation spits out the transformed operator, its quantized output, 543 and its dequantized value together. We did this because dequantization need to use the 544 scale and zero point parameters from the quantization to recover the approximate original value. After each 545 iteration, the new dequantization node will be used as the input to the next node (e.g., dq2 -> linear). 546 547 For operators like conv2d and linear, their weights and bias are packed in a quantized format in the ScriptObject. 548 During the transformation, we unpack those objects, get their dequantized tensor, populate those 549 as attributes to the module, and use getattr to access them. 550 551 One exception in the transformation is conv_prepack and linear_prepack. Those calls pack 552 weight and bias constant tensors into ScriptObject, which are then used by subsequent conv2d or linear calls. 553 During transformation, we directly skip transforming conv_prepack or linear_prepack. We check whether ScriptObject to the 554 quantized::conv2d or linear is from conv_prepack or linear_prepack. If it is, we then inline those parameters 555 to the operator by converting them to a getattr fx.node. 556 557 For prepacked::conv2d_clamp_run and prepacked::linear_clamp_run, we directly convert them to aten.conv2d and aten.linear 558 without the need of doing de/quantization. 559 560 Three global variables defined are _INPUT_Q_DTYPE, _SCALE, _ZERO_POINT. _INPUT_Q_DTYPE determines the de/quantization 561 data type, which is the same across the entire program, but it only shows up in the very first quantization 562 call. _SCALE and _ZERO_POINT are used only when operators do not have those specified. E.g., mul.Scalar. 563 """ 564 565 global _INPUT_Q_DTYPE 566 567 quantized = False 568 569 last_quantized_node = None 570 for node in gm.graph.nodes: 571 if isinstance(node.target, OpOverload): 572 with gm.graph.inserting_before(node): 573 namespace, opname = node.target.namespace, node.target._opname 574 if namespace == "quantized" and opname not in [ 575 "conv_prepack", 576 "linear_prepack", 577 ]: 578 quantized = True 579 fx_node = fx_transform_quantized_op_to_standard_op(gm, node) 580 node.replace_all_uses_with(fx_node) 581 last_quantized_node = fx_node 582 elif namespace == "prepacked": 583 quantized = True 584 fx_node = _transform_prepacked_op(gm, node) 585 node.replace_all_uses_with(fx_node) 586 last_quantized_node = fx_node 587 elif namespace == "aten" and opname == "quantize_per_tensor": 588 inp_node, scale_node, zero_point_node, dtype_node = node.args 589 dtype_node = fx_enum_to_dtype(gm, dtype_node) 590 _INPUT_Q_DTYPE = dtype_node 591 qmin_node, qmax_node = insert_qmin_qmax_node(gm, dtype_node) 592 q_fx_node = insert_quantized_node( 593 gm, 594 inp_node, 595 scale_node, 596 zero_point_node, 597 qmin_node, 598 qmax_node, 599 dtype_node, 600 torch.per_tensor_affine, 601 ) 602 dq_fx_node = insert_dequantized_node( 603 gm, 604 q_fx_node, 605 scale_node, 606 zero_point_node, 607 qmin_node, 608 qmax_node, 609 dtype_node, 610 None, 611 torch.per_tensor_affine, 612 ) 613 node.replace_all_uses_with(dq_fx_node) 614 last_quantized_node = dq_fx_node 615 elif namespace == "aten" and opname == "dequantize": 616 assert last_quantized_node is not None 617 node.replace_all_uses_with(last_quantized_node) 618 else: 619 last_quantized_node = node 620 621 # Post-processing again to remove legacy ScriptObjects and quantizated tensors 622 # stored as attributes or in the buffer. This is used to clean up the GraphModule 623 # to not trigger tracing errors like missing __obj_flatten__ functions. 624 def _clean_attr(mod: torch.nn.Module): 625 for submod in mod.modules(): 626 attr_names_to_clean = set() 627 for k, v in submod.__dict__.items(): 628 if isinstance(v, torch.ScriptObject): 629 attr_names_to_clean.add(k) 630 if k == "_buffers": 631 buffer_name_to_clean = set() 632 for b_name, b_value in v.items(): 633 if isinstance(b_value, torch.Tensor) and b_value.dtype in [ 634 torch.qint8, 635 torch.quint8, 636 ]: 637 buffer_name_to_clean.add(b_name) 638 for b_name in buffer_name_to_clean: 639 v.pop(b_name, None) 640 for attr_name in attr_names_to_clean: 641 delattr(submod, attr_name) 642 643 if quantized: 644 """ 645 TODO: SetAttr + quantized ops will result incorrect program. This flag is used to temporarily 646 bypass test cases. 647 648 The deadcode elimination pass is needed to remove legacy quantized ops. Otherwise, retracing 649 will throw errors. However, the current way of SetAttr does inplace update to attributes, so 650 this pass regard them as dead code and remove them. Below is an example of GraphModule before 651 and after the dead code elimination pass. 652 653 class GraphModule(torch.nn.Module): 654 def forward(self, x_1): 655 # No stacktrace found for following nodes 656 data = self.data; data = None 657 data_1 = self.data 658 add_tensor = torch.ops.aten.add.Tensor(data_1, x_1, alpha = 1); data_1 = None 659 data_2 = self.data 660 copy_ = torch_Tensor_copy_(data_2, add_tensor); data_2 = add_tensor = copy_ = None 661 data_3 = self.data 662 add_tensor_1 = torch.ops.aten.add.Tensor(x_1, data_3, alpha = 1); x_1 = data_3 = None 663 return add_tensor_1 664 665 class GraphModule(torch.nn.Module): 666 def forward(self, x_1): 667 # No stacktrace found for following nodes 668 data_3 = self.data 669 add_tensor_1 = torch.ops.aten.add.Tensor(x_1, data_3, alpha = 1); x_1 = data_3 = None 670 return add_tensor_1 671 """ 672 gm.graph.eliminate_dead_code() 673 _clean_attr(gm) 674