xref: /aosp_15_r20/external/pytorch/torch/fx/passes/utils/common.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import Dict, Tuple
3
4from torch.fx._compatibility import compatibility
5from torch.fx.graph import Graph
6
7from torch.fx.graph_module import GraphModule
8from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
9from torch.nn import Module
10
11
__all__ = [
    "HolderModule",
    "lift_subgraph_as_module",
    "compare_graphs",
]
13
14
@compatibility(is_backward_compatible=False)
class HolderModule(Module):
    """
    Plain container Module used to rebuild, inside a lifted submodule, the
    attribute path hierarchy of the original module, so that the attributes
    a submodule uses can be reached under the same dotted names.

    Args:
        d: mapping of child name -> Module; each entry is registered as a
            child module via ``add_module``, in iteration order.
    """

    def __init__(self, d):
        super().__init__()
        # Register every provided child so it appears in named_children()
        # and is reachable as an attribute of this holder.
        for child_name, child in d.items():
            self.add_module(child_name, child)
26
27
@compatibility(is_backward_compatible=False)
def lift_subgraph_as_module(
    gm: GraphModule,
    subgraph: Graph,
    comp_name: str = "",
    class_name: str = "GraphModule",
) -> Tuple[GraphModule, Dict[str, str]]:
    """
    Create a GraphModule for subgraph, which copies the necessary attributes from the original parent graph_module.

    Args:
        gm (GraphModule): parent graph module

        subgraph (Graph): a valid subgraph that contains copied nodes from the parent graph

        comp_name (str): name for the new component. It is used as the prefix of the
            names in the returned mapping; when empty, targets are mapped to
            themselves (no prefix and no leading ".").

        class_name (str): name for the submodule

    Returns:
        A tuple ``(module, mapping)``.

        ``module`` is a new GraphModule of class ``class_name`` that wraps
        ``subgraph``.

        ``mapping`` maps every ``call_module`` / ``get_attr`` target string
        found in ``subgraph`` to ``f"{comp_name}.{target}"``. When
        ``comp_name`` is empty, each target maps to itself.

    Raises:
        AssertionError: if a ``call_module`` / ``get_attr`` node has a non-string target.
        AttributeError: if a target's dotted path does not exist on ``gm``.
    """

    # Loop through all module calls (call_module) and param fetches (get_attr)
    # in this component, creating HolderModules as necessary to match the path.
    # e.g. if in the original module there's a get_attr node fetches "conv.weight".
    # We create a HolderModule as root -> add a HolderModule named "conv" ->
    # make "weight" a attribute of "conv" HolderModule and point to conv.weight in
    # the original module.
    submodule = HolderModule({})
    orig_to_split_fqn_mapping: Dict[str, str] = {}
    for n in subgraph.nodes:
        if n.op not in ("call_module", "get_attr"):
            continue

        target = n.target
        assert isinstance(target, str)
        *path_parts, leaf_node_name = target.split(".")
        curr = submodule
        orig_gm = gm

        # Walk (and create where missing) the intermediate holders in lockstep
        # with the corresponding submodules of the original module.
        for name in path_parts:
            if not hasattr(curr, name):
                curr.add_module(name, HolderModule({}))

            curr = getattr(curr, name)
            orig_gm = getattr(orig_gm, name)

        leaf_node = getattr(orig_gm, leaf_node_name)

        # With an empty comp_name, f"{comp_name}.{target}" would produce a
        # malformed FQN with a leading dot; map the target to itself instead.
        orig_to_split_fqn_mapping[target] = (
            f"{comp_name}.{target}" if comp_name else target
        )
        # Relies on nn.Module.__setattr__ magic: Modules/Parameters are
        # registered on `curr`, other values become plain attributes.
        setattr(curr, leaf_node_name, leaf_node)

    return GraphModule(submodule, subgraph, class_name), orig_to_split_fqn_mapping
82
83
@compatibility(is_backward_compatible=False)
def compare_graphs(left: Graph, right: Graph) -> bool:
    """
    Check whether two FX graphs are structurally identical, i.e. whether they
        - have the same number of outputs in the same order
        - have the same number of inputs in the same order
        - have the same set of nodes, and identical connectivity

    The check uses ``left`` as a SubgraphMatcher pattern (with placeholders
    and outputs required to match) and searches for it in ``right``. The
    result is True iff at least one match is found.

    NOTE(review): the matching is directional (``left`` is the pattern);
    confirm this is strict enough for callers that need true equality.
    """
    pattern_matcher = SubgraphMatcher(
        left,
        match_output=True,
        match_placeholder=True,
    )
    found = pattern_matcher.match(right)
    return bool(found)
96    return len(matches) > 0
97