xref: /aosp_15_r20/external/executorch/backends/cadence/aot/fuse_ops.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2
3
4# This file contains all the functions that fuse ops in the fx graph.
5
6import logging
7import math
8import operator
9from collections import deque
10from numbers import Number
11from typing import cast, Sequence
12
13import torch
14import torch.fx
15from executorch.backends.cadence.aot.compiler_utils import (
16    broadcastable,
17    get_cascaded_ops,
18    get_permuted_dims,
19    get_scale,
20    get_shape,
21    get_tensor_from_attr,
22    get_transposed_dims,
23    get_zero_point,
24)
25from executorch.backends.cadence.aot.pass_utils import (
26    CadencePassAttribute,
27    register_cadence_pass,
28)
29from executorch.backends.cadence.aot.utils import get_edge_overload_packet
30from executorch.exir.dialects._ops import ops as exir_ops
31from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
32from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
33from executorch.exir.passes import dead_code_elimination_pass
34from executorch.exir.passes.spec_prop_pass import SpecPropPass
35from torch.fx.node import Argument
36from torch.nn.utils.fusion import fuse_conv_bn_weights
37
38
39@register_cadence_pass(CadencePassAttribute(opt_level=1))
40class FuseMMWithAdd(ExportPass):
    # Return True if the node is a view node.
    def is_view_node(self, node: torch.fx.Node) -> bool:
        return node.target == exir_ops.edge.aten.view_copy.default
45
46    def fuse_mm_with_add(self, graph_module: torch.fx.GraphModule):
47        """
48        Given a graph of the form:
49        X = aten.mm(A, B)
50        Y = aten.add(X, C)
51        Fuse X and Y into a single addmm node, after making sure that we can
52        broadcast C into X.
        There could be a view node that takes a view of X and feeds it
        to the aten.add node:
55        X = aten.mm(A, B)
56        Y = X.view()
57        Z = aten.add(Y, C)
58        Handle this case as well. There are a few conditions for the
59        optimization to be valid:
60        1. There should be a single user of the mm node, otherwise we cannot
61        remove it.
62        2. There should be a single user of the add node, otherwise we cannot
63        fuse it with mm.
64        """
65        graph = graph_module.graph
66        for node in graph.nodes:
67            # We want to discover a chain of mm -> add, or mm -> view -> add.
68            # Only proceed if the current node is an mm node, and has only one
69            # user/successor.
70            if node.target != exir_ops.edge.aten.mm.default or len(node.users) != 1:
71                continue
72
73            # Our addmm implementation computes (mat1 * mat2 + bias). So the
74            # addmm node in the graph should have three args. We collectively
75            # term mat1 and mat2 as mm_arg since they are the args of mm node,
76            # and bias as bias_arg.
77            # Since we already have discovered the mm node, we can get mat1 and
78            # mat2 by iterating over its args. So the current node is mm_arg.
79            # bias_arg can be found once we discover the add op that consumes
80            # the output of this mm node. Our next step is to find the add op.
81            mm_arg = node
82            user = list(node.users.keys())[0]
83            # intermediate_view is True when the fusion case is mm -> view -> add
84            intermediate_view = False
85            # Check if the single user of the mm node is a view op. If so, our
86            # graph could potentially have mm -> view -> add. We need to skip
87            # the view op, and check if its successor is the add op. One condition
88            # we need to verify is that the view op must have only a single user
89            # (the add op).
90            if self.is_view_node(user) and len(user.users) == 1:
91                # We want to maintain two invariants:
92                # (1) 'user' is a potential add op that will get fused with the
93                #     mm node;
94                # (2) 'node' is the single predecessor of 'user' that is either
95                #     the mm node, or the current view node;
96                # To maintain the invariant, we must mark this view op as 'node',
97                # and its single successor as 'user'.
98                intermediate_view = True
99                node = user
100                user = list(node.users.keys())[0]
101
102            # Thanks to the invariant, we can now simply check if 'user' is an
103            # add op. We also want to ensure that the add op has only one user,
            # otherwise we will not be able to eliminate the add op post fusion.
105            if user.target != exir_ops.edge.aten.add.Tensor or len(user.users) != 1:
106                continue
107
108            # At this point, we have found an mm and an add node that we can
109            # fuse together. One arg of the add op is 'node' (thanks to the
110            # invariant). Find the other arg, and tag it as bias_arg.
111            assert len(user.args) == 2
112            bias_arg = user.args[1] if user.args[0] == node else user.args[0]
113
114            # As a last check, make sure that we can broadcast the bias tensor
115            # to the output of mm.
116            mm_arg_shape = get_shape(graph_module, mm_arg)
117            bias_arg_shape = get_shape(graph_module, bias_arg)
118            if (
119                mm_arg_shape is None
120                or bias_arg_shape is None
121                or not broadcastable(mm_arg_shape, bias_arg_shape)
122            ):
123                continue
124
125            # Create a new addmm node, and insert it before add node. DCE should
126            # take care of removing the dead mm and/or view node. Based on the
127            # invariant, add node corresponds to 'user'.
128            with graph.inserting_before(user):
129                addmm_node = graph.call_function(
130                    exir_ops.edge.aten.addmm.default,
131                    args=(bias_arg, mm_arg.args[0], mm_arg.args[1]),
132                )
133            # Replace all the uses of add node with addmm node, and remove add
134            # node from the graph.
135            user.replace_all_uses_with(addmm_node)
136            graph.erase_node(user)
137
138            # As a finishing step, we want to ensure that the output of addmm is
            # in the expected shape. For example, let us assume the following
            # input, where A and B are (4, 4) tensors, and C is a (1, 4) tensor.
