# mypy: allow-untyped-defs
import itertools
from typing import Any, Optional, Protocol, Type

import torch
from torch.nn.parameter import is_lazy


__all__ = ["LazyModuleMixin"]


class _LazyProtocol(Protocol):
    """This class is used to avoid errors with mypy checks for the attributes in a mixin.

    https://mypy.readthedocs.io/en/latest/more_types.html#mixin-classes
    """

    def _register_load_state_dict_pre_hook(self, hook):
        ...

    def register_forward_pre_hook(self, hook, *, prepend=False, with_kwargs=False):
        ...

    def _lazy_load_hook(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        ...

    def _get_name(self):
        ...

    def _infer_parameters(self, module, input):
        ...

    @property
    def _parameters(self):
        ...

    @property
    def _buffers(self):
        ...

    @property
    def _non_persistent_buffers_set(self):
        ...

    @property
    def _load_hook(self):
        ...

    @property
    def _initialize_hook(self):
        ...


class LazyModuleMixin:
    r"""A mixin for modules that lazily initialize parameters, also known as "lazy modules".

    .. warning::
        Lazy modules are an experimental new feature under active development,
        and their API is likely to change.

    Modules that lazily initialize parameters, or "lazy modules",
    derive the shapes of their parameters from the first input(s)
    to their forward method. Until that first forward they contain
    :class:`torch.nn.UninitializedParameter` s that should not be accessed
    or used, and afterward they contain regular :class:`torch.nn.Parameter` s.
    Lazy modules are convenient since they don't require computing some
    module arguments, like the :attr:`in_features` argument of a
    typical :class:`torch.nn.Linear`.

    After construction, networks with lazy modules should first
    be converted to the desired dtype and placed on the expected device.
    This is because lazy modules only perform shape inference, so the usual dtype
    and device placement behavior applies.
    The lazy modules should then perform "dry runs" to initialize all the components in the module.
    These "dry runs" send inputs of the correct size, dtype, and device through
    the network and to each one of its lazy modules. After this the network can be used as usual.

    >>> # xdoctest: +SKIP
    >>> class LazyMLP(torch.nn.Module):
    ...     def __init__(self) -> None:
    ...         super().__init__()
    ...         self.fc1 = torch.nn.LazyLinear(10)
    ...         self.relu1 = torch.nn.ReLU()
    ...         self.fc2 = torch.nn.LazyLinear(1)
    ...         self.relu2 = torch.nn.ReLU()
    ...
    ...     def forward(self, input):
    ...         x = self.relu1(self.fc1(input))
    ...         y = self.relu2(self.fc2(x))
    ...         return y
    >>> # constructs a network with lazy modules
    >>> lazy_mlp = LazyMLP()
    >>> # transforms the network's device and dtype
    >>> # NOTE: these transforms can and should be applied after construction and before any 'dry runs'
    >>> lazy_mlp = lazy_mlp.cuda().double()
    >>> lazy_mlp
    LazyMLP( (fc1): LazyLinear(in_features=0, out_features=10, bias=True)
      (relu1): ReLU()
      (fc2): LazyLinear(in_features=0, out_features=1, bias=True)
      (relu2): ReLU()
    )
    >>> # performs a dry run to initialize the network's lazy modules
    >>> lazy_mlp(torch.ones(10, 10).cuda().double())
    >>> # after initialization, LazyLinear modules become regular Linear modules
    >>> lazy_mlp
    LazyMLP(
      (fc1): Linear(in_features=10, out_features=10, bias=True)
      (relu1): ReLU()
      (fc2): Linear(in_features=10, out_features=1, bias=True)
      (relu2): ReLU()
    )
    >>> # attaches an optimizer, since parameters can now be used as usual
    >>> optim = torch.optim.SGD(lazy_mlp.parameters(), lr=0.01)

    A final caveat when using lazy modules is that the order of initialization of a network's
    parameters may change, since the lazy modules are always initialized after other modules.
    For example, if the LazyMLP class defined above had a :class:`torch.nn.LazyLinear` module
    first and then a regular :class:`torch.nn.Linear` second, the second module would be
    initialized on construction and the first module would be initialized during the first dry run.
    This can cause the parameters of a network using lazy modules to be initialized differently
    than the parameters of a network without lazy modules, as the order of parameter initializations,
    which often depends on a stateful random number generator, is different.
    Check :doc:`/notes/randomness` for more details.

    Lazy modules can be serialized with a state dict like other modules. For example:

    >>> lazy_mlp = LazyMLP()
    >>> # The state dict shows the uninitialized parameters
    >>> lazy_mlp.state_dict()
    OrderedDict([('fc1.weight', Uninitialized parameter),
                 ('fc1.bias',
                  tensor([-1.8832e+25,  4.5636e-41, -1.8832e+25,  4.5636e-41, -6.1598e-30,
                           4.5637e-41, -1.8788e+22,  4.5636e-41, -2.0042e-31,  4.5637e-41])),
                 ('fc2.weight', Uninitialized parameter),
                 ('fc2.bias', tensor([0.0019]))])


    Lazy modules can load regular :class:`torch.nn.Parameter` s (i.e. you can serialize/deserialize
    initialized LazyModules and they will remain initialized)


    >>> full_mlp = LazyMLP()
    >>> # Dry run to initialize another module
    >>> full_mlp.forward(torch.ones(10, 1))
    >>> # Load an initialized state into a lazy module
    >>> lazy_mlp.load_state_dict(full_mlp.state_dict())
    >>> # The state dict now holds valid values
    >>> lazy_mlp.state_dict()
    OrderedDict([('fc1.weight',
                  tensor([[-0.3837],
                          [ 0.0907],
                          [ 0.6708],
                          [-0.5223],
                          [-0.9028],
                          [ 0.2851],
                          [-0.4537],
                          [ 0.6813],
                          [ 0.5766],
                          [-0.8678]])),
                 ('fc1.bias',
                  tensor([-1.8832e+25,  4.5636e-41, -1.8832e+25,  4.5636e-41, -6.1598e-30,
                           4.5637e-41, -1.8788e+22,  4.5636e-41, -2.0042e-31,  4.5637e-41])),
                 ('fc2.weight',
                  tensor([[ 0.1320,  0.2938,  0.0679,  0.2793,  0.1088, -0.1795, -0.2301,  0.2807,
                            0.2479,  0.1091]])),
                 ('fc2.bias', tensor([0.0019]))])

    Note, however, that the loaded parameters will not be replaced when doing a "dry run" if they are initialized
    when the state is loaded. This prevents using initialized modules in different contexts.
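
    To define a new lazy module on top of this mixin, subclass both
    ``LazyModuleMixin`` and the module class that the lazy module should turn
    into, declare the not-yet-known parameters as
    :class:`torch.nn.UninitializedParameter`, and override
    ``initialize_parameters``. The snippet below is only an illustrative
    sketch (``Scale`` and ``LazyScale`` are made-up names for this example);
    it follows the general pattern used by built-in lazy modules such as
    :class:`torch.nn.LazyLinear`:

    >>> # xdoctest: +SKIP
    >>> class Scale(torch.nn.Module):
    ...     def __init__(self, num_features):
    ...         super().__init__()
    ...         self.weight = torch.nn.Parameter(torch.ones(num_features))
    ...
    ...     def forward(self, input):
    ...         return input * self.weight
    >>> class LazyScale(torch.nn.modules.lazy.LazyModuleMixin, Scale):
    ...     # once initialized, instances behave like (and become) plain Scale modules
    ...     cls_to_become = Scale
    ...
    ...     def __init__(self):
    ...         super().__init__(0)
    ...         self.weight = torch.nn.UninitializedParameter()
    ...
    ...     def initialize_parameters(self, input):
    ...         # materialize the parameter from the shape of the first input
    ...         if self.has_uninitialized_params():
    ...             with torch.no_grad():
    ...                 self.weight.materialize((input.shape[-1],))
    ...                 torch.nn.init.ones_(self.weight)
    >>> lazy = LazyScale()
    >>> lazy.has_uninitialized_params()
    True
    >>> out = lazy(torch.randn(2, 3))  # dry run: infers the shape and swaps the class
    >>> type(lazy) is Scale, lazy.weight.shape
    (True, torch.Size([3]))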
    """

    # modules inheriting from this will change their __class__ to the specified
    # one after they are fully initialized
    cls_to_become: Optional[Type[Any]] = None

    def __init__(self: _LazyProtocol, *args, **kwargs):
        # Mypy doesn't like this super call in a mixin
        super().__init__(*args, **kwargs)  # type: ignore[misc]
        self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook)
        self._initialize_hook = self.register_forward_pre_hook(
            self._infer_parameters, with_kwargs=True
        )

    def _save_to_state_dict(self: _LazyProtocol, destination, prefix, keep_vars):
        # Ideally this would be implemented as a hook, but that would require
        # overriding `detach` in UninitializedParameter to return itself,
        # which is not clean
        for name, param in self._parameters.items():
            if param is not None:
                if not (is_lazy(param) or keep_vars):
                    param = param.detach()
                destination[prefix + name] = param
        for name, buf in self._buffers.items():
            if buf is not None and name not in self._non_persistent_buffers_set:
                if not (is_lazy(buf) or keep_vars):
                    buf = buf.detach()
                destination[prefix + name] = buf

    def _lazy_load_hook(
        self: _LazyProtocol,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """load_state_dict pre-hook function for lazy buffers and parameters.

        The purpose of this hook is to adjust the current state and/or
        ``state_dict`` being loaded so that a module instance serialized in
        either an uninitialized or an initialized state can be deserialized
        onto a module instance in either state.
        See comment in ``torch.nn.Module._register_load_state_dict_pre_hook``
        for the details of the hook specification.
        """
        for name, param in itertools.chain(
            self._parameters.items(), self._buffers.items()
        ):
            key = prefix + name
            if key in state_dict and param is not None:
                input_param = state_dict[key]
                if is_lazy(param):
                    # The current parameter is not initialized but the one being loaded is:
                    # create a new parameter based on the uninitialized one
                    if not is_lazy(input_param):
                        with torch.no_grad():
                            param.materialize(input_param.shape)

    def initialize_parameters(self: _LazyProtocol, *args, **kwargs):
        r"""Initialize parameters according to the input batch properties.

        This adds an interface to isolate parameter initialization from the
        forward pass when doing parameter shape inference.
        """
        raise NotImplementedError(
            f"initialize_parameters is not implemented for {self.__class__.__name__}"
        )

    def has_uninitialized_params(self: _LazyProtocol):
        r"""Check if a module has parameters or buffers that are not initialized."""
        # This is to avoid the JIT tracking this parameter and forcing
        # custom modules' __setstate__ to add it
        params = self._parameters.values()
        buffers = self._buffers.values()
        for param in itertools.chain(params, buffers):
            if is_lazy(param):
                return True
        return False

    # torchrec tests the code consistency with the following code
    # fmt: off
    def _infer_parameters(self: _LazyProtocol, module, args, kwargs=None):
        r"""Infers the size and initializes the parameters according to the provided input batch.

        Given a module that contains parameters that were declared inferrable
        using :class:`torch.nn.parameter.UninitializedParameter`, runs a forward pass
        in the complete module using the provided input to initialize all the parameters
        as needed.
        The module is set into evaluation mode before running the forward pass in order
        to avoid saving statistics or calculating gradients.
        """
        kwargs = kwargs if kwargs else {}
        module.initialize_parameters(*args, **kwargs)
        if module.has_uninitialized_params():
            raise RuntimeError(f'module {self._get_name()} has not been fully initialized')
        module._initialize_hook.remove()
        module._load_hook.remove()
        delattr(module, '_initialize_hook')
        delattr(module, '_load_hook')
        if module.cls_to_become is not None:
            module.__class__ = module.cls_to_become
    # fmt: on

    def _replicate_for_data_parallel(self: _LazyProtocol):
        raise RuntimeError(
            "Modules with uninitialized parameters can't be used with `DataParallel`. "
            "Run a dummy forward pass to correctly initialize the modules"
        )
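

if __name__ == "__main__":
    # Illustrative sketch only, run when this file is executed directly; it is
    # not part of the module's API. It demonstrates the lazy initialization
    # flow described above using the built-in torch.nn.LazyLinear; the shapes
    # below are arbitrary example values.
    m = torch.nn.LazyLinear(out_features=4)
    assert m.has_uninitialized_params()  # weight and bias are still uninitialized
    m(torch.randn(2, 8))  # dry run: in_features is inferred from the input
    assert type(m) is torch.nn.Linear  # cls_to_become swapped the class after init
    assert m.weight.shape == (4, 8)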