# mypy: allow-untyped-defs
import abc
import warnings
import weakref
from functools import wraps

from torch.ao.pruning._experimental.data_sparsifier import BaseDataSparsifier


__all__ = ["BaseDataScheduler"]


class BaseDataScheduler:
    r"""
    The BaseDataScheduler is the abstract scheduler class specifically for the
    BaseDataSparsifier class. This class controls a specific hyperparameter of
    the sparsifier class and varies it across the training process (or across time).

    Args:
        data_sparsifier (instance of BaseDataSparsifier)
            The implemented data sparsifier class in which ``update_mask`` is defined
        schedule_param (str)
            A specific hyperparameter of the passed sparsifier that needs to be scheduled/varied
        last_epoch (int, default=-1)
            This is passed when training needs to be resumed from a particular
            point.
        verbose (bool, default=False)
            Verbosity of the BaseDataScheduler

    The *get_schedule_param()* function needs to be implemented by the user.
    """

    def __init__(
        self, data_sparsifier, schedule_param: str, last_epoch=-1, verbose=False
    ):
        # Attach sparsifier
        if not isinstance(data_sparsifier, BaseDataSparsifier):
            raise TypeError(
                f"{type(data_sparsifier).__name__} is not an instance of torch.ao.pruning.BaseDataSparsifier"
            )
        self.data_sparsifier = data_sparsifier
        self.schedule_param = schedule_param

        # Initialize epoch and base hyper-params
        self.base_param = {
            name: config.get(schedule_param, None)
            for name, config in self.data_sparsifier.data_groups.items()
        }

        self.last_epoch = last_epoch

        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `scheduler.step()` is called after
        # `sparsifier.step()`
        def with_counter(method):
            if getattr(method, "_with_counter", False):
                # `sparsifier.step()` has already been replaced, return.
                return method

            # Keep a weak reference to the sparsifier instance to prevent
            # cyclic references.
            instance_ref = weakref.ref(method.__self__)
            # Get the unbound method for the same purpose.
            func = method.__func__
            cls = instance_ref().__class__
            del method

            @wraps(func)
            def wrapper(*args, **kwargs):
                instance = instance_ref()
                instance._step_count += 1  # type: ignore[union-attr]
                wrapped = func.__get__(instance, cls)
                return wrapped(*args, **kwargs)

            # Note that the returned function here is no longer a bound method,
            # so attributes like `__func__` and `__self__` no longer exist.
            wrapper._with_counter = True  # type: ignore[attr-defined]
            return wrapper

        self.data_sparsifier.step = with_counter(self.data_sparsifier.step)  # type: ignore[assignment]
        self.data_sparsifier._step_count = 0  # type: ignore[attr-defined]
        self._step_count: int = 0
        self.verbose = verbose

        # Housekeeping
        self._get_sp_called_within_step: bool = False  # sp -> schedule parameter
        self.step()

    @abc.abstractmethod
    def get_schedule_param(self):
        r"""
        Abstract method that needs to be implemented by the child class.
        The expected return type is a dictionary mapping each data group name to its new
        schedule_param value. The returned values will be updated in the sparsifier when
        the scheduler's step() function is called.

        Example:
            >>> def get_schedule_param(self):
            ...     new_param = {}
            ...     for name in self.data_sparsifier.data_groups.keys():
            ...         new_param[name] = (
            ...             self.data_sparsifier.data_groups[name][self.schedule_param] * 0.5
            ...         )
            ...     return new_param

        When the step() function is called, the value in
        self.data_sparsifier.data_groups[name][self.schedule_param] would be halved.
        """
        raise NotImplementedError

    def __repr__(self):
        format_string = self.__class__.__name__ + " ("
        format_string += "\n"
        format_string += f"Data Sparsifier {self.data_sparsifier}\n"
        format_string += f"    {self.schedule_param}: {self.base_param}\n"
        format_string += ")"
        return format_string

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the sparsifier.

        Note:
            The scheduler class does not track the state of the data_sparsifier.
            Make sure to store the state of the sparsifier before storing the
            state of the scheduler.
        """
        return {
            key: value
            for key, value in self.__dict__.items()
            if key != "data_sparsifier"
        }

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        Note:
            Remember to restore the state of the data_sparsifier before the scheduler.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_param(self):
        """Returns the schedule_param values computed by the last call to step()."""
        return self._last_param

    def step(self):
        # Raise a warning if trying to call scheduler step before the sparsifier.
        # https://github.com/pytorch/pytorch/issues/20124
        if self._step_count == 1:
            if not hasattr(self.data_sparsifier.step, "_with_counter"):
                warnings.warn(
                    "Seems like `data_sparsifier.step()` has been overridden after sparsity scheduler "
                    "initialization. Please make sure to call `data_sparsifier.step()` before "
                    "`scheduler.step()`.",
                    UserWarning,
                )

            # Check whether scheduler.step() was called before data_sparsifier.step()
            elif self.data_sparsifier._step_count < 1:  # type: ignore[attr-defined]
                warnings.warn(
                    "Detected call of `scheduler.step()` before `data_sparsifier.step()`. "
                    "You have to make sure you run `data_sparsifier.step()` BEFORE any "
                    "calls to `scheduler.step()`.",
                    UserWarning,
                )
        self._step_count += 1

        class _enable_get_sp_call:
            def __init__(self, o):
                self.o = o

            def __enter__(self):
                self.o._get_sp_called_within_step = True
                return self

            def __exit__(self, exc_type, exc_value, traceback):
                self.o._get_sp_called_within_step = False

        with _enable_get_sp_call(self):
            self.last_epoch += 1
            updated_scheduler_params = self.get_schedule_param()

        for name, param in updated_scheduler_params.items():
            self.data_sparsifier.data_groups[name][self.schedule_param] = param
            if self.verbose:
                print(f"Adjusting {self.schedule_param} for group {name} to {param}")

        self._last_param = {
            name: config.get(self.schedule_param, None)
            for name, config in self.data_sparsifier.data_groups.items()
        }
        self.data_sparsifier.enable_mask_update = True
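

# ---------------------------------------------------------------------------
# Illustrative sketch only: the class below is a hypothetical example and is
# deliberately left out of ``__all__``; it is not part of this module's public
# API. It shows one way a concrete scheduler might implement
# ``get_schedule_param``: a step-decay schedule that multiplies the scheduled
# hyperparameter by ``gamma`` every ``step_size`` epochs, loosely mirroring
# ``torch.optim.lr_scheduler.StepLR``. It assumes every data group's config
# actually contains ``schedule_param`` as a key.
class _ExampleStepDecayScheduler(BaseDataScheduler):
    def __init__(
        self,
        data_sparsifier,
        schedule_param: str = "sparsity_level",
        step_size: int = 1,
        gamma: float = 0.5,
        last_epoch=-1,
        verbose=False,
    ):
        self.step_size = step_size
        self.gamma = gamma
        super().__init__(data_sparsifier, schedule_param, last_epoch, verbose)

    def get_schedule_param(self):
        # ``step()`` runs once inside ``__init__`` (last_epoch becomes 0), so
        # the base values stay untouched until ``step_size`` epochs have
        # elapsed; after that, decay by ``gamma`` every ``step_size`` epochs.
        if self.last_epoch == 0 or self.last_epoch % self.step_size != 0:
            return {
                name: config[self.schedule_param]
                for name, config in self.data_sparsifier.data_groups.items()
            }
        return {
            name: config[self.schedule_param] * self.gamma
            for name, config in self.data_sparsifier.data_groups.items()
        }


# Intended call order inside a training loop (per the warning in ``step()``):
#
#     data_sparsifier.step()  # update the masks first
#     scheduler.step()        # then advance the schedule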