142            # T1 = torch.mm(A, B)
143            # T2 = T1.view((2, 2, 4))
144            # return torch.add(T2, C)
            # Here, the expectation is to get an output of size (2, 2, 4), which
            # is the output shape of the view node T2. However, the fused addmm will
147            # return an output of shape (4, 4). In a nutshell, we need to take
148            # care of the output shape when the following two conditions are met:
149            # 1. The fusion case is mm -> view -> add (i.e., intermediate_view
150            #    is True)
151            # 2. The single successor of addmm is not a view op.
152            addmm_user = list(addmm_node.users.keys())[0]
153            if intermediate_view and not self.is_view_node(addmm_user):
                # Create a view node that correctly reshapes the output of the
                # addmm node to match the output shape of the original add node.
                # Thanks to our invariant, we know that the correct shape is held
                # by 'node', which points to the view op in the mm -> view -> add chain.
158                # We create its copy, and insert it just before addmm_user.
159                with graph.inserting_before(addmm_user):
160                    view_copy_node = graph_module.graph.node_copy(node)
161                # Any uses of addmm are replaced with this view_copy node.
162                addmm_node.replace_all_uses_with(view_copy_node)
163                # Now we massage the args of the view_copy node, so that it takes
164                # view of addmm node.
165                view_args = list(view_copy_node.args)
166                view_args[0] = addmm_node
167                view_copy_node.args = tuple(view_args)
168
169        graph_module.recompile()
170
171    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # Run the spec prop pass before we begin the fusion pipeline
173        result = SpecPropPass()(graph_module)
174        assert result is not None
175        self.fuse_mm_with_add(result.graph_module)
176        result = super().call(result.graph_module)
177        return result
178
179
180@register_cadence_pass(CadencePassAttribute(opt_level=1))
181class FuseBatchNormWithConv(ExportPass):
182    """
183    This pass fuses a conv op with batchnorm if the following two conditions
184    are met:
185    1. The only user of conv op should be batchnorm;
186    2. Only the first element from the batchnorm output tuple should be used
187    in the graph.
188    """
189
190    def fuse_batch_norm_with_conv(self, graph_module: torch.fx.GraphModule) -> None:
191        graph = graph_module.graph
192        for conv in graph.nodes:
            # We want to discover a chain of convolution -> batch_norm.
            # Only proceed if the current node is a convolution node with a single
            # user/successor.
196            if (
197                conv.target != exir_ops.edge.aten.convolution.default
198                or len(conv.users) != 1
199            ):
200                continue
201
202            # The single user of conv op must be batch_norm. If not, bail.
203            bn = list(conv.users.keys())[0]
204            if bn.target != exir_ops.edge.aten.native_batch_norm.default:
205                continue
206
207            # All the users of batchnorm node must be getitem ops. batchnorm
208            # returns a 3-element tuple. Each user must only access the first
209            # element of the tuple.
210            if [
211                (user.target == operator.getitem and user.args[1] == 0)
212                for user in bn.users
213            ].count(False):
214                continue
215
            # Check that the weights for conv and batchnorm are both params
217            if [node.op == "get_attr" for node in {conv.args[1], bn.args[1]}].count(
218                False
219            ):
220                continue
221
222            # Get the parameters from conv op
223            assert len(conv.args) == 9
224            conv_weight = get_tensor_from_attr(graph_module, conv.args[1])
225            assert isinstance(conv_weight, torch.Tensor)
226            conv_bias = get_tensor_from_attr(graph_module, conv.args[2])
227            transpose = conv.args[6]
228
229            # Get the parameters from the batchnorm op
230            assert len(bn.args) == 8
231            bn_weight = get_tensor_from_attr(graph_module, bn.args[1])
232            bn_bias = get_tensor_from_attr(graph_module, bn.args[2])
233            running_mean = get_tensor_from_attr(graph_module, bn.args[3])
234            assert isinstance(running_mean, torch.Tensor)
235            running_var = get_tensor_from_attr(graph_module, bn.args[4])
236            assert isinstance(running_var, torch.Tensor)
237            eps = bn.args[-1]
238
239            # Compute the updated weight and bias after fusing conv op
240            # with batchnorm op.
241            fused_weight, fused_bias = fuse_conv_bn_weights(
242                conv_weight,
243                conv_bias,
244                running_mean,
245                running_var,
246                eps,
247                bn_weight,
248                bn_bias,
249                transpose,
250            )
251
252            # Modify the graph by updating the weight and bias of conv op
253            # with the fused weight and bias params, and replacing all the users
254            # of getitem(batchnorm) with the conv op.
255            with graph.inserting_before(conv):
256                fused_weight_name = f"_fused_with_bn_weight_{self.counter}"
257                graph_module.register_parameter(fused_weight_name, fused_weight)
258                fused_weight_node = graph.get_attr(fused_weight_name)
259                fused_bias_name = f"_fused_with_bn_bias_{self.counter}"
260                graph_module.register_parameter(fused_bias_name, fused_bias)
261                fused_bias_node = graph.get_attr(fused_bias_name)
262
263            # Update the weight and bias of conv op
264            conv_args = list(conv.args)
265            conv_args[1] = fused_weight_node
266            conv_args[2] = fused_bias_node
267            conv.args = tuple(conv_args)
            # Replace all uses of the batchnorm outputs (getitem nodes) with conv
