# mypy: allow-untyped-defs
import itertools
from typing import Any, Optional, Protocol, Type

import torch
from torch.nn.parameter import is_lazy


__all__ = ["LazyModuleMixin"]


class _LazyProtocol(Protocol):
    """This class is used to avoid errors with mypy checks for the attributes in a mixin.

    https://mypy.readthedocs.io/en/latest/more_types.html#mixin-classes
    """

    def _register_load_state_dict_pre_hook(self, hook):
        ...

    def register_forward_pre_hook(self, hook, *, prepend=False, with_kwargs=False):
        ...

    def _lazy_load_hook(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        ...

    def _get_name(self):
        ...

    def _infer_parameters(self, module, input):
        ...

    @property
    def _parameters(self):
        ...

    @property
    def _buffers(self):
        ...

    @property
    def _non_persistent_buffers_set(self):
        ...

    @property
    def _load_hook(self):
        ...

    @property
    def _initialize_hook(self):
        ...


class LazyModuleMixin:
    r"""A mixin for modules that lazily initialize parameters, also known as "lazy modules".

    .. warning::
        Lazy modules are an experimental new feature under active development,
        and their API is likely to change.

    Modules that lazily initialize parameters, or "lazy modules",
    derive the shapes of their parameters from the first input(s)
    to their forward method. Until that first forward they contain
    :class:`torch.nn.UninitializedParameter` s that should not be accessed
    or used, and afterward they contain regular :class:`torch.nn.Parameter` s.
    Lazy modules are convenient since they don't require computing some
    module arguments, like the :attr:`in_features` argument of a
    typical :class:`torch.nn.Linear`.

    After construction, networks with lazy modules should first
    be converted to the desired dtype and placed on the expected device.
    This is because lazy modules only perform shape inference so the usual dtype
    and device placement behavior applies.
    The lazy modules should then perform "dry runs" to initialize all the components in the module.
    These "dry runs" send inputs of the correct size, dtype, and device through
    the network and to each one of its lazy modules. After this the network can be used as usual.

    >>> # xdoctest: +SKIP
    >>> class LazyMLP(torch.nn.Module):
    ...    def __init__(self) -> None:
    ...        super().__init__()
    ...        self.fc1 = torch.nn.LazyLinear(10)
    ...        self.relu1 = torch.nn.ReLU()
    ...        self.fc2 = torch.nn.LazyLinear(1)
    ...        self.relu2 = torch.nn.ReLU()
    ...
    ...    def forward(self, input):
    ...        x = self.relu1(self.fc1(input))
    ...        y = self.relu2(self.fc2(x))
    ...        return y
    >>> # constructs a network with lazy modules
    >>> lazy_mlp = LazyMLP()
    >>> # transforms the network's device and dtype
    >>> # NOTE: these transforms can and should be applied after construction and before any 'dry runs'
    >>> lazy_mlp = lazy_mlp.cuda().double()
    >>> lazy_mlp
    LazyMLP(
      (fc1): LazyLinear(in_features=0, out_features=10, bias=True)
      (relu1): ReLU()
      (fc2): LazyLinear(in_features=0, out_features=1, bias=True)
      (relu2): ReLU()
    )
    >>> # performs a dry run to initialize the network's lazy modules
    >>> lazy_mlp(torch.ones(10, 10).cuda().double())
    >>> # after initialization, LazyLinear modules become regular Linear modules
    >>> lazy_mlp
    LazyMLP(
      (fc1): Linear(in_features=10, out_features=10, bias=True)
      (relu1): ReLU()
      (fc2): Linear(in_features=10, out_features=1, bias=True)
      (relu2): ReLU()
    )
    >>> # attaches an optimizer, since parameters can now be used as usual
    >>> optim = torch.optim.SGD(lazy_mlp.parameters(), lr=0.01)

    A final caveat when using lazy modules is that the order of initialization of a network's
    parameters may change, since the lazy modules are always initialized after other modules.
    For example, if the LazyMLP class defined above had a :class:`torch.nn.LazyLinear` module
    first and then a regular :class:`torch.nn.Linear` second, the second module would be
    initialized on construction and the first module would be initialized during the first dry run.
    This can cause the parameters of a network using lazy modules to be initialized differently
    than the parameters of a network without lazy modules, because the order of parameter
    initializations, which often depends on a stateful random number generator, is different.
    Check :doc:`/notes/randomness` for more details.
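
    As an illustrative sketch of such a network (the ``PartiallyLazyMLP`` name
    below is hypothetical), ``fc2`` would be initialized at construction time
    while ``fc1`` would only be initialized during the first dry run:

    >>> # xdoctest: +SKIP
    >>> class PartiallyLazyMLP(torch.nn.Module):
    ...    def __init__(self) -> None:
    ...        super().__init__()
    ...        # shape inferred during the first dry run
    ...        self.fc1 = torch.nn.LazyLinear(10)
    ...        # initialized immediately, at construction time
    ...        self.fc2 = torch.nn.Linear(10, 1)
    ...
    ...    def forward(self, input):
    ...        return self.fc2(torch.relu(self.fc1(input)))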

    Lazy modules can be serialized with a state dict like other modules. For example:

    >>> lazy_mlp = LazyMLP()
    >>> # The state dict shows the uninitialized parameters
    >>> lazy_mlp.state_dict()
    OrderedDict([('fc1.weight', Uninitialized parameter),
                 ('fc1.bias',
                  tensor([-1.8832e+25,  4.5636e-41, -1.8832e+25,  4.5636e-41, -6.1598e-30,
                           4.5637e-41, -1.8788e+22,  4.5636e-41, -2.0042e-31,  4.5637e-41])),
                 ('fc2.weight', Uninitialized parameter),
                 ('fc2.bias', tensor([0.0019]))])


    Lazy modules can load regular :class:`torch.nn.Parameter` s (i.e. you can serialize/deserialize
    initialized LazyModules and they will remain initialized).


    >>> full_mlp = LazyMLP()
    >>> # Dry run to initialize another module
    >>> full_mlp.forward(torch.ones(10, 1))
    >>> # Load an initialized state into a lazy module
    >>> lazy_mlp.load_state_dict(full_mlp.state_dict())
    >>> # The state dict now holds valid values
    >>> lazy_mlp.state_dict()
    OrderedDict([('fc1.weight',
                  tensor([[-0.3837],
                          [ 0.0907],
                          [ 0.6708],
                          [-0.5223],
                          [-0.9028],
                          [ 0.2851],
                          [-0.4537],
                          [ 0.6813],
                          [ 0.5766],
                          [-0.8678]])),
                 ('fc1.bias',
                  tensor([-1.8832e+25,  4.5636e-41, -1.8832e+25,  4.5636e-41, -6.1598e-30,
                           4.5637e-41, -1.8788e+22,  4.5636e-41, -2.0042e-31,  4.5637e-41])),
                 ('fc2.weight',
                  tensor([[ 0.1320,  0.2938,  0.0679,  0.2793,  0.1088, -0.1795, -0.2301,  0.2807,
                            0.2479,  0.1091]])),
                 ('fc2.bias', tensor([0.0019]))])

    Note, however, that the loaded parameters will not be replaced when doing a "dry run" if they were
    already initialized when the state was loaded. This prevents using initialized modules
    in different contexts.
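
    For illustration, a sketch continuing the example above (the exact tensor
    values depend on the loaded state):

    >>> # xdoctest: +SKIP
    >>> loaded_weight = lazy_mlp.fc1.weight.clone()
    >>> # a later dry run keeps the loaded values instead of re-initializing them
    >>> lazy_mlp(torch.ones(10, 1))
    >>> torch.equal(lazy_mlp.fc1.weight, loaded_weight)
    True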
179    """
180
181    # modules inheriting from this will change their __class__ to the specified
182    # one after they are fully initialized
183    cls_to_become: Optional[Type[Any]] = None
184
185    def __init__(self: _LazyProtocol, *args, **kwargs):
        # Mypy doesn't like this super call in a mixin
        super().__init__(*args, **kwargs)  # type: ignore[misc]
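        # Register hooks so that loading a state dict can materialize uninitialized
        # parameters and the first forward call triggers shape inference
        # (see _lazy_load_hook and _infer_parameters below).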
        self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook)
        self._initialize_hook = self.register_forward_pre_hook(
            self._infer_parameters, with_kwargs=True
        )

    def _save_to_state_dict(self: _LazyProtocol, destination, prefix, keep_vars):
        # Ideally this would be implemented as a hook, but that would require
        # overriding `detach` in UninitializedParameter to return itself,
        # which is not clean.
        for name, param in self._parameters.items():
            if param is not None:
                if not (is_lazy(param) or keep_vars):
                    param = param.detach()
                destination[prefix + name] = param
        for name, buf in self._buffers.items():
            if buf is not None and name not in self._non_persistent_buffers_set:
                if not (is_lazy(buf) or keep_vars):
                    buf = buf.detach()
                destination[prefix + name] = buf

    def _lazy_load_hook(
        self: _LazyProtocol,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
218        """load_state_dict pre-hook function for lazy buffers and parameters.
219
220        The purpose of this hook is to adjust the current state and/or
221        ``state_dict`` being loaded so that a module instance serialized in
222        both un/initialized state can be deserialized onto both un/initialized
223        module instance.
224        See comment in ``torch.nn.Module._register_load_state_dict_pre_hook``
225        for the details of the hook specification.
226        """
        for name, param in itertools.chain(
            self._parameters.items(), self._buffers.items()
        ):
            key = prefix + name
            if key in state_dict and param is not None:
                input_param = state_dict[key]
                if is_lazy(param):
                    # The current parameter is not initialized, but the one being loaded is:
                    # materialize the uninitialized parameter based on the loaded shape
                    if not is_lazy(input_param):
                        with torch.no_grad():
                            param.materialize(input_param.shape)

    def initialize_parameters(self: _LazyProtocol, *args, **kwargs):
        r"""Initialize parameters according to the input batch properties.

        This adds an interface to isolate parameter initialization from the
        forward pass when doing parameter shape inference.
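
        A minimal sketch of an override for a hypothetical lazy module with a
        single ``weight`` parameter of shape ``(out_features, in_features)``
        (the attribute names below are illustrative, not part of this interface):

        >>> # xdoctest: +SKIP
        >>> def initialize_parameters(self, input) -> None:
        ...     if self.has_uninitialized_params():
        ...         with torch.no_grad():
        ...             self.in_features = input.shape[-1]
        ...             self.weight.materialize((self.out_features, self.in_features))
        ...             self.reset_parameters()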
245        """
246        raise NotImplementedError(
247            f"initialize_parameters is not implemented for {self.__class__.__name__}"
248        )
249
250    def has_uninitialized_params(self: _LazyProtocol):
251        r"""Check if a module has parameters that are not initialized."""
        # This avoids the JIT tracking this parameter and forcing
        # custom modules' __setstate__ to add it
        params = self._parameters.values()
        buffers = self._buffers.values()
        for param in itertools.chain(params, buffers):
            if is_lazy(param):
                return True
        return False

    # torchrec tests the code consistency with the following code
    # fmt: off
    def _infer_parameters(self: _LazyProtocol, module, args, kwargs=None):
264        r"""Infers the size and initializes the parameters according to the provided input batch.
265
266        Given a module that contains parameters that were declared inferrable
267        using :class:`torch.nn.parameter.ParameterMode.Infer`, runs a forward pass
268        in the complete module using the provided input to initialize all the parameters
269        as needed.
270        The module is set into evaluation mode before running the forward pass in order
271        to avoid saving statistics or calculating gradients
272        """
        kwargs = kwargs if kwargs else {}
        module.initialize_parameters(*args, **kwargs)
        if module.has_uninitialized_params():
            raise RuntimeError(f'module {self._get_name()} has not been fully initialized')
        module._initialize_hook.remove()
        module._load_hook.remove()
        delattr(module, '_initialize_hook')
        delattr(module, '_load_hook')
        if module.cls_to_become is not None:
            module.__class__ = module.cls_to_become
    # fmt: on

    def _replicate_for_data_parallel(self: _LazyProtocol):
        raise RuntimeError(
            "Modules with uninitialized parameters can't be used with `DataParallel`. "
            "Run a dummy forward pass to correctly initialize the modules"
        )