xref: /aosp_15_r20/external/pytorch/torch/nn/utils/init.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import inspect
3
4import torch
5
6
def skip_init(module_cls, *args, **kwargs):
    r"""
    Given a module class object and args / kwargs, instantiate the module without initializing parameters / buffers.

    This can be useful if initialization is slow or if custom initialization will
    be performed, making the default initialization unnecessary. There are some caveats to this, due to
    the way this function is implemented:

    1. The module must accept a `device` arg in its constructor that is passed to any parameters
    or buffers created during construction.

    2. The module must not perform any computation on parameters in its constructor except
    initialization (i.e. functions from :mod:`torch.nn.init`).

    If these conditions are satisfied, the module can be instantiated with parameter / buffer values
    uninitialized, as if having been created using :func:`torch.empty`.

    The module is first constructed with ``device="meta"`` and then materialized with
    :meth:`torch.nn.Module.to_empty` on the requested device. If ``device`` is omitted
    or ``None``, ``"cpu"`` is used. ``device`` must be given as a keyword argument;
    passing it positionally in ``args`` conflicts with the injected ``device="meta"``.

    Args:
        module_cls: Class object; should be a subclass of :class:`torch.nn.Module`
        args: args to pass to the module's constructor
        kwargs: kwargs to pass to the module's constructor

    Returns:
        Instantiated module with uninitialized parameters / buffers

    Raises:
        RuntimeError: if ``module_cls`` is not a :class:`torch.nn.Module` subclass
            (including when it is not a class at all), or if its constructor
            signature has no ``device`` parameter.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> import torch
        >>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)
        >>> m.weight
        Parameter containing:
        tensor([[0.0000e+00, 1.5846e+29, 7.8307e+00, 2.5250e-29, 1.1210e-44]],
               requires_grad=True)
        >>> m2 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=6, out_features=1)
        >>> m2.weight
        Parameter containing:
        tensor([[-1.4677e+24,  4.5915e-41,  1.4013e-45,  0.0000e+00, -1.4677e+24,
                  4.5915e-41]], requires_grad=True)

    """
    # ``issubclass`` raises a bare TypeError when its first argument is not a
    # class (e.g. a module *instance*); check ``isinstance(..., type)`` first so
    # every invalid ``module_cls`` gets the same descriptive RuntimeError.
    if not (
        isinstance(module_cls, type) and issubclass(module_cls, torch.nn.Module)
    ):
        raise RuntimeError(f"Expected a Module; got {module_cls}")
    if "device" not in inspect.signature(module_cls).parameters:
        raise RuntimeError("Module must support a 'device' arg to skip initialization")

    final_device = kwargs.pop("device", None)
    if final_device is None:
        # ``to_empty(device=None)`` would leave every tensor on the meta device,
        # because ``torch.empty_like`` inherits the source tensor's device when
        # ``device`` is None. Treat an explicit ``None`` like an omitted device.
        final_device = "cpu"

    # Construct on the meta device so parameters/buffers get no real storage
    # and default initialization does no real work; ``to_empty`` then
    # allocates uninitialized storage on the requested device.
    kwargs["device"] = "meta"
    return module_cls(*args, **kwargs).to_empty(device=final_device)
56