# mypy: allow-untyped-defs
import copy
from typing import Any, Dict, Set, Union

import torch
from torch.fx import GraphModule
from torch.fx.graph import Graph


__all__ = [
    "FusedGraphModule",
    "ObservedGraphModule",
    "ObservedStandaloneGraphModule",
    "QuantizedGraphModule",
]


class FusedGraphModule(GraphModule):
    def __init__(
        self,
        root: Union[torch.nn.Module, Dict[str, Any]],
        graph: Graph,
        preserved_attr_names: Set[str],
    ):
        self.preserved_attr_names = preserved_attr_names
        preserved_attrs = {
            attr: getattr(root, attr)
            for attr in self.preserved_attr_names
            if hasattr(root, attr)
        }
        super().__init__(root, graph)
        for attr in preserved_attrs:
            setattr(self, attr, preserved_attrs[attr])

    # GraphModule does not copy attributes which are not in the __dict__
    # of vanilla nn.Module. So, we override __deepcopy__ in order
    # to copy the quantization specific attributes correctly.
    def __deepcopy__(self, memo):
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
        return FusedGraphModule(
            fake_mod,
            copy.deepcopy(self.graph),
            copy.deepcopy(self.preserved_attr_names),
        )
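

# A minimal sketch of the behavior the __deepcopy__ override above provides,
# assuming torch.fx.symbolic_trace is available. The module `M` and the
# attribute name "my_flag" are made up for illustration:
#
#     import copy
#     import torch
#     from torch.fx import symbolic_trace
#
#     class M(torch.nn.Module):
#         def forward(self, x):
#             return x + 1
#
#     traced = symbolic_trace(M())
#     traced.my_flag = True  # a quantization-style extra attribute
#     fused = FusedGraphModule(traced, traced.graph, {"my_flag"})
#     # The override re-attaches every attribute in preserved_attr_names
#     # after the GraphModule is reconstructed from the copied __dict__:
#     assert copy.deepcopy(fused).my_flag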


class ObservedGraphModule(GraphModule):
    def __init__(
        self,
        root: Union[torch.nn.Module, Dict[str, Any]],
        graph: Graph,
        preserved_attr_names: Set[str],
    ):
        self.preserved_attr_names = {
            "_activation_post_process_map",
            "_activation_post_process_indexes",
            "_patterns",
            "_node_name_to_qconfig",
            "_prepare_custom_config",
            "_equalization_node_name_to_qconfig",
            "_node_name_to_scope",
            "_qconfig_mapping",
            "_is_qat",
            "_observed_node_names",
        }.union(preserved_attr_names)
        preserved_attrs = {
            attr: getattr(root, attr)
            for attr in self.preserved_attr_names
            if hasattr(root, attr)
        }
        super().__init__(root, graph)
        for attr in preserved_attrs:
            setattr(self, attr, preserved_attrs[attr])

    # GraphModule does not copy attributes which are not in the __dict__
    # of vanilla nn.Module. So, we override __deepcopy__ in order
    # to copy the quantization specific attributes correctly.
    def __deepcopy__(self, memo):
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
        return ObservedGraphModule(
            fake_mod,
            copy.deepcopy(self.graph),
            copy.deepcopy(self.preserved_attr_names),
        )


def _is_observed_module(module: Any) -> bool:
    return hasattr(module, "meta") and "_observed_graph_module_attrs" in module.meta


def _get_observed_graph_module_attr(
    model: Union[torch.nn.Module, GraphModule], attr_name: str
) -> Any:
    if hasattr(model, "meta") and "_observed_graph_module_attrs" in model.meta:  # type: ignore[operator, index]
        return getattr(model.meta["_observed_graph_module_attrs"], attr_name)  # type: ignore[index]
    return None


class ObservedStandaloneGraphModule(ObservedGraphModule):
    def __init__(
        self,
        root: Union[torch.nn.Module, Dict[str, Any]],
        graph: Graph,
        preserved_attr_names: Set[str],
    ):
        preserved_attr_names = preserved_attr_names.union(
            {
                "_standalone_module_input_quantized_idxs",
                "_standalone_module_output_quantized_idxs",
            }
        )
        super().__init__(root, graph, preserved_attr_names)

    def __deepcopy__(self, memo):
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
        return ObservedStandaloneGraphModule(
            fake_mod,
            copy.deepcopy(self.graph),
            copy.deepcopy(self.preserved_attr_names),
        )


def _is_observed_standalone_module(module: Any) -> bool:
    return (
        _is_observed_module(module)
        and module.meta["_observed_graph_module_attrs"].is_observed_standalone_module
    )
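

# A minimal sketch of the meta-based helpers above. The real metadata object
# is produced by the FX prepare flow and stored on GraphModule.meta;
# types.SimpleNamespace stands in for it here purely for illustration:
#
#     import types
#     import torch
#     from torch.fx import symbolic_trace
#
#     class M(torch.nn.Module):
#         def forward(self, x):
#             return x + 1
#
#     gm = symbolic_trace(M())
#     gm.meta["_observed_graph_module_attrs"] = types.SimpleNamespace(
#         is_observed_standalone_module=False
#     )
#     assert _is_observed_module(gm)
#     assert not _is_observed_standalone_module(gm)
#     assert _get_observed_graph_module_attr(gm, "is_observed_standalone_module") is False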


def _save_packed_weight(self, destination, prefix, keep_vars):
    for attr_name in dir(self):
        if "_packed_weight" in attr_name and isinstance(
            getattr(self, attr_name), torch._C.ScriptObject
        ):  # type: ignore[attr-defined]
            packed_weight = getattr(self, attr_name)
            destination[prefix + attr_name] = packed_weight


class QuantizedGraphModule(GraphModule):
    """This class is created to make sure PackedParams
    (e.g. LinearPackedParams, Conv2dPackedParams) appear in state_dict
    so that we can serialize and deserialize the quantized graph module with
    torch.save(m.state_dict()) and m.load_state_dict(state_dict)
    """

    def __init__(
        self,
        root: Union[torch.nn.Module, Dict[str, Any]],
        graph: Graph,
        preserved_attr_names: Set[str],
    ):
        self.preserved_attr_names = preserved_attr_names
        preserved_attrs = {
            attr: getattr(root, attr)
            for attr in self.preserved_attr_names
            if hasattr(root, attr)
        }
        super().__init__(root, graph)
        for attr in preserved_attrs:
            setattr(self, attr, preserved_attrs[attr])
        self._register_state_dict_hook(_save_packed_weight)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        attrs_to_pop = []
        for attr_name in state_dict:
            if attr_name.startswith("_packed_weight") and isinstance(state_dict[attr_name], torch._C.ScriptObject):  # type: ignore[attr-defined] # noqa: B950
                setattr(self, attr_name, state_dict[attr_name])
                attrs_to_pop.append(attr_name)

        # pop the packed param attributes
        for attr_name in attrs_to_pop:
            state_dict.pop(attr_name)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def __deepcopy__(self, memo):
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
        return QuantizedGraphModule(
            fake_mod,
            copy.deepcopy(self.graph),
            copy.deepcopy(self.preserved_attr_names),
        )
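

# A minimal sketch of the save/load round-trip the class docstring describes,
# assuming the FX graph mode quantization entry points (prepare_fx/convert_fx)
# and an available fbgemm backend; the model `M` is made up for illustration:
#
#     import torch
#     from torch.ao.quantization import get_default_qconfig_mapping
#     from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
#
#     class M(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.linear = torch.nn.Linear(4, 4)
#
#         def forward(self, x):
#             return self.linear(x)
#
#     example_inputs = (torch.randn(1, 4),)
#     qconfig_mapping = get_default_qconfig_mapping("fbgemm")
#     prepared = prepare_fx(M().eval(), qconfig_mapping, example_inputs)
#     prepared(*example_inputs)  # calibrate observers
#     quantized = convert_fx(prepared)
#
#     # Packed params (e.g. LinearPackedParams) are included in state_dict,
#     # so the quantized module round-trips through save/load:
#     state_dict = quantized.state_dict()
#     quantized.load_state_dict(state_dict)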