# mypy: allow-untyped-defs
"""This file contains utilities for initializing neural network parameters."""
import math
import warnings
from typing import Optional as _Optional

import torch
from torch import Tensor


# These no_grad_* functions are necessary as wrappers around the parts of these
# functions that use `with torch.no_grad()`. The JIT doesn't support context
# managers, so these need to be implemented as builtins. Using these wrappers
# lets us keep those builtins small and re-usable.
def _no_grad_uniform_(tensor, a, b, generator=None):
    with torch.no_grad():
        return tensor.uniform_(a, b, generator=generator)


def _no_grad_normal_(tensor, mean, std, generator=None):
    with torch.no_grad():
        return tensor.normal_(mean, std, generator=generator)


def _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=None):
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1, generator=generator)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def _no_grad_fill_(tensor, val):
    with torch.no_grad():
        return tensor.fill_(val)


def _no_grad_zero_(tensor):
    with torch.no_grad():
        return tensor.zero_()


def calculate_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.

    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    SELU              :math:`\frac{3}{4}`
    ================= ====================================================

    .. warning::
        In order to implement `Self-Normalizing Neural Networks`_ ,
        you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
        This gives the initial weights a variance of ``1 / N``,
        which is necessary to induce a stable fixed point in the forward pass.
        In contrast, the default gain for ``SELU`` sacrifices the normalization
        effect for more stable gradient flow in rectangular layers.

    Args:
        nonlinearity: the non-linear function (`nn.functional` name)
        param: optional parameter for the non-linear function

    Examples:
        >>> gain = nn.init.calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2

    .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
    """
    linear_fns = [
        "linear",
        "conv1d",
        "conv2d",
        "conv3d",
        "conv_transpose1d",
        "conv_transpose2d",
        "conv_transpose3d",
    ]
    if nonlinearity in linear_fns or nonlinearity == "sigmoid":
        return 1
    elif nonlinearity == "tanh":
        return 5.0 / 3
    elif nonlinearity == "relu":
        return math.sqrt(2.0)
    elif nonlinearity == "leaky_relu":
        if param is None:
            negative_slope = 0.01
        elif (
            not isinstance(param, bool) and isinstance(param, int)
        ) or isinstance(param, float):
            # True/False are instances of int, hence check above
            negative_slope = param
        else:
            raise ValueError(f"negative_slope {param} not a valid number")
        return math.sqrt(2.0 / (1 + negative_slope**2))
    elif nonlinearity == "selu":
        return (
            3.0 / 4
        )  # Value found empirically (https://github.com/pytorch/pytorch/pull/50664)
    else:
        raise ValueError(f"Unsupported nonlinearity {nonlinearity}")


def uniform_(
    tensor: Tensor,
    a: float = 0.0,
    b: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input Tensor with values drawn from the uniform distribution.

    :math:`\mathcal{U}(a, b)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the lower bound of the uniform distribution
        b: the upper bound of the uniform distribution
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.uniform_(w)
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(
            uniform_, (tensor,), tensor=tensor, a=a, b=b, generator=generator
        )
    return _no_grad_uniform_(tensor, a, b, generator)


def normal_(
    tensor: Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input Tensor with values drawn from the normal distribution.

    :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.normal_(w)
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(
            normal_, (tensor,), tensor=tensor, mean=mean, std=std, generator=generator
        )
    return _no_grad_normal_(tensor, mean, std, generator)


def trunc_normal_(
    tensor: Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    a: float = -2.0,
    b: float = 2.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input Tensor with values drawn from a truncated normal distribution.

    The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=generator)


def constant_(tensor: Tensor, val: float) -> Tensor:
    r"""Fill the input Tensor with the value :math:`\text{val}`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        val: the value to fill the tensor with

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.constant_(w, 0.3)
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(
            constant_, (tensor,), tensor=tensor, val=val
        )
    return _no_grad_fill_(tensor, val)


def ones_(tensor: Tensor) -> Tensor:
    r"""Fill the input Tensor with the scalar value `1`.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.ones_(w)
    """
    return _no_grad_fill_(tensor, 1.0)


def zeros_(tensor: Tensor) -> Tensor:
    r"""Fill the input Tensor with the scalar value `0`.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.zeros_(w)
    """
    return _no_grad_zero_(tensor)


