# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import List, Optional, Union

import torch

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class ViewCopyToSqueezeUnsqueezePass(ExportPass):
    """
    Replaces view_copy nodes with squeeze_copy.dims nodes when the view only
    removes dims of size 1, and with unsqueeze_copy.default nodes when the
    view only adds a single dim of size 1.
    """

    def __init__(self) -> None:
        super().__init__()
        self.view_copy_op: torch._ops.OpOverload = exir_ops.edge.aten.view_copy.default
        self.squeeze_op: torch._ops.OpOverload = exir_ops.edge.aten.squeeze_copy.dims
        self.unsqueeze_op: torch._ops.OpOverload = (
            exir_ops.edge.aten.unsqueeze_copy.default
        )

    def is_node_target(
        self, node: torch.fx.Node, target: torch._ops.OperatorBase
    ) -> bool:
        return node.op == "call_function" and node.target == target

    def find_squeeze_dims(
        self,
        input_shape: List[int],
        view_shape: List[int],
    ) -> Optional[List[int]]:
        """
        Returns the indices of the size-1 dims removed by the view, or None
        if the view is not a pure squeeze.
        """
        # A squeeze must strictly reduce the rank.
        if len(input_shape) <= len(view_shape):
            return None

        # Walk both shapes in lockstep; every mismatch must be a size-1 dim
        # present in input_shape but absent from view_shape.
        i = 0
        j = 0
        idx = []
        while i < len(input_shape):
            # Guard j: once view_shape is exhausted (e.g. squeezing a
            # trailing dim, [2, 1] -> [2]), any remaining input dims must
            # all be of size 1.
            if j >= len(view_shape) or input_shape[i] != view_shape[j]:
                if input_shape[i] == 1:
                    idx.append(i)
                    # Hold j in place so the next input dim is compared
                    # against the same view dim.
                    j -= 1
                # otherwise the shapes disagree on a non-1 dim: not a squeeze
                else:
                    return None
            i += 1
            j += 1
        return idx

    def find_unsqueeze_dim(
        self,
        input_shape: List[int],
        view_shape: List[int],
    ) -> Optional[int]:
        """
        Returns the index of the single size-1 dim added by the view, or
        None if the view is not a pure unsqueeze.
        """
        # An unsqueeze increases the rank by exactly 1.
        if len(view_shape) - len(input_shape) != 1:
            return None

        # Walk both shapes in lockstep; the single mismatch must be a
        # size-1 dim that exists only in view_shape.
        i = 0
        j = 0
        idx = -1
        while j < len(view_shape):
            # Guard i: when the added dim is last (e.g. [2] -> [2, 1]),
            # input_shape is exhausted before view_shape.
            if i >= len(input_shape) or input_shape[i] != view_shape[j]:
                if view_shape[j] == 1:
                    idx = j
                    # Hold i in place so the same input dim is compared
                    # against the next view dim.
                    i -= 1
                # otherwise the shapes disagree on a non-1 dim: not an unsqueeze
                else:
                    return None
            i += 1
            j += 1
        return idx

    def replace_view_copy_node(
        self,
        graph_module: torch.fx.GraphModule,
        view_node: torch.fx.Node,
        op: torch._ops.OpOverload,
        arg: Union[List[int], int],
    ) -> None:
        with graph_module.graph.inserting_before(view_node):
            new_node = graph_module.graph.create_node(
                "call_function",
                op,
                (view_node.args[0], arg),
            )
            new_node.meta = view_node.meta
            view_node.replace_all_uses_with(new_node)
            graph_module.graph.erase_node(view_node)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        modified = False
        for node in graph_module.graph.nodes:
            if self.is_node_target(node, self.view_copy_op):
                input_node = node.args[0]
                input_shape = input_node.meta["val"].shape
                view_shape = node.args[1]
                squeeze_dims = self.find_squeeze_dims(input_shape, view_shape)
                if squeeze_dims:
                    self.replace_view_copy_node(
                        graph_module, node, self.squeeze_op, squeeze_dims
                    )
                    modified = True
                    continue
                unsqueeze_dim = self.find_unsqueeze_dim(input_shape, view_shape)
                # Compare against None explicitly: 0 is a valid unsqueeze
                # dim but is falsy.
                if unsqueeze_dim is not None:
                    self.replace_view_copy_node(
                        graph_module, node, self.unsqueeze_op, unsqueeze_dim
                    )
                    modified = True
                    continue

        if modified:
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, modified)
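

# A minimal usage sketch (illustrative, not part of the pass itself): run the
# pass through the standard exir edge pipeline. The example module, its input
# shape, and the `_SqueezeExample` name are hypothetical; `to_edge`,
# `torch.export.export`, and `EdgeProgramManager.transform` are existing
# entry points.
if __name__ == "__main__":
    from executorch.exir import to_edge

    class _SqueezeExample(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # (4, 1, 8) -> (4, 8) removes a single size-1 dim, so this pass
            # rewrites the resulting view_copy into squeeze_copy.dims.
            return x.view(4, 8)

    example_inputs = (torch.randn(4, 1, 8),)
    edge = to_edge(torch.export.export(_SqueezeExample(), example_inputs))
    edge = edge.transform([ViewCopyToSqueezeUnsqueezePass()])
    print(edge.exported_program().graph_module.graph)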