# mypy: allow-untyped-defs
import abc
import warnings
import weakref
from functools import wraps

from torch.ao.pruning._experimental.data_sparsifier import BaseDataSparsifier


__all__ = ["BaseDataScheduler"]


class BaseDataScheduler:
    r"""
    The BaseDataScheduler is the abstract scheduler class for the
    BaseDataSparsifier class. This class controls a specific hyperparameter of
    the sparsifier and varies it across the training process (or across time).

    Args:
        data_sparsifier (instance of BaseDataSparsifier)
            The implemented data sparsifier class in which update_mask is implemented
        schedule_param (str)
            The specific hyperparameter of the passed sparsifier that needs to be scheduled/varied
        last_epoch (int, default=-1)
            This is passed when training needs to be resumed from a particular
            point.
        verbose (bool, default=False)
            Verbosity of the BaseDataScheduler

    The *get_schedule_param()* function needs to be implemented by the user.
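
    Example (a minimal sketch of a concrete subclass; the class name, the decay
    rule, and the ``"sparsity_level"`` hyperparameter are illustrative, and the
    scheduled hyperparameter is assumed to be set for every data group)::

        >>> class LinearDecayScheduler(BaseDataScheduler):
        ...     def get_schedule_param(self):
        ...         # Decay each group's base value by 10% per epoch, floored at zero.
        ...         return {
        ...             name: self.base_param[name] * max(0.0, 1.0 - 0.1 * self.last_epoch)
        ...             for name in self.data_sparsifier.data_groups
        ...         }
        >>> # scheduler = LinearDecayScheduler(data_sparsifier, "sparsity_level")
        >>> # Run data_sparsifier.step() before scheduler.step() in the training loop.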
31    """
32
33    def __init__(
34        self, data_sparsifier, schedule_param: str, last_epoch=-1, verbose=False
35    ):
        # Attach sparsifier
        if not isinstance(data_sparsifier, BaseDataSparsifier):
            raise TypeError(
                f"{type(data_sparsifier).__name__} is not an instance of torch.ao.pruning.BaseDataSparsifier"
            )
        self.data_sparsifier = data_sparsifier
        self.schedule_param = schedule_param

        # Initialize epoch and base hyper-params
        self.base_param = {
            name: config.get(schedule_param, None)
            for name, config in self.data_sparsifier.data_groups.items()
        }

        self.last_epoch = last_epoch

        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `scheduler.step()` is called after
        # `sparsifier.step()`
        def with_counter(method):
            if getattr(method, "_with_counter", False):
                # `sparsifier.step()` has already been replaced, return.
                return method

            # Keep a weak reference to the sparsifier instance to prevent
            # cyclic references.
            instance_ref = weakref.ref(method.__self__)
            # Get the unbound method for the same purpose.
            func = method.__func__
            cls = instance_ref().__class__
            del method

            @wraps(func)
            def wrapper(*args, **kwargs):
                instance = instance_ref()
                instance._step_count += 1  # type: ignore[union-attr]
                wrapped = func.__get__(instance, cls)
                return wrapped(*args, **kwargs)

            # Note that the returned function here is no longer a bound method,
            # so attributes like `__func__` and `__self__` no longer exist.
            wrapper._with_counter = True  # type: ignore[attr-defined]
            return wrapper
        self.data_sparsifier.step = with_counter(self.data_sparsifier.step)  # type: ignore[assignment]
        self.data_sparsifier._step_count = 0  # type: ignore[attr-defined]
        self._step_count: int = 0
        self.verbose = verbose

        # Housekeeping
        self._get_sp_called_within_step: bool = False  # sp -> schedule parameter
        self.step()

    @abc.abstractmethod
    def get_schedule_param(self):
        r"""
        Abstract method that needs to be implemented by the child class.
        The expected return type is a dictionary mapping each data group's name
        to its schedule_param value. The returned values will be updated in the
        sparsifier when the scheduler's step() function is called.

        Example:
            >>> def get_schedule_param(self):
            ...     new_param = {}
            ...     for name in self.data_sparsifier.data_groups.keys():
            ...         new_param[name] = (
            ...             self.data_sparsifier.data_groups[name][self.schedule_param] * 0.5
            ...         )
            ...     return new_param

        When step() is called, the value in
        self.data_sparsifier.data_groups[name][self.schedule_param] is halved.
        """
        raise NotImplementedError

    def __repr__(self):
        format_string = self.__class__.__name__ + " ("
        format_string += "\n"
        format_string += f"Data Sparsifier {self.data_sparsifier}\n"
        format_string += f"    {self.schedule_param}: {self.base_param}\n"
        format_string += ")"
        return format_string

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the sparsifier.

        Note:
            The scheduler class does not track the state of the data_sparsifier.
            Make sure to store the state of the sparsifier before storing the
            state of the scheduler.
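
        Example (an illustrative save flow; ``sparsifier``, ``scheduler``, and
        the checkpoint layout are assumed names, not part of this API)::

            >>> checkpoint = {
            ...     "sparsifier": sparsifier.state_dict(),  # sparsifier state first
            ...     "scheduler": scheduler.state_dict(),
            ... }
            >>> torch.save(checkpoint, "checkpoint.pth")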
127        """
128        return {
129            key: value
130            for key, value in self.__dict__.items()
131            if key != "data_sparsifier"
132        }
133
    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        Note:
            Remember to restore the state of the data_sparsifier before the scheduler.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
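
        Example (an illustrative restore flow, mirroring the sketch in
        :meth:`state_dict`)::

            >>> checkpoint = torch.load("checkpoint.pth")
            >>> sparsifier.load_state_dict(checkpoint["sparsifier"])  # sparsifier first
            >>> scheduler.load_state_dict(checkpoint["scheduler"])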
143        """
144        self.__dict__.update(state_dict)
145
    def get_last_param(self):
        """Returns the schedule parameter value for each data group, as computed
        by the most recent call to step()."""
        return self._last_param

    def step(self):
        # Raise a warning if scheduler step is called before the sparsifier's.
        # https://github.com/pytorch/pytorch/issues/20124
        if self._step_count == 1:
            if not hasattr(self.data_sparsifier.step, "_with_counter"):
                warnings.warn(
                    "Seems like `data_sparsifier.step()` has been overridden after sparsity scheduler "
                    "initialization. Please make sure to call `data_sparsifier.step()` before "
                    "`scheduler.step()`.",
                    UserWarning,
                )

            # Warn if the user called scheduler.step() before the first
            # data_sparsifier.step()
            elif self.data_sparsifier._step_count < 1:  # type: ignore[attr-defined]
                warnings.warn(
                    "Detected call of `scheduler.step()` before `data_sparsifier.step()`. "
                    "Make sure you call `data_sparsifier.step()` BEFORE any "
                    "calls to `scheduler.step()`.",
                    UserWarning,
                )
        self._step_count += 1

        # Context manager that flags get_schedule_param() as being invoked from
        # within step(), via the _get_sp_called_within_step attribute.
        class _enable_get_sp_call:
            def __init__(self, o):
                self.o = o

            def __enter__(self):
                self.o._get_sp_called_within_step = True
                return self

            def __exit__(self, exc_type, exc_value, traceback):
                self.o._get_sp_called_within_step = False

        with _enable_get_sp_call(self):
            self.last_epoch += 1
            updated_scheduler_params = self.get_schedule_param()

        for name, param in updated_scheduler_params.items():
            self.data_sparsifier.data_groups[name][self.schedule_param] = param
            if self.verbose:
                print(f"Adjusting {self.schedule_param} for group {name} to {param}")

        self._last_param = {
            name: config.get(self.schedule_param, None)
            for name, config in self.data_sparsifier.data_groups.items()
        }
        self.data_sparsifier.enable_mask_update = True