def eye_(tensor):
    r"""Fill the 2-dimensional input `Tensor` with the identity matrix.

    Preserves the identity of the inputs in `Linear` layers, where as
    many inputs are preserved as possible.

    Args:
        tensor: a 2-dimensional `torch.Tensor`

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.eye_(w)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    with torch.no_grad():
        torch.eye(*tensor.shape, out=tensor, requires_grad=tensor.requires_grad)
    return tensor


def dirac_(tensor, groups=1):
    r"""Fill the {3, 4, 5}-dimensional input `Tensor` with the Dirac delta function.

    Preserves the identity of the inputs in `Convolutional`
    layers, where as many input channels are preserved as possible. In case
    of groups>1, each group of channels preserves identity

    Args:
        tensor: a {3, 4, 5}-dimensional `torch.Tensor`
        groups (int, optional): number of groups in the conv layer (default: 1)

    Examples:
        >>> w = torch.empty(3, 16, 5, 5)
        >>> nn.init.dirac_(w)
        >>> w = torch.empty(3, 24, 5, 5)
        >>> nn.init.dirac_(w, 3)
    """
    dimensions = tensor.ndimension()
    if dimensions not in [3, 4, 5]:
        raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported")

    sizes = tensor.size()

    if sizes[0] % groups != 0:
        raise ValueError("dim 0 must be divisible by groups")

    out_chans_per_grp = sizes[0] // groups
    min_dim = min(out_chans_per_grp, sizes[1])

    with torch.no_grad():
        tensor.zero_()

        for g in range(groups):
            for d in range(min_dim):
                if dimensions == 3:  # Temporal convolution
                    tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2] = 1
                elif dimensions == 4:  # Spatial convolution
                    tensor[
                        g * out_chans_per_grp + d,
                        d,
                        tensor.size(2) // 2,
                        tensor.size(3) // 2,
                    ] = 1
                else:  # Volumetric convolution
                    tensor[
                        g * out_chans_per_grp + d,
                        d,
                        tensor.size(2) // 2,
                        tensor.size(3) // 2,
                        tensor.size(4) // 2,
                    ] = 1
    return tensor


def _calculate_fan_in_and_fan_out(tensor):
    dimensions = tensor.dim()
    if dimensions < 2:
        raise ValueError(
            "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
        )

    num_input_fmaps = tensor.size(1)
    num_output_fmaps = tensor.size(0)
    receptive_field_size = 1
    if tensor.dim() > 2:
        # math.prod is not always available, accumulate the product manually
        # we could use functools.reduce but that is not supported by TorchScript
        for s in tensor.shape[2:]:
            receptive_field_size *= s
    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size

    return fan_in, fan_out


def xavier_uniform_(
    tensor: Tensor,
    gain: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input `Tensor` with values using a Xavier uniform distribution.

    The method is described in `Understanding the difficulty of training
    deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010).
    The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))

    Note:
        Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
        that the weight matrix is used in a transposed manner,
        (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``).
        This is important for correct initialization.
        If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``,
        pass in a transposed weight matrix, i.e. ``nn.init.xavier_uniform_(w.T, ...)``.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation

    return _no_grad_uniform_(tensor, -a, a, generator)


def xavier_normal_(
    tensor: Tensor,
    gain: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input `Tensor` with values using a Xavier normal distribution.

    The method is described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010). The resulting tensor
    will have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_normal_(w)

    Note:
        Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
        that the weight matrix is used in a transposed manner,
        (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``).
        This is important for correct initialization.
        If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``,
        pass in a transposed weight matrix, i.e. ``nn.init.xavier_normal_(w.T, ...)``.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))

    return _no_grad_normal_(tensor, 0.0, std, generator)


def _calculate_correct_fan(tensor, mode):
    mode = mode.lower()
    valid_modes = ["fan_in", "fan_out"]
    if mode not in valid_modes:
        raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}")

    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    return fan_in if mode == "fan_in" else fan_out


def kaiming_uniform_(
    tensor: Tensor,
    a: float = 0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the input `Tensor` with values using a Kaiming uniform distribution.

    The method is described in `Delving deep into rectifiers: Surpassing
    human-level performance on ImageNet classification` - He, K. et al. (2015).
    The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')

    Note:
        Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
        that the weight matrix is used in a transposed manner,
        (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``).
        This is important for correct initialization.
        If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``,
        pass in a transposed weight matrix, i.e. ``nn.init.kaiming_uniform_(w.T, ...)``.
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(
            kaiming_uniform_,
            (tensor,),
            tensor=tensor,
            a=a,
            mode=mode,
            nonlinearity=nonlinearity,
            generator=generator,
        )

    if 0 in tensor.shape:
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    with torch.no_grad():
        return tensor.uniform_(-bound, bound, generator=generator)


def kaiming_normal_(
    tensor: Tensor,
    a: float = 0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the input `Tensor` with values using a Kaiming normal distribution.

    The method is described in `Delving deep into rectifiers: Surpassing
    human-level performance on ImageNet classification` - He, K. et al. (2015).
    The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')

    Note:
        Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
        that the weight matrix is used in a transposed manner,
        (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``).
        This is important for correct initialization.
        If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``,
        pass in a transposed weight matrix, i.e. ``nn.init.kaiming_normal_(w.T, ...)``.
    """
    if 0 in tensor.shape:
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    with torch.no_grad():
        return tensor.normal_(0, std, generator=generator)