269            for user in bn.users:
270                assert user.target == operator.getitem
271                user.replace_all_uses_with(conv)
272            self.counter += 1
273
274        graph_module.recompile()
275
276    def __init__(self):
277        super().__init__()
278        self.counter = 0
279
280    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
281        self.fuse_batch_norm_with_conv(graph_module)
282        result = super().call(graph_module)
283        return result
284
285
286@register_cadence_pass(CadencePassAttribute(opt_level=1))
287class FuseQuantizedBatchNormWithConv(ExportPass):
288    """
289    This pass fuses a quantized::conv op with quantized::batchnorm if the
290    following two conditions are met:
291    1. The only user of quantized::conv op should be quantized::batchnorm;
292    2. The outputs of both ops are quantized with same scale and zero_point
293    """
294
295    def fuse_quantized_batch_norm_with_conv(
296        self, graph_module: torch.fx.GraphModule
297    ) -> None:
298        graph = graph_module.graph
299        for conv in graph.nodes:
            # We want to discover a chain of quantized::conv ->
            # quantized::batch_norm. Only proceed if the current node is a
            # quantized::conv node with a single user/successor.
303            if (
304                conv.target
305                not in {
306                    exir_ops.edge.quantized.conv1d.default,
307                    exir_ops.edge.quantized.conv2d.new,
308                }
309                or len(conv.users) != 1
310            ):
311                continue
312
313            # The single user of conv op must be batch_norm. If not, bail.
314            bn = list(conv.users.keys())[0]
315            if bn.target not in {
316                exir_ops.edge.quantized.batch_norm1d.default,
317                exir_ops.edge.quantized.batch_norm2d.default,
318            }:
319                continue
320
            # The outputs of conv and bn must have the same scale and zero_point
322            if not math.isclose(
323                conv.args[-2], bn.args[-2], rel_tol=1e-05, abs_tol=1e-05
324            ):
325                continue
326            if conv.args[-1] != bn.args[-1]:
327                continue
328
            # The weight and bias of the quantized::conv op are packed in its
            # second arg. Unpack them.
