xref: /aosp_15_r20/external/pytorch/torch/nn/init.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Worker"""This file contains utilities for initializing neural network parameters."""
3*da0073e9SAndroid Build Coastguard Workerimport math
4*da0073e9SAndroid Build Coastguard Workerimport warnings
5*da0073e9SAndroid Build Coastguard Workerfrom typing import Optional as _Optional
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerimport torch
8*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker# These no_grad_* functions are necessary as wrappers around the parts of these
12*da0073e9SAndroid Build Coastguard Worker# functions that use `with torch.no_grad()`. The JIT doesn't support context
13*da0073e9SAndroid Build Coastguard Worker# managers, so these need to be implemented as builtins. Using these wrappers
14*da0073e9SAndroid Build Coastguard Worker# lets us keep those builtins small and re-usable.
def _no_grad_uniform_(tensor, a, b, generator=None):
    """Fill ``tensor`` in place from U(a, b) without recording autograd history.

    Args:
        tensor: tensor to overwrite.
        a: lower bound passed to ``Tensor.uniform_``.
        b: upper bound passed to ``Tensor.uniform_``.
        generator: optional ``torch.Generator`` used for sampling.

    Returns:
        The value returned by ``tensor.uniform_`` (the tensor itself).
    """
    with torch.no_grad():
        filled = tensor.uniform_(a, b, generator=generator)
    return filled
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Worker
def _no_grad_normal_(tensor, mean, std, generator=None):
    """Fill ``tensor`` in place from N(mean, std^2) without recording autograd history.

    Args:
        tensor: tensor to overwrite.
        mean: mean passed to ``Tensor.normal_``.
        std: standard deviation passed to ``Tensor.normal_``.
        generator: optional ``torch.Generator`` used for sampling.

    Returns:
        The value returned by ``tensor.normal_`` (the tensor itself).
    """
    with torch.no_grad():
        filled = tensor.normal_(mean, std, generator=generator)
    return filled
23*da0073e9SAndroid Build Coastguard Worker
24*da0073e9SAndroid Build Coastguard Worker
def _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=None):
    """Fill ``tensor`` in place with N(mean, std^2) samples truncated to [a, b].

    Uses inverse-transform sampling: draw uniformly in "erf space" between the
    images of the two cut-offs, map back with ``erfinv``, rescale to
    ``mean``/``std``, then clamp. All tensor work runs under ``torch.no_grad()``.

    Method based on
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Args:
        tensor: tensor to overwrite in place.
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: lower cut-off.
        b: upper cut-off.
        generator: optional ``torch.Generator`` passed to ``uniform_``.

    Returns:
        ``tensor`` itself.

    Warns:
        UserWarning: if ``mean`` lies more than ``2 * std`` outside ``[a, b]``.
    """

    def standard_normal_cdf(x):
        # Phi(x) = (1 + erf(x / sqrt(2))) / 2
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    # Kept in this function (not a helper) so stacklevel=2 targets our caller.
    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        lower_cdf = standard_normal_cdf((a - mean) / std)
        upper_cdf = standard_normal_cdf((b - mean) / std)

        # erf(z / sqrt(2)) == 2 * Phi(z) - 1, so sampling uniformly on
        # [2*lower - 1, 2*upper - 1] is sampling Phi uniformly on [lower, upper].
        lo_erf = 2 * lower_cdf - 1
        hi_erf = 2 * upper_cdf - 1
        tensor.uniform_(lo_erf, hi_erf, generator=generator)

        # erfinv yields z / sqrt(2) for a truncated standard normal z.
        tensor.erfinv_()

        # Undo the 1/sqrt(2), scale by std, then shift by mean.
        scale = std * math.sqrt(2.0)
        tensor.mul_(scale)
        tensor.add_(mean)

        # Numerical error (e.g. erfinv(+-1) = +-inf) can land values outside
        # the bounds; clamp them back in.
        tensor.clamp_(min=a, max=b)
        return tensor
