1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import copy
8
9import torch
10from executorch.exir.dialects._ops import ops as exir_ops
11from executorch.exir.pass_base import ExportPass, PassResult
12
13"""
The passes below were taken from bolt/nn/executorch/passes/quant_fusion.py
15"""
16
# Edge-dialect per-tensor dequantize op. This is the only dequantize target
# the code in this file matches against and re-emits when duplicating nodes.
T_DQuantPerTensor = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
18
19
class DequantDuplicator:
    """Callable that gives each consumer its own per-tensor dequantize node.

    The first time a particular dequantize node is passed in, it is recorded
    in ``pass_obj.dequant_map`` and returned unchanged. Each later time the
    same node is passed in, a new dequantize node with the same args/kwargs
    is inserted before ``pass_obj.current_node`` and returned instead.

    Any argument that is not a ``call_function`` node targeting
    ``T_DQuantPerTensor`` is returned untouched.

    Args:
        pass_obj: Owner object. It must expose ``dequant_map`` (dict keyed by
            dequantize node), ``mod`` (object with a ``graph`` attribute) and
            ``current_node`` (the node whose arguments are being rewritten).
    """

    def __init__(self, pass_obj):
        self.pass_obj = pass_obj
        # Set to True whenever __call__ creates a duplicate. It is never
        # cleared here; the owner resets it before processing each node.
        self.duplicated_dequant = False

    def __call__(self, arg):
        """Return ``arg`` itself, or a fresh duplicate if it is a reused dequant.

        Args:
            arg: Any value found in a node's argument list.

        Returns:
            ``arg`` unchanged, or a newly inserted dequantize node.
        """
        if (
            not isinstance(arg, torch.fx.Node)
            or arg.op != "call_function"
            or arg.target != T_DQuantPerTensor  # TODO handle per channel case
        ):
            return arg

        dequant_map = self.pass_obj.dequant_map
        if arg not in dequant_map:
            # First consumer keeps the original node.
            dequant_map[arg] = (arg.args, arg.kwargs, arg.meta)
            return arg

        args, kwargs, meta = dequant_map[arg]
        graph = self.pass_obj.mod.graph
        with graph.inserting_before(self.pass_obj.current_node):
            dup_dequant = graph.call_function(
                T_DQuantPerTensor,
                args=args,
                kwargs=kwargs,
            )

        # Give the duplicate its own metadata dict. Assigning ``meta`` directly
        # would make the original node and every duplicate share one mutable
        # dict, so a later write to any of them would leak into the others.
        dup_dequant.meta = copy.copy(meta)
        if "val" in meta:
            dup_dequant.meta["val"] = copy.copy(meta["val"])

        self.duplicated_dequant = True
        return dup_dequant
48
49
class DuplicateDequantNodePass(ExportPass):
    r"""
    Duplicates dequantize nodes that are shared between consumers, so that
    every quantized op has its own unique dequant node. Quantized ops are
    represented as dq -> op -> q, and sharing dq nodes between quantized ops
    makes it impossible to partition against quantized ops. As a result we
    need to duplicate dq nodes for ops which share a dq node.

    In this example, the graph below:

                            --> op --> q
                          /
      dq -> op -> q -> dq
                          \
                            --> op --> q

    is transformed into:

                      --> dq --> op --> q
                    /
      dq -> op -> q
                    \
                      --> dq --> op --> q


    """

    def __init__(self):
        super().__init__()
        self.dequant_map = {}  # Map of dequant results to its node's arguments

    def call(self, mod: torch.fx.GraphModule) -> PassResult:
        """Rewrite ``mod`` in place so no dequantize node feeds two arguments.

        Args:
            mod: Graph module whose ``graph`` is rewritten and recompiled.

        Returns:
            A ``PassResult`` wrapping ``mod``. The modified flag is always
            reported as True.
        """
        self.mod = mod
        # Start every run from a clean slate. Entries left over from an
        # earlier call would make the first consumer of each dequant look
        # like a repeat, and would keep stale nodes alive.
        self.dequant_map = {}
        duplicator = DequantDuplicator(self)

        # Iterate over a snapshot because nodes are inserted while walking.
        for node in list(mod.graph.nodes):
            # The duplicator inserts new dequant nodes before this node.
            self.current_node = node

            if node.op != "call_function":
                continue

            new_args = []
            duplicator.duplicated_dequant = False
            for arg in node.args:
                if isinstance(arg, list):
                    new_args.append([duplicator(item) for item in arg])
                else:
                    new_args.append(duplicator(arg))

            if not duplicator.duplicated_dequant:
                continue

            # At least one argument was swapped for a duplicate: rebuild the
            # node with the new arguments and redirect its users to it.
            with mod.graph.inserting_before(node):
                new_node = mod.graph.call_function(
                    node.target,
                    args=tuple(new_args),
                    kwargs=node.kwargs,
                )
            # Copy so the replacement does not share a metadata dict with
            # the node it supersedes.
            new_node.meta = copy.copy(node.meta)
            node.replace_all_uses_with(new_node)

        # Superseded nodes have no users left; drop them.
        mod.graph.eliminate_dead_code()
        mod.recompile()

        return PassResult(mod, True)
112