331            assert conv.args[1].op == "get_attr"
332            packed_args = getattr(graph_module, conv.args[1].target)
333            conv_weight_tensor, conv_bias_tensor = packed_args.unpack()
334            # Assert that we have discovered the conv op's weight and bias tensors
335            assert isinstance(conv_weight_tensor, torch.Tensor)
336            assert conv_bias_tensor is None or isinstance(
337                conv_bias_tensor, torch.Tensor
338            )
339
340            # Get the scale, zero_point, and dtype of convolution weight
341            assert conv_weight_tensor.is_quantized
342            per_tensor_quantization = (
343                conv_weight_tensor.qscheme() == torch.per_tensor_affine
344            )
345            weight_dtype = conv_weight_tensor.dtype
346            weight_scale = get_scale(conv_weight_tensor)
347            weight_zero_point = get_zero_point(conv_weight_tensor, reduce=False)
348            weight_axis = (
349                0
350                if per_tensor_quantization
351                else conv_weight_tensor.q_per_channel_axis()
352            )
353            # Dequantize the convolution weight
354            conv_weight_tensor = conv_weight_tensor.dequantize()
355
356            # Get the parameters from the batchnorm op
357            assert len(bn.args) == 8
358            (bn_weight, bn_bias, running_mean, running_var, eps) = bn.args[1:6]
359            # Get the tensors from the batchnorm args
360            bn_weight_tensor = get_tensor_from_attr(graph_module, bn_weight)
361            bn_bias_tensor = get_tensor_from_attr(graph_module, bn_bias)
362            running_mean_tensor = get_tensor_from_attr(graph_module, running_mean)
363            running_var_tensor = get_tensor_from_attr(graph_module, running_var)
364
365            # Assert that we have discovered the batch_norm op's tensors
366            assert bn_weight_tensor is None or isinstance(
367                bn_weight_tensor, torch.Tensor
368            )
369            assert bn_bias_tensor is None or isinstance(bn_bias_tensor, torch.Tensor)
370            assert isinstance(running_mean_tensor, torch.Tensor)
371            assert isinstance(running_var_tensor, torch.Tensor)
372
373            # Get the fused weights and bias
374            fused_weight, fused_bias = fuse_conv_bn_weights(
375                conv_weight_tensor,
376                conv_bias_tensor,
377                running_mean_tensor,
378                running_var_tensor,
379                eps,
380                bn_weight_tensor,
381                bn_bias_tensor,
382                transpose=False,
383            )
384
385            # Requantize the fused weight with the scale and zero point of the
386            # quantized::conv's weight
387            if per_tensor_quantization:
388                fused_weight = torch.quantize_per_tensor(
389                    fused_weight,
390                    weight_scale.item(),
391                    cast(int, weight_zero_point.item()),
392                    weight_dtype,
393                )
394            else:
395                fused_weight = torch.quantize_per_channel(
396                    fused_weight,
397                    weight_scale,
398                    weight_zero_point,
399                    weight_axis,
400                    weight_dtype,
401                )
402
403            # Now that we have the fused weight and bias, pack them for the
404            # quantized::conv.
405            stride = packed_args.stride()
406            padding = packed_args.padding()
407            dilation = packed_args.dilation()
408            groups = packed_args.groups()
409            args = (fused_weight, fused_bias, stride, padding, dilation, groups)
410            packed_args = (
411                exir_ops.edge.quantized.conv1d_prepack(*args)
412                if conv.target == exir_ops.edge.quantized.conv1d.default
413                else exir_ops.edge.quantized.conv2d_prepack(*args)
414            )
415
416            # Modify the graph by updating the weight and bias of conv op
417            # with the fused weight and bias params, and replacing all the users
418            # of batchnorm with the conv op.
419            conv_args = list(conv.args)
420            conv_args[1] = packed_args
421            conv.args = tuple(conv_args)
422            bn.replace_all_uses_with(conv)
423            graph.erase_node(bn)
424            self.counter += 1
425
        # Note: there is a quantized.conv2d.new operator in the resulting graph
        # that takes a torch.classes.quantized.Conv2dPackedParamsBase as one of its
        # inputs; this prevents us from directly calling graph_module.recompile().
429        graph_module._code = graph_module._graph.python_code(root_module="self").src
430
431    def __init__(self):
432        super().__init__()
433        self.counter = 0
434
435    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
436        self.fuse_quantized_batch_norm_with_conv(graph_module)
437        result = super().call(graph_module)
438        return result
439
440
441@register_cadence_pass(CadencePassAttribute(opt_level=1))
442class FuseCascadedTransposeOrPermuteOps(ExportPass):
443    """
444    Fuse a cascaded chain of transpose and permute ops
445    """
446
447    transpose_or_permute_target = {
448        exir_ops.edge.aten.transpose_copy.int,
449        exir_ops.edge.aten.permute_copy.default,
450    }
451
    # Find a chain of transpose or permute ops, and fuse them into a single
    # permute op.
454    def fuse_cascaded_transpose_or_permute_ops(
455        self, graph_module: torch.fx.GraphModule
456    ):
457        graph = graph_module.graph
458        for node in graph.nodes:
459            # We are only interested in permute/transpose ops
460            if node.target not in self.transpose_or_permute_target:
461                continue
462            # Get the cascaded chain of transpose/permute ops starting at node
463            cascaded_transpose_or_permute_ops = get_cascaded_ops(
464                [node], self.transpose_or_permute_target
465            )
466            # The chain must have more than 1 node
467            if len(cascaded_transpose_or_permute_ops) == 1:
468                continue
469
470            out_shape = get_shape(graph_module, node)
471            assert out_shape is not None
472            out_dims = len(out_shape)
473            # This is the trivial dimension order
474            dims = list(range(out_dims))
475            # Compute the effect of the chain on dims
476            for tp in cascaded_transpose_or_permute_ops:
477                dims = (
478                    get_transposed_dims(tp, dims)
479                    if tp.target == exir_ops.edge.aten.transpose_copy.int
480                    else get_permuted_dims(tp, dims)
481                )
482
            # If the ops in the chain cancel each other out, the final dims will
            # be the same as the initial order; in that case, the chain was a nop.
            # Otherwise, create a new permute op that encompasses the effect of
            # the chain.
