# mypy: allow-untyped-defs
import math
from typing import Any

import torch
from torch import Tensor
from torch.nn import functional as F, init
from torch.nn.parameter import Parameter, UninitializedParameter

from .lazy import LazyModuleMixin
from .module import Module


__all__ = [
    "Bilinear",
    "Identity",
    "LazyLinear",
    "Linear",
]


class Identity(Module):
    r"""A placeholder identity operator that is argument-insensitive.

    Args:
        args: any argument (unused)
        kwargs: any keyword argument (unused)

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 20])
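        >>> # Because Identity ignores its arguments, it can stand in for a
        >>> # layer you want to disable, e.g. a model head (a minimal sketch):
        >>> head = nn.Sequential(nn.Linear(20, 10), nn.Identity())
        >>> head(torch.randn(128, 20)).size()
        torch.Size([128, 10])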

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__()

    def forward(self, input: Tensor) -> Tensor:
        return input


class Linear(Module):
    r"""Applies an affine linear transformation to the incoming data: :math:`y = xA^T + b`.

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(*, H_{in})` where :math:`*` means any number of
          dimensions including none and :math:`H_{in} = \text{in\_features}`.
        - Output: :math:`(*, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out\_features}`.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in\_features}}`
        bias:   the learnable bias of the module of shape :math:`(\text{out\_features})`.
                If :attr:`bias` is ``True``, the values are initialized from
                :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                :math:`k = \frac{1}{\text{in\_features}}`

    Examples::

        >>> m = nn.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
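        >>> # The forward pass computes the affine map from the formula above;
        >>> # a minimal equivalence check (up to floating-point tolerance):
        >>> torch.allclose(output, input @ m.weight.T + m.bias)
        True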
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(
            torch.empty((out_features, in_features), **factory_kwargs)
        )
        if bias:
            self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
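        # Sketch of the equivalence: kaiming_uniform_ samples from
        # U(-bound, bound) with bound = gain * sqrt(3 / fan_in), where the
        # default leaky_relu gain is sqrt(2 / (1 + a^2)). With a = sqrt(5),
        # gain = sqrt(1/3), so bound = sqrt(1/3) * sqrt(3 / fan_in)
        #                            = 1 / sqrt(fan_in) = 1 / sqrt(in_features).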
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self) -> str:
        return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"


# This class exists solely to avoid triggering an obscure error when scripting
# an improperly quantized attention layer. See this issue for details:
# https://github.com/pytorch/pytorch/issues/58969
# TODO: fail fast on quantization API usage error, then remove this class
# and replace uses of it with plain Linear
class NonDynamicallyQuantizableLinear(Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(
            in_features, out_features, bias=bias, device=device, dtype=dtype
        )


class Bilinear(Module):
    r"""Applies a bilinear transformation to the incoming data: :math:`y = x_1^T A x_2 + b`.

    Args:
        in1_features: size of each first input sample
        in2_features: size of each second input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input1: :math:`(*, H_{in1})` where :math:`H_{in1}=\text{in1\_features}` and
          :math:`*` means any number of additional dimensions including none. All but the last dimension
          of the inputs should be the same.
        - Input2: :math:`(*, H_{in2})` where :math:`H_{in2}=\text{in2\_features}`.
        - Output: :math:`(*, H_{out})` where :math:`H_{out}=\text{out\_features}`
          and all but the last dimension are the same shape as the input.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in1\_features}, \text{in2\_features})`.
            The values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in1\_features}}`
        bias:   the learnable bias of the module of shape :math:`(\text{out\_features})`.
                If :attr:`bias` is ``True``, the values are initialized from
                :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
                :math:`k = \frac{1}{\text{in1\_features}}`

    Examples::

        >>> m = nn.Bilinear(20, 30, 40)
        >>> input1 = torch.randn(128, 20)
        >>> input2 = torch.randn(128, 30)
        >>> output = m(input1, input2)
        >>> print(output.size())
        torch.Size([128, 40])
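        >>> # Written out with einsum, each output is x1^T A_o x2 + b_o; a
        >>> # minimal equivalence check (up to floating-point tolerance):
        >>> expected = torch.einsum("bi,oij,bj->bo", input1, m.weight, input2)
        >>> torch.allclose(output, expected + m.bias, atol=1e-5)
        True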
    """

    __constants__ = ["in1_features", "in2_features", "out_features"]
    in1_features: int
    in2_features: int
    out_features: int
    weight: Tensor

    def __init__(
        self,
        in1_features: int,
        in2_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in1_features = in1_features
        self.in2_features = in2_features
        self.out_features = out_features
        self.weight = Parameter(
            torch.empty((out_features, in1_features, in2_features), **factory_kwargs)
        )

        if bias:
            self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
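        # weight.size(1) is in1_features, so this realizes the
        # U(-sqrt(k), sqrt(k)) initialization with k = 1/in1_features
        # described in the class docstring.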
        bound = 1 / math.sqrt(self.weight.size(1))
        init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input1: Tensor, input2: Tensor) -> Tensor:
        return F.bilinear(input1, input2, self.weight, self.bias)

    def extra_repr(self) -> str:
        return (
            f"in1_features={self.in1_features}, in2_features={self.in2_features}, "
            f"out_features={self.out_features}, bias={self.bias is not None}"
        )


class LazyLinear(LazyModuleMixin, Linear):
    r"""A :class:`torch.nn.Linear` module where ``in_features`` is inferred.

    In this module, ``weight`` and ``bias`` are of
    :class:`torch.nn.UninitializedParameter` class. They will be initialized
    after the first call to ``forward``, and the module will then become a
    regular :class:`torch.nn.Linear` module. The ``in_features`` argument of
    the :class:`Linear` is inferred from ``input.shape[-1]``.

    See :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in\_features}}`
        bias:   the learnable bias of the module of shape :math:`(\text{out\_features})`.
                If :attr:`bias` is ``True``, the values are initialized from
                :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                :math:`k = \frac{1}{\text{in\_features}}`
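
    Examples::

        >>> # A minimal sketch of the lazy flow: parameters stay uninitialized
        >>> # until the first forward pass materializes them.
        >>> m = nn.LazyLinear(30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> output.size()
        torch.Size([128, 30])
        >>> m.in_features  # inferred from input.shape[-1]
        20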
    """

    cls_to_become = Linear  # type: ignore[assignment]
    weight: UninitializedParameter
    bias: UninitializedParameter  # type: ignore[assignment]

    def __init__(
        self, out_features: int, bias: bool = True, device=None, dtype=None
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        # bias is hardcoded to False to avoid creating a tensor
        # that would immediately be overwritten.
        super().__init__(0, 0, False)
        self.weight = UninitializedParameter(**factory_kwargs)
        self.out_features = out_features
        if bias:
            self.bias = UninitializedParameter(**factory_kwargs)

    def reset_parameters(self) -> None:
        if not self.has_uninitialized_params() and self.in_features != 0:
            super().reset_parameters()

    def initialize_parameters(self, input) -> None:  # type: ignore[override]
        if self.has_uninitialized_params():
            with torch.no_grad():
                self.in_features = input.shape[-1]
                self.weight.materialize((self.out_features, self.in_features))
                if self.bias is not None:
                    self.bias.materialize((self.out_features,))
                self.reset_parameters()


# TODO: PartialLinear - maybe in sparse?
294