def orthogonal_(
    tensor,
    gain=1,
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the input `Tensor` with a (semi) orthogonal matrix.

    Described in `Exact solutions to the nonlinear dynamics of learning in deep
    linear neural networks` - Saxe, A. et al. (2013). The input tensor must have
    at least 2 dimensions, and for tensors with more than 2 dimensions the
    trailing dimensions are flattened.

    Args:
        tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
        gain: optional scaling factor
        generator: the torch Generator to sample from (default: None)

    Examples:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> w = torch.empty(3, 5)
        >>> nn.init.orthogonal_(w)
    """
    if tensor.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")

    if tensor.numel() == 0:
        # no-op
        return tensor
    rows = tensor.size(0)
    cols = tensor.numel() // rows
    flattened = tensor.new_empty((rows, cols)).normal_(0, 1, generator=generator)

    if rows < cols:
        flattened.t_()

    # Compute the qr factorization
    q, r = torch.linalg.qr(flattened)
    # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
    d = torch.diag(r, 0)
    ph = d.sign()
    q *= ph

    if rows < cols:
        q.t_()

    with torch.no_grad():
        tensor.view_as(q).copy_(q)
        tensor.mul_(gain)
    return tensor


def sparse_(
    tensor,
    sparsity,
    std=0.01,
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the 2D input `Tensor` as a sparse matrix.

    The non-zero elements will be drawn from the normal distribution
    :math:`\mathcal{N}(0, \text{std}^2)` (with the default ``std=0.01``),
    as described in `Deep learning via Hessian-free optimization` -
    Martens, J. (2010).

    Args:
        tensor: a 2-dimensional `torch.Tensor`
        sparsity: The fraction of elements in each column to be set to zero
            (the per-column count is ``ceil(sparsity * rows)``)
        std: the standard deviation of the normal distribution used to generate
            the non-zero values
        generator: the torch Generator to sample from (default: None). It is
            used both for the values and for choosing which rows are zeroed.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.sparse_(w, sparsity=0.1)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    rows, cols = tensor.shape
    num_zeros = int(math.ceil(sparsity * rows))

    # Draw the row permutations from ``generator`` as well, so a seeded
    # generator makes the whole result reproducible. ``torch.randperm`` needs
    # the output on the generator's device. Without a generator the call is
    # left exactly as before (global RNG, default device).
    if generator is None:
        randperm_kwargs = {}
    else:
        randperm_kwargs = {"generator": generator, "device": generator.device}

    with torch.no_grad():
        tensor.normal_(0, std, generator=generator)
        for col_idx in range(cols):
            row_indices = torch.randperm(rows, **randperm_kwargs)
            zero_indices = row_indices[:num_zeros]
            tensor[zero_indices, col_idx] = 0
    return tensor


# for backward compatibility
def _make_deprecate(meth):
    new_name = meth.__name__
    old_name = new_name[:-1]

    def deprecated_init(*args, **kwargs):
        warnings.warn(
            f"`nn.init.{old_name}` is now deprecated in favor of `nn.init.{new_name}`.",
            FutureWarning,
            stacklevel=2,
        )
        return meth(*args, **kwargs)

    deprecated_init.__doc__ = rf"""
    {old_name}(...)

    .. warning::
        This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`.

    See :func:`~torch.nn.init.{new_name}` for details."""
    deprecated_init.__name__ = old_name
    return deprecated_init


uniform = _make_deprecate(uniform_)
normal = _make_deprecate(normal_)
constant = _make_deprecate(constant_)
eye = _make_deprecate(eye_)
dirac = _make_deprecate(dirac_)
xavier_uniform = _make_deprecate(xavier_uniform_)
xavier_normal = _make_deprecate(xavier_normal_)
kaiming_uniform = _make_deprecate(kaiming_uniform_)
kaiming_normal = _make_deprecate(kaiming_normal_)
orthogonal = _make_deprecate(orthogonal_)
sparse = _make_deprecate(sparse_)