1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs 2*da0073e9SAndroid Build Coastguard Worker"""This file contains utilities for initializing neural network parameters.""" 3*da0073e9SAndroid Build Coastguard Workerimport math 4*da0073e9SAndroid Build Coastguard Workerimport warnings 5*da0073e9SAndroid Build Coastguard Workerfrom typing import Optional as _Optional 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Workerimport torch 8*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Worker# These no_grad_* functions are necessary as wrappers around the parts of these 12*da0073e9SAndroid Build Coastguard Worker# functions that use `with torch.no_grad()`. The JIT doesn't support context 13*da0073e9SAndroid Build Coastguard Worker# managers, so these need to be implemented as builtins. Using these wrappers 14*da0073e9SAndroid Build Coastguard Worker# lets us keep those builtins small and re-usable. 
def _no_grad_uniform_(tensor, a, b, generator=None):
    """Overwrite ``tensor`` in place with draws from U(a, b), outside autograd.

    Returns whatever ``Tensor.uniform_`` returns (the tensor that was filled).
    """
    with torch.no_grad():
        result = tensor.uniform_(a, b, generator=generator)
    return result


def _no_grad_normal_(tensor, mean, std, generator=None):
    """Overwrite ``tensor`` in place with draws from N(mean, std**2), outside autograd.

    Returns whatever ``Tensor.normal_`` returns (the tensor that was filled).
    """
    with torch.no_grad():
        result = tensor.normal_(mean, std, generator=generator)
    return result


def _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=None):
    """Fill ``tensor`` in place with a normal distribution truncated to ``[a, b]``.

    Uses inverse-CDF sampling, following
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf :
    draw uniformly between the normal CDF values of the two cut-offs, then map
    the draws back through the inverse CDF (via ``erfinv``).

    A warning is emitted when ``mean`` lies more than two standard deviations
    outside ``[a, b]``, because the resulting distribution may be incorrect.

    Args:
        tensor: tensor that is overwritten in place.
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: lower cut-off.
        b: upper cut-off.
        generator: optional ``torch.Generator`` to sample from.

    Returns:
        ``tensor`` itself.
    """
    root_two = math.sqrt(2.0)

    def _phi(z):
        # Standard normal CDF written in terms of the error function.
        return (1.0 + math.erf(z / root_two)) / 2.0

    if mean < a - 2 * std or mean > b + 2 * std:
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        # CDF values at the standardized lower and upper cut-offs.
        cdf_lo = _phi((a - mean) / std)
        cdf_hi = _phi((b - mean) / std)

        # erf(z / sqrt(2)) == 2 * Phi(z) - 1, so sampling that interval and
        # applying erfinv yields truncated standard-normal draws divided by sqrt(2).
        tensor.uniform_(2 * cdf_lo - 1, 2 * cdf_hi - 1, generator=generator)
        tensor.erfinv_()

        # Undo the 1/sqrt(2) factor and move to N(mean, std^2) in one chain.
        tensor.mul_(std * root_two).add_(mean)

        # Numerical error must never leave values outside [a, b].
        tensor.clamp_(min=a, max=b)
    return tensor
