xref: /aosp_15_r20/external/executorch/backends/xnnpack/_passes/convert_to_upsample_bilinear2d.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7import torch
8from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
9from executorch.backends.xnnpack.partition.graphs import bilinear_2d
10from executorch.backends.xnnpack.utils.utils import check_or_raise
11from executorch.exir.dialects._ops import ops as exir_ops
12from torch.fx.passes.infra.pass_base import PassResult
13from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher
14
15
class ConvertToUpsampleBilinear2d(XNNPACKPass):
    """Collapse matched bilinear-upsample subgraphs into one edge op.

    Every pattern returned by ``bilinear_2d.get_graphs_dict()`` is matched
    against the graph. Each match is replaced by a single
    ``exir_ops.edge.aten.upsample_bilinear2d.vec`` node. That node carries
    the pattern's ``align_corners`` flag and the match's static output
    height/width.
    """

    # Maps a match's original output node to the upsample node that replaced
    # it. After replacement the old output node loses all users and is
    # normally removed by dead-code elimination. A later match that was
    # computed before the rewrite may still name that old node as its input;
    # this mapping redirects it to the new node.
    # The class-level value is only a fallback for direct calls to
    # ``create_upsample_bilinear_2d``. ``call`` binds a fresh per-instance dict
    # so that nodes (and the graphs they reference) never leak across
    # instances or invocations.
    output_nodes_to_bilinear_node: dict[torch.fx.Node, torch.fx.Node] = {}

    def create_upsample_bilinear_2d(
        self,
        graph_module: torch.fx.GraphModule,
        internal_match: InternalMatch,
        align_corners: bool,
    ) -> None:
        """Replace one matched subgraph with an ``upsample_bilinear2d.vec`` node.

        Args:
            graph_module: Module whose graph is rewritten in place.
            internal_match: Match whose first returning node is the subgraph
                output. Its last placeholder node is used as the upsample
                input.
            align_corners: Value forwarded as the op's ``align_corners``
                argument.

        The output size is read from the last two dimensions of
        ``output.meta["val"].shape``. Non-``int`` (symbolic) sizes are
        rejected through ``check_or_raise``. Afterwards, dead code is
        eliminated and the module is recompiled.
        """
        output = internal_match.returning_nodes[0]
        output_shape = output.meta["val"].shape
        output_h = output_shape[-2]
        output_w = output_shape[-1]
        check_or_raise(
            isinstance(output_h, int) and isinstance(output_w, int),
            "XNNPACK Upsample Bilinear2d does not support dynamic shape",
        )

        input_node = internal_match.placeholder_nodes[-1]
        # If this input was the output of an earlier match that has already
        # been replaced, wire to the replacement node instead of the old one.
        input_node = self.output_nodes_to_bilinear_node.get(input_node, input_node)
        with graph_module.graph.inserting_before(output):
            upsample_node = graph_module.graph.create_node(
                "call_function",
                exir_ops.edge.aten.upsample_bilinear2d.vec,
                # TODO(T166527012): Using output_h and output_w here only works with static shapes
                args=(input_node, [output_h, output_w], align_corners, None),
            )
        output.replace_all_uses_with(upsample_node)
        self.output_nodes_to_bilinear_node[output] = upsample_node
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Rewrite every bilinear-upsample pattern match in ``graph_module``.

        After rewriting, the module is recompiled and handed to
        ``XNNPACKPass.call``. Its resulting graph module is returned in a
        ``PassResult`` whose ``modified`` flag is always ``True``.
        """
        # Bind a fresh, per-instance mapping for this run. The class-level dict
        # is shared by all instances, so reusing it would carry stale nodes
        # from earlier graphs into this run and keep those graphs alive.
        self.output_nodes_to_bilinear_node = {}
        try:
            for pattern, align_corners in bilinear_2d.get_graphs_dict().items():
                # Literals in the pattern are treated as wildcards, so matching
                # does not depend on concrete constant values.
                sm = SubgraphMatcher(pattern.graph, ignore_literals=True)
                # Snapshot all matches before the graph is mutated below.
                matches = list(sm.match(graph_module.graph))
                for partition_to_replace in matches:
                    self.create_upsample_bilinear_2d(
                        graph_module, partition_to_replace, align_corners
                    )
        finally:
            # Drop references to replaced/erased nodes once rewriting is done.
            self.output_nodes_to_bilinear_node = {}

        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)
60