.. _modules:

Modules
=======

PyTorch uses modules to represent neural networks. Modules are:

* **Building blocks of stateful computation.**
  PyTorch provides a robust library of modules and makes it simple to define new custom modules, allowing for
  easy construction of elaborate, multi-layer neural networks.
* **Tightly integrated with PyTorch's**
  `autograd <https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html>`_
  **system.** Modules make it simple to specify learnable parameters for PyTorch's Optimizers to update.
* **Easy to work with and transform.** Modules are straightforward to save and restore, transfer between
  CPU / GPU / TPU devices, prune, quantize, and more.

This note describes modules, and is intended for all PyTorch users. Since modules are so fundamental to PyTorch,
many topics in this note are elaborated on in other notes or tutorials, and links to many of those documents
are provided here as well.

.. contents:: :local:

A Simple Custom Module
----------------------

To get started, let's look at a simpler, custom version of PyTorch's :class:`~torch.nn.Linear` module.
This module applies an affine transformation to its input.

.. code-block:: python

   import torch
   from torch import nn

   class MyLinear(nn.Module):
     def __init__(self, in_features, out_features):
       super().__init__()
       self.weight = nn.Parameter(torch.randn(in_features, out_features))
       self.bias = nn.Parameter(torch.randn(out_features))

     def forward(self, input):
       return (input @ self.weight) + self.bias

This simple module has the following fundamental characteristics of modules:

* **It inherits from the base Module class.**
  All modules should subclass :class:`~torch.nn.Module` for composability with other modules.
* **It defines some "state" that is used in computation.**
  Here, the state consists of randomly-initialized ``weight`` and ``bias`` tensors that define the affine
  transformation. Because each of these is defined as a :class:`~torch.nn.parameter.Parameter`, they are
  *registered* for the module and will automatically be tracked and returned from calls
  to :func:`~torch.nn.Module.parameters`. Parameters can be
  considered the "learnable" aspects of the module's computation (more on this later). Note that modules
  are not required to have state; a module can be entirely stateless.
* **It defines a forward() function that performs the computation.** For this affine transformation module, the input
  is matrix-multiplied with the ``weight`` parameter (using the ``@`` short-hand notation) and added to the ``bias``
  parameter to produce the output. More generally, the ``forward()`` implementation for a module can perform arbitrary
  computation involving any number of inputs and outputs.

This simple module demonstrates how modules package state and computation together. Instances of this module can be
constructed and called:

.. code-block:: python

   m = MyLinear(4, 3)
   sample_input = torch.randn(4)
   m(sample_input)
   : tensor([-0.3037, -1.0413, -4.2057], grad_fn=<AddBackward0>)

Note that the module itself is callable, and that calling it invokes its ``forward()`` function.
The name refers to the concepts of "forward pass" and "backward pass", which apply to each module.
The "forward pass" is responsible for applying the computation represented by the module
to the given input(s) (as shown in the snippet above). The "backward pass" computes gradients of
the module's outputs with respect to its inputs and parameters; the parameter gradients are what
gradient descent methods use to "train" the parameters. PyTorch's autograd system automatically takes
care of this backward pass computation, so it is not required to manually implement a ``backward()``
function for each module. The process of training module parameters through successive
forward / backward passes is covered in detail in
:ref:`Neural Network Training with Modules`.

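As a small illustration of the backward pass, calling ``backward()`` on a scalar derived from the
module's output populates the ``grad`` field of each registered parameter (the exact gradient values
will vary with the random initialization):

.. code-block:: python

   m = MyLinear(4, 3)
   sample_input = torch.randn(4)

   # Forward pass: run the module's computation on the input.
   output = m(sample_input)

   # Backward pass: autograd computes gradients of the summed output with respect to
   # the module's parameters; no manual backward() implementation is required.
   output.sum().backward()
   print(m.weight.grad.shape)
   : torch.Size([4, 3])
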
The full set of parameters registered by the module can be iterated through via a call to
:func:`~torch.nn.Module.parameters` or :func:`~torch.nn.Module.named_parameters`,
where the latter includes each parameter's name:

.. code-block:: python

   for parameter in m.named_parameters():
     print(parameter)
   : ('weight', Parameter containing:
   tensor([[ 1.0597,  1.1796,  0.8247],
           [-0.5080, -1.2635, -1.1045],
           [ 0.0593,  0.2469, -1.4299],
           [-0.4926, -0.5457,  0.4793]], requires_grad=True))
   ('bias', Parameter containing:
   tensor([ 0.3634,  0.2015, -0.8525], requires_grad=True))

In general, the parameters registered by a module are aspects of the module's computation that should be
"learned". A later section of this note shows how to update these parameters using one of PyTorch's Optimizers.
Before we get to that, however, let's first examine how modules can be composed with one another.

Modules as Building Blocks
--------------------------

Modules can contain other modules, making them useful building blocks for developing more elaborate functionality.
The simplest way to do this is using the :class:`~torch.nn.Sequential` module. It allows us to chain together
multiple modules:

.. code-block:: python

   net = nn.Sequential(
     MyLinear(4, 3),
     nn.ReLU(),
     MyLinear(3, 1)
   )

   sample_input = torch.randn(4)
   net(sample_input)
   : tensor([-0.6749], grad_fn=<AddBackward0>)

Note that :class:`~torch.nn.Sequential` automatically feeds the output of the first ``MyLinear`` module as input
into the :class:`~torch.nn.ReLU`, and the output of that as input into the second ``MyLinear`` module. As
shown, it is limited to in-order chaining of modules with a single input and output.

In general, it is recommended to define a custom module for anything beyond the simplest use cases, as this gives
full flexibility on how submodules are used for a module's computation.

For example, here's a simple neural network implemented as a custom module:

.. code-block:: python

   import torch.nn.functional as F

   class Net(nn.Module):
     def __init__(self):
       super().__init__()
       self.l0 = MyLinear(4, 3)
       self.l1 = MyLinear(3, 1)
     def forward(self, x):
       x = self.l0(x)
       x = F.relu(x)
       x = self.l1(x)
       return x

This module is composed of two "children" or "submodules" (\ ``l0`` and ``l1``\ ) that define the layers of
the neural network and are utilized for computation within the module's ``forward()`` method. Immediate
children of a module can be iterated through via a call to :func:`~torch.nn.Module.children` or
:func:`~torch.nn.Module.named_children`:

.. code-block:: python

   net = Net()
   for child in net.named_children():
     print(child)
   : ('l0', MyLinear())
   ('l1', MyLinear())

To go deeper than just the immediate children, :func:`~torch.nn.Module.modules` and
:func:`~torch.nn.Module.named_modules` *recursively* iterate through a module and its child modules:

.. code-block:: python

   class BigNet(nn.Module):
     def __init__(self):
       super().__init__()
       self.l1 = MyLinear(5, 4)
       self.net = Net()
     def forward(self, x):
       return self.net(self.l1(x))

   big_net = BigNet()
   for module in big_net.named_modules():
     print(module)
   : ('', BigNet(
     (l1): MyLinear()
     (net): Net(
       (l0): MyLinear()
       (l1): MyLinear()
     )
   ))
   ('l1', MyLinear())
   ('net', Net(
     (l0): MyLinear()
     (l1): MyLinear()
   ))
   ('net.l0', MyLinear())
   ('net.l1', MyLinear())

Sometimes, it's necessary for a module to dynamically define submodules.
The :class:`~torch.nn.ModuleList` and :class:`~torch.nn.ModuleDict` modules are useful here; they
register submodules from a list or dict:

.. code-block:: python

   class DynamicNet(nn.Module):
     def __init__(self, num_layers):
       super().__init__()
       self.linears = nn.ModuleList(
         [MyLinear(4, 4) for _ in range(num_layers)])
       self.activations = nn.ModuleDict({
         'relu': nn.ReLU(),
         'lrelu': nn.LeakyReLU()
       })
       self.final = MyLinear(4, 1)
     def forward(self, x, act):
       for linear in self.linears:
         x = linear(x)
       x = self.activations[act](x)
       x = self.final(x)
       return x

   dynamic_net = DynamicNet(3)
   sample_input = torch.randn(4)
   output = dynamic_net(sample_input, 'relu')

For any given module, its parameters consist of its direct parameters as well as the parameters of all submodules.
This means that calls to :func:`~torch.nn.Module.parameters` and :func:`~torch.nn.Module.named_parameters` will
recursively include child parameters, allowing for convenient optimization of all parameters within the network:

.. code-block:: python

   for parameter in dynamic_net.named_parameters():
     print(parameter)
   : ('linears.0.weight', Parameter containing:
   tensor([[-1.2051,  0.7601,  1.1065,  0.1963],
           [ 3.0592,  0.4354,  1.6598,  0.9828],
           [-0.4446,  0.4628,  0.8774,  1.6848],
           [-0.1222,  1.5458,  1.1729,  1.4647]], requires_grad=True))
   ('linears.0.bias', Parameter containing:
   tensor([ 1.5310,  1.0609, -2.0940,  1.1266], requires_grad=True))
   ('linears.1.weight', Parameter containing:
   tensor([[ 2.1113, -0.0623, -1.0806,  0.3508],
           [-0.0550,  1.5317,  1.1064, -0.5562],
           [-0.4028, -0.6942,  1.5793, -1.0140],
           [-0.0329,  0.1160, -1.7183, -1.0434]], requires_grad=True))
   ('linears.1.bias', Parameter containing:
   tensor([ 0.0361, -0.9768, -0.3889,  1.1613], requires_grad=True))
   ('linears.2.weight', Parameter containing:
   tensor([[-2.6340, -0.3887, -0.9979,  0.0767],
           [-0.3526,  0.8756, -1.5847, -0.6016],
           [-0.3269, -0.1608,  0.2897, -2.0829],
           [ 2.6338,  0.9239,  0.6943, -1.5034]], requires_grad=True))
   ('linears.2.bias', Parameter containing:
   tensor([ 1.0268,  0.4489, -0.9403,  0.1571], requires_grad=True))
   ('final.weight', Parameter containing:
   tensor([[ 0.2509], [-0.5052], [ 0.3088], [-1.4951]], requires_grad=True))
   ('final.bias', Parameter containing:
   tensor([0.3381], requires_grad=True))

It's also easy to move all parameters to a different device or change their precision using
:func:`~torch.nn.Module.to`:

.. code-block:: python

   # Move all parameters to a CUDA device
   dynamic_net.to(device='cuda')

   # Change precision of all parameters
   dynamic_net.to(dtype=torch.float64)

   dynamic_net(torch.randn(4, device='cuda', dtype=torch.float64), 'relu')
   : tensor([6.5166], device='cuda:0', dtype=torch.float64, grad_fn=<AddBackward0>)

More generally, an arbitrary function can be applied to a module and its submodules recursively by
using the :func:`~torch.nn.Module.apply` function. For example, to apply custom initialization to parameters
of a module and its submodules:

.. code-block:: python

   # Define a function to initialize the weights of MyLinear submodules.
   # Note that no_grad() is used here to avoid tracking this computation in the autograd graph.
   @torch.no_grad()
   def init_weights(m):
     if isinstance(m, MyLinear):
       nn.init.xavier_normal_(m.weight)
       m.bias.fill_(0.0)

   # Apply the function recursively on the module and its submodules.
   dynamic_net.apply(init_weights)

These examples show how elaborate neural networks can be formed through module composition and conveniently
manipulated. To allow for quick and easy construction of neural networks with minimal boilerplate, PyTorch
provides a large library of performant modules within the :mod:`torch.nn` namespace that perform common neural
network operations like pooling, convolutions, loss functions, etc.

In the next section, we give a full example of training a neural network.

For more information, check out:

* Library of PyTorch-provided modules: `torch.nn <https://pytorch.org/docs/stable/nn.html>`_
* Defining neural net modules: https://pytorch.org/tutorials/beginner/examples_nn/polynomial_module.html

.. _Neural Network Training with Modules:

Neural Network Training with Modules
------------------------------------

Once a network is built, it has to be trained, and its parameters can be easily optimized with one of PyTorch's
Optimizers from :mod:`torch.optim`:

.. code-block:: python

   # Create the network (from previous section) and optimizer
   net = Net()
   optimizer = torch.optim.SGD(net.parameters(), lr=1e-4, weight_decay=1e-2, momentum=0.9)

   # Run a sample training loop that "teaches" the network
   # to output the constant zero function
   for _ in range(10000):
     input = torch.randn(4)
     output = net(input)
     loss = torch.abs(output)
     net.zero_grad()
     loss.backward()
     optimizer.step()

   # After training, switch the module to eval mode to do inference, compute performance metrics, etc.
   # (see discussion below for a description of training and evaluation modes)
   ...
   net.eval()
   ...

In this simplified example, the network learns to output zero, as any non-zero output is "penalized" according
to its absolute value by employing :func:`torch.abs` as a loss function. While this is not a very interesting task, the
key parts of training are present:

* A network is created.
* An optimizer (in this case, a stochastic gradient descent optimizer) is created, and the network's
  parameters are associated with it.
* A training loop...

    * acquires an input,
    * runs the network,
    * computes a loss,
    * zeros the network's parameters' gradients,
    * calls ``loss.backward()`` to compute gradients for the parameters,
    * calls ``optimizer.step()`` to apply the gradients to the parameters.

After the above snippet has been run, note that the network's parameters have changed. In particular, examining the
value of ``l1``\ 's ``weight`` parameter shows that its values are now much closer to 0 (as may be expected):

.. code-block:: python

   print(net.l1.weight)
   : Parameter containing:
   tensor([[-0.0013],
           [ 0.0030],
           [-0.0008]], requires_grad=True)

Note that the above process is done entirely while the network module is in "training mode". Modules default to
training mode and can be switched between training and evaluation modes using :func:`~torch.nn.Module.train` and
:func:`~torch.nn.Module.eval`. Modules can behave differently depending on which mode they are in. For example,
the ``BatchNorm`` family of modules (e.g. :class:`~torch.nn.BatchNorm2d`) maintains a running mean and variance
during training that are not updated when the module is in evaluation mode. In general, modules should be in
training mode during training and only switched to evaluation mode for inference or evaluation. Below is an
example of a custom module that behaves differently between the two modes:

.. code-block:: python

   class ModalModule(nn.Module):
     def __init__(self):
       super().__init__()

     def forward(self, x):
       if self.training:
         # Add a constant only in training mode.
         return x + 1.
       else:
         return x


   m = ModalModule()
   x = torch.randn(5)

   print('training mode output: {}'.format(m(x)))
   : tensor([1.6614, 1.2669, 1.0617, 1.6213, 0.5481])

   m.eval()
   print('evaluation mode output: {}'.format(m(x)))
   : tensor([ 0.6614,  0.2669,  0.0617,  0.6213, -0.4519])

Training neural networks can often be tricky. For more information, check out:

* Using Optimizers: https://pytorch.org/tutorials/beginner/examples_nn/two_layer_net_optim.html
* Neural network training: https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html
* Introduction to autograd: https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html

Module State
------------

In the previous section, we demonstrated training a module's "parameters", or learnable aspects of computation.
Now, if we want to save the trained model to disk, we can do so by saving its ``state_dict`` (i.e. "state dictionary"):

.. code-block:: python

   # Save the module
   torch.save(net.state_dict(), 'net.pt')

   ...

   # Load the module later on
   new_net = Net()
   new_net.load_state_dict(torch.load('net.pt'))
   : <All keys matched successfully>

A module's ``state_dict`` contains state that affects its computation. This includes, but is not limited to, the
module's parameters. For some modules, it may be useful to have state beyond parameters that affects module
computation but is not learnable. For such cases, PyTorch provides the concept of "buffers", both "persistent"
and "non-persistent". Following is an overview of the various types of state a module can have:

* **Parameters**\ : learnable aspects of computation; contained within the ``state_dict``
* **Buffers**\ : non-learnable aspects of computation

  * **Persistent** buffers: contained within the ``state_dict`` (i.e. serialized when saving & loading)
  * **Non-persistent** buffers: not contained within the ``state_dict`` (i.e. left out of serialization)

As a motivating example for the use of buffers, consider a simple module that maintains a running mean. We want
the current value of the running mean to be considered part of the module's ``state_dict`` so that it will be
restored when loading a serialized form of the module, but we don't want it to be learnable.
This snippet shows how to use :func:`~torch.nn.Module.register_buffer` to accomplish this:

419
420   class RunningMean(nn.Module):
421     def __init__(self, num_features, momentum=0.9):
422       super().__init__()
423       self.momentum = momentum
424       self.register_buffer('mean', torch.zeros(num_features))
425     def forward(self, x):
426       self.mean = self.momentum * self.mean + (1.0 - self.momentum) * x
427       return self.mean
428
429Now, the current value of the running mean is considered part of the module's ``state_dict``
430and will be properly restored when loading the module from disk:
431
432.. code-block:: python
433
434   m = RunningMean(4)
435   for _ in range(10):
436     input = torch.randn(4)
437     m(input)
438
439   print(m.state_dict())
440   : OrderedDict([('mean', tensor([ 0.1041, -0.1113, -0.0647,  0.1515]))]))
441
442   # Serialized form will contain the 'mean' tensor
443   torch.save(m.state_dict(), 'mean.pt')
444
445   m_loaded = RunningMean(4)
446   m_loaded.load_state_dict(torch.load('mean.pt'))
447   assert(torch.all(m.mean == m_loaded.mean))
448
As mentioned previously, buffers can be left out of the module's ``state_dict`` by marking them as non-persistent:

.. code-block:: python

   self.register_buffer('unserialized_thing', torch.randn(5), persistent=False)

Both persistent and non-persistent buffers are affected by model-wide device / dtype changes applied with
:func:`~torch.nn.Module.to`:

.. code-block:: python

   # Moves all module parameters and buffers to the specified device / dtype
   m.to(device='cuda', dtype=torch.float64)

Buffers of a module can be iterated over using :func:`~torch.nn.Module.buffers` or
:func:`~torch.nn.Module.named_buffers`.

.. code-block:: python

   for buffer in m.named_buffers():
     print(buffer)

The following class demonstrates the various ways of registering parameters and buffers within a module:

.. code-block:: python

   class StatefulModule(nn.Module):
     def __init__(self):
       super().__init__()
       # Setting a nn.Parameter as an attribute of the module automatically registers the tensor
       # as a parameter of the module.
       self.param1 = nn.Parameter(torch.randn(2))

       # Alternative string-based way to register a parameter.
       self.register_parameter('param2', nn.Parameter(torch.randn(3)))

       # Reserves the "param3" attribute as a parameter, preventing it from being set to anything
       # except a parameter. "None" entries like this will not be present in the module's state_dict.
       self.register_parameter('param3', None)

       # Registers a list of parameters.
       self.param_list = nn.ParameterList([nn.Parameter(torch.randn(2)) for i in range(3)])

       # Registers a dictionary of parameters.
       self.param_dict = nn.ParameterDict({
         'foo': nn.Parameter(torch.randn(3)),
         'bar': nn.Parameter(torch.randn(4))
       })

       # Registers a persistent buffer (one that appears in the module's state_dict).
       self.register_buffer('buffer1', torch.randn(4), persistent=True)

       # Registers a non-persistent buffer (one that does not appear in the module's state_dict).
       self.register_buffer('buffer2', torch.randn(5), persistent=False)

       # Reserves the "buffer3" attribute as a buffer, preventing it from being set to anything
       # except a buffer. "None" entries like this will not be present in the module's state_dict.
       self.register_buffer('buffer3', None)

       # Adding a submodule registers its parameters as parameters of the module.
       self.linear = nn.Linear(2, 3)

   m = StatefulModule()

   # Save and load state_dict.
   torch.save(m.state_dict(), 'state.pt')
   m_loaded = StatefulModule()
   m_loaded.load_state_dict(torch.load('state.pt'))

   # Note that non-persistent buffer "buffer2" and reserved attributes "param3" and "buffer3" do
   # not appear in the state_dict.
   print(m_loaded.state_dict())
   : OrderedDict([('param1', tensor([-0.0322,  0.9066])),
                  ('param2', tensor([-0.4472,  0.1409,  0.4852])),
                  ('buffer1', tensor([ 0.6949, -0.1944,  1.2911, -2.1044])),
                  ('param_list.0', tensor([ 0.4202, -0.1953])),
                  ('param_list.1', tensor([ 1.5299, -0.8747])),
                  ('param_list.2', tensor([-1.6289,  1.4898])),
                  ('param_dict.bar', tensor([-0.6434,  1.5187,  0.0346, -0.4077])),
                  ('param_dict.foo', tensor([-0.0845, -1.4324,  0.7022])),
                  ('linear.weight', tensor([[-0.3915, -0.6176],
                                            [ 0.6062, -0.5992],
                                            [ 0.4452, -0.2843]])),
                  ('linear.bias', tensor([-0.3710, -0.0795, -0.3947]))])

For more information, check out:

* Saving and loading: https://pytorch.org/tutorials/beginner/saving_loading_models.html
* Serialization semantics: https://pytorch.org/docs/main/notes/serialization.html
* What is a state dict? https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html

Module Initialization
---------------------

By default, parameters and floating-point buffers for modules provided by :mod:`torch.nn` are initialized during
module instantiation as 32-bit floating point values on the CPU using an initialization scheme determined to
perform well historically for the module type. For certain use cases, it may be desired to initialize with a different
dtype, device (e.g. GPU), or initialization technique.

Examples:

.. code-block:: python

   # Initialize module directly onto GPU.
   m = nn.Linear(5, 3, device='cuda')

   # Initialize module with 16-bit floating point parameters.
   m = nn.Linear(5, 3, dtype=torch.half)

   # Skip default parameter initialization and perform custom (e.g. orthogonal) initialization.
   m = torch.nn.utils.skip_init(nn.Linear, 5, 3)
   nn.init.orthogonal_(m.weight)

Note that the device and dtype options demonstrated above also apply to any floating-point buffers registered
for the module:

.. code-block:: python

   m = nn.BatchNorm2d(3, dtype=torch.half)
   print(m.running_mean)
   : tensor([0., 0., 0.], dtype=torch.float16)

While module writers can use any device or dtype to initialize parameters in their custom modules, good practice is
to default to ``dtype=torch.float`` and ``device='cpu'`` as well. Optionally, you can provide full flexibility
in these areas for your custom module by conforming to the convention, demonstrated above, that all
:mod:`torch.nn` modules follow (a sketch of a conforming custom module appears after this list):

* Provide a ``device`` constructor kwarg that applies to any parameters / buffers registered by the module.
* Provide a ``dtype`` constructor kwarg that applies to any parameters / floating-point buffers registered by
  the module.
* Only use initialization functions (i.e. functions from :mod:`torch.nn.init`) on parameters and buffers within the
  module's constructor. Note that this is only required to use :func:`~torch.nn.utils.skip_init`; see
  `this page <https://pytorch.org/tutorials/prototype/skip_param_init.html#updating-modules-to-support-skipping-initialization>`_ for an explanation.

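As a rough sketch of this convention (the class name is illustrative only, and this is simplified relative to the
real :class:`~torch.nn.Linear` implementation), a variant of ``MyLinear`` that accepts ``device`` and ``dtype``
kwargs and restricts initialization to :mod:`torch.nn.init` functions might look like:

.. code-block:: python

   class MyDeviceAwareLinear(nn.Module):
     def __init__(self, in_features, out_features, device=None, dtype=None):
       super().__init__()
       factory_kwargs = {'device': device, 'dtype': dtype}
       # Allocate uninitialized storage with the requested device / dtype...
       self.weight = nn.Parameter(torch.empty((in_features, out_features), **factory_kwargs))
       self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
       # ...and fill in values only through torch.nn.init functions, so that
       # torch.nn.utils.skip_init() can safely skip this work.
       nn.init.kaiming_uniform_(self.weight)
       nn.init.zeros_(self.bias)

     def forward(self, input):
       return (input @ self.weight) + self.bias

   m = MyDeviceAwareLinear(4, 3, device='cuda', dtype=torch.half)
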
For more information, check out:

* Skipping module parameter initialization: https://pytorch.org/tutorials/prototype/skip_param_init.html

Module Hooks
------------

In :ref:`Neural Network Training with Modules`, we demonstrated the training process for a module, which iteratively
performs forward and backward passes, updating module parameters each iteration. For more control
over this process, PyTorch provides "hooks" that can perform arbitrary computation during a forward or backward
pass, even modifying how the pass is done if desired. Useful applications of this functionality include
debugging, visualizing activations, examining gradients in depth, etc. Hooks can be added to modules
you haven't written yourself, meaning this functionality can be applied to third-party or PyTorch-provided modules.

PyTorch provides two types of hooks for modules:

* **Forward hooks** are called during the forward pass. They can be installed for a given module with
  :func:`~torch.nn.Module.register_forward_pre_hook` and :func:`~torch.nn.Module.register_forward_hook`.
  The former is called just before the module's forward function runs, and the latter just after it returns.
  Alternatively, these hooks can be installed globally for all modules with the analogous
  :func:`~torch.nn.modules.module.register_module_forward_pre_hook` and
  :func:`~torch.nn.modules.module.register_module_forward_hook` functions.
* **Backward hooks** are called during the backward pass. They can be installed with
  :func:`~torch.nn.Module.register_full_backward_pre_hook` and :func:`~torch.nn.Module.register_full_backward_hook`.
  The former is called just before the gradients for the module are computed and gives access to the gradients of
  the module's outputs, while the latter is called once those gradients have been computed and gives access to the
  gradients of both the inputs and outputs. Alternatively, they can be installed globally for all modules with
  :func:`~torch.nn.modules.module.register_module_full_backward_hook` and
  :func:`~torch.nn.modules.module.register_module_full_backward_pre_hook`.

All hooks allow the user to return an updated value that will be used throughout the remaining computation.
Thus, these hooks can be used to either execute arbitrary code along the regular module forward/backward or
modify some inputs/outputs without having to change the module's ``forward()`` function.

619
620.. code-block:: python
621
622   torch.manual_seed(1)
623
624   def forward_pre_hook(m, inputs):
625     # Allows for examination and modification of the input before the forward pass.
626     # Note that inputs are always wrapped in a tuple.
627     input = inputs[0]
628     return input + 1.
629
630   def forward_hook(m, inputs, output):
631     # Allows for examination of inputs / outputs and modification of the outputs
632     # after the forward pass. Note that inputs are always wrapped in a tuple while outputs
633     # are passed as-is.
634
635     # Residual computation a la ResNet.
636     return output + inputs[0]
637
638   def backward_hook(m, grad_inputs, grad_outputs):
639     # Allows for examination of grad_inputs / grad_outputs and modification of
640     # grad_inputs used in the rest of the backwards pass. Note that grad_inputs and
641     # grad_outputs are always wrapped in tuples.
642     new_grad_inputs = [torch.ones_like(gi) * 42. for gi in grad_inputs]
643     return new_grad_inputs
644
645   # Create sample module & input.
646   m = nn.Linear(3, 3)
647   x = torch.randn(2, 3, requires_grad=True)
648
649   # ==== Demonstrate forward hooks. ====
650   # Run input through module before and after adding hooks.
651   print('output with no forward hooks: {}'.format(m(x)))
652   : output with no forward hooks: tensor([[-0.5059, -0.8158,  0.2390],
653                                           [-0.0043,  0.4724, -0.1714]], grad_fn=<AddmmBackward>)
654
655   # Note that the modified input results in a different output.
656   forward_pre_hook_handle = m.register_forward_pre_hook(forward_pre_hook)
657   print('output with forward pre hook: {}'.format(m(x)))
658   : output with forward pre hook: tensor([[-0.5752, -0.7421,  0.4942],
659                                           [-0.0736,  0.5461,  0.0838]], grad_fn=<AddmmBackward>)
660
661   # Note the modified output.
662   forward_hook_handle = m.register_forward_hook(forward_hook)
663   print('output with both forward hooks: {}'.format(m(x)))
664   : output with both forward hooks: tensor([[-1.0980,  0.6396,  0.4666],
665                                             [ 0.3634,  0.6538,  1.0256]], grad_fn=<AddBackward0>)
666
667   # Remove hooks; note that the output here matches the output before adding hooks.
668   forward_pre_hook_handle.remove()
669   forward_hook_handle.remove()
670   print('output after removing forward hooks: {}'.format(m(x)))
671   : output after removing forward hooks: tensor([[-0.5059, -0.8158,  0.2390],
672                                                  [-0.0043,  0.4724, -0.1714]], grad_fn=<AddmmBackward>)
673
674   # ==== Demonstrate backward hooks. ====
675   m(x).sum().backward()
676   print('x.grad with no backwards hook: {}'.format(x.grad))
677   : x.grad with no backwards hook: tensor([[ 0.4497, -0.5046,  0.3146],
678                                            [ 0.4497, -0.5046,  0.3146]])
679
680   # Clear gradients before running backward pass again.
681   m.zero_grad()
682   x.grad.zero_()
683
684   m.register_full_backward_hook(backward_hook)
685   m(x).sum().backward()
686   print('x.grad with backwards hook: {}'.format(x.grad))
687   : x.grad with backwards hook: tensor([[42., 42., 42.],
688                                         [42., 42., 42.]])
689
Advanced Features
-----------------

PyTorch also provides several more advanced features that are designed to work with modules. All these functionalities
are available for custom-written modules, with the small caveat that certain features may require modules to conform
to particular constraints in order to be supported. In-depth discussion of these features and the corresponding
requirements can be found in the links below.

Distributed Training
********************

Various methods for distributed training exist within PyTorch, both for scaling up training using multiple GPUs
as well as training across multiple machines. Check out the
`distributed training overview page <https://pytorch.org/tutorials/beginner/dist_overview.html>`_ for
detailed information on how to utilize these.

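As a minimal, hedged sketch (assuming a launch via ``torchrun``, which supplies the environment variables that
``init_process_group()`` reads; the module and sizes here are arbitrary), wrapping a module in
:class:`~torch.nn.parallel.DistributedDataParallel` looks roughly like:

.. code-block:: python

   import torch.distributed as dist
   from torch.nn.parallel import DistributedDataParallel as DDP

   # Launched with e.g.: torchrun --nproc_per_node=2 train.py
   dist.init_process_group(backend='gloo')  # use 'nccl' when training on GPUs

   model = nn.Linear(4, 2)
   ddp_model = DDP(model)  # gradients are averaged across processes during backward()

   output = ddp_model(torch.randn(8, 4))
   output.sum().backward()

   dist.destroy_process_group()
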
Profiling Performance
*********************

The `PyTorch Profiler <https://pytorch.org/tutorials/beginner/profiler.html>`_ can be useful for identifying
performance bottlenecks within your models. It measures and outputs performance characteristics for
both memory usage and time spent.

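A small sketch of what profiling a module's forward pass might look like (using the ``Net`` module from earlier;
the sort key and row limit are arbitrary choices):

.. code-block:: python

   from torch.profiler import profile, ProfilerActivity

   net = Net()
   x = torch.randn(4)

   # Record CPU time and memory usage for a single forward pass.
   with profile(activities=[ProfilerActivity.CPU], profile_memory=True) as prof:
     net(x)

   # Print a table of the most expensive operations.
   print(prof.key_averages().table(sort_by='cpu_time_total', row_limit=5))
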
Improving Performance with Quantization
***************************************

Applying quantization techniques to modules can improve performance and memory usage by utilizing lower
bitwidths than floating-point precision. Check out the various PyTorch-provided mechanisms for quantization
`here <https://pytorch.org/docs/stable/quantization.html>`_.

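For instance, dynamic quantization (one of the mechanisms linked above) can be applied to an existing
floating-point module in a single call; the sketch below assumes a simple ``nn.Linear``-based model and int8
weights, and is not the only available quantization workflow:

.. code-block:: python

   import torch.ao.quantization

   float_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

   # Replace the Linear submodules with dynamically-quantized versions that store int8 weights.
   quantized_model = torch.ao.quantization.quantize_dynamic(
       float_model, {nn.Linear}, dtype=torch.qint8)

   print(quantized_model(torch.randn(2, 4)))
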
Improving Memory Usage with Pruning
***********************************

Large deep learning models are often over-parametrized, resulting in high memory usage. To combat this, PyTorch
provides mechanisms for model pruning, which can help reduce memory usage while maintaining task accuracy. The
`Pruning tutorial <https://pytorch.org/tutorials/intermediate/pruning_tutorial.html>`_ describes how to utilize
the pruning techniques PyTorch provides or define custom pruning techniques as necessary.

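As a small sketch of the built-in utilities (the module, parameter name, and pruning amount below are arbitrary
choices for illustration):

.. code-block:: python

   import torch.nn.utils.prune as prune

   m = nn.Linear(4, 3)

   # Zero out the 50% of 'weight' entries with the smallest L1 magnitude.
   prune.l1_unstructured(m, name='weight', amount=0.5)
   print(m.weight)  # reparametrized as weight_orig * weight_mask

   # Make the pruning permanent by removing the reparametrization.
   prune.remove(m, 'weight')
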
Parametrizations
****************

For certain applications, it can be beneficial to constrain the parameter space during model training. For example,
enforcing orthogonality of the learned parameters can improve convergence for RNNs. PyTorch provides a mechanism for
applying `parametrizations <https://pytorch.org/tutorials/intermediate/parametrizations.html>`_ such as this, and
further allows for custom constraints to be defined.

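As a brief sketch, a custom parametrization is itself a module whose ``forward()`` maps an unconstrained tensor to
the constrained one; the symmetric-weight example below is illustrative and follows the pattern from the tutorial
linked above:

.. code-block:: python

   from torch.nn.utils import parametrize

   class Symmetric(nn.Module):
     def forward(self, W):
       # Build a symmetric matrix from the upper triangle of the unconstrained weight.
       return W.triu() + W.triu(1).transpose(-1, -2)

   m = nn.Linear(3, 3)
   parametrize.register_parametrization(m, 'weight', Symmetric())

   # The constraint now holds whenever m.weight is accessed, including after optimizer steps.
   assert torch.allclose(m.weight, m.weight.T)
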
Transforming Modules with FX
****************************

The `FX <https://pytorch.org/docs/stable/fx.html>`_ component of PyTorch provides a flexible way to transform
modules by operating directly on module computation graphs. This can be used to programmatically generate or
manipulate modules for a broad array of use cases. To explore FX, check out these examples of using FX for
`convolution + batch norm fusion <https://pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.html>`_ and
`CPU performance analysis <https://pytorch.org/tutorials/intermediate/fx_profiling_tutorial.html>`_.

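As a tiny sketch of the entry point (tracing the ``Net`` module from earlier; what you do with the resulting graph
depends on the transformation you have in mind):

.. code-block:: python

   import torch.fx

   # Symbolically trace the module into a GraphModule whose computation graph can be inspected or rewritten.
   traced = torch.fx.symbolic_trace(Net())
   print(traced.graph)  # the recorded sequence of operations
   print(traced.code)   # Python code regenerated from the graph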