Workerdef _no_grad_fill_(tensor, val): 63*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 64*da0073e9SAndroid Build Coastguard Worker return tensor.fill_(val) 65*da0073e9SAndroid Build Coastguard Worker 66*da0073e9SAndroid Build Coastguard Worker 67*da0073e9SAndroid Build Coastguard Workerdef _no_grad_zero_(tensor): 68*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 69*da0073e9SAndroid Build Coastguard Worker return tensor.zero_() 70*da0073e9SAndroid Build Coastguard Worker 71*da0073e9SAndroid Build Coastguard Worker 72*da0073e9SAndroid Build Coastguard Workerdef calculate_gain(nonlinearity, param=None): 73*da0073e9SAndroid Build Coastguard Worker r"""Return the recommended gain value for the given nonlinearity function. 74*da0073e9SAndroid Build Coastguard Worker 75*da0073e9SAndroid Build Coastguard Worker The values are as follows: 76*da0073e9SAndroid Build Coastguard Worker 77*da0073e9SAndroid Build Coastguard Worker ================= ==================================================== 78*da0073e9SAndroid Build Coastguard Worker nonlinearity gain 79*da0073e9SAndroid Build Coastguard Worker ================= ==================================================== 80*da0073e9SAndroid Build Coastguard Worker Linear / Identity :math:`1` 81*da0073e9SAndroid Build Coastguard Worker Conv{1,2,3}D :math:`1` 82*da0073e9SAndroid Build Coastguard Worker Sigmoid :math:`1` 83*da0073e9SAndroid Build Coastguard Worker Tanh :math:`\frac{5}{3}` 84*da0073e9SAndroid Build Coastguard Worker ReLU :math:`\sqrt{2}` 85*da0073e9SAndroid Build Coastguard Worker Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` 86*da0073e9SAndroid Build Coastguard Worker SELU :math:`\frac{3}{4}` 87*da0073e9SAndroid Build Coastguard Worker ================= ==================================================== 88*da0073e9SAndroid Build Coastguard Worker 89*da0073e9SAndroid Build Coastguard Worker .. 
warning:: 90*da0073e9SAndroid Build Coastguard Worker In order to implement `Self-Normalizing Neural Networks`_ , 91*da0073e9SAndroid Build Coastguard Worker you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``. 92*da0073e9SAndroid Build Coastguard Worker This gives the initial weights a variance of ``1 / N``, 93*da0073e9SAndroid Build Coastguard Worker which is necessary to induce a stable fixed point in the forward pass. 94*da0073e9SAndroid Build Coastguard Worker In contrast, the default gain for ``SELU`` sacrifices the normalization 95*da0073e9SAndroid Build Coastguard Worker effect for more stable gradient flow in rectangular layers. 96*da0073e9SAndroid Build Coastguard Worker 97*da0073e9SAndroid Build Coastguard Worker Args: 98*da0073e9SAndroid Build Coastguard Worker nonlinearity: the non-linear function (`nn.functional` name) 99*da0073e9SAndroid Build Coastguard Worker param: optional parameter for the non-linear function 100*da0073e9SAndroid Build Coastguard Worker 101*da0073e9SAndroid Build Coastguard Worker Examples: 102*da0073e9SAndroid Build Coastguard Worker >>> gain = nn.init.calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2 103*da0073e9SAndroid Build Coastguard Worker 104*da0073e9SAndroid Build Coastguard Worker .. 
_Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html 105*da0073e9SAndroid Build Coastguard Worker """ 106*da0073e9SAndroid Build Coastguard Worker linear_fns = [ 107*da0073e9SAndroid Build Coastguard Worker "linear", 108*da0073e9SAndroid Build Coastguard Worker "conv1d", 109*da0073e9SAndroid Build Coastguard Worker "conv2d", 110*da0073e9SAndroid Build Coastguard Worker "conv3d", 111*da0073e9SAndroid Build Coastguard Worker "conv_transpose1d", 112*da0073e9SAndroid Build Coastguard Worker "conv_transpose2d", 113*da0073e9SAndroid Build Coastguard Worker "conv_transpose3d", 114*da0073e9SAndroid Build Coastguard Worker ] 115*da0073e9SAndroid Build Coastguard Worker if nonlinearity in linear_fns or nonlinearity == "sigmoid": 116*da0073e9SAndroid Build Coastguard Worker return 1 117*da0073e9SAndroid Build Coastguard Worker elif nonlinearity == "tanh": 118*da0073e9SAndroid Build Coastguard Worker return 5.0 / 3 119*da0073e9SAndroid Build Coastguard Worker elif nonlinearity == "relu": 120*da0073e9SAndroid Build Coastguard Worker return math.sqrt(2.0) 121*da0073e9SAndroid Build Coastguard Worker elif nonlinearity == "leaky_relu": 122*da0073e9SAndroid Build Coastguard Worker if param is None: 123*da0073e9SAndroid Build Coastguard Worker negative_slope = 0.01 124*da0073e9SAndroid Build Coastguard Worker elif ( 125*da0073e9SAndroid Build Coastguard Worker not isinstance(param, bool) 126*da0073e9SAndroid Build Coastguard Worker and isinstance(param, int) 127*da0073e9SAndroid Build Coastguard Worker or isinstance(param, float) 128*da0073e9SAndroid Build Coastguard Worker ): 129*da0073e9SAndroid Build Coastguard Worker # True/False are instances of int, hence check above 130*da0073e9SAndroid Build Coastguard Worker negative_slope = param 131*da0073e9SAndroid Build Coastguard Worker else: 132*da0073e9SAndroid Build Coastguard Worker raise ValueError(f"negative_slope {param} not a valid number") 133*da0073e9SAndroid 
Build Coastguard Worker return math.sqrt(2.0 / (1 + negative_slope**2)) 134*da0073e9SAndroid Build Coastguard Worker elif nonlinearity == "selu": 135*da0073e9SAndroid Build Coastguard Worker return ( 136*da0073e9SAndroid Build Coastguard Worker 3.0 / 4 137*da0073e9SAndroid Build Coastguard Worker ) # Value found empirically (https://github.com/pytorch/pytorch/pull/50664) 138*da0073e9SAndroid Build Coastguard Worker else: 139*da0073e9SAndroid Build Coastguard Worker raise ValueError(f"Unsupported nonlinearity {nonlinearity}") 140*da0073e9SAndroid Build Coastguard Worker 141*da0073e9SAndroid Build Coastguard Worker 142*da0073e9SAndroid Build Coastguard Workerdef uniform_( 143*da0073e9SAndroid Build Coastguard Worker tensor: Tensor, 144*da0073e9SAndroid Build Coastguard Worker a: float = 0.0, 145*da0073e9SAndroid Build Coastguard Worker b: float = 1.0, 146*da0073e9SAndroid Build Coastguard Worker generator: _Optional[torch.Generator] = None, 147*da0073e9SAndroid Build Coastguard Worker) -> Tensor: 148*da0073e9SAndroid Build Coastguard Worker r"""Fill the input Tensor with values drawn from the uniform distribution. 149*da0073e9SAndroid Build Coastguard Worker 150*da0073e9SAndroid Build Coastguard Worker :math:`\mathcal{U}(a, b)`. 
151*da0073e9SAndroid Build Coastguard Worker 152*da0073e9SAndroid Build Coastguard Worker Args: 153*da0073e9SAndroid Build Coastguard Worker tensor: an n-dimensional `torch.Tensor` 154*da0073e9SAndroid Build Coastguard Worker a: the lower bound of the uniform distribution 155*da0073e9SAndroid Build Coastguard Worker b: the upper bound of the uniform distribution 156*da0073e9SAndroid Build Coastguard Worker generator: the torch Generator to sample from (default: None) 157*da0073e9SAndroid Build Coastguard Worker 158*da0073e9SAndroid Build Coastguard Worker Examples: 159*da0073e9SAndroid Build Coastguard Worker >>> w = torch.empty(3, 5) 160*da0073e9SAndroid Build Coastguard Worker >>> nn.init.uniform_(w) 161*da0073e9SAndroid Build Coastguard Worker """ 162*da0073e9SAndroid Build Coastguard Worker if torch.overrides.has_torch_function_variadic(tensor): 163*da0073e9SAndroid Build Coastguard Worker return torch.overrides.handle_torch_function( 164*da0073e9SAndroid Build Coastguard Worker uniform_, (tensor,), tensor=tensor, a=a, b=b, generator=generator 165*da0073e9SAndroid Build Coastguard Worker ) 166*da0073e9SAndroid Build Coastguard Worker return _no_grad_uniform_(tensor, a, b, generator) 167*da0073e9SAndroid Build Coastguard Worker 168*da0073e9SAndroid Build Coastguard Worker 169*da0073e9SAndroid Build Coastguard Workerdef normal_( 170*da0073e9SAndroid Build Coastguard Worker tensor: Tensor, 171*da0073e9SAndroid Build Coastguard Worker mean: float = 0.0, 172*da0073e9SAndroid Build Coastguard Worker std: float = 1.0, 173*da0073e9SAndroid Build Coastguard Worker generator: _Optional[torch.Generator] = None, 174*da0073e9SAndroid Build Coastguard Worker) -> Tensor: 175*da0073e9SAndroid Build Coastguard Worker r"""Fill the input Tensor with values drawn from the normal distribution. 176*da0073e9SAndroid Build Coastguard Worker 177*da0073e9SAndroid Build Coastguard Worker :math:`\mathcal{N}(\text{mean}, \text{std}^2)`. 
178*da0073e9SAndroid Build Coastguard Worker 179*da0073e9SAndroid Build Coastguard Worker Args: 180*da0073e9SAndroid Build Coastguard Worker tensor: an n-dimensional `torch.Tensor` 181*da0073e9SAndroid Build Coastguard Worker mean: the mean of the normal distribution 182*da0073e9SAndroid Build Coastguard Worker std: the standard deviation of the normal distribution 183*da0073e9SAndroid Build Coastguard Worker generator: the torch Generator to sample from (default: None) 184*da0073e9SAndroid Build Coastguard Worker 185*da0073e9SAndroid Build Coastguard Worker Examples: 186*da0073e9SAndroid Build Coastguard Worker >>> w = torch.empty(3, 5) 187*da0073e9SAndroid Build Coastguard Worker >>> nn.init.normal_(w) 188*da0073e9SAndroid Build Coastguard Worker """ 189*da0073e9SAndroid Build Coastguard Worker if torch.overrides.has_torch_function_variadic(tensor): 190*da0073e9SAndroid Build Coastguard Worker return torch.overrides.handle_torch_function( 191*da0073e9SAndroid Build Coastguard Worker normal_, (tensor,), tensor=tensor, mean=mean, std=std, generator=generator 192*da0073e9SAndroid Build Coastguard Worker ) 193*da0073e9SAndroid Build Coastguard Worker return _no_grad_normal_(tensor, mean, std, generator) 194*da0073e9SAndroid Build Coastguard Worker 195*da0073e9SAndroid Build Coastguard Worker 196*da0073e9SAndroid Build Coastguard Workerdef trunc_normal_( 197*da0073e9SAndroid Build Coastguard Worker tensor: Tensor, 198*da0073e9SAndroid Build Coastguard Worker mean: float = 0.0, 199*da0073e9SAndroid Build Coastguard Worker std: float = 1.0, 200*da0073e9SAndroid Build Coastguard Worker a: float = -2.0, 201*da0073e9SAndroid Build Coastguard Worker b: float = 2.0, 202*da0073e9SAndroid Build Coastguard Worker generator: _Optional[torch.Generator] = None, 203*da0073e9SAndroid Build Coastguard Worker) -> Tensor: 204*da0073e9SAndroid Build Coastguard Worker r"""Fill the input Tensor with values drawn from a truncated normal distribution. 
205*da0073e9SAndroid Build Coastguard Worker 206*da0073e9SAndroid Build Coastguard Worker The values are effectively drawn from the 207*da0073e9SAndroid Build Coastguard Worker normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 208*da0073e9SAndroid Build Coastguard Worker with values outside :math:`[a, b]` redrawn until they are within 209*da0073e9SAndroid Build Coastguard Worker the bounds. The method used for generating the random values works 210*da0073e9SAndroid Build Coastguard Worker best when :math:`a \leq \text{mean} \leq b`. 211*da0073e9SAndroid Build Coastguard Worker 212*da0073e9SAndroid Build Coastguard Worker Args: 213*da0073e9SAndroid Build Coastguard Worker tensor: an n-dimensional `torch.Tensor` 214*da0073e9SAndroid Build Coastguard Worker mean: the mean of the normal distribution 215*da0073e9SAndroid Build Coastguard Worker std: the standard deviation of the normal distribution 216*da0073e9SAndroid Build Coastguard Worker a: the minimum cutoff value 217*da0073e9SAndroid Build Coastguard Worker b: the maximum cutoff value 218*da0073e9SAndroid Build Coastguard Worker generator: the torch Generator to sample from (default: None) 219*da0073e9SAndroid Build Coastguard Worker 220*da0073e9SAndroid Build Coastguard Worker Examples: 221*da0073e9SAndroid Build Coastguard Worker >>> w = torch.empty(3, 5) 222*da0073e9SAndroid Build Coastguard Worker >>> nn.init.trunc_normal_(w) 223*da0073e9SAndroid Build Coastguard Worker """ 224*da0073e9SAndroid Build Coastguard Worker return _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=generator) 225*da0073e9SAndroid Build Coastguard Worker 226*da0073e9SAndroid Build Coastguard Worker 227*da0073e9SAndroid Build Coastguard Workerdef constant_(tensor: Tensor, val: float) -> Tensor: 228*da0073e9SAndroid Build Coastguard Worker r"""Fill the input Tensor with the value :math:`\text{val}`. 
229*da0073e9SAndroid Build Coastguard Worker 230*da0073e9SAndroid Build Coastguard Worker Args: 231*da0073e9SAndroid Build Coastguard Worker tensor: an n-dimensional `torch.Tensor` 232*da0073e9SAndroid Build Coastguard Worker val: the value to fill the tensor with 233*da0073e9SAndroid Build Coastguard Worker 234*da0073e9SAndroid Build Coastguard Worker Examples: 235*da0073e9SAndroid Build Coastguard Worker >>> w = torch.empty(3, 5) 236*da0073e9SAndroid Build Coastguard Worker >>> nn.init.constant_(w, 0.3) 237*da0073e9SAndroid Build Coastguard Worker """ 238*da0073e9SAndroid Build Coastguard Worker if torch.overrides.has_torch_function_variadic(tensor): 239*da0073e9SAndroid Build Coastguard Worker return torch.overrides.handle_torch_function( 240*da0073e9SAndroid Build Coastguard Worker constant_, (tensor,), tensor=tensor, val=val 241*da0073e9SAndroid Build Coastguard Worker ) 242*da0073e9SAndroid Build Coastguard Worker return _no_grad_fill_(tensor, val) 243*da0073e9SAndroid Build Coastguard Worker 244*da0073e9SAndroid Build Coastguard Worker 245*da0073e9SAndroid Build Coastguard Workerdef ones_(tensor: Tensor) -> Tensor: 246*da0073e9SAndroid Build Coastguard Worker r"""Fill the input Tensor with the scalar value `1`. 
247*da0073e9SAndroid Build Coastguard Worker 248*da0073e9SAndroid Build Coastguard Worker Args: 249*da0073e9SAndroid Build Coastguard Worker tensor: an n-dimensional `torch.Tensor` 250*da0073e9SAndroid Build Coastguard Worker 251*da0073e9SAndroid Build Coastguard Worker Examples: 252*da0073e9SAndroid Build Coastguard Worker >>> w = torch.empty(3, 5) 253*da0073e9SAndroid Build Coastguard Worker >>> nn.init.ones_(w) 254*da0073e9SAndroid Build Coastguard Worker """ 255*da0073e9SAndroid Build Coastguard Worker return _no_grad_fill_(tensor, 1.0) 256*da0073e9SAndroid Build Coastguard Worker 257*da0073e9SAndroid Build Coastguard Worker 258*da0073e9SAndroid Build Coastguard Workerdef zeros_(tensor: Tensor) -> Tensor: 259*da0073e9SAndroid Build Coastguard Worker r"""Fill the input Tensor with the scalar value `0`. 260*da0073e9SAndroid Build Coastguard Worker 261*da0073e9SAndroid Build Coastguard Worker Args: 262*da0073e9SAndroid Build Coastguard Worker tensor: an n-dimensional `torch.Tensor` 263*da0073e9SAndroid Build Coastguard Worker 264*da0073e9SAndroid Build Coastguard Worker Examples: 265*da0073e9SAndroid Build Coastguard Worker >>> w = torch.empty(3, 5) 266*da0073e9SAndroid Build Coastguard Worker >>> nn.init.zeros_(w) 267*da0073e9SAndroid Build Coastguard Worker """ 268*da0073e9SAndroid Build Coastguard Worker return _no_grad_zero_(tensor) 269*da0073e9SAndroid Build Coastguard Worker 270*da0073e9SAndroid Build Coastguard Worker 271*da0073e9SAndroid Build Coastguard Workerdef eye_(tensor): 272*da0073e9SAndroid Build Coastguard Worker r"""Fill the 2-dimensional input `Tensor` with the identity matrix. 273*da0073e9SAndroid Build Coastguard Worker 274*da0073e9SAndroid Build Coastguard Worker Preserves the identity of the inputs in `Linear` layers, where as 275*da0073e9SAndroid Build Coastguard Worker many inputs are preserved as possible. 
276*da0073e9SAndroid Build Coastguard Worker 277*da0073e9SAndroid Build Coastguard Worker Args: 278*da0073e9SAndroid Build Coastguard Worker tensor: a 2-dimensional `torch.Tensor` 279*da0073e9SAndroid Build Coastguard Worker 280*da0073e9SAndroid Build Coastguard Worker Examples: 281*da0073e9SAndroid Build Coastguard Worker >>> w = torch.empty(3, 5) 282*da0073e9SAndroid Build Coastguard Worker >>> nn.init.eye_(w) 283*da0073e9SAndroid Build Coastguard Worker """ 284*da0073e9SAndroid Build Coastguard Worker if tensor.ndimension() != 2: 285*da0073e9SAndroid Build Coastguard Worker raise ValueError("Only tensors with 2 dimensions are supported") 286*da0073e9SAndroid Build Coastguard Worker 287*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 288*da0073e9SAndroid Build Coastguard Worker torch.eye(*tensor.shape, out=tensor, requires_grad=tensor.requires_grad) 289*da0073e9SAndroid Build Coastguard Worker return tensor 290*da0073e9SAndroid Build Coastguard Worker 291*da0073e9SAndroid Build Coastguard Worker 292*da0073e9SAndroid Build Coastguard Workerdef dirac_(tensor, groups=1): 293*da0073e9SAndroid Build Coastguard Worker r"""Fill the {3, 4, 5}-dimensional input `Tensor` with the Dirac delta function. 294*da0073e9SAndroid Build Coastguard Worker 295*da0073e9SAndroid Build Coastguard Worker Preserves the identity of the inputs in `Convolutional` 296*da0073e9SAndroid Build Coastguard Worker layers, where as many input channels are preserved as possible. 
In case 297*da0073e9SAndroid Build Coastguard Worker of groups>1, each group of channels preserves identity 298*da0073e9SAndroid Build Coastguard Worker 299*da0073e9SAndroid Build Coastguard Worker Args: 300*da0073e9SAndroid Build Coastguard Worker tensor: a {3, 4, 5}-dimensional `torch.Tensor` 301*da0073e9SAndroid Build Coastguard Worker groups (int, optional): number of groups in the conv layer (default: 1) 302*da0073e9SAndroid Build Coastguard Worker Examples: 303*da0073e9SAndroid Build Coastguard Worker >>> w = torch.empty(3, 16, 5, 5) 304*da0073e9SAndroid Build Coastguard Worker >>> nn.init.dirac_(w) 305*da0073e9SAndroid Build Coastguard Worker >>> w = torch.empty(3, 24, 5, 5) 306*da0073e9SAndroid Build Coastguard Worker >>> nn.init.dirac_(w, 3) 307*da0073e9SAndroid Build Coastguard Worker """ 308*da0073e9SAndroid Build Coastguard Worker dimensions = tensor.ndimension() 309*da0073e9SAndroid Build Coastguard Worker if dimensions not in [3, 4, 5]: 310*da0073e9SAndroid Build Coastguard Worker raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported") 311*da0073e9SAndroid Build Coastguard Worker 312*da0073e9SAndroid Build Coastguard Worker sizes = tensor.size() 313*da0073e9SAndroid Build Coastguard Worker 314*da0073e9SAndroid Build Coastguard Worker if sizes[0] % groups != 0: 315*da0073e9SAndroid Build Coastguard Worker raise ValueError("dim 0 must be divisible by groups") 316*da0073e9SAndroid Build Coastguard Worker 317*da0073e9SAndroid Build Coastguard Worker out_chans_per_grp = sizes[0] // groups 318*da0073e9SAndroid Build Coastguard Worker min_dim = min(out_chans_per_grp, sizes[1]) 319*da0073e9SAndroid Build Coastguard Worker 320*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 321*da0073e9SAndroid Build Coastguard Worker tensor.zero_() 322*da0073e9SAndroid Build Coastguard Worker 323*da0073e9SAndroid Build Coastguard Worker for g in range(groups): 324*da0073e9SAndroid Build Coastguard Worker for d in range(min_dim): 
325*da0073e9SAndroid Build Coastguard Worker if dimensions == 3: # Temporal convolution 326*da0073e9SAndroid Build Coastguard Worker tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2] = 1 327*da0073e9SAndroid Build Coastguard Worker elif dimensions == 4: # Spatial convolution 328*da0073e9SAndroid Build Coastguard Worker tensor[ 329*da0073e9SAndroid Build Coastguard Worker g * out_chans_per_grp + d, 330*da0073e9SAndroid Build Coastguard Worker d, 331*da0073e9SAndroid Build Coastguard Worker tensor.size(2) // 2, 332*da0073e9SAndroid Build Coastguard Worker tensor.size(3) // 2, 333*da0073e9SAndroid Build Coastguard Worker ] = 1 334*da0073e9SAndroid Build Coastguard Worker else: # Volumetric convolution 335*da0073e9SAndroid Build Coastguard Worker tensor[ 336*da0073e9SAndroid Build Coastguard Worker g * out_chans_per_grp + d, 337*da0073e9SAndroid Build Coastguard Worker d, 338*da0073e9SAndroid Build Coastguard Worker tensor.size(2) // 2, 339*da0073e9SAndroid Build Coastguard Worker tensor.size(3) // 2, 340*da0073e9SAndroid Build Coastguard Worker tensor.size(4) // 2, 341*da0073e9SAndroid Build Coastguard Worker ] = 1 342*da0073e9SAndroid Build Coastguard Worker return tensor 343*da0073e9SAndroid Build Coastguard Worker 344*da0073e9SAndroid Build Coastguard Worker 345*da0073e9SAndroid Build Coastguard Workerdef _calculate_fan_in_and_fan_out(tensor): 346*da0073e9SAndroid Build Coastguard Worker dimensions = tensor.dim() 347*da0073e9SAndroid Build Coastguard Worker if dimensions < 2: 348*da0073e9SAndroid Build Coastguard Worker raise ValueError( 349*da0073e9SAndroid Build Coastguard Worker "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions" 350*da0073e9SAndroid Build Coastguard Worker ) 351*da0073e9SAndroid Build Coastguard Worker 352*da0073e9SAndroid Build Coastguard Worker num_input_fmaps = tensor.size(1) 353*da0073e9SAndroid Build Coastguard Worker num_output_fmaps = tensor.size(0) 354*da0073e9SAndroid Build Coastguard Worker 
receptive_field_size = 1 355*da0073e9SAndroid Build Coastguard Worker if tensor.dim() > 2: 356*da0073e9SAndroid Build Coastguard Worker # math.prod is not always available, accumulate the product manually 357*da0073e9SAndroid Build Coastguard Worker # we could use functools.reduce but that is not supported by TorchScript 358*da0073e9SAndroid Build Coastguard Worker for s in tensor.shape[2:]: 359*da0073e9SAndroid Build Coastguard Worker receptive_field_size *= s 360*da0073e9SAndroid Build Coastguard Worker fan_in = num_input_fmaps * receptive_field_size 361*da0073e9SAndroid Build Coastguard Worker fan_out = num_output_fmaps * receptive_field_size 362*da0073e9SAndroid Build Coastguard Worker 363*da0073e9SAndroid Build Coastguard Worker return fan_in, fan_out 364*da0073e9SAndroid Build Coastguard Worker 365*da0073e9SAndroid Build Coastguard Worker 366*da0073e9SAndroid Build Coastguard Workerdef xavier_uniform_( 367*da0073e9SAndroid Build Coastguard Worker tensor: Tensor, 368*da0073e9SAndroid Build Coastguard Worker gain: float = 1.0, 369*da0073e9SAndroid Build Coastguard Worker generator: _Optional[torch.Generator] = None, 370*da0073e9SAndroid Build Coastguard Worker) -> Tensor: 371*da0073e9SAndroid Build Coastguard Worker r"""Fill the input `Tensor` with values using a Xavier uniform distribution. 372*da0073e9SAndroid Build Coastguard Worker 373*da0073e9SAndroid Build Coastguard Worker The method is described in `Understanding the difficulty of training 374*da0073e9SAndroid Build Coastguard Worker deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010). 375*da0073e9SAndroid Build Coastguard Worker The resulting tensor will have values sampled from 376*da0073e9SAndroid Build Coastguard Worker :math:`\mathcal{U}(-a, a)` where 377*da0073e9SAndroid Build Coastguard Worker 378*da0073e9SAndroid Build Coastguard Worker .. 
math:: 379*da0073e9SAndroid Build Coastguard Worker a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}} 380*da0073e9SAndroid Build Coastguard Worker 381*da0073e9SAndroid Build Coastguard Worker Also known as Glorot initialization. 382*da0073e9SAndroid Build Coastguard Worker 383*da0073e9SAndroid Build Coastguard Worker Args: 384*da0073e9SAndroid Build Coastguard Worker tensor: an n-dimensional `torch.Tensor` 385*da0073e9SAndroid Build Coastguard Worker gain: an optional scaling factor 386*da0073e9SAndroid Build Coastguard Worker generator: the torch Generator to sample from (default: None) 387*da0073e9SAndroid Build Coastguard Worker 388*da0073e9SAndroid Build Coastguard Worker Examples: 389*da0073e9SAndroid Build Coastguard Worker >>> w = torch.empty(3, 5) 390*da0073e9SAndroid Build Coastguard Worker >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu')) 391*da0073e9SAndroid Build Coastguard Worker 392*da0073e9SAndroid Build Coastguard Worker Note: 393*da0073e9SAndroid Build Coastguard Worker Be aware that ``fan_in`` and ``fan_out`` are calculated assuming 394*da0073e9SAndroid Build Coastguard Worker that the weight matrix is used in a transposed manner, 395*da0073e9SAndroid Build Coastguard Worker (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``). 396*da0073e9SAndroid Build Coastguard Worker This is important for correct initialization. 397*da0073e9SAndroid Build Coastguard Worker If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``, 398*da0073e9SAndroid Build Coastguard Worker pass in a transposed weight matrix, i.e. ``nn.init.xavier_uniform_(w.T, ...)``. 
399*da0073e9SAndroid Build Coastguard Worker """ 400*da0073e9SAndroid Build Coastguard Worker fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) 401*da0073e9SAndroid Build Coastguard Worker std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) 402*da0073e9SAndroid Build Coastguard Worker a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation 403*da0073e9SAndroid Build Coastguard Worker 404*da0073e9SAndroid Build Coastguard Worker return _no_grad_uniform_(tensor, -a, a, generator) 405*da0073e9SAndroid Build Coastguard Worker 406*da0073e9SAndroid Build Coastguard Worker 407*da0073e9SAndroid Build Coastguard Workerdef xavier_normal_( 408*da0073e9SAndroid Build Coastguard Worker tensor: Tensor, 409*da0073e9SAndroid Build Coastguard Worker gain: float = 1.0, 410*da0073e9SAndroid Build Coastguard Worker generator: _Optional[torch.Generator] = None, 411*da0073e9SAndroid Build Coastguard Worker) -> Tensor: 412*da0073e9SAndroid Build Coastguard Worker r"""Fill the input `Tensor` with values using a Xavier normal distribution. 413*da0073e9SAndroid Build Coastguard Worker 414*da0073e9SAndroid Build Coastguard Worker The method is described in `Understanding the difficulty of training deep feedforward 415*da0073e9SAndroid Build Coastguard Worker neural networks` - Glorot, X. & Bengio, Y. (2010). The resulting tensor 416*da0073e9SAndroid Build Coastguard Worker will have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where 417*da0073e9SAndroid Build Coastguard Worker 418*da0073e9SAndroid Build Coastguard Worker .. math:: 419*da0073e9SAndroid Build Coastguard Worker \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}} 420*da0073e9SAndroid Build Coastguard Worker 421*da0073e9SAndroid Build Coastguard Worker Also known as Glorot initialization. 
422*da0073e9SAndroid Build Coastguard Worker 423*da0073e9SAndroid Build Coastguard Worker Args: 424*da0073e9SAndroid Build Coastguard Worker tensor: an n-dimensional `torch.Tensor` 425*da0073e9SAndroid Build Coastguard Worker gain: an optional scaling factor 426*da0073e9SAndroid Build Coastguard Worker generator: the torch Generator to sample from (default: None) 427*da0073e9SAndroid Build Coastguard Worker 428*da0073e9SAndroid Build Coastguard Worker Examples: 429*da0073e9SAndroid Build Coastguard Worker >>> w = torch.empty(3, 5) 430*da0073e9SAndroid Build Coastguard Worker >>> nn.init.xavier_normal_(w) 431*da0073e9SAndroid Build Coastguard Worker 432*da0073e9SAndroid Build Coastguard Worker Note: 433*da0073e9SAndroid Build Coastguard Worker Be aware that ``fan_in`` and ``fan_out`` are calculated assuming 434*da0073e9SAndroid Build Coastguard Worker that the weight matrix is used in a transposed manner, 435*da0073e9SAndroid Build Coastguard Worker (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``). 436*da0073e9SAndroid Build Coastguard Worker This is important for correct initialization. 437*da0073e9SAndroid Build Coastguard Worker If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``, 438*da0073e9SAndroid Build Coastguard Worker pass in a transposed weight matrix, i.e. ``nn.init.xavier_normal_(w.T, ...)``. 
439*da0073e9SAndroid Build Coastguard Worker """ 440*da0073e9SAndroid Build Coastguard Worker fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) 441*da0073e9SAndroid Build Coastguard Worker std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) 442*da0073e9SAndroid Build Coastguard Worker 443*da0073e9SAndroid Build Coastguard Worker return _no_grad_normal_(tensor, 0.0, std, generator) 444*da0073e9SAndroid Build Coastguard Worker 445*da0073e9SAndroid Build Coastguard Worker 446*da0073e9SAndroid Build Coastguard Workerdef _calculate_correct_fan(tensor, mode): 447*da0073e9SAndroid Build Coastguard Worker mode = mode.lower() 448*da0073e9SAndroid Build Coastguard Worker valid_modes = ["fan_in", "fan_out"] 449*da0073e9SAndroid Build Coastguard Worker if mode not in valid_modes: 450*da0073e9SAndroid Build Coastguard Worker raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}") 451*da0073e9SAndroid Build Coastguard Worker 452*da0073e9SAndroid Build Coastguard Worker fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) 453*da0073e9SAndroid Build Coastguard Worker return fan_in if mode == "fan_in" else fan_out 454*da0073e9SAndroid Build Coastguard Worker 455*da0073e9SAndroid Build Coastguard Worker 456*da0073e9SAndroid Build Coastguard Workerdef kaiming_uniform_( 457*da0073e9SAndroid Build Coastguard Worker tensor: Tensor, 458*da0073e9SAndroid Build Coastguard Worker a: float = 0, 459*da0073e9SAndroid Build Coastguard Worker mode: str = "fan_in", 460*da0073e9SAndroid Build Coastguard Worker nonlinearity: str = "leaky_relu", 461*da0073e9SAndroid Build Coastguard Worker generator: _Optional[torch.Generator] = None, 462*da0073e9SAndroid Build Coastguard Worker): 463*da0073e9SAndroid Build Coastguard Worker r"""Fill the input `Tensor` with values using a Kaiming uniform distribution. 
464*da0073e9SAndroid Build Coastguard Worker 465*da0073e9SAndroid Build Coastguard Worker The method is described in `Delving deep into rectifiers: Surpassing 466*da0073e9SAndroid Build Coastguard Worker human-level performance on ImageNet classification` - He, K. et al. (2015). 467*da0073e9SAndroid Build Coastguard Worker The resulting tensor will have values sampled from 468*da0073e9SAndroid Build Coastguard Worker :math:`\mathcal{U}(-\text{bound}, \text{bound})` where 469*da0073e9SAndroid Build Coastguard Worker 470*da0073e9SAndroid Build Coastguard Worker .. math:: 471*da0073e9SAndroid Build Coastguard Worker \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} 472*da0073e9SAndroid Build Coastguard Worker 473*da0073e9SAndroid Build Coastguard Worker Also known as He initialization. 474*da0073e9SAndroid Build Coastguard Worker 475*da0073e9SAndroid Build Coastguard Worker Args: 476*da0073e9SAndroid Build Coastguard Worker tensor: an n-dimensional `torch.Tensor` 477*da0073e9SAndroid Build Coastguard Worker a: the negative slope of the rectifier used after this layer (only 478*da0073e9SAndroid Build Coastguard Worker used with ``'leaky_relu'``) 479*da0073e9SAndroid Build Coastguard Worker mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` 480*da0073e9SAndroid Build Coastguard Worker preserves the magnitude of the variance of the weights in the 481*da0073e9SAndroid Build Coastguard Worker forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the 482*da0073e9SAndroid Build Coastguard Worker backwards pass. 483*da0073e9SAndroid Build Coastguard Worker nonlinearity: the non-linear function (`nn.functional` name), 484*da0073e9SAndroid Build Coastguard Worker recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). 
485*da0073e9SAndroid Build Coastguard Worker generator: the torch Generator to sample from (default: None) 486*da0073e9SAndroid Build Coastguard Worker 487*da0073e9SAndroid Build Coastguard Worker Examples: 488*da0073e9SAndroid Build Coastguard Worker >>> w = torch.empty(3, 5) 489*da0073e9SAndroid Build Coastguard Worker >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu') 490*da0073e9SAndroid Build Coastguard Worker 491*da0073e9SAndroid Build Coastguard Worker Note: 492*da0073e9SAndroid Build Coastguard Worker Be aware that ``fan_in`` and ``fan_out`` are calculated assuming 493*da0073e9SAndroid Build Coastguard Worker that the weight matrix is used in a transposed manner, 494*da0073e9SAndroid Build Coastguard Worker (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``). 495*da0073e9SAndroid Build Coastguard Worker This is important for correct initialization. 496*da0073e9SAndroid Build Coastguard Worker If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``, 497*da0073e9SAndroid Build Coastguard Worker pass in a transposed weight matrix, i.e. ``nn.init.kaiming_uniform_(w.T, ...)``. 
498*da0073e9SAndroid Build Coastguard Worker """ 499*da0073e9SAndroid Build Coastguard Worker if torch.overrides.has_torch_function_variadic(tensor): 500*da0073e9SAndroid Build Coastguard Worker return torch.overrides.handle_torch_function( 501*da0073e9SAndroid Build Coastguard Worker kaiming_uniform_, 502*da0073e9SAndroid Build Coastguard Worker (tensor,), 503*da0073e9SAndroid Build Coastguard Worker tensor=tensor, 504*da0073e9SAndroid Build Coastguard Worker a=a, 505*da0073e9SAndroid Build Coastguard Worker mode=mode, 506*da0073e9SAndroid Build Coastguard Worker nonlinearity=nonlinearity, 507*da0073e9SAndroid Build Coastguard Worker generator=generator, 508*da0073e9SAndroid Build Coastguard Worker ) 509*da0073e9SAndroid Build Coastguard Worker 510*da0073e9SAndroid Build Coastguard Worker if 0 in tensor.shape: 511*da0073e9SAndroid Build Coastguard Worker warnings.warn("Initializing zero-element tensors is a no-op") 512*da0073e9SAndroid Build Coastguard Worker return tensor 513*da0073e9SAndroid Build Coastguard Worker fan = _calculate_correct_fan(tensor, mode) 514*da0073e9SAndroid Build Coastguard Worker gain = calculate_gain(nonlinearity, a) 515*da0073e9SAndroid Build Coastguard Worker std = gain / math.sqrt(fan) 516*da0073e9SAndroid Build Coastguard Worker bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation 517*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 518*da0073e9SAndroid Build Coastguard Worker return tensor.uniform_(-bound, bound, generator=generator) 519*da0073e9SAndroid Build Coastguard Worker 520*da0073e9SAndroid Build Coastguard Worker 521*da0073e9SAndroid Build Coastguard Workerdef kaiming_normal_( 522*da0073e9SAndroid Build Coastguard Worker tensor: Tensor, 523*da0073e9SAndroid Build Coastguard Worker a: float = 0, 524*da0073e9SAndroid Build Coastguard Worker mode: str = "fan_in", 525*da0073e9SAndroid Build Coastguard Worker nonlinearity: str = "leaky_relu", 526*da0073e9SAndroid Build Coastguard Worker 
generator: _Optional[torch.Generator] = None, 527*da0073e9SAndroid Build Coastguard Worker): 528*da0073e9SAndroid Build Coastguard Worker r"""Fill the input `Tensor` with values using a Kaiming normal distribution. 529*da0073e9SAndroid Build Coastguard Worker 530*da0073e9SAndroid Build Coastguard Worker The method is described in `Delving deep into rectifiers: Surpassing 531*da0073e9SAndroid Build Coastguard Worker human-level performance on ImageNet classification` - He, K. et al. (2015). 532*da0073e9SAndroid Build Coastguard Worker The resulting tensor will have values sampled from 533*da0073e9SAndroid Build Coastguard Worker :math:`\mathcal{N}(0, \text{std}^2)` where 534*da0073e9SAndroid Build Coastguard Worker 535*da0073e9SAndroid Build Coastguard Worker .. math:: 536*da0073e9SAndroid Build Coastguard Worker \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}} 537*da0073e9SAndroid Build Coastguard Worker 538*da0073e9SAndroid Build Coastguard Worker Also known as He initialization. 539*da0073e9SAndroid Build Coastguard Worker 540*da0073e9SAndroid Build Coastguard Worker Args: 541*da0073e9SAndroid Build Coastguard Worker tensor: an n-dimensional `torch.Tensor` 542*da0073e9SAndroid Build Coastguard Worker a: the negative slope of the rectifier used after this layer (only 543*da0073e9SAndroid Build Coastguard Worker used with ``'leaky_relu'``) 544*da0073e9SAndroid Build Coastguard Worker mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` 545*da0073e9SAndroid Build Coastguard Worker preserves the magnitude of the variance of the weights in the 546*da0073e9SAndroid Build Coastguard Worker forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the 547*da0073e9SAndroid Build Coastguard Worker backwards pass. 548*da0073e9SAndroid Build Coastguard Worker nonlinearity: the non-linear function (`nn.functional` name), 549*da0073e9SAndroid Build Coastguard Worker recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). 
550*da0073e9SAndroid Build Coastguard Worker generator: the torch Generator to sample from (default: None) 551*da0073e9SAndroid Build Coastguard Worker 552*da0073e9SAndroid Build Coastguard Worker Examples: 553*da0073e9SAndroid Build Coastguard Worker >>> w = torch.empty(3, 5) 554*da0073e9SAndroid Build Coastguard Worker >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu') 555*da0073e9SAndroid Build Coastguard Worker 556*da0073e9SAndroid Build Coastguard Worker Note: 557*da0073e9SAndroid Build Coastguard Worker Be aware that ``fan_in`` and ``fan_out`` are calculated assuming 558*da0073e9SAndroid Build Coastguard Worker that the weight matrix is used in a transposed manner, 559*da0073e9SAndroid Build Coastguard Worker (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``). 560*da0073e9SAndroid Build Coastguard Worker This is important for correct initialization. 561*da0073e9SAndroid Build Coastguard Worker If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``, 562*da0073e9SAndroid Build Coastguard Worker pass in a transposed weight matrix, i.e. ``nn.init.kaiming_normal_(w.T, ...)``. 
563*da0073e9SAndroid Build Coastguard Worker """ 564*da0073e9SAndroid Build Coastguard Worker if 0 in tensor.shape: 565*da0073e9SAndroid Build Coastguard Worker warnings.warn("Initializing zero-element tensors is a no-op") 566*da0073e9SAndroid Build Coastguard Worker return tensor 567*da0073e9SAndroid Build Coastguard Worker fan = _calculate_correct_fan(tensor, mode) 568*da0073e9SAndroid Build Coastguard Worker gain = calculate_gain(nonlinearity, a) 569*da0073e9SAndroid Build Coastguard Worker std = gain / math.sqrt(fan) 570*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 571*da0073e9SAndroid Build Coastguard Worker return tensor.normal_(0, std, generator=generator) 572*da0073e9SAndroid Build Coastguard Worker 573*da0073e9SAndroid Build Coastguard Worker 574*da0073e9SAndroid Build Coastguard Workerdef orthogonal_( 575*da0073e9SAndroid Build Coastguard Worker tensor, 576*da0073e9SAndroid Build Coastguard Worker gain=1, 577*da0073e9SAndroid Build Coastguard Worker generator: _Optional[torch.Generator] = None, 578*da0073e9SAndroid Build Coastguard Worker): 579*da0073e9SAndroid Build Coastguard Worker r"""Fill the input `Tensor` with a (semi) orthogonal matrix. 580*da0073e9SAndroid Build Coastguard Worker 581*da0073e9SAndroid Build Coastguard Worker Described in `Exact solutions to the nonlinear dynamics of learning in deep 582*da0073e9SAndroid Build Coastguard Worker linear neural networks` - Saxe, A. et al. (2013). The input tensor must have 583*da0073e9SAndroid Build Coastguard Worker at least 2 dimensions, and for tensors with more than 2 dimensions the 584*da0073e9SAndroid Build Coastguard Worker trailing dimensions are flattened. 
585*da0073e9SAndroid Build Coastguard Worker 586*da0073e9SAndroid Build Coastguard Worker Args: 587*da0073e9SAndroid Build Coastguard Worker tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2` 588*da0073e9SAndroid Build Coastguard Worker gain: optional scaling factor 589*da0073e9SAndroid Build Coastguard Worker generator: the torch Generator to sample from (default: None) 590*da0073e9SAndroid Build Coastguard Worker 591*da0073e9SAndroid Build Coastguard Worker Examples: 592*da0073e9SAndroid Build Coastguard Worker >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) 593*da0073e9SAndroid Build Coastguard Worker >>> w = torch.empty(3, 5) 594*da0073e9SAndroid Build Coastguard Worker >>> nn.init.orthogonal_(w) 595*da0073e9SAndroid Build Coastguard Worker """ 596*da0073e9SAndroid Build Coastguard Worker if tensor.ndimension() < 2: 597*da0073e9SAndroid Build Coastguard Worker raise ValueError("Only tensors with 2 or more dimensions are supported") 598*da0073e9SAndroid Build Coastguard Worker 599*da0073e9SAndroid Build Coastguard Worker if tensor.numel() == 0: 600*da0073e9SAndroid Build Coastguard Worker # no-op 601*da0073e9SAndroid Build Coastguard Worker return tensor 602*da0073e9SAndroid Build Coastguard Worker rows = tensor.size(0) 603*da0073e9SAndroid Build Coastguard Worker cols = tensor.numel() // rows 604*da0073e9SAndroid Build Coastguard Worker flattened = tensor.new_empty((rows, cols)).normal_(0, 1, generator=generator) 605*da0073e9SAndroid Build Coastguard Worker 606*da0073e9SAndroid Build Coastguard Worker if rows < cols: 607*da0073e9SAndroid Build Coastguard Worker flattened.t_() 608*da0073e9SAndroid Build Coastguard Worker 609*da0073e9SAndroid Build Coastguard Worker # Compute the qr factorization 610*da0073e9SAndroid Build Coastguard Worker q, r = torch.linalg.qr(flattened) 611*da0073e9SAndroid Build Coastguard Worker # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf 612*da0073e9SAndroid Build Coastguard Worker d = 
torch.diag(r, 0) 613*da0073e9SAndroid Build Coastguard Worker ph = d.sign() 614*da0073e9SAndroid Build Coastguard Worker q *= ph 615*da0073e9SAndroid Build Coastguard Worker 616*da0073e9SAndroid Build Coastguard Worker if rows < cols: 617*da0073e9SAndroid Build Coastguard Worker q.t_() 618*da0073e9SAndroid Build Coastguard Worker 619*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 620*da0073e9SAndroid Build Coastguard Worker tensor.view_as(q).copy_(q) 621*da0073e9SAndroid Build Coastguard Worker tensor.mul_(gain) 622*da0073e9SAndroid Build Coastguard Worker return tensor 623*da0073e9SAndroid Build Coastguard Worker 624*da0073e9SAndroid Build Coastguard Worker 625*da0073e9SAndroid Build Coastguard Workerdef sparse_( 626*da0073e9SAndroid Build Coastguard Worker tensor, 627*da0073e9SAndroid Build Coastguard Worker sparsity, 628*da0073e9SAndroid Build Coastguard Worker std=0.01, 629*da0073e9SAndroid Build Coastguard Worker generator: _Optional[torch.Generator] = None, 630*da0073e9SAndroid Build Coastguard Worker): 631*da0073e9SAndroid Build Coastguard Worker r"""Fill the 2D input `Tensor` as a sparse matrix. 632*da0073e9SAndroid Build Coastguard Worker 633*da0073e9SAndroid Build Coastguard Worker The non-zero elements will be drawn from the normal distribution 634*da0073e9SAndroid Build Coastguard Worker :math:`\mathcal{N}(0, 0.01)`, as described in `Deep learning via 635*da0073e9SAndroid Build Coastguard Worker Hessian-free optimization` - Martens, J. (2010). 
636*da0073e9SAndroid Build Coastguard Worker 637*da0073e9SAndroid Build Coastguard Worker Args: 638*da0073e9SAndroid Build Coastguard Worker tensor: an n-dimensional `torch.Tensor` 639*da0073e9SAndroid Build Coastguard Worker sparsity: The fraction of elements in each column to be set to zero 640*da0073e9SAndroid Build Coastguard Worker std: the standard deviation of the normal distribution used to generate 641*da0073e9SAndroid Build Coastguard Worker the non-zero values 642*da0073e9SAndroid Build Coastguard Worker generator: the torch Generator to sample from (default: None) 643*da0073e9SAndroid Build Coastguard Worker 644*da0073e9SAndroid Build Coastguard Worker Examples: 645*da0073e9SAndroid Build Coastguard Worker >>> w = torch.empty(3, 5) 646*da0073e9SAndroid Build Coastguard Worker >>> nn.init.sparse_(w, sparsity=0.1) 647*da0073e9SAndroid Build Coastguard Worker """ 648*da0073e9SAndroid Build Coastguard Worker if tensor.ndimension() != 2: 649*da0073e9SAndroid Build Coastguard Worker raise ValueError("Only tensors with 2 dimensions are supported") 650*da0073e9SAndroid Build Coastguard Worker 651*da0073e9SAndroid Build Coastguard Worker rows, cols = tensor.shape 652*da0073e9SAndroid Build Coastguard Worker num_zeros = int(math.ceil(sparsity * rows)) 653*da0073e9SAndroid Build Coastguard Worker 654*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 655*da0073e9SAndroid Build Coastguard Worker tensor.normal_(0, std, generator=generator) 656*da0073e9SAndroid Build Coastguard Worker for col_idx in range(cols): 657*da0073e9SAndroid Build Coastguard Worker row_indices = torch.randperm(rows) 658*da0073e9SAndroid Build Coastguard Worker zero_indices = row_indices[:num_zeros] 659*da0073e9SAndroid Build Coastguard Worker tensor[zero_indices, col_idx] = 0 660*da0073e9SAndroid Build Coastguard Worker return tensor 661*da0073e9SAndroid Build Coastguard Worker 662*da0073e9SAndroid Build Coastguard Worker 663*da0073e9SAndroid Build Coastguard Worker# for backward 
compatibility 664*da0073e9SAndroid Build Coastguard Workerdef _make_deprecate(meth): 665*da0073e9SAndroid Build Coastguard Worker new_name = meth.__name__ 666*da0073e9SAndroid Build Coastguard Worker old_name = new_name[:-1] 667*da0073e9SAndroid Build Coastguard Worker 668*da0073e9SAndroid Build Coastguard Worker def deprecated_init(*args, **kwargs): 669*da0073e9SAndroid Build Coastguard Worker warnings.warn( 670*da0073e9SAndroid Build Coastguard Worker f"`nn.init.{old_name}` is now deprecated in favor of `nn.init.{new_name}`.", 671*da0073e9SAndroid Build Coastguard Worker FutureWarning, 672*da0073e9SAndroid Build Coastguard Worker stacklevel=2, 673*da0073e9SAndroid Build Coastguard Worker ) 674*da0073e9SAndroid Build Coastguard Worker return meth(*args, **kwargs) 675*da0073e9SAndroid Build Coastguard Worker 676*da0073e9SAndroid Build Coastguard Worker deprecated_init.__doc__ = rf""" 677*da0073e9SAndroid Build Coastguard Worker {old_name}(...) 678*da0073e9SAndroid Build Coastguard Worker 679*da0073e9SAndroid Build Coastguard Worker .. warning:: 680*da0073e9SAndroid Build Coastguard Worker This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`. 
681*da0073e9SAndroid Build Coastguard Worker 682*da0073e9SAndroid Build Coastguard Worker See :func:`~torch.nn.init.{new_name}` for details.""" 683*da0073e9SAndroid Build Coastguard Worker deprecated_init.__name__ = old_name 684*da0073e9SAndroid Build Coastguard Worker return deprecated_init 685*da0073e9SAndroid Build Coastguard Worker 686*da0073e9SAndroid Build Coastguard Worker 687*da0073e9SAndroid Build Coastguard Workeruniform = _make_deprecate(uniform_) 688*da0073e9SAndroid Build Coastguard Workernormal = _make_deprecate(normal_) 689*da0073e9SAndroid Build Coastguard Workerconstant = _make_deprecate(constant_) 690*da0073e9SAndroid Build Coastguard Workereye = _make_deprecate(eye_) 691*da0073e9SAndroid Build Coastguard Workerdirac = _make_deprecate(dirac_) 692*da0073e9SAndroid Build Coastguard Workerxavier_uniform = _make_deprecate(xavier_uniform_) 693*da0073e9SAndroid Build Coastguard Workerxavier_normal = _make_deprecate(xavier_normal_) 694*da0073e9SAndroid Build Coastguard Workerkaiming_uniform = _make_deprecate(kaiming_uniform_) 695*da0073e9SAndroid Build Coastguard Workerkaiming_normal = _make_deprecate(kaiming_normal_) 696*da0073e9SAndroid Build Coastguard Workerorthogonal = _make_deprecate(orthogonal_) 697*da0073e9SAndroid Build Coastguard Workersparse = _make_deprecate(sparse_) 698