# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Any, Dict, List, Optional

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dim_order_utils import get_dim_order
from executorch.exir.pass_base import ExportPass, PassResult


class MergeToDimPass(ExportPass):
    """
    Merge redundant back-to-back ``_to_dim_order_copy`` pairs.

    A pair is removed when all of the following hold:
      * a ``to_dim(contiguous)`` copy has exactly one user,
      * that user is a ``to_dim(channels_last)`` copy with the same dtype,
      * that second copy has exactly one user.
    The final user is rewired to read the tensor that fed the first copy, so
    consecutive channels-last ops stay in channels-last dim order. Only 4-D
    dim orders are matched.

    Example:
        # Done for 1 to 1
        before pass: x -> to_dim(channels_last) -> conv -> to_dim(contiguous) -> to_dim(channels_last) -> conv -> to_dim(contiguous) -> out
        after pass:  x -> to_dim(channels_last) -> conv -> conv -> to_dim(contiguous) -> out

        # Not done for 1 to N: when either copy in the pair has more than one
        # user, the graph is left unchanged.

        # N to 1 and N to N: no cross-branch merging is implemented; each
        # contiguous -> channels_last pair is matched independently.
    """

    @staticmethod
    def _is_to_dim_order_copy(node: torch.fx.Node) -> bool:
        """Return True if ``node`` calls the edge-dialect ``_to_dim_order_copy`` op."""
        return node.target == exir_ops.edge.dim_order_ops._to_dim_order_copy.default

    @staticmethod
    def _to_dim_kwargs(node: torch.fx.Node) -> Optional[Dict[str, Any]]:
        """Return the dict holding ``dtype``/``dim_order`` for a to-dim-order node.

        ``node.kwargs`` (where FX records keyword arguments) is consulted first.
        A trailing dict positional argument is accepted as a fallback. Returns
        None when neither is present, so the caller can skip the node.
        """
        if node.kwargs:
            return dict(node.kwargs)
        if node.args and isinstance(node.args[-1], dict):
            return node.args[-1]
        return None

    @staticmethod
    def _as_list(dim_order: Any) -> Optional[List[int]]:
        """Normalize a dim order (list/tuple/None) so list == tuple compares equal."""
        return None if dim_order is None else list(dim_order)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Remove matching contiguous -> channels_last copy pairs in place.

        Args:
            graph_module: the module whose graph is rewritten in place.

        Returns:
            PassResult wrapping ``graph_module``; ``modified`` is True only if
            at least one pair was removed.
        """
        contiguous = list(get_dim_order(torch.contiguous_format, 4))
        channels_last = list(get_dim_order(torch.channels_last, 4))
        graph = graph_module.graph
        modified = False

        # Iterate over a snapshot: matched nodes are erased during the loop.
        # Erased nodes have no users left, so the user-count guard skips them.
        for node in list(graph.nodes):
            if not self._is_to_dim_order_copy(node) or len(node.users) != 1:
                continue
            if not node.args or not isinstance(node.args[0], torch.fx.Node):
                continue

            out_to_dim_node = next(iter(node.users))
            # The consumer must itself be a dim-order copy; any other op
            # (e.g. a binary op) is not part of the pattern.
            if not self._is_to_dim_order_copy(out_to_dim_node):
                continue

            in_kwargs = self._to_dim_kwargs(node)
            out_kwargs = self._to_dim_kwargs(out_to_dim_node)
            if in_kwargs is None or out_kwargs is None:
                continue

            # The first copy must go to contiguous and the second back to
            # channels_last with an unchanged dtype: that round trip is
            # redundant when the producer already emits channels_last.
            if not (
                in_kwargs.get("dtype") == out_kwargs.get("dtype")
                and self._as_list(in_kwargs.get("dim_order")) == contiguous
                and self._as_list(out_kwargs.get("dim_order")) == channels_last
            ):
                continue

            # TODO: Handle other merging rules, including 1->N, N->1, N->N.
            if len(out_to_dim_node.users) != 1:
                continue

            # Rewire every use (positional, keyword or nested) to the tensor
            # that fed the first copy, then drop both copies. The second copy
            # is erased first because it is the first copy's only user.
            out_to_dim_node.replace_all_uses_with(node.args[0])
            graph.erase_node(out_to_dim_node)
            graph.erase_node(node)
            modified = True

        if modified:
            # Regenerate the module's forward code from the mutated graph.
            graph_module.recompile()
        return PassResult(graph_module, modified)