# mypy: allow-untyped-defs
import itertools
import operator
from typing import Any, Callable, List, Optional, OrderedDict, Set

import torch
from torch.fx import Node
from torch.fx.passes.utils.source_matcher_utils import (
    check_subgraphs_connected,
    get_source_partitions,
    SourcePartition,
)


__all__ = [
    "find_sequential_partitions",
    "get_equivalent_types",
    "update_equivalent_types_dict",
]

_EQUIVALENT_TYPES: List[Set] = [
    {torch.nn.Conv1d, torch.nn.functional.conv1d},
    {torch.nn.Conv2d, torch.nn.functional.conv2d},
    {torch.nn.AdaptiveAvgPool2d, torch.nn.functional.adaptive_avg_pool2d},
    {torch.nn.ReLU, torch.nn.functional.relu, torch.nn.functional.relu_},
    {torch.nn.BatchNorm2d, torch.nn.functional.batch_norm},
    {torch.nn.Hardtanh, torch.nn.functional.hardtanh, torch.nn.functional.hardtanh_},
    {torch.add, operator.add, operator.iadd, "add", "add_"},
    {torch.mul, operator.mul, operator.imul, "mul", "mul_"},
]


def _create_equivalent_types_dict():
    # Map every type to the full list of its equivalents (including itself).
    _DICT = {}
    for values in _EQUIVALENT_TYPES:
        for v in values:
            _DICT[v] = list(values)
    return _DICT


_EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict()


def get_equivalent_types() -> List[Set]:
    return _EQUIVALENT_TYPES


def update_equivalent_types_dict(customized_equivalent_types=None):
    """Helper function for users who want to customize _EQUIVALENT_TYPES and _EQUIVALENT_TYPES_DICT.
    When customized_equivalent_types is passed in, _EQUIVALENT_TYPES and
    _EQUIVALENT_TYPES_DICT are regenerated from it.
    """
    if customized_equivalent_types is None:
        raise ValueError("customized_equivalent_types should not be None")
    global _EQUIVALENT_TYPES
    global _EQUIVALENT_TYPES_DICT
    _EQUIVALENT_TYPES = customized_equivalent_types
    _EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict()


def _partitions_sequential(partitions: List[SourcePartition]):
    # True if each partition's subgraph is directly connected to the next one.
    prev_partition = None
    for partition in partitions:
        if prev_partition is not None and not check_subgraphs_connected(
            prev_partition, partition
        ):
            return False
        prev_partition = partition
    return True


def _get_matching_types(partition_type):
    # Return the requested type together with all of its registered equivalents.
    matching_types = [partition_type]
    if partition_type in _EQUIVALENT_TYPES_DICT:
        matching_types.extend(_EQUIVALENT_TYPES_DICT[partition_type])
    return matching_types


def _valid_type_sequence(partition_types: List[Any]):
    # Reject the sequence if any two requested types (or their equivalents) overlap.
    partition_types_set = set()  # type: ignore[var-annotated]
    for partition_type in partition_types:
        matching_types = _get_matching_types(partition_type)
        matching_types_set = set(matching_types)
        if len(partition_types_set & matching_types_set) > 0:
            return False
        partition_types_set |= matching_types_set
    return True


def find_sequential_partitions(
    gm: torch.fx.GraphModule,
    partition_types: List[Any],
    include_functional_equivalent=True,
    filter_fn: Optional[Callable[[Node], bool]] = None,
):
    """Find tuples of source partitions, one per entry in partition_types, whose
    subgraphs appear back-to-back in gm (each partition connected to the next).
    """
    if not _valid_type_sequence(partition_types):
        raise ValueError(
            f"Invalid partition types: {partition_types}. "
            "Each type in the sequence must be unique"
        )

    typed_partitions: OrderedDict[Any, List[SourcePartition]] = OrderedDict()
    for partition_type in partition_types:
        types_to_match = _get_matching_types(partition_type)
        partitions = get_source_partitions(gm.graph, types_to_match, filter_fn)
        typed_partitions[partition_type] = list(
            itertools.chain.from_iterable(partitions.values())
        )

    typed_partitions_list = list(typed_partitions.values())
    fusion_candidates = itertools.product(*typed_partitions_list)
    fused_partitions = []
    for candidate in fusion_candidates:
        if _partitions_sequential(candidate):  # type: ignore[arg-type]
            fused_partitions.append(candidate)
    return fused_partitions
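

# A minimal usage sketch (not part of the original module): it illustrates how
# find_sequential_partitions might be called to locate connected Conv2d -> ReLU
# partition pairs. The toy model _ConvReLU and the choice of torch.export.export
# for graph capture are illustrative assumptions; matching relies on the captured
# graph retaining the source metadata that get_source_partitions inspects, which
# can vary with the capture API and PyTorch version.
if __name__ == "__main__":

    class _ConvReLU(torch.nn.Module):  # hypothetical example model
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, 3)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(self.conv(x))

    example_inputs = (torch.randn(1, 3, 16, 16),)
    gm = torch.export.export(_ConvReLU(), example_inputs).module()
    fused = find_sequential_partitions(gm, [torch.nn.Conv2d, torch.nn.ReLU])
    # Each element is a (conv_partition, relu_partition) tuple of SourcePartitions
    # whose subgraphs are directly connected in the captured graph.
    print(f"Found {len(fused)} Conv2d -> ReLU partition pairs")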