# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.partition.graphs import bilinear_2d
from executorch.backends.xnnpack.utils.utils import check_or_raise
from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher


class ConvertToUpsampleBilinear2d(XNNPACKPass):
    """Collapse decomposed bilinear-upsample subgraphs into a single op.

    Every pattern graph returned by ``bilinear_2d.get_graphs_dict()`` is matched
    against the graph being transformed, and each match is replaced by one
    ``exir_ops.edge.aten.upsample_bilinear2d.vec`` node that uses the
    ``align_corners`` value associated with that pattern.
    """

    # Maps the output node of a match that has already been rewritten to the
    # upsample_bilinear2d node that replaced it. All matches for a pattern are
    # collected before any rewrite happens. A later match whose input was an
    # earlier match's output therefore still refers to the old node, and this
    # map redirects it to the replacement.
    #
    # This class-level dict is only a fallback for code that calls
    # create_upsample_bilinear_2d directly. ``call`` installs a fresh
    # per-instance dict, so nodes are neither shared between pass instances
    # nor kept alive from one graph to the next.
    output_nodes_to_bilinear_node: dict = {}

    def create_upsample_bilinear_2d(
        self,
        graph_module: torch.fx.GraphModule,
        internal_match: InternalMatch,
        align_corners: bool,
    ) -> None:
        """Replace one matched subgraph with an ``upsample_bilinear2d.vec`` node.

        Args:
            graph_module: Graph module rewritten in place (dead code is
                eliminated and the module is recompiled afterwards).
            internal_match: A match of a decomposed bilinear pattern inside
                ``graph_module``.
            align_corners: Value passed as the op's ``align_corners`` argument.

        Raises:
            Whatever ``check_or_raise`` raises when the matched output's last
            two dimensions are not plain Python ints (i.e. dynamic shapes).
        """
        output = internal_match.returning_nodes[0]
        # The target spatial size is taken from the last two dims of the
        # matched subgraph's output value.
        output_shape = output.meta["val"].shape
        output_h = output_shape[-2]
        output_w = output_shape[-1]
        check_or_raise(
            isinstance(output_h, int) and isinstance(output_w, int),
            "XNNPACK Upsample Bilinear2d does not support dynamic shape",
        )

        # NOTE(review): assumes the pattern's last placeholder is the image
        # input; confirm against the pattern graphs in bilinear_2d.
        input_node = internal_match.placeholder_nodes[-1]
        # If the input was the output of an already-replaced match, that node
        # has been erased; use its replacement instead.
        input_node = self.output_nodes_to_bilinear_node.get(input_node, input_node)
        with graph_module.graph.inserting_before(output):
            upsample_node = graph_module.graph.create_node(
                "call_function",
                exir_ops.edge.aten.upsample_bilinear2d.vec,
                # TODO(T166527012): Using output_h and output_w here only works with static shapes
                args=(input_node, [output_h, output_w], align_corners, None),
            )
        output.replace_all_uses_with(upsample_node)
        self.output_nodes_to_bilinear_node[output] = upsample_node
        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()

    def call(self, graph_module: torch.fx.GraphModule):
        """Rewrite every decomposed bilinear upsample found in ``graph_module``.

        Args:
            graph_module: Graph module to transform.

        Returns:
            A ``PassResult`` wrapping the graph module produced by the base
            pass's ``call``. The modified flag is always ``True``.
        """
        # Use a fresh redirect map for each invocation. The former shared
        # class-level dict accumulated, and kept alive, nodes from every graph
        # this pass had ever processed.
        self.output_nodes_to_bilinear_node = {}
        for pattern, align_corners in bilinear_2d.get_graphs_dict().items():
            sm = SubgraphMatcher(pattern.graph, ignore_literals=True)
            # Materialize all matches before rewriting, because each rewrite
            # mutates the graph being matched against.
            matches = list(sm.match(graph_module.graph))
            for partition_to_replace in matches:
                self.create_upsample_bilinear_2d(
                    graph_module, partition_to_replace, align_corners
                )
        # Release references to nodes of this graph (most are erased by now).
        self.output_nodes_to_bilinear_node = {}

        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)