xref: /aosp_15_r20/external/pytorch/torch/nn/modules/module.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2
3import functools
4import inspect
5import itertools
6import warnings
7import weakref
8from collections import namedtuple, OrderedDict
9from typing import (
10    Any,
11    Callable,
12    Dict,
13    Iterator,
14    List,
15    Mapping,
16    Optional,
17    overload,
18    Set,
19    Tuple,
20    TypeVar,
21    Union,
22)
23from typing_extensions import Self
24
25import torch
26from torch import device, dtype, Tensor
27from torch._prims_common import DeviceLikeType
28from torch.nn.parameter import Buffer, Parameter
29from torch.utils._python_dispatch import is_traceable_wrapper_subclass
30from torch.utils.hooks import BackwardHook, RemovableHandle
31
32
33__all__ = [
34    "register_module_forward_pre_hook",
35    "register_module_forward_hook",
36    "register_module_full_backward_pre_hook",
37    "register_module_backward_hook",
38    "register_module_full_backward_hook",
39    "register_module_buffer_registration_hook",
40    "register_module_module_registration_hook",
41    "register_module_parameter_registration_hook",
42    "Module",
43]
44
45_grad_t = Union[Tuple[Tensor, ...], Tensor]
46# See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use
47# of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be
48# the type of the subclass, not the looser type of `Module`.
49T = TypeVar("T", bound="Module")
50
51
52class _IncompatibleKeys(
53    namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
54):
55    def __repr__(self):
56        if not self.missing_keys and not self.unexpected_keys:
57            return "<All keys matched successfully>"
58        return super().__repr__()
59
60    __str__ = __repr__
61
62
63def _addindent(s_, numSpaces):
64    s = s_.split("\n")
65    # don't do anything for single-line stuff
66    if len(s) == 1:
67        return s_
68    first = s.pop(0)
69    s = [(numSpaces * " ") + line for line in s]
70    s = "\n".join(s)
71    s = first + "\n" + s
72    return s
73
74
75r"""This tracks hooks common to all modules that are executed immediately before
76.registering the buffer/module/parameter"""
77_global_buffer_registration_hooks: Dict[int, Callable] = OrderedDict()
78_global_module_registration_hooks: Dict[int, Callable] = OrderedDict()
79_global_parameter_registration_hooks: Dict[int, Callable] = OrderedDict()
80
81
82class _WrappedHook:
83    def __init__(self, hook: Callable, module: Optional["Module"] = None):
84        self.hook: Callable = hook
85        functools.update_wrapper(self, hook)
86
87        self.with_module: bool = False
88
89        if module is not None:
90            self.module: weakref.ReferenceType[Module] = weakref.ref(module)
91            self.with_module = True
92
93    def __call__(self, *args: Any, **kwargs: Any) -> Any:
94        if self.with_module:
95            module = self.module()
96            if module is None:
97                raise RuntimeError("You are trying to call the hook of a dead Module!")
98            return self.hook(module, *args, **kwargs)
99        return self.hook(*args, **kwargs)
100
101    def __getstate__(self) -> Dict:
102        result = {"hook": self.hook, "with_module": self.with_module}
103        if self.with_module:
104            result["module"] = self.module()
105
106        return result
107
108    def __setstate__(self, state: Dict):
109        self.hook = state["hook"]
110        self.with_module = state["with_module"]
111
112        if self.with_module:
113            if state["module"] is None:
114                raise RuntimeError(
115                    "You are trying to revive the hook of a dead Module!"
116                )
117            self.module = weakref.ref(state["module"])
118
119
120r"""This tracks hooks common to all modules that are executed before/after
121calling forward and backward. This is global state used for debugging/profiling
122purposes"""
123_global_backward_pre_hooks: Dict[int, Callable] = OrderedDict()
124_global_backward_hooks: Dict[int, Callable] = OrderedDict()
125_global_is_full_backward_hook: Optional[bool] = None
126_global_forward_pre_hooks: Dict[int, Callable] = OrderedDict()
127_global_forward_hooks: Dict[int, Callable] = OrderedDict()
128_global_forward_hooks_always_called: Dict[int, bool] = OrderedDict()
129
130_EXTRA_STATE_KEY_SUFFIX = "_extra_state"
131
132
133def register_module_buffer_registration_hook(
134    hook: Callable[..., None],
135) -> RemovableHandle:
136    r"""Register a buffer registration hook common to all modules.
137
138    .. warning ::
139
140        This adds global state to the `nn.Module` module
141
142    The hook will be called every time :func:`register_buffer` is invoked.
143    It should have the following signature::
144
145        hook(module, name, buffer) -> None or new buffer
146
147    The hook can modify the input or return a single modified value in the hook.
148
149    Returns:
150        :class:`torch.utils.hooks.RemovableHandle`:
151            a handle that can be used to remove the added hook by calling
152            ``handle.remove()``
153    """
154    handle = RemovableHandle(_global_buffer_registration_hooks)
155    _global_buffer_registration_hooks[handle.id] = hook
156    return handle
157
158
159def register_module_module_registration_hook(
160    hook: Callable[..., None],
161) -> RemovableHandle:
162    r"""Register a module registration hook common to all modules.
163
164    .. warning ::
165
166        This adds global state to the `nn.Module` module
167
168    The hook will be called every time :func:`register_module` is invoked.
169    It should have the following signature::
170
171        hook(module, name, submodule) -> None or new submodule
172
173    The hook can modify the input or return a single modified value in the hook.
174
175    Returns:
176        :class:`torch.utils.hooks.RemovableHandle`:
177            a handle that can be used to remove the added hook by calling
178            ``handle.remove()``
179    """
180    handle = RemovableHandle(_global_module_registration_hooks)
181    _global_module_registration_hooks[handle.id] = hook
182    return handle
183
184
185def register_module_parameter_registration_hook(
186    hook: Callable[..., None],
187) -> RemovableHandle:
188    r"""Register a parameter registration hook common to all modules.
189
190    .. warning ::
191
192        This adds global state to the `nn.Module` module
193
194    The hook will be called every time :func:`register_parameter` is invoked.
195    It should have the following signature::
196
197        hook(module, name, param) -> None or new parameter
198
199    The hook can modify the input or return a single modified value in the hook.
200
201    Returns:
202        :class:`torch.utils.hooks.RemovableHandle`:
203            a handle that can be used to remove the added hook by calling
204            ``handle.remove()``
205    """
206    handle = RemovableHandle(_global_parameter_registration_hooks)
207    _global_parameter_registration_hooks[handle.id] = hook
208    return handle
209
210
211def register_module_forward_pre_hook(hook: Callable[..., None]) -> RemovableHandle:
212    r"""Register a forward pre-hook common to all modules.
213
214    .. warning ::
215
216        This adds global state to the `nn.module` module
217        and it is only intended for debugging/profiling purposes.
218
219    The hook will be called every time before :func:`forward` is invoked.
220    It should have the following signature::
221
222        hook(module, input) -> None or modified input
223
224    The input contains only the positional arguments given to the module.
225    Keyword arguments won't be passed to the hooks and only to the ``forward``.
226    The hook can modify the input. User can either return a tuple or a
227    single modified value in the hook. We will wrap the value into a tuple
228    if a single value is returned(unless that value is already a tuple).
229
230    This hook has precedence over the specific module hooks registered with
231    ``register_forward_pre_hook``.
232
233    Returns:
234        :class:`torch.utils.hooks.RemovableHandle`:
235            a handle that can be used to remove the added hook by calling
236            ``handle.remove()``
237    """
238    handle = RemovableHandle(_global_forward_pre_hooks)
239    _global_forward_pre_hooks[handle.id] = hook
240    return handle
241
242
243def register_module_forward_hook(
244    hook: Callable[..., None],
245    *,
246    always_call: bool = False,
247) -> RemovableHandle:
248    r"""Register a global forward hook for all the modules.
249
250    .. warning ::
251
252        This adds global state to the `nn.module` module
253        and it is only intended for debugging/profiling purposes.
254
255    The hook will be called every time after :func:`forward` has computed an output.
256    It should have the following signature::
257
258        hook(module, input, output) -> None or modified output
259
260    The input contains only the positional arguments given to the module.
261    Keyword arguments won't be passed to the hooks and only to the ``forward``.
262    The hook can modify the output. It can modify the input inplace but
263    it will not have effect on forward since this is called after
264    :func:`forward` is called.
265
266    Parameters:
267        hook (Callable): The user defined hook to be registered.
268        always_call (bool): If ``True`` the ``hook`` will be run regardless of
269            whether an exception is raised while calling the Module.
270            Default: ``False``
271    Returns:
272        :class:`torch.utils.hooks.RemovableHandle`:
273            a handle that can be used to remove the added hook by calling
274            ``handle.remove()``
275
276    This hook will be executed before specific module hooks registered with
277    ``register_forward_hook``.
278    """
279    handle = RemovableHandle(
280        _global_forward_hooks, extra_dict=_global_forward_hooks_always_called
281    )
282    _global_forward_hooks[handle.id] = hook
283    if always_call:
284        _global_forward_hooks_always_called[handle.id] = True
285    return handle
286
287
288def register_module_backward_hook(
289    hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]],
290) -> RemovableHandle:
291    r"""Register a backward hook common to all the modules.
292
293    This function is deprecated in favor of
294    :func:`torch.nn.modules.module.register_module_full_backward_hook`
295    and the behavior of this function will change in future versions.
296
297    Returns:
298        :class:`torch.utils.hooks.RemovableHandle`:
299            a handle that can be used to remove the added hook by calling
300            ``handle.remove()``
301
302    """
303    global _global_is_full_backward_hook
304    if _global_is_full_backward_hook is True:
305        raise RuntimeError(
306            "Cannot use both regular backward hooks and full backward hooks as a "
307            "global Module hook. Please use only one of them."
308        )
309
310    _global_is_full_backward_hook = False
311
312    handle = RemovableHandle(_global_backward_hooks)
313    _global_backward_hooks[handle.id] = hook
314    return handle
315
316
317def register_module_full_backward_pre_hook(
318    hook: Callable[["Module", _grad_t], Union[None, _grad_t]],
319) -> RemovableHandle:
320    r"""Register a backward pre-hook common to all the modules.
321
322    .. warning ::
323        This adds global state to the `nn.module` module
324        and it is only intended for debugging/profiling purposes.
325
326    Hooks registered using this function behave in the same way as those
327    registered by :meth:`torch.nn.Module.register_full_backward_pre_hook`.
328    Refer to its documentation for more details.
329
330    Hooks registered using this function will be called before hooks registered
331    using :meth:`torch.nn.Module.register_full_backward_pre_hook`.
332
333    Returns:
334        :class:`torch.utils.hooks.RemovableHandle`:
335            a handle that can be used to remove the added hook by calling
336            ``handle.remove()``
337
338    """
339    handle = RemovableHandle(_global_backward_pre_hooks)
340    _global_backward_pre_hooks[handle.id] = hook
341    return handle
342
343
344def register_module_full_backward_hook(
345    hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]],
346) -> RemovableHandle:
347    r"""Register a backward hook common to all the modules.
348
349    .. warning ::
350        This adds global state to the `nn.module` module
351        and it is only intended for debugging/profiling purposes.
352
353    Hooks registered using this function behave in the same way as those
354    registered by :meth:`torch.nn.Module.register_full_backward_hook`.
355    Refer to its documentation for more details.
356
357    Hooks registered using this function will be called before hooks registered
358    using :meth:`torch.nn.Module.register_full_backward_hook`.
359
360    Returns:
361        :class:`torch.utils.hooks.RemovableHandle`:
362            a handle that can be used to remove the added hook by calling
363            ``handle.remove()``
364
365    """
366    global _global_is_full_backward_hook
367    if _global_is_full_backward_hook is False:
368        raise RuntimeError(
369            "Cannot use both regular backward hooks and full backward hooks as a "
370            "global Module hook. Please use only one of them."
371        )
372
373    _global_is_full_backward_hook = True
374
375    handle = RemovableHandle(_global_backward_hooks)
376    _global_backward_hooks[handle.id] = hook
377    return handle
378
379
380# Trick mypy into not applying contravariance rules to inputs by defining
381# forward as a value, rather than a function.  See also
382# https://github.com/python/mypy/issues/8795
383def _forward_unimplemented(self, *input: Any) -> None:
384    r"""Define the computation performed at every call.
385
386    Should be overridden by all subclasses.
387
388    .. note::
389        Although the recipe for forward pass needs to be defined within
390        this function, one should call the :class:`Module` instance afterwards
391        instead of this since the former takes care of running the
392        registered hooks while the latter silently ignores them.
393    """
394    raise NotImplementedError(
395        f'Module [{type(self).__name__}] is missing the required "forward" function'
396    )
397
398
399class Module:
400    r"""Base class for all neural network modules.
401
402    Your models should also subclass this class.
403
404    Modules can also contain other Modules, allowing to nest them in
405    a tree structure. You can assign the submodules as regular attributes::
406
407        import torch.nn as nn
408        import torch.nn.functional as F
409
410        class Model(nn.Module):
411            def __init__(self) -> None:
412                super().__init__()
413                self.conv1 = nn.Conv2d(1, 20, 5)
414                self.conv2 = nn.Conv2d(20, 20, 5)
415
416            def forward(self, x):
417                x = F.relu(self.conv1(x))
418                return F.relu(self.conv2(x))
419
420    Submodules assigned in this way will be registered, and will have their
421    parameters converted too when you call :meth:`to`, etc.
422
423    .. note::
424        As per the example above, an ``__init__()`` call to the parent class
425        must be made before assignment on the child.
426
427    :ivar training: Boolean represents whether this module is in training or
428                    evaluation mode.
429    :vartype training: bool
430    """
431
432    dump_patches: bool = False
433
434    _version: int = 1
435    r"""This allows better BC support for :meth:`load_state_dict`. In
436    :meth:`state_dict`, the version number will be saved as in the attribute
437    `_metadata` of the returned state dict, and thus pickled. `_metadata` is a
438    dictionary with keys that follow the naming convention of state dict. See
439    ``_load_from_state_dict`` on how to use this information in loading.
440
441    If new parameters/buffers are added/removed from a module, this number shall
442    be bumped, and the module's `_load_from_state_dict` method can compare the
443    version number and do appropriate changes if the state dict is from before
444    the change."""
445
446    training: bool
447    _parameters: Dict[str, Optional[Parameter]]
448    _buffers: Dict[str, Optional[Tensor]]
449    _non_persistent_buffers_set: Set[str]
450    _backward_pre_hooks: Dict[int, Callable]
451    _backward_hooks: Dict[int, Callable]
452    _is_full_backward_hook: Optional[bool]
453    _forward_hooks: Dict[int, Callable]
454    # Marks whether the corresponding _forward_hooks accept kwargs or not.
455    # As JIT does not support Set[int], this dict is used as a set, where all
456    # hooks represented in this dict accept kwargs.
457    _forward_hooks_with_kwargs: Dict[int, bool]
458    # forward hooks that should always be called even if an exception is raised
459    _forward_hooks_always_called: Dict[int, bool]
460    _forward_pre_hooks: Dict[int, Callable]
461    # Marks whether the corresponding _forward_hooks accept kwargs or not.
462    # As JIT does not support Set[int], this dict is used as a set, where all
463    # hooks represented in this dict accept kwargs.
464    _forward_pre_hooks_with_kwargs: Dict[int, bool]
465    _state_dict_hooks: Dict[int, Callable]
466    _load_state_dict_pre_hooks: Dict[int, Callable]
467    _state_dict_pre_hooks: Dict[int, Callable]
468    _load_state_dict_post_hooks: Dict[int, Callable]
469    _modules: Dict[str, Optional["Module"]]
470    call_super_init: bool = False
471    _compiled_call_impl: Optional[Callable] = None
472
473    def __init__(self, *args, **kwargs) -> None:
474        """Initialize internal Module state, shared by both nn.Module and ScriptModule."""
475        torch._C._log_api_usage_once("python.nn_module")
476
477        # Backward compatibility: no args used to be allowed when call_super_init=False
478        if self.call_super_init is False and bool(kwargs):
479            raise TypeError(
480                f"{type(self).__name__}.__init__() got an unexpected keyword argument '{next(iter(kwargs))}'"
481                ""
482            )
483
484        if self.call_super_init is False and bool(args):
485            raise TypeError(
486                f"{type(self).__name__}.__init__() takes 1 positional argument but {len(args) + 1} were"
487                " given"
488            )
489
490        """
491        Calls super().__setattr__('a', a) instead of the typical self.a = a
492        to avoid Module.__setattr__ overhead. Module's __setattr__ has special
493        handling for parameters, submodules, and buffers but simply calls into
494        super().__setattr__ for all other attributes.
495        """
496        super().__setattr__("training", True)
497        super().__setattr__("_parameters", {})
498        super().__setattr__("_buffers", {})
499        super().__setattr__("_non_persistent_buffers_set", set())
500        super().__setattr__("_backward_pre_hooks", OrderedDict())
501        super().__setattr__("_backward_hooks", OrderedDict())
502        super().__setattr__("_is_full_backward_hook", None)
503        super().__setattr__("_forward_hooks", OrderedDict())
504        super().__setattr__("_forward_hooks_with_kwargs", OrderedDict())
505        super().__setattr__("_forward_hooks_always_called", OrderedDict())
506        super().__setattr__("_forward_pre_hooks", OrderedDict())
507        super().__setattr__("_forward_pre_hooks_with_kwargs", OrderedDict())
508        super().__setattr__("_state_dict_hooks", OrderedDict())
509        super().__setattr__("_state_dict_pre_hooks", OrderedDict())
510        super().__setattr__("_load_state_dict_pre_hooks", OrderedDict())
511        super().__setattr__("_load_state_dict_post_hooks", OrderedDict())
512        super().__setattr__("_modules", {})
513
514        if self.call_super_init:
515            super().__init__(*args, **kwargs)
516
517    forward: Callable[..., Any] = _forward_unimplemented
518
519    def register_buffer(
520        self, name: str, tensor: Optional[Tensor], persistent: bool = True
521    ) -> None:
522        r"""Add a buffer to the module.
523
524        This is typically used to register a buffer that should not to be
525        considered a model parameter. For example, BatchNorm's ``running_mean``
526        is not a parameter, but is part of the module's state. Buffers, by
527        default, are persistent and will be saved alongside parameters. This
528        behavior can be changed by setting :attr:`persistent` to ``False``. The
529        only difference between a persistent buffer and a non-persistent buffer
530        is that the latter will not be a part of this module's
531        :attr:`state_dict`.
532
533        Buffers can be accessed as attributes using given names.
534
535        Args:
536            name (str): name of the buffer. The buffer can be accessed
537                from this module using the given name
538            tensor (Tensor or None): buffer to be registered. If ``None``, then operations
539                that run on buffers, such as :attr:`cuda`, are ignored. If ``None``,
540                the buffer is **not** included in the module's :attr:`state_dict`.
541            persistent (bool): whether the buffer is part of this module's
542                :attr:`state_dict`.
543
544        Example::
545
546            >>> # xdoctest: +SKIP("undefined vars")
547            >>> self.register_buffer('running_mean', torch.zeros(num_features))
548
549        """
550        if persistent is False and isinstance(self, torch.jit.ScriptModule):
551            raise RuntimeError("ScriptModule does not support non-persistent buffers")
552
553        if "_buffers" not in self.__dict__:
554            raise AttributeError("cannot assign buffer before Module.__init__() call")
555        elif not isinstance(name, str):
556            raise TypeError(
557                f"buffer name should be a string. Got {torch.typename(name)}"
558            )
559        elif "." in name:
560            raise KeyError('buffer name can\'t contain "."')
561        elif name == "":
562            raise KeyError('buffer name can\'t be empty string ""')
563        elif hasattr(self, name) and name not in self._buffers:
564            raise KeyError(f"attribute '{name}' already exists")
565        elif tensor is not None and not isinstance(tensor, torch.Tensor):
566            raise TypeError(
567                f"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' "
568                "(torch Tensor or None required)"
569            )
570        else:
571            for hook in _global_buffer_registration_hooks.values():
572                output = hook(self, name, tensor)
573                if output is not None:
574                    tensor = output
575            self._buffers[name] = tensor
576            if persistent:
577                self._non_persistent_buffers_set.discard(name)
578            else:
579                self._non_persistent_buffers_set.add(name)
580
581    def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
582        r"""Add a parameter to the module.
583
584        The parameter can be accessed as an attribute using given name.
585
586        Args:
587            name (str): name of the parameter. The parameter can be accessed
588                from this module using the given name
589            param (Parameter or None): parameter to be added to the module. If
590                ``None``, then operations that run on parameters, such as :attr:`cuda`,
591                are ignored. If ``None``, the parameter is **not** included in the
592                module's :attr:`state_dict`.
593        """
594        if "_parameters" not in self.__dict__:
595            raise AttributeError(
596                "cannot assign parameter before Module.__init__() call"
597            )
598
599        elif not isinstance(name, str):
600            raise TypeError(
601                f"parameter name should be a string. Got {torch.typename(name)}"
602            )
603        elif "." in name:
604            raise KeyError('parameter name can\'t contain "."')
605        elif name == "":
606            raise KeyError('parameter name can\'t be empty string ""')
607        elif hasattr(self, name) and name not in self._parameters:
608            raise KeyError(f"attribute '{name}' already exists")
609
610        if param is None:
611            self._parameters[name] = None
612        elif not isinstance(param, Parameter):
613            raise TypeError(
614                f"cannot assign '{torch.typename(param)}' object to parameter '{name}' "
615                "(torch.nn.Parameter or None required)"
616            )
617        elif param.grad_fn:
618            raise ValueError(
619                f"Cannot assign non-leaf Tensor to parameter '{name}'. Model "
620                f"parameters must be created explicitly. To express '{name}' "
621                "as a function of another Tensor, compute the value in "
622                "the forward() method."
623            )
624        else:
625            for hook in _global_parameter_registration_hooks.values():
626                output = hook(self, name, param)
627                if output is not None:
628                    param = output
629            self._parameters[name] = param
630
631    def add_module(self, name: str, module: Optional["Module"]) -> None:
632        r"""Add a child module to the current module.
633
634        The module can be accessed as an attribute using the given name.
635
636        Args:
637            name (str): name of the child module. The child module can be
638                accessed from this module using the given name
639            module (Module): child module to be added to the module.
640        """
641        if not isinstance(module, Module) and module is not None:
642            raise TypeError(f"{torch.typename(module)} is not a Module subclass")
643        elif not isinstance(name, str):
644            raise TypeError(
645                f"module name should be a string. Got {torch.typename(name)}"
646            )
647        elif hasattr(self, name) and name not in self._modules:
648            raise KeyError(f"attribute '{name}' already exists")
649        elif "." in name:
650            raise KeyError(f'module name can\'t contain ".", got: {name}')
651        elif name == "":
652            raise KeyError('module name can\'t be empty string ""')
653        for hook in _global_module_registration_hooks.values():
654            output = hook(self, name, module)
655            if output is not None:
656                module = output
657        self._modules[name] = module
658
659    def register_module(self, name: str, module: Optional["Module"]) -> None:
660        r"""Alias for :func:`add_module`."""
661        self.add_module(name, module)
662
663    def get_submodule(self, target: str) -> "Module":
664        """Return the submodule given by ``target`` if it exists, otherwise throw an error.
665
666        For example, let's say you have an ``nn.Module`` ``A`` that
667        looks like this:
668
669        .. code-block:: text
670
671            A(
672                (net_b): Module(
673                    (net_c): Module(
674                        (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
675                    )
676                    (linear): Linear(in_features=100, out_features=200, bias=True)
677                )
678            )
679
680        (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested
681        submodule ``net_b``, which itself has two submodules ``net_c``
682        and ``linear``. ``net_c`` then has a submodule ``conv``.)
683
684        To check whether or not we have the ``linear`` submodule, we
685        would call ``get_submodule("net_b.linear")``. To check whether
686        we have the ``conv`` submodule, we would call
687        ``get_submodule("net_b.net_c.conv")``.
688
689        The runtime of ``get_submodule`` is bounded by the degree
690        of module nesting in ``target``. A query against
691        ``named_modules`` achieves the same result, but it is O(N) in
692        the number of transitive modules. So, for a simple check to see
693        if some submodule exists, ``get_submodule`` should always be
694        used.
695
696        Args:
697            target: The fully-qualified string name of the submodule
698                to look for. (See above example for how to specify a
699                fully-qualified string.)
700
701        Returns:
702            torch.nn.Module: The submodule referenced by ``target``
703
704        Raises:
705            AttributeError: If the target string references an invalid
706                path or resolves to something that is not an
707                ``nn.Module``
708        """
709        if target == "":
710            return self
711
712        atoms: List[str] = target.split(".")
713        mod: torch.nn.Module = self
714
715        for item in atoms:
716            if not hasattr(mod, item):
717                raise AttributeError(
718                    mod._get_name() + " has no " "attribute `" + item + "`"
719                )
720
721            mod = getattr(mod, item)
722
723            if not isinstance(mod, torch.nn.Module):
724                raise AttributeError("`" + item + "` is not " "an nn.Module")
725
726        return mod
727
728    def set_submodule(self, target: str, module: "Module") -> None:
729        """
730        Set the submodule given by ``target`` if it exists, otherwise throw an error.
731
732        For example, let's say you have an ``nn.Module`` ``A`` that
733        looks like this:
734
735        .. code-block:: text
736
737            A(
738                (net_b): Module(
739                    (net_c): Module(
740                        (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
741                    )
742                    (linear): Linear(in_features=100, out_features=200, bias=True)
743                )
744            )
745
746        (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested
747        submodule ``net_b``, which itself has two submodules ``net_c``
748        and ``linear``. ``net_c`` then has a submodule ``conv``.)
749
750        To overide the ``Conv2d`` with a new submodule ``Linear``, you
751        would call
752        ``set_submodule("net_b.net_c.conv", nn.Linear(33, 16))``.
753
754        Args:
755            target: The fully-qualified string name of the submodule
756                to look for. (See above example for how to specify a
757                fully-qualified string.)
758            module: The module to set the submodule to.
759
760        Raises:
761            ValueError: If the target string is empty
762            AttributeError: If the target string references an invalid
763                path or resolves to something that is not an
764                ``nn.Module``
765        """
766        if target == "":
767            raise ValueError("Cannot set the submodule without a target name!")
768
769        atoms: List[str] = target.split(".")
770        name = atoms.pop(-1)
771        mod: torch.nn.Module = self
772
773        for item in atoms:
774            if not hasattr(mod, item):
775                raise AttributeError(
776                    mod._get_name() + " has no attribute `" + item + "`"
777                )
778
779            mod = getattr(mod, item)
780
781            # Use isinstance instead of type here to also handle subclass of nn.Module
782            if not isinstance(mod, torch.nn.Module):
783                raise AttributeError("`" + item + "` is not an nn.Module")
784
785        setattr(mod, name, module)
786
787    def get_parameter(self, target: str) -> "Parameter":
788        """Return the parameter given by ``target`` if it exists, otherwise throw an error.
789
790        See the docstring for ``get_submodule`` for a more detailed
791        explanation of this method's functionality as well as how to
792        correctly specify ``target``.
793
794        Args:
795            target: The fully-qualified string name of the Parameter
796                to look for. (See ``get_submodule`` for how to specify a
797                fully-qualified string.)
798
799        Returns:
800            torch.nn.Parameter: The Parameter referenced by ``target``
801
802        Raises:
803            AttributeError: If the target string references an invalid
804                path or resolves to something that is not an
805                ``nn.Parameter``
806        """
807        module_path, _, param_name = target.rpartition(".")
808
809        mod: torch.nn.Module = self.get_submodule(module_path)
810
811        if not hasattr(mod, param_name):
812            raise AttributeError(
813                mod._get_name() + " has no attribute `" + param_name + "`"
814            )
815
816        param: torch.nn.Parameter = getattr(mod, param_name)
817
818        if not isinstance(param, torch.nn.Parameter):
819            raise AttributeError("`" + param_name + "` is not an " "nn.Parameter")
820
821        return param
822
823    def get_buffer(self, target: str) -> "Tensor":
824        """Return the buffer given by ``target`` if it exists, otherwise throw an error.
825
826        See the docstring for ``get_submodule`` for a more detailed
827        explanation of this method's functionality as well as how to
828        correctly specify ``target``.
829
830        Args:
831            target: The fully-qualified string name of the buffer
832                to look for. (See ``get_submodule`` for how to specify a
833                fully-qualified string.)
834
835        Returns:
836            torch.Tensor: The buffer referenced by ``target``
837
838        Raises:
839            AttributeError: If the target string references an invalid
840                path or resolves to something that is not a
841                buffer
842        """
843        module_path, _, buffer_name = target.rpartition(".")
844
845        mod: torch.nn.Module = self.get_submodule(module_path)
846
847        if not hasattr(mod, buffer_name):
848            raise AttributeError(
849                mod._get_name() + " has no attribute `" + buffer_name + "`"
850            )
851
852        buffer: torch.Tensor = getattr(mod, buffer_name)
853
854        if buffer_name not in mod._buffers:
855            raise AttributeError("`" + buffer_name + "` is not a buffer")
856
857        return buffer
858
859    def get_extra_state(self) -> Any:
860        """Return any extra state to include in the module's state_dict.
861
862        Implement this and a corresponding :func:`set_extra_state` for your module
863        if you need to store extra state. This function is called when building the
864        module's `state_dict()`.
865
866        Note that extra state should be picklable to ensure working serialization
867        of the state_dict. We only provide provide backwards compatibility guarantees
868        for serializing Tensors; other objects may break backwards compatibility if
869        their serialized pickled form changes.
870
871        Returns:
872            object: Any extra state to store in the module's state_dict
873        """
874        raise RuntimeError(
875            "Reached a code path in Module.get_extra_state() that should never be called. "
876            "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
877            "to report this bug."
878        )
879
880    def set_extra_state(self, state: Any) -> None:
881        """Set extra state contained in the loaded `state_dict`.
882
883        This function is called from :func:`load_state_dict` to handle any extra state
884        found within the `state_dict`. Implement this function and a corresponding
885        :func:`get_extra_state` for your module if you need to store extra state within its
886        `state_dict`.
887
888        Args:
889            state (dict): Extra state from the `state_dict`
890        """
891        raise RuntimeError(
892            "Reached a code path in Module.set_extra_state() that should never be called. "
893            "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
894            "to report this bug."
895        )
896
897    def _apply(self, fn, recurse=True):
898        if recurse:
899            for module in self.children():
900                module._apply(fn)
901
902        def compute_should_use_set_data(tensor, tensor_applied):
903            if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
904                # If the new tensor has compatible tensor type as the existing tensor,
905                # the current behavior is to change the tensor in-place using `.data =`,
906                # and the future behavior is to overwrite the existing tensor. However,
907                # changing the current behavior is a BC-breaking change, and we want it
908                # to happen in future releases. So for now we introduce the
909                # `torch.__future__.get_overwrite_module_params_on_conversion()`
910                # global flag to let the user control whether they want the future
911                # behavior of overwriting the existing tensor or not.
912                return not torch.__future__.get_overwrite_module_params_on_conversion()
913            else:
914                return False
915
916        should_use_swap_tensors = (
917            torch.__future__.get_swap_module_params_on_conversion()
918        )
919
920        for key, param in self._parameters.items():
921            if param is None:
922                continue
923            # Tensors stored in modules are graph leaves, and we don't want to
924            # track autograd history of `param_applied`, so we have to use
925            # `with torch.no_grad():`
926            with torch.no_grad():
927                param_applied = fn(param)
928            p_should_use_set_data = compute_should_use_set_data(param, param_applied)
929
930            # subclasses may have multiple child tensors so we need to use swap_tensors
931            p_should_use_swap_tensors = (
932                should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied)
933            )
934
935            param_grad = param.grad
936            if p_should_use_swap_tensors:
937                try:
938                    if param_grad is not None:
939                        # Accessing param.grad makes its at::Tensor's use_count 2, which will prevent swapping.
940                        # Decrement use count of the gradient by setting to None
941                        param.grad = None
942                    param_applied = torch.nn.Parameter(
943                        param_applied, requires_grad=param.requires_grad
944                    )
945                    torch.utils.swap_tensors(param, param_applied)
946                except Exception as e:
947                    if param_grad is not None:
948                        param.grad = param_grad
949                    raise RuntimeError(
950                        f"_apply(): Couldn't swap {self._get_name()}.{key}"
951                    ) from e
952                out_param = param
953            elif p_should_use_set_data:
954                param.data = param_applied
955                out_param = param
956            else:
957                assert isinstance(param, Parameter)
958                assert param.is_leaf
959                out_param = Parameter(param_applied, param.requires_grad)
960                self._parameters[key] = out_param
961
962            if param_grad is not None:
963                with torch.no_grad():
964                    grad_applied = fn(param_grad)
965                g_should_use_set_data = compute_should_use_set_data(
966                    param_grad, grad_applied
967                )
968                if p_should_use_swap_tensors:
969                    grad_applied.requires_grad_(param_grad.requires_grad)
970                    try:
971                        torch.utils.swap_tensors(param_grad, grad_applied)
972                    except Exception as e:
973                        raise RuntimeError(
974                            f"_apply(): Couldn't swap {self._get_name()}.{key}.grad"
975                        ) from e
976                    out_param.grad = param_grad
977                elif g_should_use_set_data:
978                    assert out_param.grad is not None
979                    out_param.grad.data = grad_applied
980                else:
981                    assert param_grad.is_leaf
982                    out_param.grad = grad_applied.requires_grad_(
983                        param_grad.requires_grad
984                    )
985
986        for key, buf in self._buffers.items():
987            if buf is not None:
988                self._buffers[key] = fn(buf)
989
990        return self
991
992    def apply(self: T, fn: Callable[["Module"], None]) -> T:
993        r"""Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self.
994
995        Typical use includes initializing the parameters of a model
996        (see also :ref:`nn-init-doc`).
997
998        Args:
999            fn (:class:`Module` -> None): function to be applied to each submodule
1000
1001        Returns:
1002            Module: self
1003
1004        Example::
1005
1006            >>> @torch.no_grad()
1007            >>> def init_weights(m):
1008            >>>     print(m)
1009            >>>     if type(m) == nn.Linear:
1010            >>>         m.weight.fill_(1.0)
1011            >>>         print(m.weight)
1012            >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
1013            >>> net.apply(init_weights)
1014            Linear(in_features=2, out_features=2, bias=True)
1015            Parameter containing:
1016            tensor([[1., 1.],
1017                    [1., 1.]], requires_grad=True)
1018            Linear(in_features=2, out_features=2, bias=True)
1019            Parameter containing:
1020            tensor([[1., 1.],
1021                    [1., 1.]], requires_grad=True)
1022            Sequential(
1023              (0): Linear(in_features=2, out_features=2, bias=True)
1024              (1): Linear(in_features=2, out_features=2, bias=True)
1025            )
1026
1027        """
1028        for module in self.children():
1029            module.apply(fn)
1030        fn(self)
1031        return self
1032
1033    def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
1034        r"""Move all model parameters and buffers to the GPU.
1035
1036        This also makes associated parameters and buffers different objects. So
1037        it should be called before constructing optimizer if the module will
1038        live on GPU while being optimized.
1039
1040        .. note::
1041            This method modifies the module in-place.
1042
1043        Args:
1044            device (int, optional): if specified, all parameters will be
1045                copied to that device
1046
1047        Returns:
1048            Module: self
1049        """
1050        return self._apply(lambda t: t.cuda(device))
1051
1052    def ipu(self: T, device: Optional[Union[int, device]] = None) -> T:
1053        r"""Move all model parameters and buffers to the IPU.
1054
1055        This also makes associated parameters and buffers different objects. So
1056        it should be called before constructing optimizer if the module will
1057        live on IPU while being optimized.
1058
1059        .. note::
1060            This method modifies the module in-place.
1061
1062        Arguments:
1063            device (int, optional): if specified, all parameters will be
1064                copied to that device
1065
1066        Returns:
1067            Module: self
1068        """
1069        return self._apply(lambda t: t.ipu(device))
1070
1071    def xpu(self: T, device: Optional[Union[int, device]] = None) -> T:
1072        r"""Move all model parameters and buffers to the XPU.
1073
1074        This also makes associated parameters and buffers different objects. So
1075        it should be called before constructing optimizer if the module will
1076        live on XPU while being optimized.
1077
1078        .. note::
1079            This method modifies the module in-place.
1080
1081        Arguments:
1082            device (int, optional): if specified, all parameters will be
1083                copied to that device
1084
1085        Returns:
1086            Module: self
1087        """
1088        return self._apply(lambda t: t.xpu(device))
1089
1090    def mtia(self: T, device: Optional[Union[int, device]] = None) -> T:
1091        r"""Move all model parameters and buffers to the MTIA.
1092
1093        This also makes associated parameters and buffers different objects. So
1094        it should be called before constructing optimizer if the module will
1095        live on MTIA while being optimized.
1096
1097        .. note::
1098            This method modifies the module in-place.
1099
1100        Arguments:
1101            device (int, optional): if specified, all parameters will be
1102                copied to that device
1103
1104        Returns:
1105            Module: self
1106        """
1107        return self._apply(lambda t: t.mtia(device))
1108
1109    def cpu(self: T) -> T:
1110        r"""Move all model parameters and buffers to the CPU.
1111
1112        .. note::
1113            This method modifies the module in-place.
1114
1115        Returns:
1116            Module: self
1117        """
1118        return self._apply(lambda t: t.cpu())
1119
1120    def type(self: T, dst_type: Union[dtype, str]) -> T:
1121        r"""Casts all parameters and buffers to :attr:`dst_type`.
1122
1123        .. note::
1124            This method modifies the module in-place.
1125
1126        Args:
1127            dst_type (type or string): the desired type
1128
1129        Returns:
1130            Module: self
1131        """
1132        return self._apply(lambda t: t.type(dst_type))
1133
1134    def float(self: T) -> T:
1135        r"""Casts all floating point parameters and buffers to ``float`` datatype.
1136
1137        .. note::
1138            This method modifies the module in-place.
1139
1140        Returns:
1141            Module: self
1142        """
1143        return self._apply(lambda t: t.float() if t.is_floating_point() else t)
1144
1145    def double(self: T) -> T:
1146        r"""Casts all floating point parameters and buffers to ``double`` datatype.
1147
1148        .. note::
1149            This method modifies the module in-place.
1150
1151        Returns:
1152            Module: self
1153        """
1154        return self._apply(lambda t: t.double() if t.is_floating_point() else t)
1155
1156    def half(self: T) -> T:
1157        r"""Casts all floating point parameters and buffers to ``half`` datatype.
1158
1159        .. note::
1160            This method modifies the module in-place.
1161
1162        Returns:
1163            Module: self
1164        """
1165        return self._apply(lambda t: t.half() if t.is_floating_point() else t)
1166
1167    def bfloat16(self: T) -> T:
1168        r"""Casts all floating point parameters and buffers to ``bfloat16`` datatype.
1169
1170        .. note::
1171            This method modifies the module in-place.
1172
1173        Returns:
1174            Module: self
1175        """
1176        return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t)
1177
1178    def to_empty(
1179        self: T, *, device: Optional[DeviceLikeType], recurse: bool = True
1180    ) -> T:
1181        r"""Move the parameters and buffers to the specified device without copying storage.
1182
1183        Args:
1184            device (:class:`torch.device`): The desired device of the parameters
1185                and buffers in this module.
1186            recurse (bool): Whether parameters and buffers of submodules should
1187                be recursively moved to the specified device.
1188
1189        Returns:
1190            Module: self
1191        """
1192        return self._apply(
1193            lambda t: torch.empty_like(t, device=device), recurse=recurse
1194        )
1195
1196    @overload
1197    def to(
1198        self,
1199        device: Optional[DeviceLikeType] = ...,
1200        dtype: Optional[dtype] = ...,
1201        non_blocking: bool = ...,
1202    ) -> Self:
1203        ...
1204
1205    @overload
1206    def to(self, dtype: dtype, non_blocking: bool = ...) -> Self:
1207        ...
1208
1209    @overload
1210    def to(self, tensor: Tensor, non_blocking: bool = ...) -> Self:
1211        ...
1212
1213    def to(self, *args, **kwargs):
1214        r"""Move and/or cast the parameters and buffers.
1215
1216        This can be called as
1217
1218        .. function:: to(device=None, dtype=None, non_blocking=False)
1219           :noindex:
1220
1221        .. function:: to(dtype, non_blocking=False)
1222           :noindex:
1223
1224        .. function:: to(tensor, non_blocking=False)
1225           :noindex:
1226
1227        .. function:: to(memory_format=torch.channels_last)
1228           :noindex:
1229
1230        Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
1231        floating point or complex :attr:`dtype`\ s. In addition, this method will
1232        only cast the floating point or complex parameters and buffers to :attr:`dtype`
1233        (if given). The integral parameters and buffers will be moved
1234        :attr:`device`, if that is given, but with dtypes unchanged. When
1235        :attr:`non_blocking` is set, it tries to convert/move asynchronously
1236        with respect to the host if possible, e.g., moving CPU Tensors with
1237        pinned memory to CUDA devices.
1238
1239        See below for examples.
1240
1241        .. note::
1242            This method modifies the module in-place.
1243
1244        Args:
1245            device (:class:`torch.device`): the desired device of the parameters
1246                and buffers in this module
1247            dtype (:class:`torch.dtype`): the desired floating point or complex dtype of
1248                the parameters and buffers in this module
1249            tensor (torch.Tensor): Tensor whose dtype and device are the desired
1250                dtype and device for all parameters and buffers in this module
1251            memory_format (:class:`torch.memory_format`): the desired memory
1252                format for 4D parameters and buffers in this module (keyword
1253                only argument)
1254
1255        Returns:
1256            Module: self
1257
1258        Examples::
1259
1260            >>> # xdoctest: +IGNORE_WANT("non-deterministic")
1261            >>> linear = nn.Linear(2, 2)
1262            >>> linear.weight
1263            Parameter containing:
1264            tensor([[ 0.1913, -0.3420],
1265                    [-0.5113, -0.2325]])
1266            >>> linear.to(torch.double)
1267            Linear(in_features=2, out_features=2, bias=True)
1268            >>> linear.weight
1269            Parameter containing:
1270            tensor([[ 0.1913, -0.3420],
1271                    [-0.5113, -0.2325]], dtype=torch.float64)
1272            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
1273            >>> gpu1 = torch.device("cuda:1")
1274            >>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
1275            Linear(in_features=2, out_features=2, bias=True)
1276            >>> linear.weight
1277            Parameter containing:
1278            tensor([[ 0.1914, -0.3420],
1279                    [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
1280            >>> cpu = torch.device("cpu")
1281            >>> linear.to(cpu)
1282            Linear(in_features=2, out_features=2, bias=True)
1283            >>> linear.weight
1284            Parameter containing:
1285            tensor([[ 0.1914, -0.3420],
1286                    [-0.5112, -0.2324]], dtype=torch.float16)
1287
1288            >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
1289            >>> linear.weight
1290            Parameter containing:
1291            tensor([[ 0.3741+0.j,  0.2382+0.j],
1292                    [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
1293            >>> linear(torch.ones(3, 2, dtype=torch.cdouble))
1294            tensor([[0.6122+0.j, 0.1150+0.j],
1295                    [0.6122+0.j, 0.1150+0.j],
1296                    [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
1297
1298        """
1299        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
1300            *args, **kwargs
1301        )
1302
1303        if dtype is not None:
1304            if not (dtype.is_floating_point or dtype.is_complex):
1305                raise TypeError(
1306                    "nn.Module.to only accepts floating point or complex "
1307                    f"dtypes, but got desired dtype={dtype}"
1308                )
1309            if dtype.is_complex:
1310                warnings.warn(
1311                    "Complex modules are a new feature under active development whose design may change, "
1312                    "and some modules might not work as expected when using complex tensors as parameters or buffers. "
1313                    "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
1314                    "if a complex module does not work as expected."
1315                )
1316
1317        def convert(t):
1318            try:
1319                if convert_to_format is not None and t.dim() in (4, 5):
1320                    return t.to(
1321                        device,
1322                        dtype if t.is_floating_point() or t.is_complex() else None,
1323                        non_blocking,
1324                        memory_format=convert_to_format,
1325                    )
1326                return t.to(
1327                    device,
1328                    dtype if t.is_floating_point() or t.is_complex() else None,
1329                    non_blocking,
1330                )
1331            except NotImplementedError as e:
1332                if str(e) == "Cannot copy out of meta tensor; no data!":
1333                    raise NotImplementedError(
1334                        f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
1335                        f"when moving module from meta to a different device."
1336                    ) from None
1337                else:
1338                    raise
1339
1340        return self._apply(convert)
1341
1342    def register_full_backward_pre_hook(
1343        self,
1344        hook: Callable[["Module", _grad_t], Union[None, _grad_t]],
1345        prepend: bool = False,
1346    ) -> RemovableHandle:
1347        r"""Register a backward pre-hook on the module.
1348
1349        The hook will be called every time the gradients for the module are computed.
1350        The hook should have the following signature::
1351
1352            hook(module, grad_output) -> tuple[Tensor] or None
1353
1354        The :attr:`grad_output` is a tuple. The hook should
1355        not modify its arguments, but it can optionally return a new gradient with
1356        respect to the output that will be used in place of :attr:`grad_output` in
1357        subsequent computations. Entries in :attr:`grad_output` will be ``None`` for
1358        all non-Tensor arguments.
1359
1360        For technical reasons, when this hook is applied to a Module, its forward function will
1361        receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
1362        of each Tensor returned by the Module's forward function.
1363
1364        .. warning ::
1365            Modifying inputs inplace is not allowed when using backward hooks and
1366            will raise an error.
1367
1368        Args:
1369            hook (Callable): The user-defined hook to be registered.
1370            prepend (bool): If true, the provided ``hook`` will be fired before
1371                all existing ``backward_pre`` hooks on this
1372                :class:`torch.nn.modules.Module`. Otherwise, the provided
1373                ``hook`` will be fired after all existing ``backward_pre`` hooks
1374                on this :class:`torch.nn.modules.Module`. Note that global
1375                ``backward_pre`` hooks registered with
1376                :func:`register_module_full_backward_pre_hook` will fire before
1377                all hooks registered by this method.
1378
1379        Returns:
1380            :class:`torch.utils.hooks.RemovableHandle`:
1381                a handle that can be used to remove the added hook by calling
1382                ``handle.remove()``
1383
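        Example (an illustrative sketch; the module, loss and scaling factor are
        arbitrary)::

            >>> # xdoctest: +SKIP("illustrative sketch")
            >>> def scale_grad_output(module, grad_output):
            ...     # return a new grad_output tuple; None entries are passed through
            ...     return tuple(g * 0.5 if g is not None else g for g in grad_output)
            >>> linear = torch.nn.Linear(3, 1)
            >>> handle = linear.register_full_backward_pre_hook(scale_grad_output)
            >>> linear(torch.randn(2, 3)).sum().backward()
            >>> handle.remove()
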
1384        """
1385        handle = RemovableHandle(self._backward_pre_hooks)
1386        self._backward_pre_hooks[handle.id] = hook
1387        if prepend:
1388            self._backward_pre_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
1389        return handle
1390
1391    def register_backward_hook(
1392        self, hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]]
1393    ) -> RemovableHandle:
1394        r"""Register a backward hook on the module.
1395
1396        This function is deprecated in favor of :meth:`~torch.nn.Module.register_full_backward_hook` and
1397        the behavior of this function will change in future versions.
1398
1399        Returns:
1400            :class:`torch.utils.hooks.RemovableHandle`:
1401                a handle that can be used to remove the added hook by calling
1402                ``handle.remove()``
1403
1404        """
1405        if self._is_full_backward_hook is True:
1406            raise RuntimeError(
1407                "Cannot use both regular backward hooks and full backward hooks on a "
1408                "single Module. Please use only one of them."
1409            )
1410
1411        self._is_full_backward_hook = False
1412
1413        handle = RemovableHandle(self._backward_hooks)
1414        self._backward_hooks[handle.id] = hook
1415        return handle
1416
1417    def register_full_backward_hook(
1418        self,
1419        hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]],
1420        prepend: bool = False,
1421    ) -> RemovableHandle:
1422        r"""Register a backward hook on the module.
1423
1424        The hook will be called every time the gradients with respect to a module
1425        are computed, i.e. the hook will execute if and only if the gradients with
1426        respect to module outputs are computed. The hook should have the following
1427        signature::
1428
1429            hook(module, grad_input, grad_output) -> tuple[Tensor] or None
1430
1431        The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients
1432        with respect to the inputs and outputs respectively. The hook should
1433        not modify its arguments, but it can optionally return a new gradient with
1434        respect to the input that will be used in place of :attr:`grad_input` in
1435        subsequent computations. :attr:`grad_input` will only correspond to the inputs given
1436        as positional arguments and all kwarg arguments are ignored. Entries
1437        in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor
1438        arguments.
1439
1440        For technical reasons, when this hook is applied to a Module, its forward function will
1441        receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
1442        of each Tensor returned by the Module's forward function.
1443
1444        .. warning ::
1445            Modifying inputs or outputs inplace is not allowed when using backward hooks and
1446            will raise an error.
1447
1448        Args:
1449            hook (Callable): The user-defined hook to be registered.
1450            prepend (bool): If true, the provided ``hook`` will be fired before
1451                all existing ``backward`` hooks on this
1452                :class:`torch.nn.modules.Module`. Otherwise, the provided
1453                ``hook`` will be fired after all existing ``backward`` hooks on
1454                this :class:`torch.nn.modules.Module`. Note that global
1455                ``backward`` hooks registered with
1456                :func:`register_module_full_backward_hook` will fire before
1457                all hooks registered by this method.
1458
1459        Returns:
1460            :class:`torch.utils.hooks.RemovableHandle`:
1461                a handle that can be used to remove the added hook by calling
1462                ``handle.remove()``
1463
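        Example (an illustrative sketch; the module and the inspection logic are
        arbitrary)::

            >>> # xdoctest: +SKIP("illustrative sketch")
            >>> def report_grad_shapes(module, grad_input, grad_output):
            ...     # inspect gradients without modifying them; returning None keeps them
            ...     print([None if g is None else g.shape for g in grad_input])
            ...     print([None if g is None else g.shape for g in grad_output])
            >>> linear = torch.nn.Linear(3, 1)
            >>> handle = linear.register_full_backward_hook(report_grad_shapes)
            >>> linear(torch.randn(2, 3)).sum().backward()
            >>> handle.remove()
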
1464        """
1465        if self._is_full_backward_hook is False:
1466            raise RuntimeError(
1467                "Cannot use both regular backward hooks and full backward hooks on a "
1468                "single Module. Please use only one of them."
1469            )
1470
1471        self._is_full_backward_hook = True
1472
1473        handle = RemovableHandle(self._backward_hooks)
1474        self._backward_hooks[handle.id] = hook
1475        if prepend:
1476            self._backward_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
1477        return handle
1478
1479    def _get_backward_hooks(self):
1480        r"""Return the backward hooks for use in the call function.
1481
1482        It returns two lists, one with the full backward hooks and one with the non-full
1483        backward hooks.
1484        """
1485        full_backward_hooks: List[Callable] = []
1486        if _global_is_full_backward_hook is True:
1487            full_backward_hooks += _global_backward_hooks.values()
1488        if self._is_full_backward_hook is True:
1489            full_backward_hooks += self._backward_hooks.values()
1490
1491        non_full_backward_hooks: List[Callable] = []
1492        if _global_is_full_backward_hook is False:
1493            non_full_backward_hooks += _global_backward_hooks.values()
1494        if self._is_full_backward_hook is False:
1495            non_full_backward_hooks += self._backward_hooks.values()
1496
1497        return full_backward_hooks, non_full_backward_hooks
1498
1499    def _get_backward_pre_hooks(self):
1500        backward_pre_hooks: List[Callable] = []
1501        backward_pre_hooks += _global_backward_pre_hooks.values()
1502        backward_pre_hooks += self._backward_pre_hooks.values()
1503
1504        return backward_pre_hooks
1505
1506    def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn):
1507        if not isinstance(result, torch.Tensor):
1508            if not (
1509                isinstance(result, tuple)
1510                and all(isinstance(r, torch.Tensor) for r in result)
1511            ):
1512                warnings.warn(
1513                    "Using non-full backward hooks on a Module that does not return a "
1514                    "single Tensor or a tuple of Tensors is deprecated and will be removed "
1515                    "in future versions. This hook will be missing some of the grad_output. "
1516                    "Please use register_full_backward_hook to get the documented behavior.",
1517                    FutureWarning,
1518                    stacklevel=2,
1519                )
1520                return
1521        else:
1522            result = (result,)
1523
1524        if not isinstance(inputs, torch.Tensor):
1525            if not (
1526                isinstance(inputs, tuple)
1527                and all(isinstance(i, torch.Tensor) for i in inputs)
1528            ):
1529                warnings.warn(
1530                    "Using non-full backward hooks on a Module that does not take as input a "
1531                    "single Tensor or a tuple of Tensors is deprecated and will be removed "
1532                    "in future versions. This hook will be missing some of the grad_input. "
1533                    "Please use register_full_backward_hook to get the documented behavior.",
1534                    FutureWarning,
1535                    stacklevel=2,
1536                )
1537                return
1538        else:
1539            inputs = (inputs,)
1540
1541        # At this point we are sure that inputs and result are tuple of Tensors
1542        out_grad_fn = {r.grad_fn for r in result if r.grad_fn is not None}
1543        if len(out_grad_fn) == 0 or (
1544            len(out_grad_fn) == 1 and grad_fn not in out_grad_fn
1545        ):
1546            warnings.warn(
1547                "Using a non-full backward hook when outputs are nested in python data structure "
1548                "is deprecated and will be removed in future versions. This hook will be missing "
1549                "some grad_output.",
1550                FutureWarning,
1551                stacklevel=2,
1552            )
1553        elif len(out_grad_fn) > 1:
1554            warnings.warn(
1555                "Using a non-full backward hook when outputs are generated by different autograd Nodes "
1556                "is deprecated and will be removed in future versions. This hook will be missing "
1557                "some grad_output. Please use register_full_backward_hook to get the documented behavior.",
1558                FutureWarning,
1559                stacklevel=2,
1560            )
1561        else:
1562            # At this point the grad_output part of the hook will most likely be correct
1563            inputs_grad_fn = {i.grad_fn for i in inputs if i.grad_fn is not None}
1564
1565            next_functions = {n[0] for n in grad_fn.next_functions}
1566
1567            if inputs_grad_fn != next_functions:
1568                warnings.warn(
1569                    "Using a non-full backward hook when the forward contains multiple autograd Nodes "
1570                    "is deprecated and will be removed in future versions. This hook will be missing "
1571                    "some grad_input. Please use register_full_backward_hook to get the documented "
1572                    "behavior.",
1573                    FutureWarning,
1574                    stacklevel=2,
1575                )
1576
1577    def register_forward_pre_hook(
1578        self,
1579        hook: Union[
1580            Callable[[T, Tuple[Any, ...]], Optional[Any]],
1581            Callable[
1582                [T, Tuple[Any, ...], Dict[str, Any]],
1583                Optional[Tuple[Any, Dict[str, Any]]],
1584            ],
1585        ],
1586        *,
1587        prepend: bool = False,
1588        with_kwargs: bool = False,
1589    ) -> RemovableHandle:
1590        r"""Register a forward pre-hook on the module.
1591
1592        The hook will be called every time before :func:`forward` is invoked.
1593
1594
1595        If ``with_kwargs`` is false or not specified, the input contains only
1596        the positional arguments given to the module. Keyword arguments won't be
1597        passed to the hook, but only to the ``forward``. The hook can modify the
1598        input. The user can return either a tuple or a single modified value from
1599        the hook. We will wrap the value into a tuple if a single value is returned
1600        (unless that value is already a tuple). The hook should have the
1601        following signature::
1602
1603            hook(module, args) -> None or modified input
1604
1605        If ``with_kwargs`` is true, the forward pre-hook will be passed the
1606        kwargs given to the forward function. If the hook modifies the
1607        input, both the args and kwargs should be returned. The hook should have
1608        the following signature::
1609
1610            hook(module, args, kwargs) -> None or a tuple of modified input and kwargs
1611
1612        Args:
1613            hook (Callable): The user defined hook to be registered.
1614            prepend (bool): If true, the provided ``hook`` will be fired before
1615                all existing ``forward_pre`` hooks on this
1616                :class:`torch.nn.modules.Module`. Otherwise, the provided
1617                ``hook`` will be fired after all existing ``forward_pre`` hooks
1618                on this :class:`torch.nn.modules.Module`. Note that global
1619                ``forward_pre`` hooks registered with
1620                :func:`register_module_forward_pre_hook` will fire before all
1621                hooks registered by this method.
1622                Default: ``False``
1623            with_kwargs (bool): If true, the ``hook`` will be passed the kwargs
1624                given to the forward function.
1625                Default: ``False``
1626
1627        Returns:
1628            :class:`torch.utils.hooks.RemovableHandle`:
1629                a handle that can be used to remove the added hook by calling
1630                ``handle.remove()``
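
        Example (an illustrative sketch; the module and clamping range are
        arbitrary)::

            >>> # xdoctest: +SKIP("illustrative sketch")
            >>> def clamp_args(module, args):
            ...     # returning a tuple replaces the positional inputs to forward
            ...     return tuple(a.clamp(-1.0, 1.0) for a in args)
            >>> linear = torch.nn.Linear(3, 1)
            >>> handle = linear.register_forward_pre_hook(clamp_args)
            >>> out = linear(torch.randn(2, 3))
            >>> handle.remove()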
1631        """
1632        handle = RemovableHandle(
1633            self._forward_pre_hooks, extra_dict=self._forward_pre_hooks_with_kwargs
1634        )
1635        self._forward_pre_hooks[handle.id] = hook
1636        if with_kwargs:
1637            self._forward_pre_hooks_with_kwargs[handle.id] = True
1638
1639        if prepend:
1640            self._forward_pre_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
1641        return handle
1642
1643    def register_forward_hook(
1644        self,
1645        hook: Union[
1646            Callable[[T, Tuple[Any, ...], Any], Optional[Any]],
1647            Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],
1648        ],
1649        *,
1650        prepend: bool = False,
1651        with_kwargs: bool = False,
1652        always_call: bool = False,
1653    ) -> RemovableHandle:
1654        r"""Register a forward hook on the module.
1655
1656        The hook will be called every time after :func:`forward` has computed an output.
1657
1658        If ``with_kwargs`` is ``False`` or not specified, the input contains only
1659        the positional arguments given to the module. Keyword arguments won't be
1660        passed to the hook, but only to the ``forward``. The hook can modify the
1661        output. It can modify the input inplace, but that will have no effect on
1662        the forward pass, since this hook runs after :func:`forward` has been called. The hook
1663        should have the following signature::
1664
1665            hook(module, args, output) -> None or modified output
1666
1667        If ``with_kwargs`` is ``True``, the forward hook will be passed the
1668        ``kwargs`` given to the forward function and is expected to return the
1669        output, possibly modified. The hook should have the following signature::
1670
1671            hook(module, args, kwargs, output) -> None or modified output
1672
1673        Args:
1674            hook (Callable): The user defined hook to be registered.
1675            prepend (bool): If ``True``, the provided ``hook`` will be fired
1676                before all existing ``forward`` hooks on this
1677                :class:`torch.nn.modules.Module`. Otherwise, the provided
1678                ``hook`` will be fired after all existing ``forward`` hooks on
1679                this :class:`torch.nn.modules.Module`. Note that global
1680                ``forward`` hooks registered with
1681                :func:`register_module_forward_hook` will fire before all hooks
1682                registered by this method.
1683                Default: ``False``
1684            with_kwargs (bool): If ``True``, the ``hook`` will be passed the
1685                kwargs given to the forward function.
1686                Default: ``False``
1687            always_call (bool): If ``True`` the ``hook`` will be run regardless of
1688                whether an exception is raised while calling the Module.
1689                Default: ``False``
1690
1691        Returns:
1692            :class:`torch.utils.hooks.RemovableHandle`:
1693                a handle that can be used to remove the added hook by calling
1694                ``handle.remove()``
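
        Example (an illustrative sketch; the module and the captured dictionary
        are arbitrary)::

            >>> # xdoctest: +SKIP("illustrative sketch")
            >>> activations = {}
            >>> def save_output(module, args, output):
            ...     # stash a detached copy of the output; returning None leaves it unchanged
            ...     activations["linear"] = output.detach()
            >>> linear = torch.nn.Linear(3, 1)
            >>> handle = linear.register_forward_hook(save_output)
            >>> out = linear(torch.randn(2, 3))
            >>> handle.remove()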
1695        """
1696        handle = RemovableHandle(
1697            self._forward_hooks,
1698            extra_dict=[
1699                self._forward_hooks_with_kwargs,
1700                self._forward_hooks_always_called,
1701            ],
1702        )
1703        self._forward_hooks[handle.id] = hook
1704        if with_kwargs:
1705            self._forward_hooks_with_kwargs[handle.id] = True
1706        if always_call:
1707            self._forward_hooks_always_called[handle.id] = True
1708        if prepend:
1709            self._forward_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
1710        return handle
1711
1712    def _slow_forward(self, *input, **kwargs):
1713        tracing_state = torch._C._get_tracing_state()
1714        if not tracing_state or isinstance(self.forward, torch._C.ScriptMethod):
1715            return self.forward(*input, **kwargs)
1716        recording_scopes = torch.jit._trace._trace_module_map is not None
1717        if recording_scopes:
1718            # type ignore was added because at this point one knows that
1719            # torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any]
1720            name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None  # type: ignore[index, operator] # noqa: B950
1721            if name:
1722                tracing_state.push_scope(name)
1723            else:
1724                recording_scopes = False
1725        try:
1726            result = self.forward(*input, **kwargs)
1727        finally:
1728            if recording_scopes:
1729                tracing_state.pop_scope()
1730        return result
1731
1732    def _wrapped_call_impl(self, *args, **kwargs):
1733        if self._compiled_call_impl is not None:
1734            return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
1735        else:
1736            return self._call_impl(*args, **kwargs)
1737
1738    # torchrec tests the code consistency with the following code
1739    # fmt: off
1740    def _call_impl(self, *args, **kwargs):
1741        forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
1742        # If we don't have any hooks, we want to skip the rest of the logic in
1743        # this function, and just call forward.
1744        if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745                or _global_backward_pre_hooks or _global_backward_hooks
1746                or _global_forward_hooks or _global_forward_pre_hooks):
1747            return forward_call(*args, **kwargs)
1748
1749        result = None
1750        called_always_called_hooks = set()
1751
1752        def inner():
1753            nonlocal result, args, kwargs
1754
1755            full_backward_hooks, non_full_backward_hooks = [], []
1756            backward_pre_hooks = []
1757            if self._backward_pre_hooks or _global_backward_pre_hooks:
1758                backward_pre_hooks = self._get_backward_pre_hooks()
1759
1760            if self._backward_hooks or _global_backward_hooks:
1761                full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()
1762
1763            if _global_forward_pre_hooks or self._forward_pre_hooks:
1764                for hook_id, hook in (
1765                    *_global_forward_pre_hooks.items(),
1766                    *self._forward_pre_hooks.items(),
1767                ):
1768                    if hook_id in self._forward_pre_hooks_with_kwargs:
1769                        args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
1770                        if args_kwargs_result is not None:
1771                            if isinstance(args_kwargs_result, tuple) and len(args_kwargs_result) == 2:
1772                                args, kwargs = args_kwargs_result
1773                            else:
1774                                raise RuntimeError(
1775                                    "forward pre-hook must return None or a tuple "
1776                                    f"of (new_args, new_kwargs), but got {args_kwargs_result}."
1777                                )
1778                    else:
1779                        args_result = hook(self, args)
1780                        if args_result is not None:
1781                            if not isinstance(args_result, tuple):
1782                                args_result = (args_result,)
1783                            args = args_result
1784
1785            bw_hook = None
1786            if full_backward_hooks or backward_pre_hooks:
1787                bw_hook = BackwardHook(self, full_backward_hooks, backward_pre_hooks)
1788                args = bw_hook.setup_input_hook(args)
1789
1790            result = forward_call(*args, **kwargs)
1791            if _global_forward_hooks or self._forward_hooks:
1792                for hook_id, hook in (
1793                    *_global_forward_hooks.items(),
1794                    *self._forward_hooks.items(),
1795                ):
1796                    # mark that this always-called hook has been run
1797                    if hook_id in self._forward_hooks_always_called or hook_id in _global_forward_hooks_always_called:
1798                        called_always_called_hooks.add(hook_id)
1799
1800                    if hook_id in self._forward_hooks_with_kwargs:
1801                        hook_result = hook(self, args, kwargs, result)
1802                    else:
1803                        hook_result = hook(self, args, result)
1804
1805                    if hook_result is not None:
1806                        result = hook_result
1807
1808            if bw_hook:
1809                if not isinstance(result, (torch.Tensor, tuple)):
1810                    warnings.warn("For backward hooks to be called,"
1811                                  " module output should be a Tensor or a tuple of Tensors"
1812                                  f" but received {type(result)}")
1813                result = bw_hook.setup_output_hook(result)
1814
1815            # Handle the non-full backward hooks
1816            if non_full_backward_hooks:
1817                var = result
1818                while not isinstance(var, torch.Tensor):
1819                    if isinstance(var, dict):
1820                        var = next(v for v in var.values() if isinstance(v, torch.Tensor))
1821                    else:
1822                        var = var[0]
1823                grad_fn = var.grad_fn
1824                if grad_fn is not None:
1825                    for hook in non_full_backward_hooks:
1826                        grad_fn.register_hook(_WrappedHook(hook, self))
1827                    self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
1828
1829            return result
1830
1831        from torch.compiler import is_compiling
1832
1833        # This is technically not behavior equivalent when compiling, but it's
1834        # incredibly unlikely we will ever support throwing an exception in NN
1835        # module, and then catching it here, and then reraising it, and then
1836        # catching it again, and expecting the resulting frame to be compiled.
1837        # The reraise here just gunks up our exception handling for no good
1838        # reason.  Don't try to run the always-called hooks in the event of an
1839        # exception.
1840        if is_compiling():
1841            return inner()
1842
1843        try:
1844            return inner()
1845        except Exception:
1846            # run always called hooks if they have not already been run
1847            # For now only forward hooks have the always_call option but perhaps
1848            # this functionality should be added to full backward hooks as well.
1849            for hook_id, hook in _global_forward_hooks.items():
1850                if hook_id in _global_forward_hooks_always_called and hook_id not in called_always_called_hooks:  # type: ignore[possibly-undefined]
1851                    try:
1852                        hook_result = hook(self, args, result)  # type: ignore[possibly-undefined]
1853                        if hook_result is not None:
1854                            result = hook_result
1855                    except Exception as e:
1856                        warnings.warn("global module forward hook with ``always_call=True`` raised an exception "
1857                                      f"that was silenced as another error was raised in forward: {str(e)}")
1858                        continue
1859
1860            for hook_id, hook in self._forward_hooks.items():
1861                if hook_id in self._forward_hooks_always_called and hook_id not in called_always_called_hooks:  # type: ignore[possibly-undefined]
1862                    try:
1863                        if hook_id in self._forward_hooks_with_kwargs:
1864                            hook_result = hook(self, args, kwargs, result)  # type: ignore[possibly-undefined]
1865                        else:
1866                            hook_result = hook(self, args, result)  # type: ignore[possibly-undefined]
1867                        if hook_result is not None:
1868                            result = hook_result
1869                    except Exception as e:
1870                        warnings.warn("module forward hook with ``always_call=True`` raised an exception "
1871                                      f"that was silenced as another error was raised in forward: {str(e)}")
1872                        continue
1873            # raise exception raised in try block
1874            raise
1875    # fmt: on
1876
1877    __call__: Callable[..., Any] = _wrapped_call_impl
1878
1879    def __getstate__(self):
1880        state = self.__dict__.copy()
1881        state.pop("_compiled_call_impl", None)
1882        return state
1883
1884    def __setstate__(self, state):
1885        self.__dict__.update(state)
1886
1887        # Support loading old checkpoints that don't have the following attrs:
1888        if "_forward_pre_hooks" not in self.__dict__:
1889            self._forward_pre_hooks = OrderedDict()
1890        if "_forward_pre_hooks_with_kwargs" not in self.__dict__:
1891            self._forward_pre_hooks_with_kwargs = OrderedDict()
1892        if "_forward_hooks_with_kwargs" not in self.__dict__:
1893            self._forward_hooks_with_kwargs = OrderedDict()
1894        if "_forward_hooks_always_called" not in self.__dict__:
1895            self._forward_hooks_always_called = OrderedDict()
1896        if "_state_dict_hooks" not in self.__dict__:
1897            self._state_dict_hooks = OrderedDict()
1898        if "_state_dict_pre_hooks" not in self.__dict__:
1899            self._state_dict_pre_hooks = OrderedDict()
1900        if "_load_state_dict_pre_hooks" not in self.__dict__:
1901            self._load_state_dict_pre_hooks = OrderedDict()
1902        if "_load_state_dict_post_hooks" not in self.__dict__:
1903            self._load_state_dict_post_hooks = OrderedDict()
1904        if "_non_persistent_buffers_set" not in self.__dict__:
1905            self._non_persistent_buffers_set = set()
1906        if "_is_full_backward_hook" not in self.__dict__:
1907            self._is_full_backward_hook = None
1908        if "_backward_pre_hooks" not in self.__dict__:
1909            self._backward_pre_hooks = OrderedDict()
1910
1911    # On the return type:
1912    # We choose to return `Any` in the `__getattr__` type signature instead of a more strict `Union[Tensor, Module]`.
1913    # This is done for better interop with various type checkers for the end users.
1914    # Having a stricter return type doesn't play nicely with `register_buffer()` and forces
1915    # people to excessively use type-ignores, asserts, casts, etc.
1916    # See full discussion on the problems with returning `Union` here
1917    # https://github.com/microsoft/pyright/issues/4213
1918    def __getattr__(self, name: str) -> Any:
1919        if "_parameters" in self.__dict__:
1920            _parameters = self.__dict__["_parameters"]
1921            if name in _parameters:
1922                return _parameters[name]
1923        if "_buffers" in self.__dict__:
1924            _buffers = self.__dict__["_buffers"]
1925            if name in _buffers:
1926                return _buffers[name]
1927        if "_modules" in self.__dict__:
1928            modules = self.__dict__["_modules"]
1929            if name in modules:
1930                return modules[name]
1931        raise AttributeError(
1932            f"'{type(self).__name__}' object has no attribute '{name}'"
1933        )
1934
1935    def __setattr__(self, name: str, value: Union[Tensor, "Module"]) -> None:
1936        def remove_from(*dicts_or_sets):
1937            for d in dicts_or_sets:
1938                if name in d:
1939                    if isinstance(d, dict):
1940                        del d[name]
1941                    else:
1942                        d.discard(name)
1943
1944        params = self.__dict__.get("_parameters")
1945        if isinstance(value, Parameter):
1946            if params is None:
1947                raise AttributeError(
1948                    "cannot assign parameters before Module.__init__() call"
1949                )
1950            remove_from(
1951                self.__dict__,
1952                self._buffers,
1953                self._modules,
1954                self._non_persistent_buffers_set,
1955            )
1956            self.register_parameter(name, value)
1957        elif params is not None and name in params:
1958            if value is not None:
1959                raise TypeError(
1960                    f"cannot assign '{torch.typename(value)}' as parameter '{name}' "
1961                    "(torch.nn.Parameter or None expected)"
1962                )
1963            self.register_parameter(name, value)
1964        else:
1965            modules = self.__dict__.get("_modules")
1966            if isinstance(value, Module):
1967                if modules is None:
1968                    raise AttributeError(
1969                        "cannot assign module before Module.__init__() call"
1970                    )
1971                remove_from(
1972                    self.__dict__,
1973                    self._parameters,
1974                    self._buffers,
1975                    self._non_persistent_buffers_set,
1976                )
1977                for hook in _global_module_registration_hooks.values():
1978                    output = hook(self, name, value)
1979                    if output is not None:
1980                        value = output
1981                modules[name] = value
1982            elif modules is not None and name in modules:
1983                if value is not None:
1984                    raise TypeError(
1985                        f"cannot assign '{torch.typename(value)}' as child module '{name}' "
1986                        "(torch.nn.Module or None expected)"
1987                    )
1988                for hook in _global_module_registration_hooks.values():
1989                    output = hook(self, name, value)
1990                    if output is not None:
1991                        value = output
1992                modules[name] = value
1993            else:
1994                buffers = self.__dict__.get("_buffers")
1995                if isinstance(value, Buffer) or buffers is not None and name in buffers:
1996                    if value is not None and not isinstance(value, torch.Tensor):
1997                        raise TypeError(
1998                            f"cannot assign '{torch.typename(value)}' as buffer '{name}' "
1999                            "(torch.nn.Buffer, torch.Tensor or None expected)"
2000                        )
2001                    if isinstance(value, Buffer):
2002                        persistent = value.persistent
2003                    else:
2004                        persistent = name not in self._non_persistent_buffers_set
2005                    # === HACK ===
2006                    # This whole block below should just be:
2007                    # self.register_buffer(name, value, persistent)
2008
2009                    # But we need to support subclasses of nn.Module that (wrongfully)
2010                    # implement a register_buffer() method that doesn't have the
2011                    # "persistent" argument. Only pass it in if it is accepted;
2012                    # otherwise assume it is always True.
2013                    if self.register_buffer is torch.nn.Module.register_buffer:
2014                        self.register_buffer(name, value, persistent)
2015                    else:
2016                        sign = inspect.signature(self.register_buffer)
2017                        if "persistent" in sign.parameters:
2018                            self.register_buffer(name, value, persistent)
2019                        else:
2020                            if not persistent:
2021                                raise RuntimeError(
2022                                    "Registering a non-persistent buffer "
2023                                    "on a Module subclass that implements "
2024                                    "register_buffer() without the persistent "
2025                                    "argument is not allowed."
2026                                )
2027                            # Assume that the implementation without the argument has the
2028                            # behavior from before the argument was added: persistent=True
2029                            self.register_buffer(name, value)
2030                    # === HACK END ===
2031                else:
2032                    super().__setattr__(name, value)
2033
2034    def __delattr__(self, name):
2035        if name in self._parameters:
2036            del self._parameters[name]
2037        elif name in self._buffers:
2038            del self._buffers[name]
2039            self._non_persistent_buffers_set.discard(name)
2040        elif name in self._modules:
2041            del self._modules[name]
2042        else:
2043            super().__delattr__(name)
2044
2045    def _register_state_dict_hook(self, hook):
2046        r"""Register a post-hook for the :meth:`~torch.nn.Module.state_dict` method.
2047
2048        It should have the following signature::

2049            hook(module, state_dict, prefix, local_metadata) -> None or state_dict
2050
2051        The registered hooks can modify the ``state_dict`` inplace or return a new one.
2052        If a new ``state_dict`` is returned, it will only be respected for the root
2053        module on which :meth:`~nn.Module.state_dict` is called.
2054        """
2055        if getattr(hook, "_from_public_api", False):
2056            raise RuntimeError(
2057                "Cannot register the same function as the state dict post hook that was "
2058                "previously registered via register_state_dict_post_hook"
2059            )
2060        handle = RemovableHandle(self._state_dict_hooks)
2061        self._state_dict_hooks[handle.id] = hook
2062        return handle
2063
2064    def register_state_dict_post_hook(self, hook):
2065        r"""Register a post-hook for the :meth:`~torch.nn.Module.state_dict` method.
2066
2067        It should have the following signature::

2068            hook(module, state_dict, prefix, local_metadata) -> None
2069
2070        The registered hooks can modify the ``state_dict`` inplace.
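
        Example (an illustrative sketch; which keys get dropped is arbitrary)::

            >>> # xdoctest: +SKIP("illustrative sketch")
            >>> def drop_bias_entries(module, state_dict, prefix, local_metadata):
            ...     # modify the state_dict in place; the hook must return None
            ...     for key in list(state_dict.keys()):
            ...         if key.endswith("bias"):
            ...             del state_dict[key]
            >>> linear = torch.nn.Linear(3, 1)
            >>> handle = linear.register_state_dict_post_hook(drop_bias_entries)
            >>> list(linear.state_dict().keys())
            ['weight']
            >>> handle.remove()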
2071        """
2072        # In _register_state_dict_hook there was a bug described in
2073        # https://github.com/pytorch/pytorch/issues/117437 where the return value
2074        # was only respected for the root module but not child submodules.
2075        # We fix this in this public version by only allowing inplace modifications on
2076        # the state_dict by the hook. However, since hooks registered via both these
2077        # APIs will be added to `_state_dict_hooks` and the type of `_state_dict_hooks`
2078        # cannot be changed due to many dependencies on it, we mark a hook
2079        # as being registered via the public API by setting `_from_public_api` on it.
2080        # In the implementation of `state_dict`, if the callable does not have this
2081        # flag, the old behavior of respecting the return value will be preserved
2082        # for the root module, otherwise, we ensure that the hook returns None.
2083        hook._from_public_api = True
2084        handle = RemovableHandle(self._state_dict_hooks)
2085        self._state_dict_hooks[handle.id] = hook
2086        return handle
2087
2088    def register_state_dict_pre_hook(self, hook):
2089        r"""Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method.
2090
2091        It should have the following signature::

2092            hook(module, prefix, keep_vars) -> None
2093
2094        The registered hooks can be used to perform pre-processing before the ``state_dict``
2095        call is made.
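
        Example (an illustrative sketch; the printed message is arbitrary)::

            >>> # xdoctest: +SKIP("illustrative sketch")
            >>> def announce_state_dict(module, prefix, keep_vars):
            ...     print(f"state_dict() about to run with prefix={prefix!r}")
            >>> linear = torch.nn.Linear(3, 1)
            >>> handle = linear.register_state_dict_pre_hook(announce_state_dict)
            >>> _ = linear.state_dict()
            state_dict() about to run with prefix=''
            >>> handle.remove()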
2096        """
2097        handle = RemovableHandle(self._state_dict_pre_hooks)
2098        self._state_dict_pre_hooks[handle.id] = hook
2099        return handle
2100
2101    def _save_to_state_dict(self, destination, prefix, keep_vars):
2102        r"""Save module state to the `destination` dictionary.
2103
2104        The `destination` dictionary will contain the state
2105        of the module, but not its descendants. This is called on every
2106        submodule in :meth:`~torch.nn.Module.state_dict`.
2107
2108        In rare cases, subclasses can achieve class-specific behavior by
2109        overriding this method with custom logic.
2110
2111        Args:
2112            destination (dict): a dict where state will be stored
2113            prefix (str): the prefix for parameters and buffers used in this
2114                module
            keep_vars (bool): if ``False``, parameters and buffers are detached
                from autograd before being stored; if ``True``, they are stored
                as-is
2115        """
2116        for name, param in self._parameters.items():
2117            if param is not None:
2118                destination[prefix + name] = param if keep_vars else param.detach()
2119        for name, buf in self._buffers.items():
2120            if buf is not None and name not in self._non_persistent_buffers_set:
2121                destination[prefix + name] = buf if keep_vars else buf.detach()
2122        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
2123        if (
2124            getattr(self.__class__, "get_extra_state", Module.get_extra_state)
2125            is not Module.get_extra_state
2126        ):
2127            destination[extra_state_key] = self.get_extra_state()
2128
2129    # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
2130    # that same object. If they pass nothing, an `OrderedDict` is created and returned.
2131    T_destination = TypeVar("T_destination", bound=Dict[str, Any])
2132
2133    @overload
2134    def state_dict(
2135        self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...
2136    ) -> T_destination:
2137        ...
2138
2139    @overload
2140    def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]:
2141        ...
2142
2143    # TODO: Change `*args` to `*` and remove the corresponding warning in docs when BC allows.
2144    # Also remove the logic for arg parsing together.
2145    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
2146        r"""Return a dictionary containing references to the whole state of the module.
2147
2148        Both parameters and persistent buffers (e.g. running averages) are
2149        included. Keys are corresponding parameter and buffer names.
2150        Parameters and buffers set to ``None`` are not included.
2151
2152        .. note::
2153            The returned object is a shallow copy. It contains references
2154            to the module's parameters and buffers.
2155
2156        .. warning::
2157            Currently ``state_dict()`` also accepts positional arguments for
2158            ``destination``, ``prefix`` and ``keep_vars`` in order. However,
2159            this is being deprecated and keyword arguments will be enforced in
2160            future releases.
2161
2162        .. warning::
2163            Please avoid the use of argument ``destination`` as it is not
2164            designed for end-users.
2165
2166        Args:
2167            destination (dict, optional): If provided, the state of the module will
2168                be updated into the dict and the same object is returned.
2169                Otherwise, an ``OrderedDict`` will be created and returned.
2170                Default: ``None``.
2171            prefix (str, optional): a prefix added to parameter and buffer
2172                names to compose the keys in state_dict. Default: ``''``.
2173            keep_vars (bool, optional): by default the :class:`~torch.Tensor` s
2174                returned in the state dict are detached from autograd. If it's
2175                set to ``True``, detaching will not be performed.
2176                Default: ``False``.
2177
2178        Returns:
2179            dict:
2180                a dictionary containing a whole state of the module
2181
2182        Example::
2183
2184            >>> # xdoctest: +SKIP("undefined vars")
2185            >>> module.state_dict().keys()
2186            ['bias', 'weight']
2187
2188        """
2189        # TODO: Remove `args` and the parsing logic when BC allows.
2190        if len(args) > 0:
2191            # DeprecationWarning is ignored by default
2192            warnings.warn(
2193                "Positional args are being deprecated, use kwargs instead. Refer to "
2194                "https://pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.state_dict"
2195                " for details.",
2196                FutureWarning,
2197                stacklevel=2,
2198            )
2199            if destination is None:
2200                destination = args[0]
2201            if len(args) > 1 and prefix == "":
2202                prefix = args[1]
2203            if len(args) > 2 and keep_vars is False:
2204                keep_vars = args[2]
2205
2206        if destination is None:
2207            destination = OrderedDict()
2208            destination._metadata = OrderedDict()
2209
2210        local_metadata = dict(version=self._version)
2211        if hasattr(destination, "_metadata"):
2212            destination._metadata[prefix[:-1]] = local_metadata
2213
2214        for hook in self._state_dict_pre_hooks.values():
2215            hook(self, prefix, keep_vars)
2216        self._save_to_state_dict(destination, prefix, keep_vars)
2217        for name, module in self._modules.items():
2218            if module is not None:
2219                module.state_dict(
2220                    destination=destination,
2221                    prefix=prefix + name + ".",
2222                    keep_vars=keep_vars,
2223                )
2224        for hook in self._state_dict_hooks.values():
2225            hook_result = hook(self, destination, prefix, local_metadata)
2226            if not getattr(hook, "_from_public_api", False):
2227                if hook_result is not None:
2228                    destination = hook_result
2229            else:
2230                if hook_result is not None:
2231                    raise RuntimeError("state_dict post-hook must return None")
2232        return destination
2233
2234    def _register_load_state_dict_pre_hook(self, hook, with_module=False):
2235        r"""See :meth:`~torch.nn.Module.register_load_state_dict_pre_hook` for details.
2236
2237        A subtle difference is that if ``with_module`` is set to ``False``, then the
2238        hook will not take the ``module`` as the first argument whereas
2239        :meth:`~torch.nn.Module.register_load_state_dict_pre_hook` always takes the
2240        ``module`` as the first argument.
2241
2242        Arguments:
2243            hook (Callable): Callable hook that will be invoked before
2244                loading the state dict.
2245            with_module (bool, optional): Whether or not to pass the module
2246                instance to the hook as the first parameter.
2247        """
2248        handle = RemovableHandle(self._load_state_dict_pre_hooks)
2249        self._load_state_dict_pre_hooks[handle.id] = _WrappedHook(
2250            hook, self if with_module else None
2251        )
2252        return handle
2253
2254    def register_load_state_dict_pre_hook(self, hook):
2255        r"""Register a pre-hook to be run before module's :meth:`~nn.Module.load_state_dict` is called.
2256
2257        It should have the following signature::

2258            hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None  # noqa: B950
2259
2260        Arguments:
2261            hook (Callable): Callable hook that will be invoked before
2262                loading the state dict.
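
        Example (an illustrative sketch; the legacy key name ``gamma`` is
        hypothetical)::

            >>> # xdoctest: +SKIP("illustrative sketch")
            >>> def rename_legacy_keys(module, state_dict, prefix, local_metadata,
            ...                        strict, missing_keys, unexpected_keys, error_msgs):
            ...     # migrate a renamed entry in place before it is loaded
            ...     legacy_key = prefix + "gamma"
            ...     if legacy_key in state_dict:
            ...         state_dict[prefix + "weight"] = state_dict.pop(legacy_key)
            >>> linear = torch.nn.Linear(3, 1)
            >>> handle = linear.register_load_state_dict_pre_hook(rename_legacy_keys)
            >>> handle.remove()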
2263        """
2264        return self._register_load_state_dict_pre_hook(hook, with_module=True)
2265
2266    def register_load_state_dict_post_hook(self, hook):
2267        r"""Register a post-hook to be run after module's :meth:`~nn.Module.load_state_dict` is called.
2268
2269        It should have the following signature::

2270            hook(module, incompatible_keys) -> None
2271
2272        The ``module`` argument is the current module that this hook is registered
2273        on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting
2274        of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys``
2275        is a ``list`` of ``str`` containing the missing keys and
2276        ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys.
2277
2278        The given incompatible_keys can be modified inplace if needed.
2279
2280        Note that the checks performed when calling :func:`load_state_dict` with
2281        ``strict=True`` are affected by modifications the hook makes to
2282        ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either
2283        set of keys will result in an error being thrown when ``strict=True``, and
2284        clearing out both missing and unexpected keys will avoid an error.
2285
2286        Returns:
2287            :class:`torch.utils.hooks.RemovableHandle`:
2288                a handle that can be used to remove the added hook by calling
2289                ``handle.remove()``
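
        Example (an illustrative sketch; ignoring missing bias keys only
        demonstrates the mechanics)::

            >>> # xdoctest: +SKIP("illustrative sketch")
            >>> def ignore_missing_bias(module, incompatible_keys):
            ...     # edit the list in place so strict=True no longer errors on these keys
            ...     incompatible_keys.missing_keys[:] = [
            ...         k for k in incompatible_keys.missing_keys if not k.endswith("bias")
            ...     ]
            >>> linear = torch.nn.Linear(3, 1)
            >>> handle = linear.register_load_state_dict_post_hook(ignore_missing_bias)
            >>> handle.remove()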
2290        """
2291        handle = RemovableHandle(self._load_state_dict_post_hooks)
2292        self._load_state_dict_post_hooks[handle.id] = hook
2293        return handle
2294
2295    def _load_from_state_dict(
2296        self,
2297        state_dict,
2298        prefix,
2299        local_metadata,
2300        strict,
2301        missing_keys,
2302        unexpected_keys,
2303        error_msgs,
2304    ):
2305        r"""Copy parameters and buffers from :attr:`state_dict` into only this module, but not its descendants.
2306
2307        This is called on every submodule
2308        in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
2309        module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
2310        For state dicts without metadata, :attr:`local_metadata` is empty.
2311        Subclasses can achieve class-specific backward compatible loading using
2312        the version number at `local_metadata.get("version", None)`.
2313        Additionally, :attr:`local_metadata` can also contain the key
2314        `assign_to_params_buffers` that indicates whether keys should be
2315        assigned their corresponding tensor in the state_dict.
2316
2317        .. note::
2318            :attr:`state_dict` is not the same object as the input
2319            :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
2320            it can be modified.
2321
2322        Args:
2323            state_dict (dict): a dict containing parameters and
2324                persistent buffers.
2325            prefix (str): the prefix for parameters and buffers used in this
2326                module
2327            local_metadata (dict): a dict containing the metadata for this module.
2329            strict (bool): whether to strictly enforce that the keys in
2330                :attr:`state_dict` with :attr:`prefix` match the names of
2331                parameters and buffers in this module
2332            missing_keys (list of str): if ``strict=True``, add missing keys to
2333                this list
2334            unexpected_keys (list of str): if ``strict=True``, add unexpected
2335                keys to this list
2336            error_msgs (list of str): error messages should be added to this
2337                list, and will be reported together in
2338                :meth:`~torch.nn.Module.load_state_dict`
2339        """
2340        for hook in self._load_state_dict_pre_hooks.values():
2341            hook(
2342                state_dict,
2343                prefix,
2344                local_metadata,
2345                strict,
2346                missing_keys,
2347                unexpected_keys,
2348                error_msgs,
2349            )
2350
2351        persistent_buffers = {
2352            k: v
2353            for k, v in self._buffers.items()
2354            if k not in self._non_persistent_buffers_set
2355        }
2356        local_name_params = itertools.chain(
2357            self._parameters.items(), persistent_buffers.items()
2358        )
2359        local_state = {k: v for k, v in local_name_params if v is not None}
2360        assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
2361        use_swap_tensors = torch.__future__.get_swap_module_params_on_conversion()
2362
2363        for name, param in local_state.items():
2364            key = prefix + name
2365            if key in state_dict:
2366                input_param = state_dict[key]
2367                if not torch.overrides.is_tensor_like(input_param):
2368                    error_msgs.append(
2369                        f'While copying the parameter named "{key}", '
2370                        "expected torch.Tensor or Tensor-like object from checkpoint but "
2371                        f"received {type(input_param)}"
2372                    )
2373                    continue
2374
2375                # This is used to avoid copying uninitialized parameters into
2376                # non-lazy modules, since they don't have the hook to do the checks.
2377                # In such a case, it would error when accessing the .shape attribute.
2378                is_param_lazy = torch.nn.parameter.is_lazy(param)
2379                # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
2380                if (
2381                    not is_param_lazy
2382                    and len(param.shape) == 0
2383                    and len(input_param.shape) == 1
2384                ):
2385                    input_param = input_param[0]
2386
2387                if not is_param_lazy and input_param.shape != param.shape:
2388                    # local shape should match the one in checkpoint
2389                    error_msgs.append(
2390                        f"size mismatch for {key}: copying a param with shape {input_param.shape} from checkpoint, "
2391                        f"the shape in current model is {param.shape}."
2392                    )
2393                    continue
2394
2395                if (
2396                    param.is_meta
2397                    and not input_param.is_meta
2398                    and not assign_to_params_buffers
2399                ):
2400                    warnings.warn(
2401                        f"for {key}: copying from a non-meta parameter in the checkpoint to a meta "
2402                        "parameter in the current model, which is a no-op. (Did you mean to "
2403                        "pass `assign=True` to assign items in the state dictionary to their "
2404                        "corresponding key in the module instead of copying them in place?)"
2405                    )
2406
2407                try:
2408                    with torch.no_grad():
2409                        if use_swap_tensors:
2410                            new_input_param = param.module_load(
2411                                input_param, assign=assign_to_params_buffers
2412                            )
2413                            if id(new_input_param) == id(input_param) or id(
2414                                new_input_param
2415                            ) == id(param):
2416                                raise RuntimeError(
2417                                    "module_load returned one of self or other, please .detach() "
2418                                    "the result if returning one of the inputs in module_load"
2419                                )
2420                            if isinstance(param, torch.nn.Parameter):
2421                                if not isinstance(new_input_param, torch.nn.Parameter):
2422                                    new_input_param = torch.nn.Parameter(
2423                                        new_input_param,
2424                                        requires_grad=param.requires_grad,
2425                                    )
2426                                else:
2427                                    new_input_param.requires_grad_(param.requires_grad)
2428                            torch.utils.swap_tensors(param, new_input_param)
2429                            del new_input_param
2430                        elif assign_to_params_buffers:
2431                            # Shape checks are already done above
2432                            if isinstance(param, torch.nn.Parameter):
2433                                if not isinstance(input_param, torch.nn.Parameter):
2434                                    input_param = torch.nn.Parameter(
2435                                        input_param, requires_grad=param.requires_grad
2436                                    )
2437                                else:
2438                                    input_param.requires_grad_(param.requires_grad)
2439                            setattr(self, name, input_param)
2440                        else:
2441                            param.copy_(input_param)
2442                except Exception as ex:
2443                    action = "swapping" if use_swap_tensors else "copying"
2444                    error_msgs.append(
2445                        f'While {action} the parameter named "{key}", '
2446                        f"whose dimensions in the model are {param.size()} and "
2447                        f"whose dimensions in the checkpoint are {input_param.size()}, "
2448                        f"an exception occurred : {ex.args}."
2449                    )
2450            elif strict:
2451                missing_keys.append(key)
2452
2453        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
2454        if (
2455            getattr(self.__class__, "set_extra_state", Module.set_extra_state)
2456            is not Module.set_extra_state
2457        ):
2458            if extra_state_key in state_dict:
2459                self.set_extra_state(state_dict[extra_state_key])
2460            elif strict:
2461                missing_keys.append(extra_state_key)
2462        elif strict and (extra_state_key in state_dict):
2463            unexpected_keys.append(extra_state_key)
2464
2465        if strict:
2466            for key in state_dict.keys():
2467                if key.startswith(prefix) and key != extra_state_key:
2468                    input_name = key[len(prefix) :].split(".", 1)
2469                    # Must be a Module if it has attributes
2470                    if len(input_name) > 1:
2471                        if input_name[0] not in self._modules:
2472                            unexpected_keys.append(key)
2473                    elif input_name[0] not in local_state:
2474                        unexpected_keys.append(key)
2475
2476    def load_state_dict(
2477        self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
2478    ):
2479        r"""Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.
2480
2481        If :attr:`strict` is ``True``, then
2482        the keys of :attr:`state_dict` must exactly match the keys returned
2483        by this module's :meth:`~torch.nn.Module.state_dict` function.
2484
2485        .. warning::
2486            If :attr:`assign` is ``True``, the optimizer must be created after
2487            the call to :meth:`~torch.nn.Module.load_state_dict` unless
2488            :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.
2489
2490        Args:
2491            state_dict (dict): a dict containing parameters and
2492                persistent buffers.
2493            strict (bool, optional): whether to strictly enforce that the keys
2494                in :attr:`state_dict` match the keys returned by this module's
2495                :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
2496            assign (bool, optional): When ``False``, the properties of the tensors
2497                in the current module are preserved; when ``True``, the
2498                properties of the Tensors in the state dict are preserved. The only
2499                exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s
2500                for which the value from the module is preserved.
2501                Default: ``False``
2502
2503        Returns:
2504            ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
2505                * **missing_keys** is a list of str containing any keys that are expected
2506                    by this module but missing from the provided ``state_dict``.
2507                * **unexpected_keys** is a list of str containing the keys that are not
2508                    expected by this module but present in the provided ``state_dict``.
2509
2510        Note:
2511            If a parameter or buffer is registered as ``None`` and its corresponding key
2512            exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
2513            ``RuntimeError``.
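
        Example (a minimal sketch; ``model`` is a placeholder for an existing
        module and ``"checkpoint.pt"`` for a checkpoint file on disk)::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> state_dict = torch.load("checkpoint.pt")
            >>> # strict=False tolerates missing / unexpected keys and reports them
            >>> missing, unexpected = model.load_state_dict(state_dict, strict=False)
            >>> print(missing, unexpected)
            >>> # With assign=True, remember to create the optimizer only after this call.
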
2514        """
2515        if not isinstance(state_dict, Mapping):
2516            raise TypeError(
2517                f"Expected state_dict to be dict-like, got {type(state_dict)}."
2518            )
2519
2520        missing_keys: List[str] = []
2521        unexpected_keys: List[str] = []
2522        error_msgs: List[str] = []
2523
2524        # copy state_dict so _load_from_state_dict can modify it
2525        metadata = getattr(state_dict, "_metadata", None)
2526        state_dict = OrderedDict(state_dict)
2527        if metadata is not None:
2528            # mypy isn't aware that "_metadata" exists in state_dict
2529            state_dict._metadata = metadata  # type: ignore[attr-defined]
2530
2531        def load(module, local_state_dict, prefix=""):
2532            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
2533            if assign:
2534                local_metadata["assign_to_params_buffers"] = assign
2535            module._load_from_state_dict(
2536                local_state_dict,
2537                prefix,
2538                local_metadata,
2539                True,
2540                missing_keys,
2541                unexpected_keys,
2542                error_msgs,
2543            )
2544            for name, child in module._modules.items():
2545                if child is not None:
2546                    child_prefix = prefix + name + "."
2547                    child_state_dict = {
2548                        k: v
2549                        for k, v in local_state_dict.items()
2550                        if k.startswith(child_prefix)
2551                    }
2552                    load(child, child_state_dict, child_prefix)  # noqa: F821
2553
2554            # Note that the hook can modify missing_keys and unexpected_keys.
2555            incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
2556            for hook in module._load_state_dict_post_hooks.values():
2557                out = hook(module, incompatible_keys)
2558                assert out is None, (
2559                    "Hooks registered with ``register_load_state_dict_post_hook`` are not "
2560                    "expected to return new values; if incompatible_keys needs to be modified, "
2561                    "it should be done in place."
2562                )
2563
2564        load(self, state_dict)
2565        del load
2566
2567        if strict:
2568            if len(unexpected_keys) > 0:
2569                error_msgs.insert(
2570                    0,
2571                    "Unexpected key(s) in state_dict: {}. ".format(
2572                        ", ".join(f'"{k}"' for k in unexpected_keys)
2573                    ),
2574                )
2575            if len(missing_keys) > 0:
2576                error_msgs.insert(
2577                    0,
2578                    "Missing key(s) in state_dict: {}. ".format(
2579                        ", ".join(f'"{k}"' for k in missing_keys)
2580                    ),
2581                )
2582
2583        if len(error_msgs) > 0:
2584            raise RuntimeError(
2585                "Error(s) in loading state_dict for {}:\n\t{}".format(
2586                    self.__class__.__name__, "\n\t".join(error_msgs)
2587                )
2588            )
2589        return _IncompatibleKeys(missing_keys, unexpected_keys)
2590
2591    def _named_members(
2592        self, get_members_fn, prefix="", recurse=True, remove_duplicate: bool = True
2593    ):
2594        r"""Helper to yield various names and members of modules."""
2595        memo = set()
2596        modules = (
2597            self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate)
2598            if recurse
2599            else [(prefix, self)]
2600        )
2601        for module_prefix, module in modules:
2602            members = get_members_fn(module)
2603            for k, v in members:
2604                if v is None or v in memo:
2605                    continue
2606                if remove_duplicate:
2607                    memo.add(v)
2608                name = module_prefix + ("." if module_prefix else "") + k
2609                yield name, v
2610
2611    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
2612        r"""Return an iterator over module parameters.
2613
2614        This is typically passed to an optimizer.
2615
2616        Args:
2617            recurse (bool): if True, then yields parameters of this module
2618                and all submodules. Otherwise, yields only parameters that
2619                are direct members of this module.
2620
2621        Yields:
2622            Parameter: module parameter
2623
2624        Example::
2625
2626            >>> # xdoctest: +SKIP("undefined vars")
2627            >>> for param in model.parameters():
2628            >>>     print(type(param), param.size())
2629            <class 'torch.nn.parameter.Parameter'> torch.Size([20])
2630            <class 'torch.nn.parameter.Parameter'> torch.Size([20, 1, 5, 5])
2631
2632        """
2633        for name, param in self.named_parameters(recurse=recurse):
2634            yield param
2635
2636    def named_parameters(
2637        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
2638    ) -> Iterator[Tuple[str, Parameter]]:
2639        r"""Return an iterator over module parameters, yielding both the name of the parameter and the parameter itself.
2640
2641        Args:
2642            prefix (str): prefix to prepend to all parameter names.
2643            recurse (bool): if True, then yields parameters of this module
2644                and all submodules. Otherwise, yields only parameters that
2645                are direct members of this module.
2646            remove_duplicate (bool, optional): whether to remove the duplicated
2647                parameters in the result. Defaults to True.
2648
2649        Yields:
2650            (str, Parameter): Tuple containing the name and parameter
2651
2652        Example::
2653
2654            >>> # xdoctest: +SKIP("undefined vars")
2655            >>> for name, param in self.named_parameters():
2656            >>>     if name in ['bias']:
2657            >>>         print(param.size())
2658
2659        """
2660        gen = self._named_members(
2661            lambda module: module._parameters.items(),
2662            prefix=prefix,
2663            recurse=recurse,
2664            remove_duplicate=remove_duplicate,
2665        )
2666        yield from gen
2667
2668    def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
2669        r"""Return an iterator over module buffers.
2670
2671        Args:
2672            recurse (bool): if True, then yields buffers of this module
2673                and all submodules. Otherwise, yields only buffers that
2674                are direct members of this module.
2675
2676        Yields:
2677            torch.Tensor: module buffer
2678
2679        Example::
2680
2681            >>> # xdoctest: +SKIP("undefined vars")
2682            >>> for buf in model.buffers():
2683            >>>     print(type(buf), buf.size())
2684            <class 'torch.Tensor'> torch.Size([20])
2685            <class 'torch.Tensor'> torch.Size([20, 1, 5, 5])
2686
2687        """
2688        for _, buf in self.named_buffers(recurse=recurse):
2689            yield buf
2690
2691    def named_buffers(
2692        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
2693    ) -> Iterator[Tuple[str, Tensor]]:
2694        r"""Return an iterator over module buffers, yielding both the name of the buffer and the buffer itself.
2695
2696        Args:
2697            prefix (str): prefix to prepend to all buffer names.
2698            recurse (bool, optional): if True, then yields buffers of this module
2699                and all submodules. Otherwise, yields only buffers that
2700                are direct members of this module. Defaults to True.
2701            remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.
2702
2703        Yields:
2704            (str, torch.Tensor): Tuple containing the name and buffer
2705
2706        Example::
2707
2708            >>> # xdoctest: +SKIP("undefined vars")
2709            >>> for name, buf in self.named_buffers():
2710            >>>     if name in ['running_var']:
2711            >>>         print(buf.size())
2712
2713        """
2714        gen = self._named_members(
2715            lambda module: module._buffers.items(),
2716            prefix=prefix,
2717            recurse=recurse,
2718            remove_duplicate=remove_duplicate,
2719        )
2720        yield from gen
2721
2722    def children(self) -> Iterator["Module"]:
2723        r"""Return an iterator over immediate child modules.
2724
2725        Yields:
2726            Module: a child module
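
        Example (a minimal sketch; ``model`` is a placeholder module)::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> for child in model.children():
            >>>     print(child)
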
2727        """
2728        for name, module in self.named_children():
2729            yield module
2730
2731    def named_children(self) -> Iterator[Tuple[str, "Module"]]:
2732        r"""Return an iterator over immediate child modules, yielding both the name of the module and the module itself.
2733
2734        Yields:
2735            (str, Module): Tuple containing a name and child module
2736
2737        Example::
2738
2739            >>> # xdoctest: +SKIP("undefined vars")
2740            >>> for name, module in model.named_children():
2741            >>>     if name in ['conv4', 'conv5']:
2742            >>>         print(module)
2743
2744        """
2745        memo = set()
2746        for name, module in self._modules.items():
2747            if module is not None and module not in memo:
2748                memo.add(module)
2749                yield name, module
2750
2751    def modules(self) -> Iterator["Module"]:
2752        r"""Return an iterator over all modules in the network.
2753
2754        Yields:
2755            Module: a module in the network
2756
2757        Note:
2758            Duplicate modules are returned only once. In the following
2759            example, ``l`` will be returned only once.
2760
2761        Example::
2762
2763            >>> l = nn.Linear(2, 2)
2764            >>> net = nn.Sequential(l, l)
2765            >>> for idx, m in enumerate(net.modules()):
2766            ...     print(idx, '->', m)
2767
2768            0 -> Sequential(
2769              (0): Linear(in_features=2, out_features=2, bias=True)
2770              (1): Linear(in_features=2, out_features=2, bias=True)
2771            )
2772            1 -> Linear(in_features=2, out_features=2, bias=True)
2773
2774        """
2775        for _, module in self.named_modules():
2776            yield module
2777
2778    def named_modules(
2779        self,
2780        memo: Optional[Set["Module"]] = None,
2781        prefix: str = "",
2782        remove_duplicate: bool = True,
2783    ):
2784        r"""Return an iterator over all modules in the network, yielding both the name of the module and the module itself.
2785
2786        Args:
2787            memo: a memo to store the set of modules already added to the result
2788            prefix: a prefix that will be added to the name of the module
2789            remove_duplicate: whether to remove the duplicated module instances in the result
2790                or not
2791
2792        Yields:
2793            (str, Module): Tuple of name and module
2794
2795        Note:
2796            Duplicate modules are returned only once. In the following
2797            example, ``l`` will be returned only once.
2798
2799        Example::
2800
2801            >>> l = nn.Linear(2, 2)
2802            >>> net = nn.Sequential(l, l)
2803            >>> for idx, m in enumerate(net.named_modules()):
2804            ...     print(idx, '->', m)
2805
2806            0 -> ('', Sequential(
2807              (0): Linear(in_features=2, out_features=2, bias=True)
2808              (1): Linear(in_features=2, out_features=2, bias=True)
2809            ))
2810            1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
2811
2812        """
2813        if memo is None:
2814            memo = set()
2815        if self not in memo:
2816            if remove_duplicate:
2817                memo.add(self)
2818            yield prefix, self
2819            for name, module in self._modules.items():
2820                if module is None:
2821                    continue
2822                submodule_prefix = prefix + ("." if prefix else "") + name
2823                yield from module.named_modules(
2824                    memo, submodule_prefix, remove_duplicate
2825                )
2826
2827    def train(self: T, mode: bool = True) -> T:
2828        r"""Set the module in training mode.
2829
2830        This has an effect only on certain modules. See the documentation of
2831        particular modules for details of their behavior in training/evaluation
2832        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
2833        etc.
2834
2835        Args:
2836            mode (bool): whether to set training mode (``True``) or evaluation
2837                         mode (``False``). Default: ``True``.
2838
2839        Returns:
2840            Module: self
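
        Example (a minimal sketch; ``model`` is a placeholder module)::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> model.train()  # recursively sets ``training = True`` on all submodules
            >>> model.training
            True
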
2841        """
2842        if not isinstance(mode, bool):
2843            raise ValueError("training mode is expected to be boolean")
2844        self.training = mode
2845        for module in self.children():
2846            module.train(mode)
2847        return self
2848
2849    def eval(self: T) -> T:
2850        r"""Set the module in evaluation mode.
2851
2852        This has an effect only on certain modules. See the documentation of
2853        particular modules for details of their behavior in training/evaluation
2854        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
2855        etc.
2856
2857        This is equivalent to :meth:`self.train(False) <torch.nn.Module.train>`.
2858
2859        See :ref:`locally-disable-grad-doc` for a comparison between
2860        `.eval()` and several similar mechanisms that may be confused with it.
2861
2862        Returns:
2863            Module: self
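
        Example (a minimal sketch; ``model`` and ``input`` are placeholders; note
        that ``eval()`` does not disable gradient tracking by itself)::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> model.eval()
            >>> with torch.no_grad():
            ...     output = model(input)
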
2864        """
2865        return self.train(False)
2866
2867    def requires_grad_(self: T, requires_grad: bool = True) -> T:
2868        r"""Change if autograd should record operations on parameters in this module.
2869
2870        This method sets the parameters' :attr:`requires_grad` attributes
2871        in-place.
2872
2873        This method is helpful for freezing part of the module for finetuning
2874        or training parts of a model individually (e.g., GAN training).
2875
2876        See :ref:`locally-disable-grad-doc` for a comparison between
2877        `.requires_grad_()` and several similar mechanisms that may be confused with it.
2878
2879        Args:
2880            requires_grad (bool): whether autograd should record operations on
2881                                  parameters in this module. Default: ``True``.
2882
2883        Returns:
2884            Module: self
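
        Example (a minimal sketch; a ``model`` with a ``classifier`` submodule is
        assumed purely for illustration)::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> model.requires_grad_(False)            # freeze every parameter
            >>> model.classifier.requires_grad_(True)  # then unfreeze only the head
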
2885        """
2886        for p in self.parameters():
2887            p.requires_grad_(requires_grad)
2888        return self
2889
2890    def zero_grad(self, set_to_none: bool = True) -> None:
2891        r"""Reset gradients of all model parameters.
2892
2893        See similar function under :class:`torch.optim.Optimizer` for more context.
2894
2895        Args:
2896            set_to_none (bool): instead of setting to zero, set the grads to None.
2897                See :meth:`torch.optim.Optimizer.zero_grad` for details.
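
        Example (a minimal sketch; ``model``, ``loss_fn``, ``input`` and ``target``
        are placeholders for an existing training setup)::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> loss = loss_fn(model(input), target)
            >>> loss.backward()
            >>> model.zero_grad()  # by default sets every ``p.grad`` to ``None``
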
2898        """
2899        if getattr(self, "_is_replica", False):
2900            warnings.warn(
2901                "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. "
2902                "The parameters are copied (in a differentiable manner) from the original module. "
2903                "This means they are not leaf nodes in autograd and so don't accumulate gradients. "
2904                "If you need gradients in your forward method, consider using autograd.grad instead."
2905            )
2906
2907        for p in self.parameters():
2908            if p.grad is not None:
2909                if set_to_none:
2910                    p.grad = None
2911                else:
2912                    if p.grad.grad_fn is not None:
2913                        p.grad.detach_()
2914                    else:
2915                        p.grad.requires_grad_(False)
2916                    p.grad.zero_()
2917
2918    def share_memory(self: T) -> T:
2919        r"""See :meth:`torch.Tensor.share_memory_`."""
2920        return self._apply(lambda t: t.share_memory_())
2921
2922    def _get_name(self):
2923        return self.__class__.__name__
2924
2925    def extra_repr(self) -> str:
2926        r"""Set the extra representation of the module.
2927
2928        To print customized extra information, you should re-implement
2929        this method in your own modules. Both single-line and multi-line
2930        strings are acceptable.
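
        Example (a sketch of a hypothetical subclass; ``MyLinear`` is illustrative
        only)::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> class MyLinear(nn.Module):
            ...     def __init__(self, in_features, out_features):
            ...         super().__init__()
            ...         self.in_features = in_features
            ...         self.out_features = out_features
            ...         self.weight = nn.Parameter(torch.randn(out_features, in_features))
            ...     def extra_repr(self) -> str:
            ...         return f"in_features={self.in_features}, out_features={self.out_features}"
            >>> print(MyLinear(2, 3))
            MyLinear(in_features=2, out_features=3)
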
2931        """
2932        return ""
2933
2934    def __repr__(self):
2935        # We treat the extra repr like the sub-module, one item per line
2936        extra_lines = []
2937        extra_repr = self.extra_repr()
2938        # empty string will be split into list ['']
2939        if extra_repr:
2940            extra_lines = extra_repr.split("\n")
2941        child_lines = []
2942        for key, module in self._modules.items():
2943            mod_str = repr(module)
2944            mod_str = _addindent(mod_str, 2)
2945            child_lines.append("(" + key + "): " + mod_str)
2946        lines = extra_lines + child_lines
2947
2948        main_str = self._get_name() + "("
2949        if lines:
2950            # simple one-liner info, which most builtin Modules will use
2951            if len(extra_lines) == 1 and not child_lines:
2952                main_str += extra_lines[0]
2953            else:
2954                main_str += "\n  " + "\n  ".join(lines) + "\n"
2955
2956        main_str += ")"
2957        return main_str
2958
2959    def __dir__(self):
2960        module_attrs = dir(self.__class__)
2961        attrs = list(self.__dict__.keys())
2962        parameters = list(self._parameters.keys())
2963        modules = list(self._modules.keys())
2964        buffers = list(self._buffers.keys())
2965        keys = module_attrs + attrs + parameters + modules + buffers
2966
2967        # Eliminate attrs that are not legal Python variable names
2968        keys = [key for key in keys if not key[0].isdigit()]
2969
2970        return sorted(keys)
2971
2972    def _replicate_for_data_parallel(self):
2973        replica = self.__new__(type(self))
2974        replica.__dict__ = self.__dict__.copy()
2975
2976        # Replicas do not have parameters themselves; they reference the original
2977        # module's parameters.
2978        replica._parameters = {}
2979        replica._buffers = replica._buffers.copy()
2980        replica._modules = replica._modules.copy()
2981        replica._is_replica = True  # type: ignore[assignment]
2982
2983        return replica
2984
2985    def compile(self, *args, **kwargs):
2986        """
2987        Compile this Module's forward using :func:`torch.compile`.
2988
2989        This Module's `__call__` method is compiled and all arguments are passed as-is
2990        to :func:`torch.compile`.
2991
2992        See :func:`torch.compile` for details on the arguments for this function.
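
        Example (a minimal sketch; ``model`` and ``input`` are placeholders, and
        ``mode="reduce-overhead"`` is just one of the :func:`torch.compile` options)::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> model.compile(mode="reduce-overhead")
            >>> out = model(input)  # compilation is triggered lazily on the first call
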
2993        """
2994        self._compiled_call_impl = torch.compile(self._call_impl, *args, **kwargs)
2995