# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass


class ExpandBroadcastTensorShape(ExportPass):
    """
    Make tensors have same rank for layout-transform to work properly.

    For every add/sub/mul/div ``Tensor`` edge op, each tensor operand whose
    rank is lower than the rank of the op's output is routed through a
    ``view_copy`` that prepends size-1 dimensions until the ranks match.
    """

    def __init__(self):
        super().__init__()
        # Edge-dialect ops whose operands are checked for a rank mismatch.
        self.broadcast_op_targets = [
            exir_ops.edge.aten.add.Tensor,
            exir_ops.edge.aten.sub.Tensor,
            exir_ops.edge.aten.mul.Tensor,
            exir_ops.edge.aten.div.Tensor,
        ]

    def traverse_broadcast_node(self, graph_module: torch.fx.GraphModule) -> None:
        """
        Insert rank-expanding ``view_copy`` nodes in front of broadcast ops.

        Only the broadcast node being processed is rewired to the expanded
        operand; other consumers of the operand keep reading the original
        tensor. Operands that are not graph nodes, or whose ``meta["val"]``
        is not a tensor, are left untouched.

        Args:
            graph_module: module whose graph is modified in place.
        """
        graph = graph_module.graph
        # (operand, output rank) -> view_copy node already created for it, so
        # an operand shared by several broadcast ops is expanded only once.
        expanded = {}

        for node in graph.nodes:
            if node.target not in self.broadcast_op_targets:
                continue

            # Same for every operand of this node, so compute it once.
            output_rank = len(node.meta["val"].shape)

            for arg in node.args:
                # Positional operands may be plain Python values, which carry
                # no meta and need no rank expansion.
                if not isinstance(arg, torch.fx.Node):
                    continue
                arg_val = arg.meta.get("val")
                if not isinstance(arg_val, torch.Tensor):
                    continue

                input_rank = len(arg_val.shape)
                if input_rank >= output_rank:
                    continue

                # Keyed on rank rather than on the shape itself so the key
                # stays hashable even when dimensions are symbolic.
                cache_key = (arg, output_rank)
                reshape_node = expanded.get(cache_key)
                if reshape_node is None:
                    new_shape = (1,) * (output_rank - input_rank) + tuple(
                        arg_val.shape
                    )
                    with graph.inserting_after(arg):
                        reshape_node = graph.create_node(
                            "call_function",
                            exir_ops.edge.aten.view_copy.default,
                            (arg, new_shape),
                        )
                    # meta needs to be copied elementwisely for fake-tensor
                    # to be updated correctly and not affect meta of arg
                    reshape_node.meta.update(arg.meta)
                    reshape_node.meta["val"] = arg_val.reshape(new_shape)
                    expanded[cache_key] = reshape_node

                # Rewire only this broadcast node; redirecting every user of
                # `arg` would hand a higher-rank tensor to unrelated consumers.
                node.replace_input_with(arg, reshape_node)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Run the pass on ``graph_module`` and return the resulting PassResult.

        The graph is rewritten in place, the module is recompiled, and
        ``dead_code_elimination_pass`` is then run on it. The result is
        always reported as modified.
        """
        self.traverse_broadcast_node(graph_module)
        graph_module.recompile()
        dead_code_elimination_pass(graph_module)
        return PassResult(graph_module, True)