# mypy: allow-untyped-defs
import logging
import operator
from typing import List, Optional, Tuple, Union

import torch
import torch.export._trace
from torch._ops import OpOverload
from torch.ao.quantization.fx._decomposed import (
    dequantize_per_channel,
    dequantize_per_tensor,
    quantize_per_tensor,
)
from torch.ao.quantization.utils import calculate_qmin_qmax
from torch.fx.graph_module import _assign_attr


log = logging.getLogger(__name__)

# These values need to be carried across multiple operators.
_INPUT_Q_DTYPE: Optional[Union[torch.dtype, torch.fx.Node]] = None
_SCALE: Optional[Union[float, torch.fx.Node]] = None
_ZERO_POINT: Optional[Union[float, torch.fx.Node]] = None


def int_to_valid_dtype(val: int) -> torch.dtype:
    from torch._export.converter import _TORCH_ENUM_TO_DTYPE  # Avoid circular import.

    if isinstance(val, torch.dtype):
        return val
    dtype = _TORCH_ENUM_TO_DTYPE[val]
    if dtype == torch.quint8:
        return torch.uint8
    elif dtype == torch.qint8:
        return torch.int8
    return dtype


def fx_enum_to_dtype(gm: torch.fx.GraphModule, val: int) -> torch.fx.Node:
    return gm.graph.call_function(int_to_valid_dtype, (val,))


def insert_quantized_node(
    gm: torch.fx.GraphModule,
    val_node: torch.fx.Node,
    scale_node: Union[float, torch.fx.Node],
    zero_point_node: Union[float, torch.fx.Node],
    qmin_node: Union[float, int, torch.fx.Node],
    qmax_node: Union[float, int, torch.fx.Node],
    dtype_node: Union[torch.dtype, torch.fx.Node],
    qscheme: Optional[torch.qscheme],
) -> torch.fx.Node:
    return gm.graph.call_function(
        quantize_per_tensor,
        (
            val_node,
            scale_node,
            zero_point_node,
            qmin_node,
            qmax_node,
            dtype_node,
        ),
    )


def get_dequantized(
    val: torch.Tensor,
    scale: Union[float, torch.Tensor],
    zero_point: Union[float, torch.Tensor],
    qmin: Union[float, int],
    qmax: Union[float, int],
    dtype: torch.dtype,
    axis: Optional[int],
    qscheme: Optional[torch.qscheme],
) -> torch.Tensor:
    if qscheme is torch.per_tensor_affine:
        return dequantize_per_tensor(
            val,
            scale,
            zero_point,
            qmin,
            qmax,
            dtype,
        )
    elif qscheme is torch.per_channel_affine:
        return dequantize_per_channel(
            val,
            scale,
            zero_point,
            axis,
            qmin,
            qmax,
            dtype,
        )
    else:
        raise RuntimeError(f"Unsupported dequantization scheme: {qscheme}")


def insert_dequantized_node(
    gm: torch.fx.GraphModule,
    val_node: torch.fx.Node,
    scale_node: Union[float, torch.fx.Node],
    zero_point_node: Union[float, torch.fx.Node],
    qmin_node: Union[float, int, torch.fx.Node],
    qmax_node: Union[float, int, torch.fx.Node],
    dtype_node: Union[torch.dtype, torch.fx.Node],
    axis_node: Optional[Union[int, torch.fx.Node]],
    qscheme: Optional[torch.qscheme],
) -> torch.fx.Node:
    if qscheme is torch.per_tensor_affine:
        return gm.graph.call_function(
            dequantize_per_tensor,
            (
                val_node,
                scale_node,
                zero_point_node,
                qmin_node,
                qmax_node,
                dtype_node,
            ),
        )
    elif qscheme is torch.per_channel_affine:
        return gm.graph.call_function(
            dequantize_per_channel,
            (
                val_node,
                scale_node,
                zero_point_node,
                axis_node,
                qmin_node,
                qmax_node,
                dtype_node,
            ),
        )
    else:
        raise RuntimeError(f"Unsupported dequantization scheme: {qscheme}")


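# For example, with no customized range and no range reduction, torch.uint8
# maps to (0, 255) and torch.int8 to (-128, 127).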
def get_qmin_qmax(dtype: torch.dtype) -> Tuple[Union[int, float], Union[int, float]]:
    return calculate_qmin_qmax(None, None, False, dtype, False)  # type: ignore[arg-type]


def insert_qmin_qmax_node(
    gm: torch.fx.GraphModule, dtype_node: Union[torch.dtype, torch.fx.Node]
) -> Tuple[torch.fx.Node, torch.fx.Node]:
    q_min_max_node = gm.graph.call_function(
        calculate_qmin_qmax, (None, None, False, dtype_node, False)
    )
    qmin_node = gm.graph.call_function(operator.getitem, (q_min_max_node, 0))
    qmax_node = gm.graph.call_function(operator.getitem, (q_min_max_node, 1))
    return qmin_node, qmax_node


def get_script_object(
    gm: torch.nn.Module, node: torch.fx.Node
) -> torch._C.ScriptObject:
    assert isinstance(node, torch.fx.Node)
    assert node.op == "get_attr"
    attr_name = node.target
    assert isinstance(attr_name, str)

    mod = gm
    for attr in attr_name.split("."):
        mod = getattr(mod, attr)
    assert isinstance(mod, torch._C.ScriptObject)
    return mod


def insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject(
    gm: torch.fx.GraphModule,
    param_node: torch.fx.Node,
) -> Tuple[torch.fx.Node, Optional[torch.fx.Node]]:
    """Directly inline the weight and bias tensors unpacked from a get_attr node
    that points to a packed-params ScriptObject."""
    mod = get_script_object(gm, param_node)
    w_qtensor, b_qtensor = mod.unpack()  # type: ignore[attr-defined]
    w_attr_name, b_attr_name = (
        f"dequantized_{param_node.target}_w",
        f"dequantized_{param_node.target}_b",
    )
    return insert_weight_and_bias_get_attr_node(
        gm, w_qtensor, b_qtensor, w_attr_name, b_attr_name
    )


def insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor(
    gm: torch.fx.GraphModule,
    get_attr_to_weight_node: torch.fx.Node,
    get_attr_to_bias_node: Optional[torch.fx.Node],
) -> Tuple[torch.fx.Node, Optional[torch.fx.Node]]:
    assert isinstance(get_attr_to_weight_node.target, str)
    w_qtensor = getattr(gm, get_attr_to_weight_node.target)
    w_attr_name = f"dequantized_{get_attr_to_weight_node.target}_w"

    if get_attr_to_bias_node is not None:
        assert isinstance(get_attr_to_bias_node.target, str)
        b_qtensor = getattr(gm, get_attr_to_bias_node.target)
        b_attr_name = f"dequantized_{get_attr_to_bias_node.target}_b"
    else:
        b_qtensor, b_attr_name = None, ""

    return insert_weight_and_bias_get_attr_node(
        gm, w_qtensor, b_qtensor, w_attr_name, b_attr_name
    )


def insert_weight_and_bias_get_attr_node(
    gm: torch.fx.GraphModule,
    w_qtensor: torch.Tensor,
    b_qtensor: Optional[torch.Tensor],
    w_attr_name: str,
    b_attr_name: str,
) -> Tuple[torch.fx.Node, Optional[torch.fx.Node]]:
    w_tensor = get_tensor_from_qtensor(w_qtensor)
    _assign_attr(w_tensor, gm, w_attr_name)
    w_tensor_attr = gm.graph.get_attr(w_attr_name)

    if b_qtensor is not None:
        b_tensor = get_tensor_from_qtensor(b_qtensor, dequant=False)
        _assign_attr(b_tensor, gm, b_attr_name)
        b_tensor_attr = gm.graph.get_attr(b_attr_name)
    else:
        b_tensor_attr = None

    return w_tensor_attr, b_tensor_attr


def get_tensor_from_qtensor(
    qtensor: torch.Tensor, dequant: bool = True
) -> torch.Tensor:
    # Manually convert to a plain integer tensor because the quantized dtypes
    # (qint8/quint8) are not used in the transformed graph.
    if qtensor.dtype in [torch.qint8, torch.quint8]:
        tensor = qtensor.int_repr()
    else:
        tensor = qtensor

    # Weights need dequantization with scale and zero_point adjustment; bias
    # does not.
    if dequant:
        qscheme = qtensor.qscheme()
        if qscheme == torch.per_channel_affine:
            scale, zero_point, axis = (
                qtensor.q_per_channel_scales(),
                qtensor.q_per_channel_zero_points(),
                qtensor.q_per_channel_axis(),
            )
        else:
            scale, zero_point, axis = (
                qtensor.q_scale(),  # type: ignore[assignment]
                qtensor.q_zero_point(),  # type: ignore[assignment]
                None,
            )
        dtype = tensor.dtype
        qmin, qmax = get_qmin_qmax(dtype)
        return get_dequantized(
            tensor, scale, zero_point, qmin, qmax, dtype, axis, qscheme
        )
    return tensor


def insert_fused_activation_node(
    gm: torch.fx.GraphModule, opname: str, fx_node: torch.fx.Node
) -> torch.fx.Node:
    if opname in ["conv1d_relu", "conv2d_relu", "linear_relu", "add_relu", "mul_relu"]:
        fx_node = gm.graph.call_function(torch.ops.aten.relu, (fx_node,))
    return fx_node


def _conv1d_op_with_squeeze(
    inp: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    stride: List[int],
    padding: List[int],
    dilation: List[int],
    groups: int,
) -> torch.Tensor:
    # In the quantized stack, conv1d is emulated with conv2d: the input is unsqueezed
    # before and the result squeezed after the conv2d call to match the dimensions of
    # the weight.
    # Reference: https://github.com/pytorch/pytorch/blob/eca0cb0fbe84bb0a34fa94afe261bceecd52c436/aten/src/ATen/native/quantized/cpu/qconv.cpp#L1827  # noqa: B950
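    # Shape flow for a typical (N, C, L) input: unsqueeze(..., 2) gives (N, C, 1, L),
    # conv2d produces (N, C_out, 1, L_out), and squeeze(..., 2) restores
    # (N, C_out, L_out).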
    s_inp = torch.ops.aten.unsqueeze(inp, 2)
    conv1d_res = torch.ops.aten.conv2d(
        s_inp,
        weight,
        bias,
        stride,
        padding,
        dilation,
        groups,
    )
    uns_conv1d_res = torch.ops.aten.squeeze(conv1d_res, 2)
    return uns_conv1d_res


def _transform_conv_with_packedparam(gm: torch.fx.GraphModule, node: torch.fx.Node):
    """Conv specific transformation function."""
    assert isinstance(node.target, torch._ops.OpOverload)
    opname = node.target._opname
    scale_node, zero_point_node = node.args[2], node.args[3]

    op_f = (
        torch.ops.aten.conv2d
        if opname in ["conv2d", "conv2d_relu"]
        else _conv1d_op_with_squeeze
    )

    inp_node, param_node = node.args[0], node.args[1]
    assert isinstance(inp_node, torch.fx.Node)
    assert isinstance(param_node, torch.fx.Node)

    if param_node.op == "call_function":
        # Using Conv2dPrepackParam from conv_prepack.
        # We directly skip the packing call and inline weights and bias.
        w_node, b_node = param_node.args[0], param_node.args[1]
        assert isinstance(w_node, torch.fx.Node)
        assert b_node is None or isinstance(b_node, torch.fx.Node)
        (
            param_0,
            param_1,
        ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor(
            gm, w_node, b_node
        )
        op_res_node = gm.graph.call_function(
            op_f, (inp_node, param_0, param_1, *param_node.args[2:])
        )
    else:
        # Using ConvPrepackedParam.
        param = get_script_object(gm, param_node)
        (
            param_0,
            param_1,
        ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject(
            gm, param_node
        )  # type: ignore[assignment]
        op_res_node = gm.graph.call_function(
            op_f,
            (
                inp_node,
                param_0,
                param_1,
                param.stride(),  # type: ignore[attr-defined]
                param.padding(),  # type: ignore[attr-defined]
                param.dilation(),  # type: ignore[attr-defined]
                param.groups(),  # type: ignore[attr-defined]
            ),
        )
    return op_res_node, scale_node, zero_point_node


def _transform_linear_with_packedparam(gm: torch.fx.GraphModule, node: torch.fx.Node):
    """Linear specific transformation function."""
    scale_node, zero_point_node = node.args[2], node.args[3]

    inp_node, param_node = node.args[0], node.args[1]
    assert isinstance(inp_node, torch.fx.Node)
    assert isinstance(param_node, torch.fx.Node)

    if param_node.op == "call_function":
        # Using LinearPrepackParam from linear_prepack.
        # We directly skip the packing call and inline weights and bias.
        w_node, b_node = param_node.args[0], param_node.args[1]
        assert isinstance(w_node, torch.fx.Node)
        assert b_node is None or isinstance(b_node, torch.fx.Node)
        (
            param_0,
            param_1,
        ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_qtensor(
            gm, w_node, b_node
        )
        op_res_node = gm.graph.call_function(
            torch.ops.aten.linear, (inp_node, param_0, param_1, *param_node.args[2:])
        )
    else:
        # Using LinearPackedParams.
        (
            param_0,
            param_1,
        ) = insert_weight_and_bias_get_attr_node_from_get_attr_to_scriptobject(
            gm, param_node
        )  # type: ignore[assignment]
        op_res_node = gm.graph.call_function(
            torch.ops.aten.linear, (inp_node, param_0, param_1)
        )
    return op_res_node, scale_node, zero_point_node


def _transform_op_where_last_two_arguments_are_scale_and_zero_point(
    gm: torch.fx.GraphModule, node: torch.fx.Node
):
    """
    This transformation can be used for ops whose last two arguments are the
    scale and zero point and whose remaining arguments need no unpacking.
    """
    to_standard_op = {
        "mul": torch.ops.aten.mul,
        "mul_relu": torch.ops.aten.mul,
        "add": torch.ops.aten.add,
        "add_relu": torch.ops.aten.add,
        "softmax": torch.ops.aten.softmax,
        "cat": torch.ops.aten.cat,
        "hardswish": torch.ops.aten.hardswish,
    }

    assert isinstance(node.target, torch._ops.OpOverload)
    opname, args = node.target._opname, node.args
    scale_node, zero_point_node = args[-2], args[-1]
    op_res_node = gm.graph.call_function(to_standard_op[opname], tuple(args[:-2]))
    return op_res_node, scale_node, zero_point_node


def _transform_scalar_arithmetic(gm: torch.fx.GraphModule, node: torch.fx.Node):
    """Transform scalar overload for basic arithmetic."""
    to_standard_op = {
        "mul": torch.ops.aten.mul.Scalar,
        "add": torch.ops.aten.add.Scalar,
    }
    assert isinstance(node.target, torch._ops.OpOverload)
    opname, args = node.target._opname, node.args
    op_res_node = gm.graph.call_function(to_standard_op[opname], args)
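    # Scalar overloads carry no scale/zero point of their own, so reuse the
    # values recorded from the most recent quantized op.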
    return op_res_node, _SCALE, _ZERO_POINT


def _transform_prepacked_op(gm: torch.fx.GraphModule, node: torch.fx.Node):
    """
    Transformation for functions under the prepacked namespace. They share the
    same handling logic: the [...]OpContext ScriptObject carries all of the
    parameters.
    """
    assert isinstance(node.target, torch._ops.OpOverload)
    opname, args = node.target._opname, node.args
    op_f = None
    if opname == "conv2d_clamp_run":
        op_f = torch.ops.aten.conv2d
    elif opname == "linear_clamp_run":
        op_f = torch.ops.aten.linear
    else:
        raise RuntimeError(f"Invalid operator {opname}")

    assert isinstance(args[1], torch.fx.Node)
    so = get_script_object(gm, args[1])

    func_args = []
    func_args += [args[0]]
    func_args += so.unpack()[:2]  # type: ignore[attr-defined]
    if opname == "conv2d_clamp_run":
        func_args += torch.ops.prepacked.unpack_prepacked_sizes_conv2d(so)[2:]

    op_res_node = gm.graph.call_function(op_f, tuple(func_args))
    return op_res_node


def _transform_batch_norm(gm: torch.fx.GraphModule, node: torch.fx.Node):
    args = node.args
    scale_node, zero_point_node = args[-2], args[-1]
    op_res_node = gm.graph.call_function(
        torch.ops.aten.native_batch_norm, (*args[:-3], False, 0.1, args[-3])
    )
    op_res_node = gm.graph.call_function(operator.getitem, (op_res_node, 0))
    return op_res_node, scale_node, zero_point_node


def fx_transform_quantized_op_to_standard_op(
    gm: torch.fx.GraphModule, node: torch.fx.Node
) -> torch.fx.Node:
    global _SCALE, _ZERO_POINT, _INPUT_Q_DTYPE

    assert isinstance(node.target, torch._ops.OpOverload)
    opname, overload = node.target._opname, node.target._overloadname

    key = f"{opname}.{overload}"
    opname_to_transform_f = {
        "conv1d.new": _transform_conv_with_packedparam,
        "conv1d_relu.new": _transform_conv_with_packedparam,
        "conv1d.default": _transform_conv_with_packedparam,
        "conv1d_relu.default": _transform_conv_with_packedparam,
        "conv2d.new": _transform_conv_with_packedparam,
        "conv2d_relu.new": _transform_conv_with_packedparam,
        "conv2d.default": _transform_conv_with_packedparam,
        "conv2d_relu.default": _transform_conv_with_packedparam,
        "linear.default": _transform_linear_with_packedparam,
        "linear_relu.default": _transform_linear_with_packedparam,
        "add.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
        "add_relu.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
        "mul.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
        "mul_relu.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
        "softmax.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
        "cat.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
        "hardswish.default": _transform_op_where_last_two_arguments_are_scale_and_zero_point,
        "batch_norm2d.default": _transform_batch_norm,
        "mul.Scalar": _transform_scalar_arithmetic,
        "add.Scalar": _transform_scalar_arithmetic,
    }

    if key not in opname_to_transform_f:
        raise RuntimeError(f"Unsupported quantized op during transformation: {key}")

    op_res_node, scale_node, zero_point_node = opname_to_transform_f[key](gm, node)

    # Add fused activation layer.
    op_res_node = insert_fused_activation_node(gm, opname, op_res_node)
    _SCALE, _ZERO_POINT = scale_node, zero_point_node

    assert _INPUT_Q_DTYPE is not None
    qmin_node, qmax_node = insert_qmin_qmax_node(gm, _INPUT_Q_DTYPE)
    q_fx_node = insert_quantized_node(
        gm,
        op_res_node,
        scale_node,
        zero_point_node,
        qmin_node,
        qmax_node,
        _INPUT_Q_DTYPE,
        torch.per_tensor_affine,
    )
    dq_fx_node = insert_dequantized_node(
        gm,
        q_fx_node,
        scale_node,
        zero_point_node,
        qmin_node,
        qmax_node,
        _INPUT_Q_DTYPE,
        None,
        torch.per_tensor_affine,
    )
    return dq_fx_node


def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule):
    """
    Replace legacy quantized ops (aten.quantize_per_tensor, quantized.conv) with
    PT2 ops (quantized_decomposed.quantize_per_tensor, aten.conv).

    Before:    x || -> aten.q        || -> quantized.conv2d     || -> quantized.linear    || -> aten.dq || -> y

    After:     x || -> qd.q -> qd.dq || -> aten.conv2d -> qd.q -> qd.dq || aten.linear -> qd.q -> qd.dq || -> y

    (qd == quantized_decomposed library, q = quantize, dq = dequantize)
                                          ^
                                          |
                getattr(w), getattr(b) from Conv2dPrepackParam

    During each iteration, the transformation emits the transformed operator together with its
    quantized and dequantized outputs, because dequantization needs the scale and zero point
    from the preceding quantization to recover an approximation of the original value. After each
    iteration, the new dequantization node is used as the input to the next node (e.g., dq2 -> linear).

    For operators like conv2d and linear, the weights and bias are packed in a quantized format
    inside a ScriptObject. During the transformation, we unpack those objects, dequantize the
    tensors, register them as attributes on the module, and access them through get_attr nodes.

    One exception is conv_prepack and linear_prepack. Those calls pack weight and bias constant
    tensors into a ScriptObject, which is then consumed by subsequent conv2d or linear calls.
    We skip transforming conv_prepack and linear_prepack themselves and instead check whether the
    ScriptObject passed to quantized::conv2d or quantized::linear comes from conv_prepack or
    linear_prepack. If it does, we inline those parameters into the operator by converting them
    to get_attr fx nodes.

    prepacked::conv2d_clamp_run and prepacked::linear_clamp_run are converted directly to
    aten.conv2d and aten.linear without any de/quantization.

    Three global variables are defined: _INPUT_Q_DTYPE, _SCALE, and _ZERO_POINT. _INPUT_Q_DTYPE
    determines the de/quantization data type, which is the same across the entire program but
    only appears in the very first quantization call. _SCALE and _ZERO_POINT are used only for
    operators that do not carry their own, e.g., mul.Scalar.
    """
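    # Minimal usage sketch (illustrative only; `ts_front_end` is a hypothetical
    # helper that yields an fx.GraphModule traced from a quantized TorchScript
    # model):
    #
    #     gm = ts_front_end(scripted_quantized_model)
    #     replace_quantized_ops_with_standard_ops(gm)
    #     gm.recompile()  # regenerate forward() after the graph mutation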

    global _INPUT_Q_DTYPE

    quantized = False

    last_quantized_node = None
    for node in gm.graph.nodes:
        if isinstance(node.target, OpOverload):
            with gm.graph.inserting_before(node):
                namespace, opname = node.target.namespace, node.target._opname
                if namespace == "quantized" and opname not in [
                    "conv_prepack",
                    "linear_prepack",
                ]:
                    quantized = True
                    fx_node = fx_transform_quantized_op_to_standard_op(gm, node)
                    node.replace_all_uses_with(fx_node)
                    last_quantized_node = fx_node
                elif namespace == "prepacked":
                    quantized = True
                    fx_node = _transform_prepacked_op(gm, node)
                    node.replace_all_uses_with(fx_node)
                    last_quantized_node = fx_node
                elif namespace == "aten" and opname == "quantize_per_tensor":
                    inp_node, scale_node, zero_point_node, dtype_node = node.args
                    dtype_node = fx_enum_to_dtype(gm, dtype_node)
                    _INPUT_Q_DTYPE = dtype_node
                    qmin_node, qmax_node = insert_qmin_qmax_node(gm, dtype_node)
                    q_fx_node = insert_quantized_node(
                        gm,
                        inp_node,
                        scale_node,
                        zero_point_node,
                        qmin_node,
                        qmax_node,
                        dtype_node,
                        torch.per_tensor_affine,
                    )
                    dq_fx_node = insert_dequantized_node(
                        gm,
                        q_fx_node,
                        scale_node,
                        zero_point_node,
                        qmin_node,
                        qmax_node,
                        dtype_node,
                        None,
                        torch.per_tensor_affine,
                    )
                    node.replace_all_uses_with(dq_fx_node)
                    last_quantized_node = dq_fx_node
                elif namespace == "aten" and opname == "dequantize":
                    assert last_quantized_node is not None
                    node.replace_all_uses_with(last_quantized_node)
                else:
                    last_quantized_node = node

    # Post-process to remove legacy ScriptObjects and quantized tensors stored as
    # attributes or buffers. This cleans up the GraphModule so that retracing does
    # not hit errors such as missing __obj_flatten__ functions.
    def _clean_attr(mod: torch.nn.Module):
        for submod in mod.modules():
            attr_names_to_clean = set()
            for k, v in submod.__dict__.items():
                if isinstance(v, torch.ScriptObject):
                    attr_names_to_clean.add(k)
                if k == "_buffers":
                    buffer_name_to_clean = set()
                    for b_name, b_value in v.items():
                        if isinstance(b_value, torch.Tensor) and b_value.dtype in [
                            torch.qint8,
                            torch.quint8,
                        ]:
                            buffer_name_to_clean.add(b_name)
                    for b_name in buffer_name_to_clean:
                        v.pop(b_name, None)
            for attr_name in attr_names_to_clean:
                delattr(submod, attr_name)

    if quantized:
        """
        TODO: SetAttr + quantized ops will result in an incorrect program. This flag is used to
        temporarily bypass test cases.

        The dead code elimination pass is needed to remove legacy quantized ops; otherwise, retracing
        will throw errors. However, SetAttr currently updates attributes in place, so this pass regards
        those updates as dead code and removes them. Below is an example of a GraphModule before and
        after the dead code elimination pass.

        class GraphModule(torch.nn.Module):
            def forward(self, x_1):
                # No stacktrace found for following nodes
                data = self.data;  data = None
                data_1 = self.data
                add_tensor = torch.ops.aten.add.Tensor(data_1, x_1, alpha = 1);  data_1 = None
                data_2 = self.data
                copy_ = torch_Tensor_copy_(data_2, add_tensor);  data_2 = add_tensor = copy_ = None
                data_3 = self.data
                add_tensor_1 = torch.ops.aten.add.Tensor(x_1, data_3, alpha = 1);  x_1 = data_3 = None
                return add_tensor_1

        class GraphModule(torch.nn.Module):
            def forward(self, x_1):
                # No stacktrace found for following nodes
                data_3 = self.data
                add_tensor_1 = torch.ops.aten.add.Tensor(x_1, data_3, alpha = 1);  x_1 = data_3 = None
                return add_tensor_1
        """
        gm.graph.eliminate_dead_code()
        _clean_attr(gm)