.. _modules:

Modules
=======

PyTorch uses modules to represent neural networks. Modules are:

* **Building blocks of stateful computation.**
  PyTorch provides a robust library of modules and makes it simple to define new custom modules, allowing for
  easy construction of elaborate, multi-layer neural networks.
* **Tightly integrated with PyTorch's**
  `autograd <https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html>`_
  **system.** Modules make it simple to specify learnable parameters for PyTorch's Optimizers to update.
* **Easy to work with and transform.** Modules are straightforward to save and restore, transfer between
  CPU / GPU / TPU devices, prune, quantize, and more.

This note describes modules and is intended for all PyTorch users. Since modules are so fundamental to PyTorch,
many topics in this note are elaborated on in other notes or tutorials, and links to many of those documents
are provided here as well.

.. contents:: :local:

A Simple Custom Module
----------------------

To get started, let's look at a simpler, custom version of PyTorch's :class:`~torch.nn.Linear` module.
This module applies an affine transformation to its input.

.. code-block:: python

    import torch
    from torch import nn

    class MyLinear(nn.Module):
        def __init__(self, in_features, out_features):
            super().__init__()
            self.weight = nn.Parameter(torch.randn(in_features, out_features))
            self.bias = nn.Parameter(torch.randn(out_features))

        def forward(self, input):
            return (input @ self.weight) + self.bias

This simple module has the following fundamental characteristics of modules:

* **It inherits from the base Module class.**
  All modules should subclass :class:`~torch.nn.Module` for composability with other modules.
* **It defines some "state" that is used in computation.**
  Here, the state consists of randomly-initialized ``weight`` and ``bias`` tensors that define the affine
  transformation. Because each is defined as a :class:`~torch.nn.parameter.Parameter`, both are
  *registered* with the module and will automatically be tracked and returned from calls
  to :func:`~torch.nn.Module.parameters`. Parameters can be
  considered the "learnable" aspects of the module's computation (more on this later). Note that modules
  are not required to have state; they can also be stateless.
* **It defines a forward() function that performs the computation.** For this affine transformation module, the input
  is matrix-multiplied with the ``weight`` parameter (using the ``@`` short-hand notation) and added to the ``bias``
  parameter to produce the output. More generally, the ``forward()`` implementation for a module can perform arbitrary
  computation involving any number of inputs and outputs.

This simple module demonstrates how modules package state and computation together. Instances of this module can be
constructed and called:

.. code-block:: python

    m = MyLinear(4, 3)
    sample_input = torch.randn(4)
    m(sample_input)
    : tensor([-0.3037, -1.0413, -4.2057], grad_fn=<AddBackward0>)

Note that the module itself is callable, and that calling it invokes its ``forward()`` function.
The name refers to the concepts of "forward pass" and "backward pass", which apply to each module.
The "forward pass" is responsible for applying the computation represented by the module
to the given input(s) (as shown in the above snippet). The "backward pass" computes gradients of
module outputs with respect to its inputs and parameters, which can be used for "training" parameters through gradient
descent methods.
PyTorch's autograd system automatically takes care of this backward pass computation, so it
is not required to manually implement a ``backward()`` function for each module. The process of training
module parameters through successive forward / backward passes is covered in detail in
:ref:`Neural Network Training with Modules`.

The full set of parameters registered by the module can be iterated through via a call to
:func:`~torch.nn.Module.parameters` or :func:`~torch.nn.Module.named_parameters`,
where the latter includes each parameter's name:

.. code-block:: python

    for parameter in m.named_parameters():
        print(parameter)
    : ('weight', Parameter containing:
    tensor([[ 1.0597,  1.1796,  0.8247],
            [-0.5080, -1.2635, -1.1045],
            [ 0.0593,  0.2469, -1.4299],
            [-0.4926, -0.5457,  0.4793]], requires_grad=True))
    ('bias', Parameter containing:
    tensor([ 0.3634,  0.2015, -0.8525], requires_grad=True))

In general, the parameters registered by a module are aspects of the module's computation that should be
"learned". A later section of this note shows how to update these parameters using one of PyTorch's Optimizers.
Before we get to that, however, let's first examine how modules can be composed with one another.

Modules as Building Blocks
--------------------------

Modules can contain other modules, making them useful building blocks for developing more elaborate functionality.
The simplest way to do this is using the :class:`~torch.nn.Sequential` module. It allows us to chain together
multiple modules:

.. code-block:: python

    net = nn.Sequential(
        MyLinear(4, 3),
        nn.ReLU(),
        MyLinear(3, 1)
    )

    sample_input = torch.randn(4)
    net(sample_input)
    : tensor([-0.6749], grad_fn=<AddBackward0>)

Note that :class:`~torch.nn.Sequential` automatically feeds the output of the first ``MyLinear`` module as input
into the :class:`~torch.nn.ReLU`, and the output of that as input into the second ``MyLinear`` module. As
shown, it is limited to in-order chaining of modules with a single input and output.

In general, it is recommended to define a custom module for anything beyond the simplest use cases, as this gives
full flexibility over how submodules are used for a module's computation.

For example, here's a simple neural network implemented as a custom module:

.. code-block:: python

    import torch.nn.functional as F

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.l0 = MyLinear(4, 3)
            self.l1 = MyLinear(3, 1)

        def forward(self, x):
            x = self.l0(x)
            x = F.relu(x)
            x = self.l1(x)
            return x

This module is composed of two "children" or "submodules" (``l0`` and ``l1``) that define the layers of
the neural network and are utilized for computation within the module's ``forward()`` method. Immediate
children of a module can be iterated through via a call to :func:`~torch.nn.Module.children` or
:func:`~torch.nn.Module.named_children`:

.. code-block:: python

    net = Net()
    for child in net.named_children():
        print(child)
    : ('l0', MyLinear())
    ('l1', MyLinear())

To go deeper than just the immediate children, :func:`~torch.nn.Module.modules` and
:func:`~torch.nn.Module.named_modules` *recursively* iterate through a module and its child modules:

.. code-block:: python

    class BigNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.l1 = MyLinear(5, 4)
            self.net = Net()

        def forward(self, x):
            return self.net(self.l1(x))

    big_net = BigNet()
    for module in big_net.named_modules():
        print(module)
    : ('', BigNet(
      (l1): MyLinear()
      (net): Net(
        (l0): MyLinear()
        (l1): MyLinear()
      )
    ))
    ('l1', MyLinear())
    ('net', Net(
      (l0): MyLinear()
      (l1): MyLinear()
    ))
    ('net.l0', MyLinear())
    ('net.l1', MyLinear())

Sometimes, it's necessary for a module to dynamically define submodules.
The :class:`~torch.nn.ModuleList` and :class:`~torch.nn.ModuleDict` modules are useful here; they
register submodules from a list or dict:

.. code-block:: python

    class DynamicNet(nn.Module):
        def __init__(self, num_layers):
            super().__init__()
            self.linears = nn.ModuleList(
                [MyLinear(4, 4) for _ in range(num_layers)])
            self.activations = nn.ModuleDict({
                'relu': nn.ReLU(),
                'lrelu': nn.LeakyReLU()
            })
            self.final = MyLinear(4, 1)

        def forward(self, x, act):
            for linear in self.linears:
                x = linear(x)
                x = self.activations[act](x)
            x = self.final(x)
            return x

    dynamic_net = DynamicNet(3)
    sample_input = torch.randn(4)
    output = dynamic_net(sample_input, 'relu')

For any given module, its parameters consist of its direct parameters as well as the parameters of all submodules.
This means that calls to :func:`~torch.nn.Module.parameters` and :func:`~torch.nn.Module.named_parameters` will
recursively include child parameters, allowing for convenient optimization of all parameters within the network:

.. code-block:: python

    for parameter in dynamic_net.named_parameters():
        print(parameter)
    : ('linears.0.weight', Parameter containing:
    tensor([[-1.2051,  0.7601,  1.1065,  0.1963],
            [ 3.0592,  0.4354,  1.6598,  0.9828],
            [-0.4446,  0.4628,  0.8774,  1.6848],
            [-0.1222,  1.5458,  1.1729,  1.4647]], requires_grad=True))
    ('linears.0.bias', Parameter containing:
    tensor([ 1.5310,  1.0609, -2.0940,  1.1266], requires_grad=True))
    ('linears.1.weight', Parameter containing:
    tensor([[ 2.1113, -0.0623, -1.0806,  0.3508],
            [-0.0550,  1.5317,  1.1064, -0.5562],
            [-0.4028, -0.6942,  1.5793, -1.0140],
            [-0.0329,  0.1160, -1.7183, -1.0434]], requires_grad=True))
    ('linears.1.bias', Parameter containing:
    tensor([ 0.0361, -0.9768, -0.3889,  1.1613], requires_grad=True))
    ('linears.2.weight', Parameter containing:
    tensor([[-2.6340, -0.3887, -0.9979,  0.0767],
            [-0.3526,  0.8756, -1.5847, -0.6016],
            [-0.3269, -0.1608,  0.2897, -2.0829],
            [ 2.6338,  0.9239,  0.6943, -1.5034]], requires_grad=True))
    ('linears.2.bias', Parameter containing:
    tensor([ 1.0268,  0.4489, -0.9403,  0.1571], requires_grad=True))
    ('final.weight', Parameter containing:
    tensor([[ 0.2509],
            [-0.5052],
            [ 0.3088],
            [-1.4951]], requires_grad=True))
    ('final.bias', Parameter containing:
    tensor([0.3381], requires_grad=True))

It's also easy to move all parameters to a different device or change their precision using
:func:`~torch.nn.Module.to`:

.. code-block:: python

    # Move all parameters to a CUDA device
    dynamic_net.to(device='cuda')

    # Change precision of all parameters
    dynamic_net.to(dtype=torch.float64)

    # Inputs must match the parameters' device and dtype. DynamicNet expects a
    # 4-element input plus the name of an activation.
    dynamic_net(torch.randn(4, device='cuda', dtype=torch.float64), 'relu')
    : tensor([6.5166], device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)

More generally, an arbitrary function can be applied to a module and its submodules recursively by
using the :func:`~torch.nn.Module.apply` function. For example, to apply custom initialization to parameters
of a module and its submodules:

.. code-block:: python

    # Define a function to initialize the weights of linear layers, covering both
    # nn.Linear and the MyLinear module defined above.
    # Note that no_grad() is used here to avoid tracking this computation in the autograd graph.
    @torch.no_grad()
    def init_weights(m):
        if isinstance(m, (nn.Linear, MyLinear)):
            nn.init.xavier_normal_(m.weight)
            m.bias.fill_(0.0)

    # Apply the function recursively on the module and its submodules.
    dynamic_net.apply(init_weights)

These examples show how elaborate neural networks can be formed through module composition and conveniently
manipulated. To allow for quick and easy construction of neural networks with minimal boilerplate, PyTorch
provides a large library of performant modules within the :mod:`torch.nn` namespace that perform common neural
network operations like pooling, convolutions, loss functions, etc.

In the next section, we give a full example of training a neural network.

For more information, check out:

* Library of PyTorch-provided modules: `torch.nn <https://pytorch.org/docs/stable/nn.html>`_
* Defining neural net modules: https://pytorch.org/tutorials/beginner/examples_nn/polynomial_module.html

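
As a rough sketch of the built-in library described above, the snippet below assembles a tiny convolutional
classifier entirely from PyTorch-provided modules and scores it with a built-in loss module. The layer sizes,
the 8x8 single-channel input, and the class labels are arbitrary values chosen for illustration, not a
recommended architecture.

.. code-block:: python

    import torch
    from torch import nn

    # Every layer below is a PyTorch-provided module; only the sizes are illustrative.
    cnn = nn.Sequential(
        nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, padding=1),  # -> (N, 4, 8, 8)
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),                                         # -> (N, 4, 4, 4)
        nn.Flatten(),                                                        # -> (N, 64)
        nn.Linear(4 * 4 * 4, 10),                                            # -> (N, 10)
    )

    # Loss functions are modules too.
    loss_fn = nn.CrossEntropyLoss()

    images = torch.randn(2, 1, 8, 8)   # a batch of two random 8x8 single-channel "images"
    targets = torch.tensor([3, 7])     # hypothetical class labels
    logits = cnn(images)
    loss = loss_fn(logits, targets)
    print(logits.shape, loss.item())

Since the result is an ordinary module, the utilities shown in this section, such as
:func:`~torch.nn.Module.named_modules`, :func:`~torch.nn.Module.to`, and :func:`~torch.nn.Module.apply`,
work on it unchanged.
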
.. _Neural Network Training with Modules:

Neural Network Training with Modules
------------------------------------

Once a network is built, it has to be trained, and its parameters can be easily optimized with one of PyTorch's
Optimizers from :mod:`torch.optim`:

.. code-block:: python

    # Create the network (from previous section) and optimizer
    net = Net()
    optimizer = torch.optim.SGD(net.parameters(), lr=1e-4, weight_decay=1e-2, momentum=0.9)

    # Run a sample training loop that "teaches" the network
    # to output the constant zero function
    for _ in range(10000):
        input = torch.randn(4)
        output = net(input)
        loss = torch.abs(output)
        net.zero_grad()
        loss.backward()
        optimizer.step()

    # After training, switch the module to eval mode to do inference, compute performance metrics, etc.
    # (see discussion below for a description of training and evaluation modes)
    ...
    net.eval()
    ...

In this simplified example, the network learns to simply output zero, as any non-zero output is "penalized" according
to its absolute value by employing :func:`torch.abs` as a loss function. While this is not a very interesting task, the
key parts of training are present:

* A network is created.
* An optimizer (in this case, a stochastic gradient descent optimizer) is created, and the network's
  parameters are associated with it.
* A training loop...

  * acquires an input,
  * runs the network,
  * computes a loss,
  * zeros the network's parameters' gradients,
  * calls ``loss.backward()`` to compute the parameters' gradients,
  * calls ``optimizer.step()`` to apply the gradients to the parameters.

After the above snippet has been run, note that the network's parameters have changed. In particular, examining the
value of ``l1``\ 's ``weight`` parameter shows that its values are now much closer to 0 (as may be expected):

.. code-block:: python

    print(net.l1.weight)
    : Parameter containing:
    tensor([[-0.0013],
            [ 0.0030],
            [-0.0008]], requires_grad=True)

Note that the above process is done entirely while the network module is in "training mode". Modules default to
training mode and can be switched between training and evaluation modes using :func:`~torch.nn.Module.train` and
:func:`~torch.nn.Module.eval`. They can behave differently depending on which mode they are in. For example, the
:class:`~torch.nn.BatchNorm2d` module maintains a running mean and variance during training that are not updated
when the module is in evaluation mode. In general, modules should be in training mode during training
and only switched to evaluation mode for inference or evaluation. Below is an example of a custom module
that behaves differently between the two modes:

.. code-block:: python

    class ModalModule(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            if self.training:
                # Add a constant only in training mode.
                return x + 1.
            else:
                return x


    m = ModalModule()
    x = torch.randn(5)

    print('training mode output: {}'.format(m(x)))
    : training mode output: tensor([1.6614, 1.2669, 1.0617, 1.6213, 0.5481])

    m.eval()
    print('evaluation mode output: {}'.format(m(x)))
    : evaluation mode output: tensor([ 0.6614,  0.2669,  0.0617,  0.6213, -0.4519])

Training neural networks can often be tricky. For more information, check out:

* Using Optimizers: https://pytorch.org/tutorials/beginner/examples_nn/two_layer_net_optim.html
* Neural network training: https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html
* Introduction to autograd: https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html

Module State
------------

In the previous section, we demonstrated training a module's "parameters", or learnable aspects of computation.
Now, if we want to save the trained model to disk, we can do so by saving its ``state_dict`` (i.e. "state dictionary"):

.. code-block:: python

    # Save the module
    torch.save(net.state_dict(), 'net.pt')

    ...

    # Load the module later on
    new_net = Net()
    new_net.load_state_dict(torch.load('net.pt'))
    : <All keys matched successfully>

A module's ``state_dict`` contains state that affects its computation. This includes, but is not limited to, the
module's parameters. For some modules, it may be useful to have state beyond parameters that affects module
computation but is not learnable. For such cases, PyTorch provides the concept of "buffers", both "persistent"
and "non-persistent". Following is an overview of the various types of state a module can have:

* **Parameters**: learnable aspects of computation; contained within the ``state_dict``
* **Buffers**: non-learnable aspects of computation

  * **Persistent** buffers: contained within the ``state_dict`` (i.e. serialized when saving & loading)
  * **Non-persistent** buffers: not contained within the ``state_dict`` (i.e. left out of serialization)

As a motivating example for the use of buffers, consider a simple module that maintains a running mean. We want
the current value of the running mean to be considered part of the module's ``state_dict`` so that it will be
restored when loading a serialized form of the module, but we don't want it to be learnable.
This snippet shows how to use :func:`~torch.nn.Module.register_buffer` to accomplish this:

.. code-block:: python

    class RunningMean(nn.Module):
        def __init__(self, num_features, momentum=0.9):
            super().__init__()
            self.momentum = momentum
            self.register_buffer('mean', torch.zeros(num_features))

        def forward(self, x):
            self.mean = self.momentum * self.mean + (1.0 - self.momentum) * x
            return self.mean

Now, the current value of the running mean is considered part of the module's ``state_dict``
and will be properly restored when loading the module from disk:

.. code-block:: python

    m = RunningMean(4)
    for _ in range(10):
        input = torch.randn(4)
        m(input)

    print(m.state_dict())
    : OrderedDict([('mean', tensor([ 0.1041, -0.1113, -0.0647,  0.1515]))])

    # Serialized form will contain the 'mean' tensor
    torch.save(m.state_dict(), 'mean.pt')

    m_loaded = RunningMean(4)
    m_loaded.load_state_dict(torch.load('mean.pt'))
    assert torch.all(m.mean == m_loaded.mean)

As mentioned previously, buffers can be left out of the module's ``state_dict`` by marking them as non-persistent:

.. code-block:: python

    # Inside a module's __init__():
    self.register_buffer('unserialized_thing', torch.randn(5), persistent=False)

Both persistent and non-persistent buffers are affected by model-wide device / dtype changes applied with
:func:`~torch.nn.Module.to`:

.. code-block:: python

    # Moves all module parameters and buffers to the specified device / dtype
    m.to(device='cuda', dtype=torch.float64)

Buffers of a module can be iterated over using :func:`~torch.nn.Module.buffers` or
:func:`~torch.nn.Module.named_buffers`:

.. code-block:: python

    for buffer in m.named_buffers():
        print(buffer)

The following class demonstrates the various ways of registering parameters and buffers within a module:

.. code-block:: python

    class StatefulModule(nn.Module):
        def __init__(self):
            super().__init__()
            # Setting a nn.Parameter as an attribute of the module automatically registers the tensor
            # as a parameter of the module.
            self.param1 = nn.Parameter(torch.randn(2))

            # Alternative string-based way to register a parameter.
            self.register_parameter('param2', nn.Parameter(torch.randn(3)))

            # Reserves the "param3" attribute as a parameter, preventing it from being set to anything
            # except a parameter. "None" entries like this will not be present in the module's state_dict.
            self.register_parameter('param3', None)

            # Registers a list of parameters.
            self.param_list = nn.ParameterList([nn.Parameter(torch.randn(2)) for _ in range(3)])

            # Registers a dictionary of parameters.
            self.param_dict = nn.ParameterDict({
                'foo': nn.Parameter(torch.randn(3)),
                'bar': nn.Parameter(torch.randn(4))
            })

            # Registers a persistent buffer (one that appears in the module's state_dict).
            self.register_buffer('buffer1', torch.randn(4), persistent=True)

            # Registers a non-persistent buffer (one that does not appear in the module's state_dict).
            self.register_buffer('buffer2', torch.randn(5), persistent=False)

            # Reserves the "buffer3" attribute as a buffer, preventing it from being set to anything
            # except a buffer. "None" entries like this will not be present in the module's state_dict.
            self.register_buffer('buffer3', None)

            # Adding a submodule registers its parameters as parameters of the module.
            self.linear = nn.Linear(2, 3)

    m = StatefulModule()

    # Save and load state_dict.
    torch.save(m.state_dict(), 'state.pt')
    m_loaded = StatefulModule()
    m_loaded.load_state_dict(torch.load('state.pt'))

    # Note that non-persistent buffer "buffer2" and reserved attributes "param3" and "buffer3" do
    # not appear in the state_dict.
    print(m_loaded.state_dict())
    : OrderedDict([('param1', tensor([-0.0322,  0.9066])),
                   ('param2', tensor([-0.4472,  0.1409,  0.4852])),
                   ('buffer1', tensor([ 0.6949, -0.1944,  1.2911, -2.1044])),
                   ('param_list.0', tensor([ 0.4202, -0.1953])),
                   ('param_list.1', tensor([ 1.5299, -0.8747])),
                   ('param_list.2', tensor([-1.6289,  1.4898])),
                   ('param_dict.bar', tensor([-0.6434,  1.5187,  0.0346, -0.4077])),
                   ('param_dict.foo', tensor([-0.0845, -1.4324,  0.7022])),
                   ('linear.weight', tensor([[-0.3915, -0.6176],
                                             [ 0.6062, -0.5992],
                                             [ 0.4452, -0.2843]])),
                   ('linear.bias', tensor([-0.3710, -0.0795, -0.3947]))])

For more information, check out:

* Saving and loading: https://pytorch.org/tutorials/beginner/saving_loading_models.html
* Serialization semantics: https://pytorch.org/docs/main/notes/serialization.html
* What is a state dict? https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html

Module Initialization
---------------------

By default, parameters and floating-point buffers for modules provided by :mod:`torch.nn` are initialized during
module instantiation as 32-bit floating point values on the CPU using an initialization scheme determined to
perform well historically for the module type. For certain use cases, you may want to initialize with a different
dtype, device (e.g. GPU), or initialization technique.

Examples:

.. code-block:: python

    # Initialize module directly onto GPU.
    m = nn.Linear(5, 3, device='cuda')

    # Initialize module with 16-bit floating point parameters.
    m = nn.Linear(5, 3, dtype=torch.half)

    # Skip default parameter initialization and perform custom (e.g. orthogonal) initialization.
    m = torch.nn.utils.skip_init(nn.Linear, 5, 3)
    nn.init.orthogonal_(m.weight)

Note that the device and dtype options demonstrated above also apply to any floating-point buffers registered
for the module:

.. code-block:: python

    m = nn.BatchNorm2d(3, dtype=torch.half)
    print(m.running_mean)
    : tensor([0., 0., 0.], dtype=torch.float16)

While module writers can use any device or dtype to initialize parameters in their custom modules, good practice is
to use ``dtype=torch.float`` and ``device='cpu'`` by default as well. Optionally, you can provide full flexibility
in these areas for your custom module by conforming to the convention demonstrated above that all
:mod:`torch.nn` modules follow:

* Provide a ``device`` constructor kwarg that applies to any parameters / buffers registered by the module.
* Provide a ``dtype`` constructor kwarg that applies to any parameters / floating-point buffers registered by
  the module.
* Only use initialization functions (i.e. functions from :mod:`torch.nn.init`) on parameters and buffers within the
  module's constructor. This is only required for compatibility with :func:`~torch.nn.utils.skip_init`; see
  `this page <https://pytorch.org/tutorials/prototype/skip_param_init.html#updating-modules-to-support-skipping-initialization>`_ for an explanation.

For more information, check out:

* Skipping module parameter initialization: https://pytorch.org/tutorials/prototype/skip_param_init.html

Module Hooks
------------

In :ref:`Neural Network Training with Modules`, we demonstrated the training process for a module, which iteratively
performs forward and backward passes, updating module parameters each iteration. For more control
over this process, PyTorch provides "hooks" that can perform arbitrary computation during a forward or backward
pass, even modifying how the pass is done if desired. Some useful examples for this functionality include
debugging, visualizing activations, examining gradients in-depth, etc. Hooks can be added to modules
you haven't written yourself, meaning this functionality can be applied to third-party or PyTorch-provided modules.
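
As a sketch of the "visualizing activations" use case, the snippet below attaches a forward hook (described in
more detail below) to each child of a :class:`~torch.nn.Sequential` built only from PyTorch-provided layers,
recording every intermediate output without modifying any ``forward()`` method. The model shape and the
``activations`` / ``save_activation`` names are hypothetical choices for this illustration.

.. code-block:: python

    import torch
    from torch import nn

    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

    # Hypothetical storage for captured outputs, keyed by child name ('0', '1', '2').
    activations = {}

    def save_activation(name):
        # Returns a forward hook that stores the module's output, detached from the autograd graph.
        # Returning None from a forward hook leaves the output unchanged.
        def hook(module, inputs, output):
            activations[name] = output.detach()
        return hook

    handles = [
        child.register_forward_hook(save_activation(name))
        for name, child in model.named_children()
    ]

    output = model(torch.randn(3, 4))
    for name, act in activations.items():
        print(name, tuple(act.shape))
    # 0 (3, 8)
    # 1 (3, 8)
    # 2 (3, 2)

    # Remove the hooks once they are no longer needed.
    for handle in handles:
        handle.remove()

Because each hook returns ``None``, this pattern only observes the computation; removing the handles afterwards
stops any further recording.
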

PyTorch provides two types of hooks for modules:

* **Forward hooks** are called during the forward pass. They can be installed for a given module with
  :func:`~torch.nn.Module.register_forward_pre_hook` and :func:`~torch.nn.Module.register_forward_hook`.
  These hooks will be called respectively just before the forward function is called and just after it is called.
  Alternatively, these hooks can be installed globally for all modules with the analogous
  :func:`~torch.nn.modules.module.register_module_forward_pre_hook` and
  :func:`~torch.nn.modules.module.register_module_forward_hook` functions.
* **Backward hooks** are called during the backward pass. They can be installed with
  :func:`~torch.nn.Module.register_full_backward_pre_hook` and :func:`~torch.nn.Module.register_full_backward_hook`.
  The pre-hook is called before the gradients with respect to the module's inputs are computed and gives access
  to the gradients with respect to the module's outputs, while the full backward hook is called after the gradients
  with respect to the module's inputs have been computed and gives access to the gradients with respect to both
  the inputs and outputs. Alternatively, they can be installed globally for all modules with
  :func:`~torch.nn.modules.module.register_module_full_backward_pre_hook` and
  :func:`~torch.nn.modules.module.register_module_full_backward_hook`.

All hooks allow the user to return an updated value that will be used throughout the remaining computation.
Thus, these hooks can be used to either execute arbitrary code along the regular module forward/backward or
modify some inputs/outputs without having to change the module's ``forward()`` function.

Below is an example demonstrating usage of forward and backward hooks:

.. code-block:: python

    torch.manual_seed(1)

    def forward_pre_hook(m, inputs):
        # Allows for examination and modification of the input before the forward pass.
        # Note that inputs are always wrapped in a tuple.
        input = inputs[0]
        return input + 1.

    def forward_hook(m, inputs, output):
        # Allows for examination of inputs / outputs and modification of the outputs
        # after the forward pass. Note that inputs are always wrapped in a tuple while outputs
        # are passed as-is.

        # Residual computation a la ResNet.
        return output + inputs[0]

    def backward_hook(m, grad_inputs, grad_outputs):
        # Allows for examination of grad_inputs / grad_outputs and modification of
        # grad_inputs used in the rest of the backwards pass. Note that grad_inputs and
        # grad_outputs are always wrapped in tuples.
        new_grad_inputs = tuple(torch.ones_like(gi) * 42. for gi in grad_inputs)
        return new_grad_inputs

    # Create sample module & input.
    m = nn.Linear(3, 3)
    x = torch.randn(2, 3, requires_grad=True)

    # ==== Demonstrate forward hooks. ====
    # Run input through module before and after adding hooks.
    print('output with no forward hooks: {}'.format(m(x)))
    : output with no forward hooks: tensor([[-0.5059, -0.8158,  0.2390],
                                            [-0.0043,  0.4724, -0.1714]], grad_fn=<AddmmBackward>)

    # Note that the modified input results in a different output.
    forward_pre_hook_handle = m.register_forward_pre_hook(forward_pre_hook)
    print('output with forward pre hook: {}'.format(m(x)))
    : output with forward pre hook: tensor([[-0.5752, -0.7421,  0.4942],
                                            [-0.0736,  0.5461,  0.0838]], grad_fn=<AddmmBackward>)

    # Note the modified output.
    forward_hook_handle = m.register_forward_hook(forward_hook)
    print('output with both forward hooks: {}'.format(m(x)))
    : output with both forward hooks: tensor([[-1.0980,  0.6396,  0.4666],
                                              [ 0.3634,  0.6538,  1.0256]], grad_fn=<AddBackward0>)

    # Remove hooks; note that the output here matches the output before adding hooks.
    forward_pre_hook_handle.remove()
    forward_hook_handle.remove()
    print('output after removing forward hooks: {}'.format(m(x)))
    : output after removing forward hooks: tensor([[-0.5059, -0.8158,  0.2390],
                                                   [-0.0043,  0.4724, -0.1714]], grad_fn=<AddmmBackward>)

    # ==== Demonstrate backward hooks. ====
    m(x).sum().backward()
    print('x.grad with no backwards hook: {}'.format(x.grad))
    : x.grad with no backwards hook: tensor([[ 0.4497, -0.5046,  0.3146],
                                             [ 0.4497, -0.5046,  0.3146]])

    # Clear gradients before running backward pass again.
    m.zero_grad()
    x.grad.zero_()

    m.register_full_backward_hook(backward_hook)
    m(x).sum().backward()
    print('x.grad with backwards hook: {}'.format(x.grad))
    : x.grad with backwards hook: tensor([[42., 42., 42.],
                                          [42., 42., 42.]])

Advanced Features
-----------------

PyTorch also provides several more advanced features that are designed to work with modules. All these functionalities
are available for custom-written modules, with the small caveat that certain features may require modules to conform
to particular constraints in order to be supported. In-depth discussion of these features and the corresponding
requirements can be found in the links below.

Distributed Training
********************

Various methods for distributed training exist within PyTorch, both for scaling up training using multiple GPUs
as well as training across multiple machines. Check out the
`distributed training overview page <https://pytorch.org/tutorials/beginner/dist_overview.html>`_ for
detailed information on how to utilize these.

Profiling Performance
*********************

The `PyTorch Profiler <https://pytorch.org/tutorials/beginner/profiler.html>`_ can be useful for identifying
performance bottlenecks within your models. It measures and outputs performance characteristics for
both memory usage and time spent.
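
A minimal sketch of profiling a module's forward pass with :mod:`torch.profiler`, assuming CPU-only execution;
the model, the input sizes, and the ``"model_inference"`` label are arbitrary choices for illustration.
:func:`~torch.profiler.record_function` marks a user-defined region so that it appears as its own row in the
summary.

.. code-block:: python

    import torch
    from torch import nn
    from torch.profiler import profile, record_function, ProfilerActivity

    model = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 1))
    inputs = torch.randn(8, 64)

    # Collect CPU timings, input shapes, and memory allocations for everything run inside the block.
    with profile(activities=[ProfilerActivity.CPU], record_shapes=True, profile_memory=True) as prof:
        with record_function("model_inference"):
            output = model(inputs)

    # Summarize the recorded events, most expensive (by total CPU time) first.
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

When running on a GPU, ``ProfilerActivity.CUDA`` can be added to ``activities`` to include device kernels as
well; the profiler tutorial linked above covers trace export and other options.
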

Improving Performance with Quantization
***************************************

Applying quantization techniques to modules can improve performance and memory usage by using lower-bitwidth
representations (e.g. 8-bit integers) instead of 32-bit floating point. Check out the various PyTorch-provided
mechanisms for quantization `here <https://pytorch.org/docs/stable/quantization.html>`_.

Improving Memory Usage with Pruning
***********************************

Large deep learning models are often over-parametrized, resulting in high memory usage. To combat this, PyTorch
provides mechanisms for model pruning, which can help reduce memory usage while maintaining task accuracy. The
`Pruning tutorial <https://pytorch.org/tutorials/intermediate/pruning_tutorial.html>`_ describes how to utilize
the pruning techniques PyTorch provides or define custom pruning techniques as necessary.

Parametrizations
****************

For certain applications, it can be beneficial to constrain the parameter space during model training. For example,
enforcing orthogonality of the learned parameters can improve convergence for RNNs. PyTorch provides a mechanism for
applying `parametrizations <https://pytorch.org/tutorials/intermediate/parametrizations.html>`_ such as this, and
further allows for custom constraints to be defined.

Transforming Modules with FX
****************************

The `FX <https://pytorch.org/docs/stable/fx.html>`_ component of PyTorch provides a flexible way to transform
modules by operating directly on module computation graphs. This can be used to programmatically generate or
manipulate modules for a broad array of use cases. To explore FX, check out these examples of using FX for
`convolution + batch norm fusion <https://pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.html>`_ and
`CPU performance analysis <https://pytorch.org/tutorials/intermediate/fx_profiling_tutorial.html>`_.
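
To give a flavor of the graph-level API, the sketch below traces a small module with
:func:`torch.fx.symbolic_trace`, edits the resulting graph so that every call to ``torch.relu`` becomes
``torch.sigmoid``, and regenerates the module's code. ``SmallNet`` and the relu-to-sigmoid swap are hypothetical,
chosen only to show the trace / edit / recompile cycle rather than a useful transformation.

.. code-block:: python

    import torch
    import torch.fx
    from torch import nn

    class SmallNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)

        def forward(self, x):
            return torch.relu(self.linear(x)) + 1.0

    net = SmallNet()

    # Symbolic tracing records the operations in forward() as a torch.fx.Graph,
    # wrapped in a GraphModule (which is itself an nn.Module).
    traced = torch.fx.symbolic_trace(net)
    print(traced.graph)

    # Edit the graph: retarget every torch.relu call to torch.sigmoid.
    for node in traced.graph.nodes:
        if node.op == "call_function" and node.target == torch.relu:
            node.target = torch.sigmoid

    traced.graph.lint()   # sanity-check the modified graph
    traced.recompile()    # regenerate forward() from the graph
    print(traced.code)

    x = torch.randn(2, 4)
    print(torch.allclose(traced(x), torch.sigmoid(net.linear(x)) + 1.0))

Changing a node's ``target`` does not rename the node, so the regenerated code may still use a variable named
``relu`` even though it now calls ``torch.sigmoid``.
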