487            if dims == list(range(out_dims)):
488                cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(
489                    node.args[0]
490                )
491            else:
492                with graph.inserting_before(cascaded_transpose_or_permute_ops[-1]):
493                    new_permute = graph.call_function(
494                        exir_ops.edge.aten.permute_copy.default,
495                        args=(node.args[0], dims),
496                    )
497                cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(new_permute)
498
499            # Now erase the chain
500            for tp in reversed(cascaded_transpose_or_permute_ops):
501                graph.erase_node(tp)
502
503        graph_module.recompile()
504
505    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
506        self.fuse_cascaded_transpose_or_permute_ops(graph_module)
507        result = super().call(graph_module)
508        return result
509
510
511@register_cadence_pass(CadencePassAttribute(opt_level=1))
512class FuseCascadedViewOps(ExportPass):
513    """
514    Fuse a cascaded chain of view ops
515    """
516
    # Find a chain of view ops, and fuse them into a single view op.
519    def fuse_cascaded_view_ops(self, graph_module: torch.fx.GraphModule):
520        graph = graph_module.graph
521        for node in graph.nodes:
522            # We are only interested in view ops
523            if node.target != exir_ops.edge.aten.view_copy.default:
524                continue
525
526            # Get the cascaded chain of view ops starting at node
527            cascaded_view_ops = get_cascaded_ops(
528                [node], [exir_ops.edge.aten.view_copy.default]
529            )
530            # The chain must have more than 1 node
531            if len(cascaded_view_ops) == 1:
532                continue
533
534            last_view_node = cascaded_view_ops[-1]
535            with graph.inserting_before(last_view_node):
536                new_view = graph.call_function(
537                    exir_ops.edge.aten.view_copy.default,
538                    args=(node.args[0], last_view_node.args[1]),
539                )
540                last_view_node.replace_all_uses_with(new_view)
541
542            # Now erase the chain
543            for v in reversed(cascaded_view_ops):
544                graph.erase_node(v)
545
546        graph_module.recompile()
547
548    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
549        self.fuse_cascaded_view_ops(graph_module)
550        dead_code_elimination_pass(graph_module)
551        result = super().call(graph_module)
552        return result
553
554
555class FuseOpPairsAcrossBranchesPass(ExportPass):
556    def check_ok_to_fuse(
557        self,
558        producer: torch.fx.Node,
559        consumers: list[torch.fx.Node],
560    ) -> bool:
561        # Always ok to replace / remove.
562        return True
563
564    def can_fuse_for_chain(
565        self,
566        producer: torch.fx.Node,
567        consumer: torch.fx.Node,
568        consumer_op_packets: set[EdgeOpOverloadPacket],
569    ) -> bool:
570        """
        Returns True if producer and consumer can be fused along a single chain,
        i.e. (-> producer -> ops -> consumer ->) becomes (-> ops -> fused_op ->).
573        """
574        if (
575            isinstance(consumer.target, EdgeOpOverload)
576            and get_edge_overload_packet(consumer.target) in consumer_op_packets
577        ):
578            return True
579        return False
580
581    def get_fuse_candidates(
582        self,
583        producer: torch.fx.Node,
584        consumer_op_packets: set[EdgeOpOverloadPacket],
585        bypass_ops: set[EdgeOpOverload],
586    ) -> list[torch.fx.Node]:
        # Start by iterating over all the users of this node, and check
        # if they have their target in consumer_op_packets.
589        users = deque(producer.users.keys())
        # This holds the list of the user ops that directly (or transitively
        # via bypass ops such as view/slice) consume the producer's output, and
        # hence can be removed.
592        removal_candidates = []
593        while users:
594            user = users.popleft()
595
596            # If the user is a bypass op, we bypass it, and examine
597            # its users instead for consumer_op_packets.
598            if user.target in bypass_ops:
599                users.extend(list(user.users.keys()))
600            elif self.can_fuse_for_chain(producer, user, consumer_op_packets):
601                removal_candidates.append(user)
602            else:
603                removal_candidates.clear()
604                break
605        return removal_candidates
606
607    def find_and_fuse(
608        self,
609        graph_module: torch.fx.GraphModule,
610        producer_op_packets: set[EdgeOpOverloadPacket],
611        consumer_op_packets: set[EdgeOpOverloadPacket],
612        bypass_ops: set[EdgeOpOverload],
613    ) -> None:
614        for node in graph_module.graph.nodes:
            # We are only interested in ops whose overload packet is in
            # producer_op_packets.
617            if not (
618                isinstance(node.target, EdgeOpOverload)
619                and get_edge_overload_packet(node.target) in producer_op_packets
620            ):
621                continue
622
623            removal_candidates = self.get_fuse_candidates(
624                node, consumer_op_packets, bypass_ops
625            )
626
627            if len(removal_candidates) == 0:
628                # No candidates found.
629                continue
630
631            if not self.check_ok_to_fuse(node, removal_candidates):
632                # Not ok to remove quant-dequant pairs or replace with requantize.
633                continue
634
635            self.fuse(node, removal_candidates, graph_module)
636
637        graph_module.recompile()
638
639    def get_fused_node(
640        self,
641        producer: torch.fx.Node,
642        consumer: torch.fx.Node,
643        graph_module: torch.fx.GraphModule,
644    ) -> torch.fx.Node:
645        return consumer
646
647    def fuse(
648        self,
649        node: torch.fx.Node,
650        removal_candidates: list[torch.fx.Node],
651        graph_module: torch.fx.GraphModule,
652    ) -> None:
        # Replace all the uses of the producer op with its input.
