# mypy: allow-untyped-defs
from typing import Any, Dict, Tuple

import torch
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.utils._pytree import tree_flatten

from torch.fx import Graph, GraphModule, Node

aten = torch.ops.aten


# Stateful (random-number-generating) ops are banned from CSE: merging two
# syntactically identical calls would change the program's sampling behavior.
rand_ops = {aten.dropout, aten._fused_dropout, aten._standard_gamma, aten.bernoulli, aten.multinomial, aten.native_dropout, aten.normal, aten.poisson, aten.binomial, aten.rrelu, aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm}  # noqa: E501,B950

# In-place ops mutate their inputs, so two identical-looking calls are not
# interchangeable either.
inplace_ops = {aten.add_, aten.sub_, aten.mul_, aten.div_, aten.pow_, aten.lerp_, aten.relu_, aten.sigmoid_, aten.tanh_}  # noqa: E501


@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def get_CSE_banned_ops():
    """Return the default set of ops that must not be CSE'd (random + in-place ops)."""
    return rand_ops.union(inplace_ops)


@torch.fx._compatibility.compatibility(is_backward_compatible=False)
class CSEPass(PassBase):

    def __init__(self, banned_ops=None):
        """
        This version of CSE Pass aims to be dialect agnostic, and it's implemented purely based on the connectivity between fx.Node.

        For functional dialects, user would only need to specify the random ops in ban list.

        Warning: CSE Pass cannot be safely applied on a FX graph in non-functional dialects.
        If your dialect contains stateful operators, please customize the banned_ops.

        """
        if banned_ops is None:
            banned_ops = set()
        self.banned_ops = banned_ops
        super().__init__()

    def call(self, graph_module: GraphModule) -> PassResult:
        """
        Return a new copy of torch.fx.GraphModule with CSE applied to the input graph.

        Example usage:

        from torch.fx.experimental.proxy_tensor import make_fx
        def f(a):
            b = a * a
            c = a * a
            return b + c

        p = CSEPass()
        traced_graph = make_fx(f)(torch.tensor(1))
        print(traced_graph)
        result = p(traced_graph)
        print(result.graph_module)
        """
        def get_aten_target(node):
            # Compare ops at overload-packet granularity so that different
            # overloads of the same op all match an entry in banned_ops.
            if hasattr(node.target, 'overloadpacket'):
                return node.target.overloadpacket
            return node.target

        modified = False
        new_graph = Graph()
        env: Dict[Node, Node] = {}  # map from node in the old graph to node in the new graph
        hash_env: Dict[Tuple[torch._ops.OpOverload, int], Node] = {}  # map from hash to a node in the new graph
        token_map: Dict[Tuple[torch._ops.OpOverload, int], Dict[str, Any]] = {}  # map from hash to token

        def substitute(arg_list):
            # Flatten nested args/kwargs and replace every old-graph Node with
            # its new-graph counterpart; the returned spec allows the nested
            # structure to be reconstructed/compared later.
            arg_list, spec = tree_flatten(arg_list)
            for i in range(len(arg_list)):
                v = arg_list[i]
                if isinstance(v, Node) and v in env:
                    arg_list[i] = env[v]
            return tuple(arg_list), spec

        for n in graph_module.graph.nodes:
            # The placeholder, output, and get_attr nodes are copied to the new
            # graph without change; do not CSE away random (banned) operations.
            if n.op == 'placeholder' or n.op == 'output' or n.op == 'get_attr' or get_aten_target(n) in self.banned_ops:
                new_node = new_graph.node_copy(n, lambda x: env[x])
                env[n] = new_node
            else:  # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method'
                # substitute args and kwargs members to their mapping in env if exists
                args, args_spec = substitute(n.args)
                kwargs, kwargs_spec = substitute(n.kwargs)

                # Each token fully identifies a computation; nodes with equal
                # tokens can be substituted for one another.
                token = {"target": n.target, "args": args, "args_spec": args_spec,
                         "kwargs": kwargs, "kwargs_spec": kwargs_spec}

                # Hash the substituted args/kwargs to a number; specs are not
                # hashed because they are not hashable.  The full-token equality
                # check below guards against hash collisions.
                hash_arg = hash((args, kwargs))
                hash_val = (n.target, hash_arg)

                # Check if an equivalent node already exists in the new graph;
                # if so this node can be eliminated.
                hash_val_in_hash_env = hash_val in hash_env
                if hash_val_in_hash_env and token_map[hash_val] == token:
                    modified = True  # substitution happens and the graph is modified
                    env[n] = hash_env[hash_val]
                    continue

                new_node = new_graph.node_copy(n, lambda x: env[x])
                env[n] = new_node
                if not hash_val_in_hash_env:
                    hash_env[hash_val] = new_node
                    token_map[hash_val] = token

        csed_gm = GraphModule(graph_module, new_graph)
        return PassResult(csed_gm, modified)