# mypy: allow-untyped-defs
import abc
import copy
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple, Type

import torch
from torch import nn
from torch.nn.utils import parametrize
from torch.nn.utils.parametrize import type_before_parametrizations

from .utils import (
    FakeSparsity,
    get_arg_info_from_tensor_fqn,
    module_contains_param,
    module_to_fqn,
    swap_module,
)


__all__ = ["BaseSparsifier"]

SUPPORTED_MODULES = {nn.Linear}

KEYS_NOT_IN_STATE_DICT = ["module", "module_fqn", "tensor_name"]


# TODO update desc with new config args
class BaseSparsifier(abc.ABC):
    r"""Base class for all sparsifiers.

    Abstract methods that need to be implemented:

    - update_mask: Function to compute a new mask for all keys in the
        `groups`.

    Args:
        - model [nn.Module]: model to configure. The model itself is not saved
            but is used for state_dict saving / loading.
        - config [list]: configuration elements; each element should be a dict
            that includes the `tensor_fqn` of the tensor to sparsify
        - defaults [dict]: default configuration that will be attached to the
            configuration. Only the keys that don't exist in the `config` will
            be updated.

    Example::

        >>> # xdoctest: +SKIP("Can't instantiate abstract class BaseSparsifier with abstract method update_mask")
        >>> config = [{'tensor_fqn': 'layer1.weight'}, {'tensor_fqn': 'linear2.weight2', 'sparsity_level': 0.5}]
        >>> defaults = {'sparsity_level': 0.7}
        >>> # model.layer1.weight will have `sparsity_level` = 0.7 (getting default)
        >>> sparsifier = BaseSparsifier(defaults)
        >>> sparsifier.prepare(model, config)
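
    A typical end-to-end flow with a concrete subclass looks roughly like the
    following (a sketch; `MySparsifier`, `model` and `train_one_epoch` are
    placeholders, not part of this module)::

        >>> # xdoctest: +SKIP("illustrative only")
        >>> sparsifier = MySparsifier(defaults={'sparsity_level': 0.7})
        >>> sparsifier.prepare(model, config=None)  # sparsify all supported layers
        >>> for _ in range(10):
        ...     train_one_epoch(model)
        ...     sparsifier.step()  # recompute the masks via update_mask
        >>> sparsifier.squash_mask()  # fold the masks into the weights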
    """

    def __init__(self, defaults: Optional[Dict[str, Any]] = None):
        super().__init__()
        self.defaults: Dict[str, Any] = defaults or {}

        self.state: Dict[str, Dict] = defaultdict(dict)
        self.groups: List[Dict[str, Any]] = []
        self.enable_mask_update = True

    def __getstate__(self) -> Dict[str, Any]:
        return {
            "defaults": self.defaults,
            "state": self.state,
            "groups": self.groups,
        }

    def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None:
        self.__dict__.update(state)

    def __repr__(self):
        format_string = self.__class__.__name__ + " ("
        for i, sparse_args in enumerate(self.groups):
            module = sparse_args["module"]
            format_string += "\n"
            format_string += f"\tGroup {i}\n"
            format_string += f"\t    module: {module}\n"
            for key in sorted(sparse_args.keys()):
                if key == "module":
                    continue
                format_string += f"\t    {key}: {sparse_args[key]}\n"
        format_string += ")"
        return format_string

    def state_dict(self) -> Dict[str, Any]:
        r"""Returns the state of the sparsifier as a :class:`dict`.

        It contains:
        * state - current state of the sparsification.
        * groups - a list containing all sparsity configuration groups
            with the key 'tensor_fqn' specifying the path to the sparsified tensor within a model

        TODO: Need a clean way of loading the state of the "prepared" module
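
        Example of the returned structure (a sketch; the exact keys inside each
        group depend on the sparsifier's defaults and config)::

            >>> # xdoctest: +SKIP("illustrative only")
            >>> sparsifier.state_dict()
            {'state': {'linear.weight': {'mask': tensor([...])}},
             'groups': [{'sparsity_level': 0.5, 'tensor_fqn': 'linear.weight'}]}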
        """

        groups: List[Dict[str, Any]] = [
            dict(
                filter(
                    lambda key_value: key_value[0] not in KEYS_NOT_IN_STATE_DICT,
                    mg.items(),
                )
            )
            for mg in self.groups
        ]

        return {
            "state": self.state,
            "groups": groups,
        }

    def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True):
        groups = copy.deepcopy(state_dict["groups"])
        states = state_dict["state"]
        for tensor_fqn, s in states.items():
            arg_info = get_arg_info_from_tensor_fqn(self.model, tensor_fqn)
            module = arg_info["module"]
            tensor_name = arg_info["tensor_name"]
            if module is None:
                if strict:
                    raise RuntimeError(f"Error loading {tensor_fqn} into the model")
                continue

            # Reuse an existing FakeSparsity parametrization on this tensor if
            # there is one; otherwise register a fresh one so the mask can be loaded.
            found = False
            if parametrize.is_parametrized(module, tensor_name):
                for p in module.parametrizations[tensor_name]:
                    if isinstance(p, FakeSparsity):
                        found = True
                        break
            if not found:
                p = FakeSparsity(torch.ones(getattr(module, tensor_name).shape))
                parametrize.register_parametrization(module, tensor_name, p)
            if s.get("mask", None) is not None:
                mask = s.pop("mask")
                p.mask = mask

            for mg in groups:
                if mg["tensor_fqn"] == tensor_fqn:
                    mg.update(arg_info)
        self.__setstate__({"state": states, "groups": groups})

    def make_config_from_model(
        self,
        model: nn.Module,
        SUPPORTED_MODULES: Set[Type] = SUPPORTED_MODULES,
    ) -> None:
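        r"""Builds `self.config` by walking `model`: for every child module
        whose type is in `SUPPORTED_MODULES`, a `{"tensor_fqn": "<module_fqn>.weight"}`
        entry is added; all other modules are recursed into.
        """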
        self.config = []
        stack = [model]
        while stack:
            module = stack.pop()
            for name, child in module.named_children():
                if type(child) in SUPPORTED_MODULES:
                    module_fqn = module_to_fqn(model, child)
                    assert isinstance(module_fqn, str)  # for mypy
                    self.config.append({"tensor_fqn": module_fqn + ".weight"})
                else:
                    stack.append(child)

    def prepare(self, model, config):
        r"""Prepares a model by adding the parametrizations.

        Note::

            The model is modified in place. If you need to preserve the original
            model, use copy.deepcopy.
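
        Example (a sketch; `MySparsifier` and `model` are placeholders for a
        concrete subclass and a user model)::

            >>> # xdoctest: +SKIP("illustrative only")
            >>> sparsifier = MySparsifier(defaults={'sparsity_level': 0.5})
            >>> sparsifier.prepare(model, config=[{'tensor_fqn': 'linear.weight'}])
            >>> # model.linear.weight is now parametrized with a FakeSparsity mask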
        """
        self.model = model  # TODO: Need to figure out how to load without this.
        self.config = config

        # If no config -- try getting all the supported layers
        if self.config is None:
            self.make_config_from_model(model)

        # TODO: Remove the configuration by reference ('module')
        for module_config in self.config:
            assert isinstance(module_config, dict), (
                "config elements should be dicts not modules i.e.: "
                "[{`tensor_fqn`: `foo.bar.weight`}, {`tensor_fqn`: ... }, ...]"
            )

            assert isinstance(self.defaults, Dict)  # for mypy
            local_args = copy.deepcopy(self.defaults)
            local_args.update(module_config)

            tensor_fqn = local_args.get("tensor_fqn", None)
            assert tensor_fqn is not None, (
                "tensor_fqn is a required argument in the sparsity config, which "
                "replaces the previous `module` and [module]`fqn` arguments"
            )

            # populate all information from tensor_fqn
            info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn)

            # check that whatever was put into local_args agrees with what was obtained
            # from tensor_fqn
            for key in info_from_tensor_fqn.keys():
                if key in local_args:
                    assert (
                        info_from_tensor_fqn[key] == local_args[key]
                        or (
                            key == "tensor_fqn"
                            and "." + info_from_tensor_fqn[key] == local_args[key]
                        )
                        # info_from_tensor_fqn will chop leading '.' from tensor_fqn so ignore that
                    ), f"Given both `{key}` and `tensor_fqn` in the config, they are expected to agree!"
            local_args.update(info_from_tensor_fqn)
            self.groups.append(local_args)
        self._prepare()

    def _prepare(self, *args, **kwargs):
        r"""Adds mask parametrization to the layer weight"""
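        # The mask stored in self.state[tensor_fqn]["mask"] is the same tensor
        # that is handed to the parametrization registered below, so updating it
        # in update_mask immediately affects the parametrized weight.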
        for config in self.groups:
            module = config["module"]
            tensor_name = config["tensor_name"]
            parametrization = config.get("parametrization", FakeSparsity)
            mask = config.get("mask", torch.ones_like(getattr(module, tensor_name)))
            self.state[config["tensor_fqn"]]["mask"] = mask
            parametrize.register_parametrization(
                module, tensor_name, parametrization(mask)
            )

    def squash_mask(
        self,
        params_to_keep: Optional[Tuple[str, ...]] = None,
        params_to_keep_per_layer: Optional[Dict[str, Tuple[str, ...]]] = None,
        *args,
        **kwargs,
    ):
        r"""Squashes the sparse masks into the appropriate tensors.

        If either `params_to_keep` or `params_to_keep_per_layer` is set,
        the module will have a `sparse_params` dict attached to it.

        Args:
            params_to_keep: Tuple of config keys to save on every sparsified
                            module as entries of its `sparse_params` dict
            params_to_keep_per_layer: Dict to specify the params that should be
                            saved for specific layers. The keys in the dict
                            should be the module fqn, while the values should
                            be a list of strings with the names of the variables
                            to save in the `sparse_params`

        Examples:
            >>> # xdoctest: +SKIP("locals are undefined")
            >>> # Don't save any sparse params
            >>> sparsifier.squash_mask()
            >>> hasattr(model.submodule1, 'sparse_params')
            False

            >>> # Keep sparse params per layer
            >>> sparsifier.squash_mask(
            ...     params_to_keep_per_layer={
            ...         'submodule1.linear1': ('foo', 'bar'),
            ...         'submodule2.linear42': ('baz',)
            ...     })
            >>> print(model.submodule1.linear1.sparse_params)
            {'foo': 42, 'bar': 24}
            >>> print(model.submodule2.linear42.sparse_params)
            {'baz': 0.1}

            >>> # Keep sparse params for all layers
            >>> sparsifier.squash_mask(params_to_keep=('foo', 'bar'))
            >>> print(model.submodule1.linear1.sparse_params)
            {'foo': 42, 'bar': 24}
            >>> print(model.submodule2.linear42.sparse_params)
            {'foo': 42, 'bar': 24}

            >>> # Keep some sparse params for all layers, and specific ones for
            >>> # some other layers
            >>> sparsifier.squash_mask(
            ...     params_to_keep=('foo', 'bar'),
            ...     params_to_keep_per_layer={
            ...         'submodule2.linear42': ('baz',)
            ...     })
            >>> print(model.submodule1.linear1.sparse_params)
            {'foo': 42, 'bar': 24}
            >>> print(model.submodule2.linear42.sparse_params)
            {'foo': 42, 'bar': 24, 'baz': 0.1}
        """
        for config in self.groups:
            module = config["module"]
            tensor_name = config["tensor_name"]
            parametrize.remove_parametrizations(
                module, tensor_name, leave_parametrized=True
            )
            sparse_params = {}
            if params_to_keep is not None:
                global_params = {k: config[k] for k in params_to_keep}
                sparse_params.update(global_params)
            if params_to_keep_per_layer is not None:
                params = params_to_keep_per_layer.get(config["module_fqn"], None)
                if params is not None:
                    per_layer_params = {k: config[k] for k in params}
                    sparse_params.update(per_layer_params)
            if sparse_params:
                # TODO handle multiple tensors being sparsified on a single module; where to store sparse_params?
                module.sparse_params = sparse_params

    def convert(
        self,
        module: nn.Module,
        mapping: Optional[Dict[Type[nn.Module], Type[nn.Module]]] = None,
        inplace: bool = False,
        parameterization: Type[nn.Module] = FakeSparsity,
    ):
        r"""Converts submodules of the input module to a different module according to `mapping`
        by calling the `from_dense` method on the target module class.

        Args:
            module: input module
            mapping: a dictionary that maps from source module type to target
                module type; can be overwritten to allow swapping user defined
                Modules
            inplace: carry out model transformations in-place; the original module
                is mutated
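
        Example (a sketch; `MySparseLinear` stands in for a target module class
        that implements `from_dense`)::

            >>> # xdoctest: +SKIP("illustrative only")
            >>> sparse_model = sparsifier.convert(
            ...     model, mapping={nn.Linear: MySparseLinear}
            ... )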
        """
        if mapping is None:
            raise NotImplementedError("Need to auto-generate mapping")
        if not inplace:
            module = copy.deepcopy(module)

        reassign = {}
        for name, mod in module.named_children():
            # leaf node
            if (
                module_contains_param(mod, parameterization)
                and type_before_parametrizations(mod) in mapping
            ):
                reassign[name] = swap_module(mod, mapping)
            else:
                # recurse
                reassign[name] = self.convert(
                    mod,
                    mapping=mapping,
                    inplace=True,
                    parameterization=parameterization,
                )

        for key, value in reassign.items():
            module._modules[key] = value

        return module

    def step(self, use_path: bool = True) -> None:
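        r"""Updates the masks of all configured groups by calling `update_mask`
        under `torch.no_grad()`; does nothing when `self.enable_mask_update` is
        False.
        """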
        if not self.enable_mask_update:
            return
        with torch.no_grad():
            for config in self.groups:
                self.update_mask(**config)

    @abc.abstractmethod
    def update_mask(self, module: nn.Module, tensor_name: str, **kwargs):
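        r"""Computes a new mask for `getattr(module, tensor_name)`.

        Called once per group by `step`; the group's full config is passed as
        keyword arguments, so `kwargs` includes at least `tensor_fqn` and
        `module_fqn` plus any defaults. Implementations are expected to update
        the group's mask, which is shared with the `FakeSparsity`
        parametrization registered in `_prepare`.

        A minimal sketch of an override (zeroes the whole tensor; a real
        sparsifier would compute the mask from the weights and the config)::

            >>> # xdoctest: +SKIP("illustrative only")
            >>> class AllZeroSparsifier(BaseSparsifier):
            ...     def update_mask(self, module, tensor_name, **kwargs):
            ...         mask = self.state[kwargs["tensor_fqn"]]["mask"]
            ...         mask.fill_(0)
        """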
        pass
352