xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/pt2e/graph_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import itertools
3import operator
4from typing import Any, Callable, List, Optional, OrderedDict, Set
5
6import torch
7from torch.fx import Node
8from torch.fx.passes.utils.source_matcher_utils import (
9    check_subgraphs_connected,
10    get_source_partitions,
11    SourcePartition,
12)
13
14
# Names exported by ``from ... import *``.
__all__ = [
    "find_sequential_partitions",
    "get_equivalent_types",
    "update_equivalent_types_dict",
]

# Groups of partition types that ``_get_matching_types`` treats as
# interchangeable. Each set is one group: an nn.Module class together with its
# functional form(s), or, for the arithmetic groups, the torch op, the Python
# ``operator`` functions and the method-name strings. The list can be replaced
# at runtime through ``update_equivalent_types_dict``. Order matters: when a
# type appears in several groups, the later group wins in the derived dict.
_EQUIVALENT_TYPES: List[Set] = [
    {torch.nn.Conv1d, torch.nn.functional.conv1d},  # 1-d convolution
    {torch.nn.Conv2d, torch.nn.functional.conv2d},  # 2-d convolution
    {torch.nn.AdaptiveAvgPool2d, torch.nn.functional.adaptive_avg_pool2d},  # pooling
    {torch.nn.ReLU, torch.nn.functional.relu, torch.nn.functional.relu_},  # activation
    {torch.nn.BatchNorm2d, torch.nn.functional.batch_norm},  # normalization
    {torch.nn.Hardtanh, torch.nn.functional.hardtanh, torch.nn.functional.hardtanh_},  # activation
    {torch.add, operator.add, operator.iadd, "add", "add_"},  # elementwise add
    {torch.mul, operator.mul, operator.imul, "mul", "mul_"},  # elementwise mul
]
31
32
def _create_equivalent_types_dict():
    """Build a lookup table from each type in ``_EQUIVALENT_TYPES`` to its group.

    Every member of a group maps to its own freshly built ``list`` holding all
    members of that group (itself included). If a member occurs in more than one
    group, the group that appears later in ``_EQUIVALENT_TYPES`` wins.
    """
    return {
        member: list(group)
        for group in _EQUIVALENT_TYPES
        for member in group
    }


# Derived lookup table; rebuilt whenever update_equivalent_types_dict() runs.
_EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict()
42
43
def get_equivalent_types() -> List[Set]:
    """Return the equivalence groups currently in effect.

    The module-level list itself is returned, not a copy. Mutating it in place
    does not rebuild the derived lookup table; pass the modified groups to
    ``update_equivalent_types_dict`` for that.
    """
    current_groups = _EQUIVALENT_TYPES
    return current_groups
47
def update_equivalent_types_dict(customized_equivalent_types=None):
    """Helper for users who want to customize the equivalent-type groups.

    Replaces ``_EQUIVALENT_TYPES`` with ``customized_equivalent_types`` and then
    regenerates ``_EQUIVALENT_TYPES_DICT`` from it.

    Args:
        customized_equivalent_types: New list of equivalence groups (sets), in
            the same shape that ``get_equivalent_types()`` returns. It is stored
            as given, not copied.

    Raises:
        ValueError: if ``customized_equivalent_types`` is None (the default).
    """
    global _EQUIVALENT_TYPES, _EQUIVALENT_TYPES_DICT
    if customized_equivalent_types is None:
        raise ValueError("customized_equivalent_types should not be None")
    _EQUIVALENT_TYPES = customized_equivalent_types
    _EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict()
59
60
61def _partitions_sequential(partitions: List[SourcePartition]):
62    prev_partition = None
63    for partition in partitions:
64        if prev_partition is not None and not check_subgraphs_connected(
65            prev_partition, partition
66        ):
67            return False
68        prev_partition = partition
69    return True
70
71
def _get_matching_types(partition_type):
    """Return ``partition_type`` followed by every type equivalent to it.

    The first element is always ``partition_type`` itself. If it has a group in
    ``_EQUIVALENT_TYPES_DICT``, the other members of that group follow, in the
    group's order. The group already contains ``partition_type``, so it is
    skipped rather than listed a second time.

    Raises:
        TypeError: if ``partition_type`` is unhashable (raised by the dict lookup).
    """
    matching_types = [partition_type]
    for equivalent in _EQUIVALENT_TYPES_DICT.get(partition_type, ()):
        # Deduplicate: every dict key is a member of its own group.
        if equivalent not in matching_types:
            matching_types.append(equivalent)
    return matching_types
77
78
def _valid_type_sequence(partition_types: List[Any]):
    """Return True if no two entries of ``partition_types`` clash.

    Each entry is expanded with ``_get_matching_types``. The sequence is invalid
    as soon as an entry's expanded set overlaps the union of the expanded sets
    of the entries before it (e.g. the same type twice, or a module and its
    functional equivalent).
    """
    seen: Set[Any] = set()
    for partition_type in partition_types:
        candidates = set(_get_matching_types(partition_type))
        if not seen.isdisjoint(candidates):
            return False
        seen.update(candidates)
    return True
88
89
def find_sequential_partitions(
    gm: torch.fx.GraphModule,
    partition_types: List[Any],
    include_functional_equivalent=True,
    filter_fn: Optional[Callable[[Node], bool]] = None,
):
    """Find chains of source partitions that follow ``partition_types`` in order.

    For each entry of ``partition_types`` every matching source partition of
    ``gm.graph`` is collected. Every combination that takes one partition per
    entry is then kept if each partition is connected to the next one.

    Args:
        gm: Graph module whose ``gm.graph`` is searched.
        partition_types: Ordered sequence of partition types (module classes,
            functional ops, or method-name strings). No two entries may be
            equivalent; this check always uses the registered equivalences.
        include_functional_equivalent: When True (the default), each entry also
            matches the types registered as equivalent to it (see
            ``get_equivalent_types``). When False, only the exact entry is
            matched.
        filter_fn: Optional node predicate forwarded to ``get_source_partitions``.

    Returns:
        A list of tuples ``(p_0, ..., p_{k-1})``. Each ``p_i`` is a partition
        matching ``partition_types[i]``, and ``check_subgraphs_connected`` holds
        for every consecutive pair.

    Raises:
        ValueError: if two entries of ``partition_types`` are equivalent.
    """
    if not _valid_type_sequence(partition_types):
        raise ValueError(
            f"Invalid partition types: {partition_types}. Each type in the sequence must be unique"
        )

    typed_partitions: OrderedDict[Any, List[SourcePartition]] = OrderedDict()
    for partition_type in partition_types:
        # Honor include_functional_equivalent: with False, match only the exact type.
        types_to_match = (
            _get_matching_types(partition_type)
            if include_functional_equivalent
            else [partition_type]
        )
        partitions = get_source_partitions(gm.graph, types_to_match, filter_fn)
        # Flatten {source: [partitions]} into one list per requested type.
        typed_partitions[partition_type] = list(
            itertools.chain.from_iterable(partitions.values())
        )

    # Try every combination of one partition per type (Cartesian product);
    # keep those whose consecutive partitions are connected.
    fusion_candidates = itertools.product(*typed_partitions.values())
    return [
        candidate
        for candidate in fusion_candidates
        if _partitions_sequential(candidate)  # type: ignore[arg-type]
    ]
116