654        node.replace_all_uses_with(cast(torch.fx.Node, node.args[0]))
655        graph_module.graph.erase_node(node)
656
        # Iterate over all the removal candidates (the consumer ops) and
        # generate replacements.
658        for rnode in removal_candidates:
659            rnode.replace_all_uses_with(self.get_fused_node(node, rnode, graph_module))
660            graph_module.graph.erase_node(rnode)
661
662
663@register_cadence_pass(CadencePassAttribute(opt_level=1))
664class FuseQuantDequantToRequantizePass(FuseOpPairsAcrossBranchesPass):
665    """
666    Fuse dequantize-quantize op pairs to a single requantize op.
667    For the special case where quant params match, this will remove
668    both dequant and quant ops.
669    """
670
    # A set of ops that can be bypassed when looking for a
    # dequantize->quantize (or quantize->dequantize) chain
673    bypass_ops: set[EdgeOpOverload] = {
674        exir_ops.edge.aten.slice_copy.Tensor,
675        exir_ops.edge.aten.view_copy.default,
676        exir_ops.edge.aten.clone.default,
677        exir_ops.edge.aten.transpose_copy.int,
678        exir_ops.edge.aten.permute_copy.default,
679    }
680
681    quantize_op_packets: set[EdgeOpOverloadPacket] = {
682        exir_ops.edge.cadence.quantize_per_tensor,
683        exir_ops.edge.quantized_decomposed.quantize_per_tensor,
684    }
685    dequantize_op_packets: set[EdgeOpOverloadPacket] = {
686        exir_ops.edge.cadence.dequantize_per_tensor,
687        exir_ops.edge.quantized_decomposed.dequantize_per_tensor,
688    }
689
690    def __init__(
691        self, allow_requantize: bool = True, force_quant_dequant_fusion: bool = False
692    ) -> None:
693        super().__init__()
694        self.allow_requantize: bool = allow_requantize
695        self.force_quant_dequant_fusion: bool = force_quant_dequant_fusion
696
697    def _pkg_name_match(self, node1: torch.fx.Node, node2: torch.fx.Node) -> bool:
698        # pyre-ignore[16]: Item `typing.Callable` has no attribute `_op`
699        return node1.target._op.namespace == node2.target._op.namespace
700
701    def can_fuse_for_chain(
702        self,
703        producer: torch.fx.Node,
704        consumer: torch.fx.Node,
705        consumer_op_packets: set[EdgeOpOverloadPacket],
706    ) -> bool:
707        return super().can_fuse_for_chain(
708            producer, consumer, consumer_op_packets
709        ) and self._pkg_name_match(producer, consumer)
710
711    def _create_requantize_node(
712        self,
713        in_tensor: torch.fx.Node,
714        in_scale: float,
715        in_zero_point: int,
716        out_scale: float,
717        out_zero_point: int,
718        out_dtype: torch.dtype,
719        graph: torch.fx.Graph,
720    ) -> torch.fx.Node:
721        in_scale_tensor = graph.call_function(
722            exir_ops.edge.aten.full.default, args=((1,), in_scale)
723        )
724        in_zero_point_tensor = graph.call_function(
725            exir_ops.edge.aten.full.default,
726            args=((1,), in_zero_point),
727            kwargs={"dtype": torch.int32},
728        )
729        out_scale_tensor = graph.call_function(
730            exir_ops.edge.aten.full.default, args=((1,), out_scale)
731        )
732        out_zero_point_tensor = graph.call_function(
733            exir_ops.edge.aten.full.default,
734            args=((1,), out_zero_point),
735            kwargs={"dtype": torch.int32},
736        )
737        # cadence::requantize(Tensor input, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, ScalarType out_dtype) -> Tensor Y
738        # TODO(hardiksharma): Add support for per-tensor requantize.
739        return graph.call_function(
740            exir_ops.edge.cadence.requantize.default,
741            args=(
742                in_tensor,
743                in_scale_tensor,
744                in_zero_point_tensor,
745                out_scale_tensor,
746                out_zero_point_tensor,
747                out_dtype,
748            ),
749        )
750
751    def _quant_params_match(self, node1: torch.fx.Node, node2: torch.fx.Node) -> bool:
752        return node1.args[1:] == node2.args[1:]
753
754    def check_ok_to_fuse(
755        self,
756        producer: torch.fx.Node,
757        consumers: list[torch.fx.Node],
758    ) -> bool:
759        """Check if all node-user pairs are nops or are ok to replace with requant."""
760        for rnode in consumers:
761            if self.allow_requantize or self._quant_params_match(producer, rnode):
                # This pair is fine: either requantize is allowed, or the quant
                # params match exactly (so the pair is a nop).
