xref: /aosp_15_r20/external/pytorch/torch/nn/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from torch.nn.parameter import (  # usort: skip
3    Buffer as Buffer,
4    Parameter as Parameter,
5    UninitializedBuffer as UninitializedBuffer,
6    UninitializedParameter as UninitializedParameter,
7)
8from torch.nn.modules import *  # usort: skip # noqa: F403
9from torch.nn import (
10    attention as attention,
11    functional as functional,
12    init as init,
13    modules as modules,
14    parallel as parallel,
15    parameter as parameter,
16    utils as utils,
17)
18from torch.nn.parallel import DataParallel as DataParallel
19
20
def factory_kwargs(kwargs):
    r"""Return a canonicalized dict of factory kwargs.

    Given kwargs, returns a canonicalized dict of factory kwargs that can be directly passed
    to factory functions like torch.empty, or errors if unrecognized kwargs are present.

    This function makes it simple to write code like this::

        class MyModule(nn.Module):
            def __init__(self, **kwargs):
                super().__init__()
                factory_kwargs = torch.nn.factory_kwargs(kwargs)
                self.weight = Parameter(torch.empty(10, **factory_kwargs))

    Why should you use this function instead of just passing `kwargs` along directly?

    1. This function does error validation, so if there are unexpected kwargs we will
       immediately report an error, instead of deferring it to the factory call.
    2. This function supports a special `factory_kwargs` argument, which can be used to
       explicitly specify a kwarg to be used for factory functions, in the event one of the
       factory kwargs conflicts with an already existing argument in the signature (e.g.
       in the signature ``def f(dtype, **kwargs)``, you can specify ``dtype`` for factory
       functions, as distinct from the dtype argument, by saying
       ``f(dtype1, factory_kwargs={"dtype": dtype2})``).

    Args:
        kwargs: a mapping of keyword arguments, or ``None`` (treated as empty). Only the
            keys ``device``, ``dtype``, ``memory_format`` and ``factory_kwargs`` are
            accepted. A ``factory_kwargs`` value of ``None`` is treated like an empty
            dict. Entries *inside* ``factory_kwargs`` are copied as-is; their keys are
            not validated here.

    Returns:
        A new dict combining the ``factory_kwargs`` entries with any directly given
        ``device`` / ``dtype`` / ``memory_format``. Neither ``kwargs`` nor its
        ``factory_kwargs`` value is mutated.

    Raises:
        TypeError: if ``kwargs`` contains a key outside the accepted set, or if one of
            ``device`` / ``dtype`` / ``memory_format`` is given both directly and inside
            ``factory_kwargs``.
    """
    if kwargs is None:
        return {}
    # Tuple (not set) so iteration order, and thus the output order and which
    # duplicate is reported first, does not depend on string hash randomization.
    simple_keys = ("device", "dtype", "memory_format")
    expected_keys = set(simple_keys) | {"factory_kwargs"}
    if not kwargs.keys() <= expected_keys:
        raise TypeError(f"unexpected kwargs {kwargs.keys() - expected_keys}")

    # Copy the caller's factory_kwargs so the merge below never mutates the input;
    # an explicit None means "no overrides", mirroring the kwargs-is-None case above.
    overrides = kwargs.get("factory_kwargs")
    r = {} if overrides is None else dict(overrides)
    for k in simple_keys:
        if k in kwargs:
            if k in r:
                raise TypeError(
                    f"{k} specified twice, in **kwargs and in factory_kwargs"
                )
            r[k] = kwargs[k]

    return r
63