# mypy: allow-untyped-defs
from typing import Dict, Tuple

from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph

from torch.fx.graph_module import GraphModule
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
from torch.nn import Module


__all__ = ["HolderModule", "lift_subgraph_as_module", "compare_graphs"]


@compatibility(is_backward_compatible=False)
class HolderModule(Module):
    """
    HolderModule is a bare container module used to copy attributes from the
    original module onto the submodules that use those attributes.

    Args:
        d: mapping of child name -> module; every entry is registered as a
            child module under its key.
    """

    def __init__(self, d):
        super().__init__()
        for k, v in d.items():
            self.add_module(k, v)


@compatibility(is_backward_compatible=False)
def lift_subgraph_as_module(
    gm: GraphModule,
    subgraph: Graph,
    comp_name: str = "",
    class_name: str = "GraphModule",
) -> Tuple[GraphModule, Dict[str, str]]:
    """
    Create a GraphModule for subgraph, which copies the necessary attributes from the original parent graph_module.

    Args:
        gm (GraphModule): parent graph module

        subgraph (Graph): a valid subgraph that contains copied nodes from the parent graph

        comp_name (str): name for the new component

        class_name (str): name for the submodule

    Returns:
        A tuple ``(submodule, orig_to_split_fqn_mapping)`` where ``submodule`` is the
        GraphModule built from ``subgraph`` and ``orig_to_split_fqn_mapping`` maps the
        target of every ``call_module`` / ``get_attr`` node in ``subgraph`` to
        ``f"{comp_name}.{target}"``. When ``comp_name`` is empty the target is mapped
        to itself, so the result never carries a leading ``"."``.

    """

    # Loop through all module calls (call_module) and param fetches (get_attr)
    # in this component, creating HolderModules as necessary to match the path.
    # e.g. if in the original module there's a get_attr node that fetches "conv.weight",
    # we create a HolderModule as root -> add a HolderModule named "conv" ->
    # make "weight" an attribute of the "conv" HolderModule and point it to conv.weight in
    # the original module.
    submodule = HolderModule({})
    orig_to_split_fqn_mapping: Dict[str, str] = {}
    for n in subgraph.nodes:
        if n.op not in ("call_module", "get_attr"):
            continue

        target = n.target
        assert isinstance(target, str)
        *parent_names, leaf_node_name = target.split(".")
        curr = submodule
        orig_gm = gm

        # Walk the dotted path in lockstep on the new holder tree and on the
        # original module, creating intermediate holders on demand.
        for name in parent_names:
            if not hasattr(curr, name):
                curr.add_module(name, HolderModule({}))

            curr = getattr(curr, name)
            orig_gm = getattr(orig_gm, name)

        leaf_node = getattr(orig_gm, leaf_node_name)

        # Only prefix with the component name when there is one; an empty
        # comp_name would otherwise yield an invalid name such as ".conv.weight".
        orig_to_split_fqn_mapping[target] = (
            f"{comp_name}.{target}" if comp_name else target
        )
        # Relies on custom __setattr__ magic.
        setattr(curr, leaf_node_name, leaf_node)

    return GraphModule(submodule, subgraph, class_name), orig_to_split_fqn_mapping


@compatibility(is_backward_compatible=False)
def compare_graphs(left: Graph, right: Graph) -> bool:
    """
    Return True if two graphs are identical, i.e they
        - have the same number of outputs in the same order
        - have the same number of inputs in the same order
        - have the same set of nodes, and identical connectivity

    The check uses ``left`` as a SubgraphMatcher pattern (with output and
    placeholder matching enabled) and reports whether it matches ``right``
    at least once.
    """

    matcher = SubgraphMatcher(left, match_output=True, match_placeholder=True)
    matches = matcher.match(right)

    return len(matches) > 0