# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from itertools import chain

import torch
from executorch.backends.example.example_operators.ops import module_to_annotator
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dim_order_utils import get_dim_order
from executorch.exir.pass_base import ExportPass, PassResult
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions


class PermuteMemoryFormatsPass(ExportPass):
    """
    Insert ``_to_dim_order_copy`` ops around every matched pattern whose
    annotator sets ``permuate_memory_format`` to True.

    Each value entering the pattern goes through a channels-last copy. Each
    pattern output goes through a contiguous copy before it is consumed.

    Example 1:
        before pass: x -> conv -> out
        after pass:  x -> to_dim(channel_last) -> conv -> to_dim(contiguous) -> out

    Example 2:
        before pass: x -> conv -> conv -> out
        after pass:  x -> to_dim(channel_last) -> conv -> to_dim(contiguous)
                       -> to_dim(channel_last) -> conv -> to_dim(contiguous) -> out

    Example 3:
        before pass: x -> conv -> linear -> out
        after pass:  x -> to_dim(channel_last) -> conv -> to_dim(contiguous)
                       -> to_dim(channel_last) -> linear -> to_dim(contiguous) -> out
    """

    @staticmethod
    def _make_dim_order_copy(
        graph: torch.fx.Graph,
        node: torch.fx.Node,
        memory_format: torch.memory_format,
    ) -> torch.fx.Node:
        """Create a ``_to_dim_order_copy`` of ``node`` in ``memory_format``.

        The new node is created at the graph's current insertion point.
        """
        return graph.call_function(
            exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
            args=(node,),
            kwargs={
                # NOTE(review): dtype=float64 casts every copied tensor to
                # float64. A pure memory-format change presumably should keep
                # the dtype; confirm before changing.
                "dtype": torch.float64,
                # NOTE(review): the rank-4 dim order assumes 4-D tensors. This
                # includes any weight/bias that is a partition input; confirm.
                "dim_order": get_dim_order(memory_format, 4),
            },
        )

    def _insert_exit_copies(
        self,
        graph: torch.fx.Graph,
        partition,  # a source partition returned by find_sequential_partitions
        exit_copies: dict,  # inserted contiguous copy -> node it copies
    ) -> None:
        """Route each pattern output through a contiguous copy before its consumers.

        Only references to the pattern's own outputs are rewritten. Other
        arguments of a consuming node are left untouched. Every copy created
        is recorded in ``exit_copies``.
        """
        output_nodes = [
            node for node in partition.output_nodes if node.op != "placeholder"
        ]
        output_set = set(output_nodes)
        # dict.fromkeys de-duplicates consumers while keeping a stable order.
        # A node that consumes several pattern outputs must be rewritten once.
        # It is materialized before any rewrite mutates ``users``.
        exit_nodes = dict.fromkeys(
            chain.from_iterable(node.users for node in output_nodes)
        )

        def to_contiguous(arg: torch.fx.Node) -> torch.fx.Node:
            if arg not in output_set:
                return arg
            contiguous = self._make_dim_order_copy(
                graph, arg, torch.contiguous_format
            )
            exit_copies[contiguous] = arg
            return contiguous

        for exit_node in exit_nodes:
            if exit_node.op not in ("output", "call_function"):
                continue
            with graph.inserting_before(exit_node):
                # map_arg walks nested tuples/lists. That covers the graph
                # output (args == (return_value,)) as well as ordinary call
                # arguments, and it keeps ``args`` a tuple.
                exit_node.args = torch.fx.node.map_arg(exit_node.args, to_contiguous)

    def _insert_entry_copies(
        self,
        graph: torch.fx.Graph,
        partition,  # a source partition returned by find_sequential_partitions
        exit_copies: dict,  # inserted contiguous copy -> node it copies
    ) -> None:
        """Feed each partition input to the pattern through a channels-last copy.

        Only nodes inside the partition are rewritten, so other consumers of
        the same input keep the original value. An input may already reach
        the partition through a contiguous copy inserted at another pattern's
        exit. In that case the channels-last copy is chained after it, giving
        ``to_dim(contiguous) -> to_dim(channel_last)``.
        """
        input_set = set(partition.input_nodes)
        entry_copies = {}  # value feeding the partition -> its channels-last copy

        def to_channels_last(arg: torch.fx.Node) -> torch.fx.Node:
            source = exit_copies.get(arg, arg)
            if source not in input_set:
                return arg
            if arg not in entry_copies:
                # Placing the copy right after its operand keeps it valid for
                # every partition node that reuses it.
                with graph.inserting_after(arg):
                    entry_copies[arg] = self._make_dim_order_copy(
                        graph, arg, torch.channels_last
                    )
            return entry_copies[arg]

        for node in partition.nodes:
            node.args = torch.fx.node.map_arg(node.args, to_channels_last)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Insert the memory-format copies for every opted-in pattern.

        Args:
            graph_module: module whose graph is rewritten in place.

        Returns:
            ``PassResult(graph_module, True)``. The graph is always reported
            as modified, even when no pattern matched.
        """
        graph = graph_module.graph
        # Contiguous copies inserted at pattern exits, mapped to the node each
        # one copies. The map is shared across patterns so a downstream
        # pattern recognizes an input that arrives through an upstream copy.
        exit_copies = {}
        for pattern, pattern_op in module_to_annotator.items():
            if not pattern_op.permuate_memory_format:
                continue
            for partition in find_sequential_partitions(graph_module, pattern):
                # Only partition[0] of each sequential match is rewritten.
                source_partition = partition[0]
                # Step 1: contiguous copy when leaving the pattern.
                self._insert_exit_copies(graph, source_partition, exit_copies)
                # Step 2: channels-last copy when entering the pattern.
                self._insert_entry_copies(graph, source_partition, exit_copies)

        # Ensure the graph is still valid
        graph.lint()
        graph_module.recompile()
        return PassResult(graph_module, True)