60*da0073e9SAndroid Build Coastguard Worker
61*da0073e9SAndroid Build Coastguard Worker
def _no_grad_fill_(tensor, val):
    """Set every element of ``tensor`` to ``val`` without recording autograd history.

    Args:
        tensor: tensor to overwrite.
        val: scalar fill value passed to ``Tensor.fill_``.

    Returns:
        The value returned by ``tensor.fill_`` (the tensor itself).
    """
    with torch.no_grad():
        filled = tensor.fill_(val)
    return filled
65*da0073e9SAndroid Build Coastguard Worker
66*da0073e9SAndroid Build Coastguard Worker
def _no_grad_zero_(tensor):
    """Zero every element of ``tensor`` without recording autograd history.

    Args:
        tensor: tensor to overwrite.

    Returns:
        The value returned by ``tensor.zero_`` (the tensor itself).
    """
    with torch.no_grad():
        cleared = tensor.zero_()
    return cleared
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard Worker
def calculate_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.

    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    SELU              :math:`\frac{3}{4}`
    ================= ====================================================

    The accepted ``nonlinearity`` strings are ``'linear'``, ``'conv1d'``,
    ``'conv2d'``, ``'conv3d'``, ``'conv_transpose1d'``,
    ``'conv_transpose2d'``, ``'conv_transpose3d'``, ``'sigmoid'``,
    ``'tanh'``, ``'relu'``, ``'leaky_relu'`` and ``'selu'``. The
    "Linear / Identity" row is selected with ``'linear'``. The transposed
    convolution names also give a gain of :math:`1`.

    .. warning::
        In order to implement `Self-Normalizing Neural Networks`_ ,
        you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
        This gives the initial weights a variance of ``1 / N``,
        which is necessary to induce a stable fixed point in the forward pass.
        In contrast, the default gain for ``SELU`` sacrifices the normalization
        effect for more stable gradient flow in rectangular layers.

    Args:
        nonlinearity: the non-linear function (`nn.functional` name)
        param: optional parameter for the non-linear function. It is only
            read for ``'leaky_relu'``, where it is the negative slope
            (``0.01`` when omitted) and must be an ``int`` or ``float``
            (``bool`` is rejected).

    Returns:
        The gain: the ``int`` ``1`` for the linear-like functions and
        sigmoid, a ``float`` otherwise.

    Raises:
        ValueError: if ``nonlinearity`` is unsupported, or ``param`` is not a
            valid number for ``'leaky_relu'``.

    Examples:
        >>> gain = nn.init.calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2

    .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
    """
    linear_like = (
        "linear",
        "conv1d",
        "conv2d",
        "conv3d",
        "conv_transpose1d",
        "conv_transpose2d",
        "conv_transpose3d",
    )
    if nonlinearity in linear_like or nonlinearity == "sigmoid":
        return 1
    if nonlinearity == "tanh":
        return 5.0 / 3
    if nonlinearity == "relu":
        return math.sqrt(2.0)
    if nonlinearity == "leaky_relu":
        if param is None:
            slope = 0.01
        elif isinstance(param, bool) or not isinstance(param, (int, float)):
            # bool subclasses int, so True/False must be rejected explicitly.
            raise ValueError(f"negative_slope {param} not a valid number")
        else:
            slope = param
        return math.sqrt(2.0 / (1 + slope**2))
    if nonlinearity == "selu":
        # Value found empirically (https://github.com/pytorch/pytorch/pull/50664)
        return 3.0 / 4
    raise ValueError(f"Unsupported nonlinearity {nonlinearity}")
140*da0073e9SAndroid Build Coastguard Worker
141*da0073e9SAndroid Build Coastguard Worker
def uniform_(
    tensor: Tensor,
    a: float = 0.0,
    b: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Overwrite the input Tensor in place with samples from :math:`\mathcal{U}(a, b)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the lower bound of the uniform distribution (default: 0.0)
        b: the upper bound of the uniform distribution (default: 1.0)
        generator: the torch Generator to sample from (default: None)

    Returns:
        ``tensor`` after filling, or whatever an overriding
        ``__torch_function__`` implementation returns.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.uniform_(w)
    """
    if not torch.overrides.has_torch_function_variadic(tensor):
        return _no_grad_uniform_(tensor, a, b, generator)
    # Tensor-like subclasses get to intercept the call.
    return torch.overrides.handle_torch_function(
        uniform_, (tensor,), tensor=tensor, a=a, b=b, generator=generator
    )