764                continue
765            return False
766        return True
767
768    def get_fused_node(
769        self,
770        producer: torch.fx.Node,
771        consumer: torch.fx.Node,
772        graph_module: torch.fx.GraphModule,
773    ) -> torch.fx.Node:
774        in_scale, in_zero_point = producer.args[1:3]
775        in_tensor, out_scale, out_zero_point, _, _, out_dtype = consumer.args
776        if in_scale == out_scale and in_zero_point == out_zero_point:
777            # If the quant params match, we can remove both dequantize-quantize ops.
778            return cast(torch.fx.Node, consumer.args[0])
779
780        assert (
781            self.allow_requantize
782        ), f"Found {producer=} {in_scale=} {in_zero_point=} | {consumer=} {out_scale=} {out_zero_point=}"
783
784        with graph_module.graph.inserting_before(consumer):
785            requantize_node = self._create_requantize_node(
786                in_tensor=cast(torch.fx.Node, consumer.args[0]),
787                in_scale=cast(float, in_scale),
788                in_zero_point=cast(int, in_zero_point),
789                out_scale=cast(float, out_scale),
790                out_zero_point=cast(int, out_zero_point),
791                out_dtype=cast(torch.dtype, out_dtype),
792                graph=graph_module.graph,
793            )
794        return requantize_node
795
796    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
797        # Remove any dequantize op that has only quantize ops as its users.
798        self.find_and_fuse(
799            graph_module,
800            producer_op_packets=self.dequantize_op_packets,
801            consumer_op_packets=self.quantize_op_packets,
802            bypass_ops=self.bypass_ops,
803        )
        # Remove any quantize op that has only dequantize ops as its users.
805        self.find_and_fuse(
806            graph_module,
807            producer_op_packets=self.quantize_op_packets,
808            consumer_op_packets=self.dequantize_op_packets,
809            # Do not requantize for quantize-dequantize pairs as this is not guaranteed
810            # to be better for performance/memory.
811            # Only fuse if all users of quant are dequant.
812            bypass_ops=(
813                self.bypass_ops
814                if self.force_quant_dequant_fusion
815                else {exir_ops.edge.aten.view_copy.default}
816            ),
817        )
818        result = super().call(graph_module)
819        return result
820
821
822@register_cadence_pass(CadencePassAttribute(opt_level=1))
823class FuseMulIntoDequantPass(ExportPass):
824    """
    Looks for the pattern where aten.mul multiplies the output of dequantize
    by the output of aten.full. If found, updates the dequant scale to reflect
    the multiplication and removes the full and mul nodes.
828    """
829
830    def attempt_fusion(
831        self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
832    ) -> None:
833        if node.target != exir_ops.edge.aten.mul.Tensor:
834            return
835
836        # ensure that one of the args to mul is dequantize and the other is aten.full
837        dequant_nodes = [
838            arg
839            for arg in node.args
840            if isinstance(arg, torch.fx.Node)
841            and isinstance(arg.target, EdgeOpOverload)
842            and get_edge_overload_packet(arg.target)
843            == exir_ops.edge.quantized_decomposed.dequantize_per_tensor
844        ]
845        multiplier_nodes = [
846            arg
847            for arg in node.args
848            if isinstance(arg, torch.fx.Node)
849            and arg.target == exir_ops.edge.aten.full.default
850        ]
851
852        if len(dequant_nodes) != 1 or len(multiplier_nodes) != 1:
853            return
854
855        deq_node = dequant_nodes[0]
856        mplier_node = multiplier_nodes[0]
857
858        # ensure that dequant and full don't have any other users
859        if len(deq_node.users) > 1 or len(mplier_node.users) > 1:
860            return
861
862        new_deq_args = list(deq_node.args)
863        assert isinstance(deq_node.args[1], Number)
864        assert isinstance(mplier_node.args[1], Number)
865        # pyre-ignore[58]: Unsupported operand *
866        new_deq_args[1] = deq_node.args[1] * mplier_node.args[1]
867
868        logging.debug(
869            f"Fused {node} and {mplier_node} into {deq_node}. Updated scale from {deq_node.args[1]} to {new_deq_args[1]}"
870        )
871
872        node.replace_all_uses_with(deq_node)
873        deq_node.args = tuple(new_deq_args)
874
875        graph_module.graph.erase_node(node)
876        graph_module.graph.erase_node(mplier_node)
877        graph_module.recompile()
878
879    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
880        for node in graph_module.graph.nodes:
881            self.attempt_fusion(graph_module, node)
882        result = super().call(graph_module)
883        return result
884
885
886@register_cadence_pass(CadencePassAttribute(opt_level=1))
887class FuseTransposeOpPairsPass(FuseOpPairsAcrossBranchesPass):
888    """
    Fuse pairs of transpose ops that are only separated by quantize/dequantize
    ops, when their combined permutation is a no-op on the non-unit dimensions.
    The producer transpose is removed, and the consumer transpose is replaced
    with a single view_copy to the expected output shape.
892    """
893
    # A set of quantize/dequantize ops that can be bypassed when looking for a
    # transpose -> transpose chain
