# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


class ConvertInterpolateWithUpsample2D(ExportPass):
    """
    Merge decomposed operators from interpolate back to one super node.

    Every source partition that originated from
    ``torch.nn.functional.interpolate`` is collapsed into a single edge-dialect
    upsample node:

    * ``upsample_nearest2d`` when the partition's output node target reports
      the name ``aten.index.Tensor``;
    * ``upsample_bilinear2d`` with ``align_corners=False`` otherwise.

    TODO: Currently we only map to upsample2d version, should extend the
    capability by reverse engineering the decomposition process.
    """

    def __init__(self):
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Rewrite interpolate partitions of ``graph_module`` in place.

        Args:
            graph_module: module whose graph is rewritten.

        Returns:
            ``PassResult`` wrapping the same ``graph_module``, always flagged
            as modified.
        """
        graph = graph_module.graph
        partitions = get_source_partitions(graph, [torch.nn.functional.interpolate])
        for src_partitions in partitions.values():
            for src_partition in src_partitions:
                input_node = src_partition.input_nodes[0]
                output_node = src_partition.output_nodes[0]

                # Target spatial size comes from the traced output shape; only
                # the trailing two dims are forwarded to the upsample op.
                # NOTE(review): assumes a 4-D (N, C, H, W) output — confirm.
                output_size = list(output_node.meta["val"].shape)[-2:]

                # TODO: robust way to get the configuration parameters and operator
                # please check torch/_decomp/decomposition.py for details
                if output_node.target.__name__ == "aten.index.Tensor":
                    # nearest_2d
                    # args: input, output_size, scales_h, scales_w
                    upsample_op = exir_ops.edge.aten.upsample_nearest2d.default
                    args = (input_node, output_size)
                else:
                    # bilinear_2d
                    # args: input, output_size, align_corners, scales_h, scales_w
                    upsample_op = exir_ops.edge.aten.upsample_bilinear2d.default
                    args = (input_node, output_size, False)

                with graph.inserting_after(input_node):
                    upsample2d_node = graph.create_node(
                        "call_function", upsample_op, args
                    )
                # Copy (not alias) the metadata so the two nodes never share
                # one mutable meta dict.
                upsample2d_node.meta = output_node.meta.copy()
                # Redirect every consumer of the decomposed result to the new
                # node; the now-unused decomposed ops are removed below.
                output_node.replace_all_uses_with(upsample2d_node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)