167*da0073e9SAndroid Build Coastguard Worker
168*da0073e9SAndroid Build Coastguard Worker
def normal_(
    tensor: Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Overwrite the input Tensor in place with samples from :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution (default: 0.0)
        std: the standard deviation of the normal distribution (default: 1.0)
        generator: the torch Generator to sample from (default: None)

    Returns:
        ``tensor`` after filling, or whatever an overriding
        ``__torch_function__`` implementation returns.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.normal_(w)
    """
    if not torch.overrides.has_torch_function_variadic(tensor):
        return _no_grad_normal_(tensor, mean, std, generator)
    # Tensor-like subclasses get to intercept the call.
    return torch.overrides.handle_torch_function(
        normal_, (tensor,), tensor=tensor, mean=mean, std=std, generator=generator
    )
194*da0073e9SAndroid Build Coastguard Worker
195*da0073e9SAndroid Build Coastguard Worker
def trunc_normal_(
    tensor: Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    a: float = -2.0,
    b: float = 2.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Overwrite the input Tensor in place with samples from a truncated normal distribution.

    The result is distributed like
    :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to
    :math:`[a, b]`, i.e. as if out-of-range draws were resampled until they
    land inside the bounds. Sampling is done by inverse transform and works
    best when :math:`a \leq \text{mean} \leq b`. A warning is emitted when
    ``mean`` lies more than ``2 * std`` outside :math:`[a, b]`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution (default: 0.0)
        std: the standard deviation of the normal distribution (default: 1.0)
        a: the minimum cutoff value (default: -2.0)
        b: the maximum cutoff value (default: 2.0)
        generator: the torch Generator to sample from (default: None)

    Returns:
        ``tensor`` after filling.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    # Call the helper directly: its warning uses stacklevel=2, which
    # points at this frame.
    result = _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=generator)
    return result
225*da0073e9SAndroid Build Coastguard Worker
226*da0073e9SAndroid Build Coastguard Worker
def constant_(tensor: Tensor, val: float) -> Tensor:
    r"""Set every element of the input Tensor to :math:`\text{val}`, in place.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        val: the value to fill the tensor with

    Returns:
        ``tensor`` after filling, or whatever an overriding
        ``__torch_function__`` implementation returns.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.constant_(w, 0.3)
    """
    if not torch.overrides.has_torch_function_variadic(tensor):
        return _no_grad_fill_(tensor, val)
    # Tensor-like subclasses get to intercept the call.
    return torch.overrides.handle_torch_function(
        constant_, (tensor,), tensor=tensor, val=val
    )
243*da0073e9SAndroid Build Coastguard Worker
244*da0073e9SAndroid Build Coastguard Worker
def ones_(tensor: Tensor) -> Tensor:
    r"""Set every element of the input Tensor to the scalar `1`, in place.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Returns:
        ``tensor`` after filling.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.ones_(w)
    """
    fill_value = 1.0
    return _no_grad_fill_(tensor, fill_value)
256*da0073e9SAndroid Build Coastguard Worker
257*da0073e9SAndroid Build Coastguard Worker
def zeros_(tensor: Tensor) -> Tensor:
    r"""Set every element of the input Tensor to the scalar `0`, in place.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Returns:
        ``tensor`` after zeroing.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.zeros_(w)
    """
    cleared = _no_grad_zero_(tensor)
    return cleared
269*da0073e9SAndroid Build Coastguard Worker
270*da0073e9SAndroid Build Coastguard Worker
def eye_(tensor):
    r"""Overwrite the 2-dimensional input `Tensor` in place with the identity matrix.

    Ones go on the main diagonal and zeros everywhere else. For a non-square
    matrix, only the first ``min(rows, cols)`` diagonal positions are set. In
    `Linear` layers this preserves the identity of as many inputs as possible.

    Args:
        tensor: a 2-dimensional `torch.Tensor`

    Returns:
        ``tensor`` itself.

    Raises:
        ValueError: if ``tensor`` is not 2-dimensional.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.eye_(w)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    n_rows, n_cols = tensor.shape
    with torch.no_grad():
        # ``out=`` writes the identity straight into ``tensor``'s storage.
        torch.eye(n_rows, n_cols, out=tensor, requires_grad=tensor.requires_grad)
    return tensor
290*da0073e9SAndroid Build Coastguard Worker
291*da0073e9SAndroid Build Coastguard Worker
def dirac_(tensor, groups=1):
    r"""Overwrite the {3, 4, 5}-dimensional input `Tensor` in place with the Dirac delta function.

    Preserves the identity of the inputs in `Convolutional` layers, where as
    many input channels are preserved as possible. When ``groups > 1``, each
    group of output channels preserves identity independently. The tensor is
    zeroed first. Then, for every group ``g`` and channel
    ``d < min(dim0 // groups, dim1)``, the entry at output channel
    ``g * (dim0 // groups) + d``, input channel ``d`` and the centre of every
    trailing dimension is set to 1.

    Args:
        tensor: a {3, 4, 5}-dimensional `torch.Tensor`
        groups (int, optional): number of groups in the conv layer (default: 1)

    Returns:
        ``tensor`` itself.

    Raises:
        ValueError: if ``tensor`` is not 3-, 4- or 5-dimensional, or if its
            first dimension is not divisible by ``groups``.

    Examples:
        >>> w = torch.empty(3, 16, 5, 5)
        >>> nn.init.dirac_(w)
        >>> w = torch.empty(3, 24, 5, 5)
        >>> nn.init.dirac_(w, 3)
    """
    ndim = tensor.ndimension()
    if ndim not in [3, 4, 5]:
        raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported")

    shape = tensor.size()
    if shape[0] % groups != 0:
        raise ValueError("dim 0 must be divisible by groups")

    per_group = shape[0] // groups
    n_identity = min(per_group, shape[1])
    # Centre tap along every trailing (temporal/spatial/volumetric) dimension.
    centre = [extent // 2 for extent in shape[2:]]

    with torch.no_grad():
        tensor.zero_()
        for group in range(groups):
            first_out = group * per_group
            for ch in range(n_identity):
                tensor[(first_out + ch, ch, *centre)] = 1
    return tensor
343*da0073e9SAndroid Build Coastguard Worker
344*da0073e9SAndroid Build Coastguard Worker
def _calculate_fan_in_and_fan_out(tensor):
    """Return ``(fan_in, fan_out)`` for a weight tensor.

    Dimension 0 is treated as output feature maps and dimension 1 as input
    feature maps. Every remaining dimension contributes to the receptive
    field size, which multiplies both counts.

    Args:
        tensor: a weight tensor with at least 2 dimensions.

    Returns:
        ``(size(1) * receptive_field, size(0) * receptive_field)``.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.
    """
    if tensor.dim() < 2:
        raise ValueError(
            "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
        )

    out_maps = tensor.size(0)
    in_maps = tensor.size(1)

    # Explicit product loop: math.prod is not always available and
    # functools.reduce is not supported by TorchScript. For a 2-D tensor the
    # slice is empty and the product stays 1.
    receptive = 1
    for extent in tensor.shape[2:]:
        receptive *= extent

    return in_maps * receptive, out_maps * receptive
364*da0073e9SAndroid Build Coastguard Worker
365*da0073e9SAndroid Build Coastguard Worker
def xavier_uniform_(
    tensor: Tensor,
    gain: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input `Tensor` with values using a Xavier uniform distribution.

    The method is described in `Understanding the difficulty of training
    deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010).
    The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor` (at least 2 dimensions)
        gain: an optional scaling factor
        generator: the torch Generator to sample from (default: None)

    Returns:
        ``tensor`` after filling. A zero-element tensor whose
        ``fan_in + fan_out`` is 0 is returned unchanged.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))

    Note:
        Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
        that the weight matrix is used in a transposed manner,
        (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``).
        This is important for correct initialization.
        If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``,
        pass in a transposed weight matrix, i.e. ``nn.init.xavier_uniform_(w.T, ...)``.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if fan_in + fan_out == 0:
        # fan_in = size(1) * rf and fan_out = size(0) * rf, so both are zero
        # only when the tensor has no elements (rf == 0, or dims 0 and 1 are
        # both 0). Nothing to fill; skip the division by zero below.
        return tensor
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    bound = math.sqrt(3.0) * std  # U(-bound, bound) has standard deviation `std`

    return _no_grad_uniform_(tensor, -bound, bound, generator)
405*da0073e9SAndroid Build Coastguard Worker
406*da0073e9SAndroid Build Coastguard Worker
def xavier_normal_(
    tensor: Tensor,
    gain: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input `Tensor` in place with samples from a Xavier normal distribution.

    Implements the scheme from `Understanding the difficulty of training deep
    feedforward neural networks` - Glorot, X. & Bengio, Y. (2010). Values are
    drawn from :math:`\mathcal{N}(0, \text{std}^2)` with

    .. math::
        \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor
        generator: the torch Generator to sample from (default: None)

    Returns:
        The filled ``tensor``.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_normal_(w)

    Note:
        ``fan_in`` and ``fan_out`` are computed under the assumption that the
        weight is applied transposed (``x @ w.T`` in ``Linear`` layers, where
        ``w.shape = [fan_out, fan_in]``); this matters for a correct scale.
        To initialize a weight used as ``x @ w`` with
        ``w.shape = [fan_in, fan_out]``, pass its transpose instead, i.e.
        ``nn.init.xavier_normal_(w.T, ...)``.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    # 2 / (fan_in + fan_out) == 1 / mean(fan_in, fan_out).
    total_fan = float(fan_in + fan_out)
    std = gain * math.sqrt(2.0 / total_fan)
    return _no_grad_normal_(tensor, 0.0, std, generator)
444*da0073e9SAndroid Build Coastguard Worker
445*da0073e9SAndroid Build Coastguard Worker
446*da0073e9SAndroid Build Coastguard Workerdef _calculate_correct_fan(tensor, mode):
447*da0073e9SAndroid Build Coastguard Worker    mode = mode.lower()
448*da0073e9SAndroid Build Coastguard Worker    valid_modes = ["fan_in", "fan_out"]
449*da0073e9SAndroid Build Coastguard Worker    if mode not in valid_modes:
450*da0073e9SAndroid Build Coastguard Worker        raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}")
451*da0073e9SAndroid Build Coastguard Worker
452*da0073e9SAndroid Build Coastguard Worker    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
453*da0073e9SAndroid Build Coastguard Worker    return fan_in if mode == "fan_in" else fan_out
454*da0073e9SAndroid Build Coastguard Worker
455*da0073e9SAndroid Build Coastguard Worker
def kaiming_uniform_(
    tensor: Tensor,
    a: float = 0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the input `Tensor` in place with samples from a Kaiming uniform distribution.

    Implements the scheme from `Delving deep into rectifiers: Surpassing
    human-level performance on ImageNet classification` - He, K. et al. (2015).
    Values are drawn from :math:`\mathcal{U}(-\text{bound}, \text{bound})` with

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    where ``gain`` is ``calculate_gain(nonlinearity, a)`` and ``fan_mode`` is
    the fan chosen by ``mode``.

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
        generator: the torch Generator to sample from (default: None)

    Returns:
        The initialized tensor. A tensor with a zero-sized dimension is
        returned untouched after a warning.

    Raises:
        ValueError: if ``mode`` is not ``'fan_in'`` or ``'fan_out'``.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')

    Note:
        ``fan_in`` and ``fan_out`` are computed under the assumption that the
        weight is applied transposed (``x @ w.T`` in ``Linear`` layers, where
        ``w.shape = [fan_out, fan_in]``); this matters for a correct scale.
        To initialize a weight used as ``x @ w`` with
        ``w.shape = [fan_in, fan_out]``, pass its transpose instead, i.e.
        ``nn.init.kaiming_uniform_(w.T, ...)``.
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        # Defer to a __torch_function__ override on ``tensor``, forwarding
        # every argument by keyword.
        return torch.overrides.handle_torch_function(
            kaiming_uniform_,
            (tensor,),
            tensor=tensor,
            a=a,
            mode=mode,
            nonlinearity=nonlinearity,
            generator=generator,
        )

    is_empty = 0 in tensor.shape
    if is_empty:
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor

    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    # U(-b, b) has standard deviation b / sqrt(3), so b = sqrt(3) * target std.
    target_std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * target_std
    with torch.no_grad():
        return tensor.uniform_(-bound, bound, generator=generator)
519*da0073e9SAndroid Build Coastguard Worker
520*da0073e9SAndroid Build Coastguard Worker
def kaiming_normal_(
    tensor: Tensor,
    a: float = 0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the input `Tensor` in place with samples from a Kaiming normal distribution.

    Implements the scheme from `Delving deep into rectifiers: Surpassing
    human-level performance on ImageNet classification` - He, K. et al. (2015).
    Values are drawn from :math:`\mathcal{N}(0, \text{std}^2)` with

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    where ``gain`` is ``calculate_gain(nonlinearity, a)`` and ``fan_mode`` is
    the fan chosen by ``mode``.

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
        generator: the torch Generator to sample from (default: None)

    Returns:
        The initialized tensor. A tensor with a zero-sized dimension is
        returned untouched after a warning.

    Raises:
        ValueError: if ``mode`` is not ``'fan_in'`` or ``'fan_out'``.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')

    Note:
        ``fan_in`` and ``fan_out`` are computed under the assumption that the
        weight is applied transposed (``x @ w.T`` in ``Linear`` layers, where
        ``w.shape = [fan_out, fan_in]``); this matters for a correct scale.
        To initialize a weight used as ``x @ w`` with
        ``w.shape = [fan_in, fan_out]``, pass its transpose instead, i.e.
        ``nn.init.kaiming_normal_(w.T, ...)``.
    """
    is_empty = 0 in tensor.shape
    if is_empty:
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor

    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    with torch.no_grad():
        return tensor.normal_(0, gain / math.sqrt(fan), generator=generator)
572*da0073e9SAndroid Build Coastguard Worker
573*da0073e9SAndroid Build Coastguard Worker
def orthogonal_(
    tensor,
    gain=1,
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the input `Tensor` with a (semi) orthogonal matrix.

    Described in `Exact solutions to the nonlinear dynamics of learning in deep
    linear neural networks` - Saxe, A. et al. (2013). The input tensor must have
    at least 2 dimensions, and for tensors with more than 2 dimensions the
    trailing dimensions are flattened.

    Args:
        tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`.
            Non-contiguous tensors are supported.
        gain: optional scaling factor
        generator: the torch Generator to sample from (default: None)

    Returns:
        ``tensor``, filled in place.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.

    Examples:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> w = torch.empty(3, 5)
        >>> nn.init.orthogonal_(w)
    """
    if tensor.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")

    if tensor.numel() == 0:
        # no-op
        return tensor
    rows = tensor.size(0)
    cols = tensor.numel() // rows
    flattened = tensor.new_empty((rows, cols)).normal_(0, 1, generator=generator)

    # Factorize a tall (or square) matrix; wide ones are transposed here and
    # transposed back after the QR step.
    if rows < cols:
        flattened.t_()

    # Compute the qr factorization
    q, r = torch.linalg.qr(flattened)
    # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
    d = torch.diag(r, 0)
    ph = d.sign()
    q *= ph

    if rows < cols:
        q.t_()

    with torch.no_grad():
        # Reshape q to the tensor's shape instead of viewing the tensor as
        # (rows, cols): a view fails for non-contiguous tensors whose strides
        # cannot be merged, while reshape + copy_ works for any layout and
        # maps elements in the same row-major order.
        tensor.copy_(q.reshape(tensor.shape))
        tensor.mul_(gain)
    return tensor
623*da0073e9SAndroid Build Coastguard Worker
624*da0073e9SAndroid Build Coastguard Worker
def sparse_(
    tensor,
    sparsity,
    std=0.01,
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the 2D input `Tensor` as a sparse matrix.

    The non-zero elements will be drawn from the normal distribution
    :math:`\mathcal{N}(0, \text{std}^2)` (``std`` defaults to ``0.01``), as
    described in `Deep learning via Hessian-free optimization` - Martens, J. (2010).

    Args:
        tensor: a 2-dimensional `torch.Tensor`
        sparsity: The fraction of elements in each column to be set to zero
            (the count per column is ``ceil(sparsity * rows)``)
        std: the standard deviation of the normal distribution used to generate
            the non-zero values
        generator: the torch Generator to sample from (default: None). It is
            used both for the normal values and for choosing which entries are
            zeroed, so a seeded generator makes the result reproducible.

    Returns:
        ``tensor``, filled in place.

    Raises:
        ValueError: if ``tensor`` is not 2-dimensional.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.sparse_(w, sparsity=0.1)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    rows, cols = tensor.shape
    num_zeros = int(math.ceil(sparsity * rows))
    # Draw permutations on the generator's own device so the generator and the
    # permutation agree; with no generator this is the default device, as before.
    perm_device = None if generator is None else generator.device

    with torch.no_grad():
        tensor.normal_(0, std, generator=generator)
        for col_idx in range(cols):
            row_indices = torch.randperm(
                rows, generator=generator, device=perm_device
            )
            zero_indices = row_indices[:num_zeros]
            tensor[zero_indices, col_idx] = 0
    return tensor
661*da0073e9SAndroid Build Coastguard Worker
662*da0073e9SAndroid Build Coastguard Worker
663*da0073e9SAndroid Build Coastguard Worker# for backward compatibility
664*da0073e9SAndroid Build Coastguard Workerdef _make_deprecate(meth):
665*da0073e9SAndroid Build Coastguard Worker    new_name = meth.__name__
666*da0073e9SAndroid Build Coastguard Worker    old_name = new_name[:-1]
667*da0073e9SAndroid Build Coastguard Worker
668*da0073e9SAndroid Build Coastguard Worker    def deprecated_init(*args, **kwargs):
669*da0073e9SAndroid Build Coastguard Worker        warnings.warn(
670*da0073e9SAndroid Build Coastguard Worker            f"`nn.init.{old_name}` is now deprecated in favor of `nn.init.{new_name}`.",
671*da0073e9SAndroid Build Coastguard Worker            FutureWarning,
672*da0073e9SAndroid Build Coastguard Worker            stacklevel=2,
673*da0073e9SAndroid Build Coastguard Worker        )
674*da0073e9SAndroid Build Coastguard Worker        return meth(*args, **kwargs)
675*da0073e9SAndroid Build Coastguard Worker
676*da0073e9SAndroid Build Coastguard Worker    deprecated_init.__doc__ = rf"""
677*da0073e9SAndroid Build Coastguard Worker    {old_name}(...)
678*da0073e9SAndroid Build Coastguard Worker
679*da0073e9SAndroid Build Coastguard Worker    .. warning::
680*da0073e9SAndroid Build Coastguard Worker        This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`.
681*da0073e9SAndroid Build Coastguard Worker
682*da0073e9SAndroid Build Coastguard Worker    See :func:`~torch.nn.init.{new_name}` for details."""
683*da0073e9SAndroid Build Coastguard Worker    deprecated_init.__name__ = old_name
684*da0073e9SAndroid Build Coastguard Worker    return deprecated_init
685*da0073e9SAndroid Build Coastguard Worker
686*da0073e9SAndroid Build Coastguard Worker
# Legacy names without the trailing underscore. Each alias emits a
# FutureWarning when called and forwards to the in-place initializer.

# Simple distributions / constant fill.
uniform = _make_deprecate(uniform_)
normal = _make_deprecate(normal_)
constant = _make_deprecate(constant_)

# Structured matrices.
eye = _make_deprecate(eye_)
dirac = _make_deprecate(dirac_)

# Fan-based schemes.
xavier_uniform = _make_deprecate(xavier_uniform_)
xavier_normal = _make_deprecate(xavier_normal_)
kaiming_uniform = _make_deprecate(kaiming_uniform_)
kaiming_normal = _make_deprecate(kaiming_normal_)

# Other initializers.
orthogonal = _make_deprecate(orthogonal_)
sparse = _make_deprecate(sparse_)
698