896    bypass_ops: set[EdgeOpOverload] = {
897        exir_ops.edge.cadence.quantize_per_tensor.default,
898        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
899        exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
900        exir_ops.edge.cadence.dequantize_per_tensor.default,
901        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
902        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
903    }
904
905    def can_fuse_for_chain(
906        self,
907        producer: torch.fx.Node,
908        consumer: torch.fx.Node,
909        consumer_op_packets: set[EdgeOpOverloadPacket],
910    ) -> bool:
911        if not super().can_fuse_for_chain(producer, consumer, consumer_op_packets):
912            return False
913
914        def get_dims(node: torch.fx.Node) -> tuple[int, int]:
915            def canonicalize(dim: int) -> int:
916                if dim < 0:
917                    dim += len(node.meta["val"].shape)
918                return dim
919
920            return tuple(canonicalize(cast(int, d)) for d in node.args[1:3])
921
922        def is_equivalent(
923            shape: Sequence[int],
924            transpose0: tuple[int, int],
925            transpose1: tuple[int, int],
926        ) -> bool:
927            def permute_order(
928                order: Sequence[int], dims: tuple[int, int]
929            ) -> Sequence[int]:
930                new_order = list(order)
931                new_order[dims[0]], new_order[dims[1]] = (
932                    new_order[dims[1]],
933                    new_order[dims[0]],
934                )
935                return new_order
936
937            order = permute_order(range(len(shape)), transpose0)
938            order = permute_order(order, transpose1)
939
940            non_unit_dims = [dim for dim in range(len(shape)) if shape[dim] != 1]
941            non_unit_dims_permuted = [dim for dim in order if shape[dim] != 1]
942
943            return non_unit_dims == non_unit_dims_permuted
944
945        return is_equivalent(
946            cast(torch.fx.Node, producer.args[0]).meta["val"].shape,
947            get_dims(producer),
948            get_dims(consumer),
949        )
950
951    def get_fused_node(
952        self,
953        producer: torch.fx.Node,
954        consumer: torch.fx.Node,
955        graph_module: torch.fx.GraphModule,
956    ) -> torch.fx.Node:
957        output_shape = consumer.meta["val"].shape
958        with graph_module.graph.inserting_after(consumer):
959            view = graph_module.graph.call_function(
960                exir_ops.edge.aten.view_copy.default,
961                (consumer.args[0], output_shape),
962                {},
963            )
964        return view
965
966    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # Fuse transpose pairs that are only separated by quantize/dequantize ops.
968        self.find_and_fuse(
969            graph_module,
970            producer_op_packets={exir_ops.edge.aten.transpose_copy},
971            consumer_op_packets={exir_ops.edge.aten.transpose_copy},
972            bypass_ops=self.bypass_ops,
973        )
974        result = super().call(graph_module)
975        return result
976
977
978@register_cadence_pass(CadencePassAttribute(opt_level=1))
979class FuseFullThenReshapePass(ExportPass):
980    """
981    A pass that fuses a chain of full and reshape-like operations into a single full operation.
982    """
983
984    fusion_candidates: set[EdgeOpOverload] = {
985        exir_ops.edge.aten.transpose_copy.int,
986        exir_ops.edge.aten.permute_copy.default,
987        exir_ops.edge.aten.view_copy.default,
988    }
989
990    def call_operator(
991        self,
992        op,
993        args: tuple[Argument, ...],
994        kwargs: dict[str, Argument],
995        meta: NodeMetadata,
996    ) -> ProxyValue:
997        if op not in self.fusion_candidates:
998            return super().call_operator(op, args, kwargs, meta)
999
1000        full_node = cast(ProxyValue, args[0]).node
1001        if not (
1002            full_node.op == "call_function"
1003            and full_node.target == exir_ops.edge.aten.full.default
1004        ):
            # Not a full -> reshape-like chain; nothing to fuse.
1006            return super().call_operator(op, args, kwargs, meta)
1007
1008        fill_value = full_node.args[1]
1009        return super().call_operator(
1010            exir_ops.edge.aten.full.default,
1011            (
1012                meta["val"].shape,
1013                fill_value,
1014            ),
1015            {"dtype": meta["val"].dtype},
1016            meta,
1017        )
1018
1019    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
1020        graph_module = super().call(graph_module).graph_module
1021        graph_module.graph.eliminate_dead_code()
1022        return PassResult(graph_module, True)
1023
1024
1025class CadenceFuseOpsInGraph:
1026    passes = [
1027        FuseMMWithAdd,
1028        FuseBatchNormWithConv,
1029        FuseQuantizedBatchNormWithConv,
1030        FuseCascadedTransposeOrPermuteOps,
1031        FuseCascadedViewOps,
1032        FuseQuantDequantToRequantizePass,
1033        FuseMulIntoDequantPass,
1034        FuseFullThenReshapePass,
1035        FuseTransposeOpPairsPass,
1036    ]
1037