# mypy: allow-untyped-defs
import copy
from typing import Any, Dict, Set, Union

import torch
from torch.fx import GraphModule
from torch.fx.graph import Graph


__all__ = [
    "FusedGraphModule",
    "ObservedGraphModule",
    "ObservedStandaloneGraphModule",
    "QuantizedGraphModule",
]


class FusedGraphModule(GraphModule):
    def __init__(
        self,
        root: Union[torch.nn.Module, Dict[str, Any]],
        graph: Graph,
        preserved_attr_names: Set[str],
    ):
        self.preserved_attr_names = preserved_attr_names
        preserved_attrs = {
            attr: getattr(root, attr)
            for attr in self.preserved_attr_names
            if hasattr(root, attr)
        }
        super().__init__(root, graph)
        for attr in preserved_attrs:
            setattr(self, attr, preserved_attrs[attr])

    # GraphModule does not copy attributes that are absent from the __dict__
    # of a vanilla nn.Module, so we override __deepcopy__ to copy the
    # quantization-specific attributes correctly.
    def __deepcopy__(self, memo):
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
        return FusedGraphModule(
            fake_mod,
            copy.deepcopy(self.graph),
            copy.deepcopy(self.preserved_attr_names),
        )
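
# A minimal sketch of why the __deepcopy__ override above matters; the
# `qconfig_stuff` attribute is hypothetical, used only for illustration:
#
#     m = torch.nn.Sequential(torch.nn.Linear(4, 4))
#     m.qconfig_stuff = {"scale": 1.0}  # not part of vanilla nn.Module state
#     traced = torch.fx.symbolic_trace(m)
#     fused = FusedGraphModule(m, traced.graph, {"qconfig_stuff"})
#     copied = copy.deepcopy(fused)
#     assert copied.qconfig_stuff == {"scale": 1.0}  # survives the deep copy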


class ObservedGraphModule(GraphModule):
    def __init__(
        self,
        root: Union[torch.nn.Module, Dict[str, Any]],
        graph: Graph,
        preserved_attr_names: Set[str],
    ):
        self.preserved_attr_names = {
            "_activation_post_process_map",
            "_activation_post_process_indexes",
            "_patterns",
            "_node_name_to_qconfig",
            "_prepare_custom_config",
            "_equalization_node_name_to_qconfig",
            "_node_name_to_scope",
            "_qconfig_mapping",
            "_is_qat",
            "_observed_node_names",
        }.union(preserved_attr_names)
        preserved_attrs = {
            attr: getattr(root, attr)
            for attr in self.preserved_attr_names
            if hasattr(root, attr)
        }
        super().__init__(root, graph)
        for attr in preserved_attrs:
            setattr(self, attr, preserved_attrs[attr])

    # GraphModule does not copy attributes that are absent from the __dict__
    # of a vanilla nn.Module, so we override __deepcopy__ to copy the
    # quantization-specific attributes correctly.
    def __deepcopy__(self, memo):
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
        return ObservedGraphModule(
            fake_mod,
            copy.deepcopy(self.graph),
            copy.deepcopy(self.preserved_attr_names),
        )


def _is_observed_module(module: Any) -> bool:
    return hasattr(module, "meta") and "_observed_graph_module_attrs" in module.meta


def _get_observed_graph_module_attr(
    model: Union[torch.nn.Module, GraphModule], attr_name: str
) -> Any:
    if hasattr(model, "meta") and "_observed_graph_module_attrs" in model.meta:  # type: ignore[operator, index]
        return getattr(model.meta["_observed_graph_module_attrs"], attr_name)  # type: ignore[index]
    return None
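
# A minimal usage sketch of the two helpers above (hedged): prepare_fx stores
# an ObservedGraphModuleAttrs record under model.meta, which these helpers
# query. `MyModel` and `example_inputs` are placeholders:
#
#     from torch.ao.quantization import get_default_qconfig_mapping
#     from torch.ao.quantization.quantize_fx import prepare_fx
#
#     model = prepare_fx(MyModel().eval(), get_default_qconfig_mapping(), example_inputs)
#     if _is_observed_module(model):
#         scopes = _get_observed_graph_module_attr(model, "node_name_to_scope")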


class ObservedStandaloneGraphModule(ObservedGraphModule):
    def __init__(
        self,
        root: Union[torch.nn.Module, Dict[str, Any]],
        graph: Graph,
        preserved_attr_names: Set[str],
    ):
        preserved_attr_names = preserved_attr_names.union(
            {
                "_standalone_module_input_quantized_idxs",
                "_standalone_module_output_quantized_idxs",
            }
        )
        super().__init__(root, graph, preserved_attr_names)

    def __deepcopy__(self, memo):
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
        return ObservedStandaloneGraphModule(
            fake_mod,
            copy.deepcopy(self.graph),
            copy.deepcopy(self.preserved_attr_names),
        )


def _is_observed_standalone_module(module: Any) -> bool:
    return (
        _is_observed_module(module)
        and module.meta["_observed_graph_module_attrs"].is_observed_standalone_module
    )
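
# Usage sketch (hedged): code working on a prepared parent module can
# special-case observed standalone child modules, e.g.:
#
#     for name, child in model.named_modules():
#         if _is_observed_standalone_module(child):
#             print(f"{name} is an observed standalone module")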


def _save_packed_weight(self, destination, prefix, keep_vars):
    for attr_name in dir(self):
        if "_packed_weight" in attr_name and isinstance(
            getattr(self, attr_name), torch._C.ScriptObject
        ):  # type: ignore[attr-defined]
            packed_weight = getattr(self, attr_name)
            destination[prefix + attr_name] = packed_weight
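
# Note: QuantizedGraphModule registers _save_packed_weight below via
# _register_state_dict_hook, so it runs after nn.Module has populated
# `destination` and copies every `_packed_weight*` ScriptObject attribute into
# the state dict under the module's key prefix.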


class QuantizedGraphModule(GraphModule):
    """This class makes sure PackedParams
    (e.g. LinearPackedParams, Conv2dPackedParams) appear in the state_dict,
    so that we can serialize and deserialize a quantized graph module with
    torch.save(m.state_dict()) and m.load_state_dict(state_dict).
    """

    def __init__(
        self,
        root: Union[torch.nn.Module, Dict[str, Any]],
        graph: Graph,
        preserved_attr_names: Set[str],
    ):
        self.preserved_attr_names = preserved_attr_names
        preserved_attrs = {
            attr: getattr(root, attr)
            for attr in self.preserved_attr_names
            if hasattr(root, attr)
        }
        super().__init__(root, graph)
        for attr in preserved_attrs:
            setattr(self, attr, preserved_attrs[attr])
        self._register_state_dict_hook(_save_packed_weight)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        attrs_to_pop = []
        for attr_name in state_dict:
            if attr_name.startswith("_packed_weight") and isinstance(state_dict[attr_name], torch._C.ScriptObject):  # type: ignore[attr-defined] # noqa: B950
                setattr(self, attr_name, state_dict[attr_name])
                attrs_to_pop.append(attr_name)

        # pop the packed param attributes
        for attr_name in attrs_to_pop:
            state_dict.pop(attr_name)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def __deepcopy__(self, memo):
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
        return QuantizedGraphModule(
            fake_mod,
            copy.deepcopy(self.graph),
            copy.deepcopy(self.preserved_attr_names),
        )

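
# A minimal round-trip sketch (hedged; `qm` stands in for a QuantizedGraphModule
# produced by torch.ao.quantization.quantize_fx.convert_fx):
#
#     torch.save(qm.state_dict(), "quantized.pt")       # packed params included
#     qm.load_state_dict(torch.load("quantized.pt"))    # packed params restored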