# mypy: allow-untyped-defs

import functools
import inspect
import itertools
import warnings
import weakref
from collections import namedtuple, OrderedDict
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
    overload,
    Set,
    Tuple,
    TypeVar,
    Union,
)
from typing_extensions import Self

import torch
from torch import device, dtype, Tensor
from torch._prims_common import DeviceLikeType
from torch.nn.parameter import Buffer, Parameter
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils.hooks import BackwardHook, RemovableHandle


__all__ = [
    "register_module_forward_pre_hook",
    "register_module_forward_hook",
    "register_module_full_backward_pre_hook",
    "register_module_backward_hook",
    "register_module_full_backward_hook",
    "register_module_buffer_registration_hook",
    "register_module_module_registration_hook",
    "register_module_parameter_registration_hook",
    "Module",
]

_grad_t = Union[Tuple[Tensor, ...], Tensor]
# See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use
# of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be
# the type of the subclass, not the looser type of `Module`.
T = TypeVar("T", bound="Module")


class _IncompatibleKeys(
    namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
):
    def __repr__(self):
        if not self.missing_keys and not self.unexpected_keys:
            return "<All keys matched successfully>"
        return super().__repr__()

    __str__ = __repr__


def _addindent(s_, numSpaces):
    s = s_.split("\n")
    # don't do anything for single-line stuff
    if len(s) == 1:
        return s_
    first = s.pop(0)
    s = [(numSpaces * " ") + line for line in s]
    s = "\n".join(s)
    s = first + "\n" + s
    return s


r"""This tracks hooks common to all modules that are executed immediately before
registering the buffer/module/parameter"""
_global_buffer_registration_hooks: Dict[int, Callable] = OrderedDict()
_global_module_registration_hooks: Dict[int, Callable] = OrderedDict()
_global_parameter_registration_hooks: Dict[int, Callable] = OrderedDict()


class _WrappedHook:
    def __init__(self, hook: Callable, module: Optional["Module"] = None):
        self.hook: Callable = hook
        functools.update_wrapper(self, hook)

        self.with_module: bool = False

        if module is not None:
            self.module: weakref.ReferenceType[Module] = weakref.ref(module)
            self.with_module = True

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        if self.with_module:
            module = self.module()
            if module is None:
                raise RuntimeError("You are trying to call the hook of a dead Module!")
            return self.hook(module, *args, **kwargs)
        return self.hook(*args, **kwargs)

    def __getstate__(self) -> Dict:
        result = {"hook": self.hook, "with_module": self.with_module}
        if self.with_module:
            result["module"] = self.module()

        return result

    def __setstate__(self, state: Dict):
        self.hook = state["hook"]
        self.with_module = state["with_module"]

        if self.with_module:
            if state["module"] is None:
                raise RuntimeError(
                    "You are trying to revive the hook of a dead Module!"
                )
            self.module = weakref.ref(state["module"])


r"""This tracks hooks common to all modules that are executed before/after
calling forward and backward. This is global state used for debugging/profiling
purposes"""
_global_backward_pre_hooks: Dict[int, Callable] = OrderedDict()
_global_backward_hooks: Dict[int, Callable] = OrderedDict()
_global_is_full_backward_hook: Optional[bool] = None
_global_forward_pre_hooks: Dict[int, Callable] = OrderedDict()
_global_forward_hooks: Dict[int, Callable] = OrderedDict()
_global_forward_hooks_always_called: Dict[int, bool] = OrderedDict()

_EXTRA_STATE_KEY_SUFFIX = "_extra_state"


def register_module_buffer_registration_hook(
    hook: Callable[..., None],
) -> RemovableHandle:
    r"""Register a buffer registration hook common to all modules.

    .. warning ::

        This adds global state to the `nn.Module` module

    The hook will be called every time :func:`register_buffer` is invoked.
    It should have the following signature::

        hook(module, name, buffer) -> None or new buffer

    The hook can modify the input or return a single modified value in the hook.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    handle = RemovableHandle(_global_buffer_registration_hooks)
    _global_buffer_registration_hooks[handle.id] = hook
    return handle


def register_module_module_registration_hook(
    hook: Callable[..., None],
) -> RemovableHandle:
    r"""Register a module registration hook common to all modules.

    .. warning ::

        This adds global state to the `nn.Module` module

    The hook will be called every time :func:`register_module` is invoked.
    It should have the following signature::

        hook(module, name, submodule) -> None or new submodule

    The hook can modify the input or return a single modified value in the hook.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    handle = RemovableHandle(_global_module_registration_hooks)
    _global_module_registration_hooks[handle.id] = hook
    return handle


def register_module_parameter_registration_hook(
    hook: Callable[..., None],
) -> RemovableHandle:
    r"""Register a parameter registration hook common to all modules.

    .. warning ::

        This adds global state to the `nn.Module` module

    The hook will be called every time :func:`register_parameter` is invoked.
    It should have the following signature::

        hook(module, name, param) -> None or new parameter

    The hook can modify the input or return a single modified value in the hook.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    handle = RemovableHandle(_global_parameter_registration_hooks)
    _global_parameter_registration_hooks[handle.id] = hook
    return handle


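# The registration hooks above share one calling convention: the hook receives
# ``(module, name, value)`` and may return a replacement value or ``None`` to
# keep the original. A minimal sketch of how a caller might use them; the
# helper below is illustrative only and is not part of this module's API.
def _example_buffer_registration_logging() -> RemovableHandle:
    """Register a hypothetical global hook that logs every buffer registration."""

    def _log_buffer(module: "Module", name: str, buffer: Optional[Tensor]) -> None:
        # Returning ``None`` keeps the buffer unchanged; returning a Tensor
        # would replace the buffer being registered.
        shape = tuple(buffer.shape) if buffer is not None else None
        print(f"registering buffer {name!r} with shape {shape} on {type(module).__name__}")

    # The caller is responsible for calling ``handle.remove()`` when done.
    return register_module_buffer_registration_hook(_log_buffer)

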
def register_module_forward_pre_hook(hook: Callable[..., None]) -> RemovableHandle:
    r"""Register a forward pre-hook common to all modules.

    .. warning ::

        This adds global state to the `nn.module` module
        and it is only intended for debugging/profiling purposes.

    The hook will be called every time before :func:`forward` is invoked.
    It should have the following signature::

        hook(module, input) -> None or modified input

    The input contains only the positional arguments given to the module.
    Keyword arguments won't be passed to the hooks and only to the ``forward``.
    The hook can modify the input. The user can either return a tuple or a
    single modified value in the hook. We will wrap the value into a tuple
    if a single value is returned (unless that value is already a tuple).

    This hook has precedence over the specific module hooks registered with
    ``register_forward_pre_hook``.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    handle = RemovableHandle(_global_forward_pre_hooks)
    _global_forward_pre_hooks[handle.id] = hook
    return handle


def register_module_forward_hook(
    hook: Callable[..., None],
    *,
    always_call: bool = False,
) -> RemovableHandle:
    r"""Register a global forward hook for all the modules.

    .. warning ::

        This adds global state to the `nn.module` module
        and it is only intended for debugging/profiling purposes.

    The hook will be called every time after :func:`forward` has computed an output.
    It should have the following signature::

        hook(module, input, output) -> None or modified output

    The input contains only the positional arguments given to the module.
    Keyword arguments won't be passed to the hooks and only to the ``forward``.
    The hook can modify the output. It can modify the input inplace but
    it will not have an effect on forward since this is called after
    :func:`forward` is called.

    Parameters:
        hook (Callable): The user defined hook to be registered.
        always_call (bool): If ``True`` the ``hook`` will be run regardless of
            whether an exception is raised while calling the Module.
            Default: ``False``
    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``

    This hook will be executed before specific module hooks registered with
    ``register_forward_hook``.
    """
    handle = RemovableHandle(
        _global_forward_hooks, extra_dict=_global_forward_hooks_always_called
    )
    _global_forward_hooks[handle.id] = hook
    if always_call:
        _global_forward_hooks_always_called[handle.id] = True
    return handle


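# A small sketch of a global forward hook used for profiling: it records the
# output shape of every module call until the handle is removed. The helper
# and the names it introduces are illustrative, not part of this module's API.
def _example_record_output_shapes(record: List[str]) -> RemovableHandle:
    """Register a hypothetical global forward hook that appends output shapes to ``record``."""

    def _shape_hook(module: "Module", args: Tuple[Any, ...], output: Any) -> None:
        # Returning ``None`` leaves the output untouched; returning a value
        # would replace the module's output.
        if isinstance(output, Tensor):
            record.append(f"{type(module).__name__}: {tuple(output.shape)}")

    # Passing ``always_call=True`` here would run the hook even if forward raises.
    return register_module_forward_hook(_shape_hook)

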
def register_module_backward_hook(
    hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]],
) -> RemovableHandle:
    r"""Register a backward hook common to all the modules.

    This function is deprecated in favor of
    :func:`torch.nn.modules.module.register_module_full_backward_hook`
    and the behavior of this function will change in future versions.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``

    """
    global _global_is_full_backward_hook
    if _global_is_full_backward_hook is True:
        raise RuntimeError(
            "Cannot use both regular backward hooks and full backward hooks as a "
            "global Module hook. Please use only one of them."
        )

    _global_is_full_backward_hook = False

    handle = RemovableHandle(_global_backward_hooks)
    _global_backward_hooks[handle.id] = hook
    return handle


def register_module_full_backward_pre_hook(
    hook: Callable[["Module", _grad_t], Union[None, _grad_t]],
) -> RemovableHandle:
    r"""Register a backward pre-hook common to all the modules.

    .. warning ::
        This adds global state to the `nn.module` module
        and it is only intended for debugging/profiling purposes.

    Hooks registered using this function behave in the same way as those
    registered by :meth:`torch.nn.Module.register_full_backward_pre_hook`.
    Refer to its documentation for more details.

    Hooks registered using this function will be called before hooks registered
    using :meth:`torch.nn.Module.register_full_backward_pre_hook`.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``

    """
    handle = RemovableHandle(_global_backward_pre_hooks)
    _global_backward_pre_hooks[handle.id] = hook
    return handle


def register_module_full_backward_hook(
    hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]],
) -> RemovableHandle:
    r"""Register a backward hook common to all the modules.

    .. warning ::
        This adds global state to the `nn.module` module
        and it is only intended for debugging/profiling purposes.

    Hooks registered using this function behave in the same way as those
    registered by :meth:`torch.nn.Module.register_full_backward_hook`.
    Refer to its documentation for more details.

    Hooks registered using this function will be called before hooks registered
    using :meth:`torch.nn.Module.register_full_backward_hook`.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``

    """
    global _global_is_full_backward_hook
    if _global_is_full_backward_hook is False:
        raise RuntimeError(
            "Cannot use both regular backward hooks and full backward hooks as a "
            "global Module hook. Please use only one of them."
        )

    _global_is_full_backward_hook = True

    handle = RemovableHandle(_global_backward_hooks)
    _global_backward_hooks[handle.id] = hook
    return handle


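# A compact sketch of a global full backward hook that reports the gradient
# norms flowing out of every module. The helper below is illustrative only
# and is not part of this module's API; note that registering it commits the
# process to full (rather than deprecated non-full) global backward hooks.
def _example_log_grad_output_norms() -> RemovableHandle:
    """Register a hypothetical global full backward hook that prints grad_output norms."""

    def _grad_hook(
        module: "Module", grad_input: _grad_t, grad_output: _grad_t
    ) -> None:
        # ``grad_output`` entries may be ``None`` for non-Tensor outputs.
        norms = [g.norm().item() for g in grad_output if g is not None]
        print(f"{type(module).__name__}: grad_output norms {norms}")

    return register_module_full_backward_hook(_grad_hook)

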
# Trick mypy into not applying contravariance rules to inputs by defining
# forward as a value, rather than a function. See also
# https://github.com/python/mypy/issues/8795
def _forward_unimplemented(self, *input: Any) -> None:
    r"""Define the computation performed at every call.

    Should be overridden by all subclasses.

    .. note::
        Although the recipe for the forward pass needs to be defined within
        this function, one should call the :class:`Module` instance afterwards
        instead of this since the former takes care of running the
        registered hooks while the latter silently ignores them.
    """
    raise NotImplementedError(
        f'Module [{type(self).__name__}] is missing the required "forward" function'
    )


class Module:
    r"""Base class for all neural network modules.

    Your models should also subclass this class.

    Modules can also contain other Modules, allowing them to be nested in
    a tree structure. You can assign the submodules as regular attributes::

        import torch.nn as nn
        import torch.nn.functional as F

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 20, 5)
                self.conv2 = nn.Conv2d(20, 20, 5)

            def forward(self, x):
                x = F.relu(self.conv1(x))
                return F.relu(self.conv2(x))

    Submodules assigned in this way will be registered, and will have their
    parameters converted too when you call :meth:`to`, etc.

    .. note::
        As per the example above, an ``__init__()`` call to the parent class
        must be made before assignment on the child.

    :ivar training: Boolean representing whether this module is in training or
                    evaluation mode.
    :vartype training: bool
    """

    dump_patches: bool = False

    _version: int = 1
    r"""This allows better BC support for :meth:`load_state_dict`. In
    :meth:`state_dict`, the version number will be saved in the attribute
    `_metadata` of the returned state dict, and thus pickled. `_metadata` is a
    dictionary with keys that follow the naming convention of state dict. See
    ``_load_from_state_dict`` on how to use this information in loading.

    If new parameters/buffers are added/removed from a module, this number shall
    be bumped, and the module's `_load_from_state_dict` method can compare the
    version number and do appropriate changes if the state dict is from before
    the change."""

    training: bool
    _parameters: Dict[str, Optional[Parameter]]
    _buffers: Dict[str, Optional[Tensor]]
    _non_persistent_buffers_set: Set[str]
    _backward_pre_hooks: Dict[int, Callable]
    _backward_hooks: Dict[int, Callable]
    _is_full_backward_hook: Optional[bool]
    _forward_hooks: Dict[int, Callable]
    # Marks whether the corresponding _forward_hooks accept kwargs or not.
    # As JIT does not support Set[int], this dict is used as a set, where all
    # hooks represented in this dict accept kwargs.
    _forward_hooks_with_kwargs: Dict[int, bool]
    # forward hooks that should always be called even if an exception is raised
    _forward_hooks_always_called: Dict[int, bool]
    _forward_pre_hooks: Dict[int, Callable]
    # Marks whether the corresponding _forward_pre_hooks accept kwargs or not.
    # As JIT does not support Set[int], this dict is used as a set, where all
    # hooks represented in this dict accept kwargs.
    _forward_pre_hooks_with_kwargs: Dict[int, bool]
    _state_dict_hooks: Dict[int, Callable]
    _load_state_dict_pre_hooks: Dict[int, Callable]
    _state_dict_pre_hooks: Dict[int, Callable]
    _load_state_dict_post_hooks: Dict[int, Callable]
    _modules: Dict[str, Optional["Module"]]
    call_super_init: bool = False
    _compiled_call_impl: Optional[Callable] = None

    def __init__(self, *args, **kwargs) -> None:
        """Initialize internal Module state, shared by both nn.Module and ScriptModule."""
        torch._C._log_api_usage_once("python.nn_module")

        # Backward compatibility: no args used to be allowed when call_super_init=False
        if self.call_super_init is False and bool(kwargs):
            raise TypeError(
                f"{type(self).__name__}.__init__() got an unexpected keyword argument '{next(iter(kwargs))}'"
                ""
            )

        if self.call_super_init is False and bool(args):
            raise TypeError(
                f"{type(self).__name__}.__init__() takes 1 positional argument but {len(args) + 1} were"
                " given"
            )

        """
        Calls super().__setattr__('a', a) instead of the typical self.a = a
        to avoid Module.__setattr__ overhead. Module's __setattr__ has special
        handling for parameters, submodules, and buffers but simply calls into
        super().__setattr__ for all other attributes.
        """
        super().__setattr__("training", True)
        super().__setattr__("_parameters", {})
        super().__setattr__("_buffers", {})
        super().__setattr__("_non_persistent_buffers_set", set())
        super().__setattr__("_backward_pre_hooks", OrderedDict())
        super().__setattr__("_backward_hooks", OrderedDict())
        super().__setattr__("_is_full_backward_hook", None)
        super().__setattr__("_forward_hooks", OrderedDict())
        super().__setattr__("_forward_hooks_with_kwargs", OrderedDict())
        super().__setattr__("_forward_hooks_always_called", OrderedDict())
        super().__setattr__("_forward_pre_hooks", OrderedDict())
        super().__setattr__("_forward_pre_hooks_with_kwargs", OrderedDict())
        super().__setattr__("_state_dict_hooks", OrderedDict())
        super().__setattr__("_state_dict_pre_hooks", OrderedDict())
        super().__setattr__("_load_state_dict_pre_hooks", OrderedDict())
        super().__setattr__("_load_state_dict_post_hooks", OrderedDict())
        super().__setattr__("_modules", {})

        if self.call_super_init:
            super().__init__(*args, **kwargs)

    forward: Callable[..., Any] = _forward_unimplemented

    def register_buffer(
        self, name: str, tensor: Optional[Tensor], persistent: bool = True
    ) -> None:
        r"""Add a buffer to the module.

        This is typically used to register a buffer that should not be
        considered a model parameter. For example, BatchNorm's ``running_mean``
        is not a parameter, but is part of the module's state. Buffers, by
        default, are persistent and will be saved alongside parameters. This
        behavior can be changed by setting :attr:`persistent` to ``False``. The
        only difference between a persistent buffer and a non-persistent buffer
        is that the latter will not be a part of this module's
        :attr:`state_dict`.

        Buffers can be accessed as attributes using given names.

        Args:
            name (str): name of the buffer. The buffer can be accessed
                from this module using the given name
            tensor (Tensor or None): buffer to be registered. If ``None``, then operations
                that run on buffers, such as :attr:`cuda`, are ignored. If ``None``,
                the buffer is **not** included in the module's :attr:`state_dict`.
            persistent (bool): whether the buffer is part of this module's
                :attr:`state_dict`.

        Example::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> self.register_buffer('running_mean', torch.zeros(num_features))

        """
        if persistent is False and isinstance(self, torch.jit.ScriptModule):
            raise RuntimeError("ScriptModule does not support non-persistent buffers")

        if "_buffers" not in self.__dict__:
            raise AttributeError("cannot assign buffer before Module.__init__() call")
        elif not isinstance(name, str):
            raise TypeError(
                f"buffer name should be a string. Got {torch.typename(name)}"
            )
        elif "." in name:
            raise KeyError('buffer name can\'t contain "."')
        elif name == "":
            raise KeyError('buffer name can\'t be empty string ""')
        elif hasattr(self, name) and name not in self._buffers:
            raise KeyError(f"attribute '{name}' already exists")
        elif tensor is not None and not isinstance(tensor, torch.Tensor):
            raise TypeError(
                f"cannot assign '{torch.typename(tensor)}' object to buffer '{name}' "
                "(torch Tensor or None required)"
            )
        else:
            for hook in _global_buffer_registration_hooks.values():
                output = hook(self, name, tensor)
                if output is not None:
                    tensor = output
            self._buffers[name] = tensor
            if persistent:
                self._non_persistent_buffers_set.discard(name)
            else:
                self._non_persistent_buffers_set.add(name)

    def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
        r"""Add a parameter to the module.

        The parameter can be accessed as an attribute using the given name.

        Args:
            name (str): name of the parameter. The parameter can be accessed
                from this module using the given name
            param (Parameter or None): parameter to be added to the module. If
                ``None``, then operations that run on parameters, such as :attr:`cuda`,
                are ignored. If ``None``, the parameter is **not** included in the
                module's :attr:`state_dict`.
        """
        if "_parameters" not in self.__dict__:
            raise AttributeError(
                "cannot assign parameter before Module.__init__() call"
            )

        elif not isinstance(name, str):
            raise TypeError(
                f"parameter name should be a string. Got {torch.typename(name)}"
            )
        elif "." in name:
            raise KeyError('parameter name can\'t contain "."')
        elif name == "":
            raise KeyError('parameter name can\'t be empty string ""')
        elif hasattr(self, name) and name not in self._parameters:
            raise KeyError(f"attribute '{name}' already exists")

        if param is None:
            self._parameters[name] = None
        elif not isinstance(param, Parameter):
            raise TypeError(
                f"cannot assign '{torch.typename(param)}' object to parameter '{name}' "
                "(torch.nn.Parameter or None required)"
            )
        elif param.grad_fn:
            raise ValueError(
                f"Cannot assign non-leaf Tensor to parameter '{name}'. Model "
                f"parameters must be created explicitly. To express '{name}' "
                "as a function of another Tensor, compute the value in "
                "the forward() method."
            )
        else:
            for hook in _global_parameter_registration_hooks.values():
                output = hook(self, name, param)
                if output is not None:
                    param = output
            self._parameters[name] = param

    def add_module(self, name: str, module: Optional["Module"]) -> None:
        r"""Add a child module to the current module.

        The module can be accessed as an attribute using the given name.

        Args:
            name (str): name of the child module. The child module can be
                accessed from this module using the given name
            module (Module): child module to be added to the module.
        """
        if not isinstance(module, Module) and module is not None:
            raise TypeError(f"{torch.typename(module)} is not a Module subclass")
        elif not isinstance(name, str):
            raise TypeError(
                f"module name should be a string. Got {torch.typename(name)}"
            )
        elif hasattr(self, name) and name not in self._modules:
            raise KeyError(f"attribute '{name}' already exists")
        elif "." in name:
            raise KeyError(f'module name can\'t contain ".", got: {name}')
        elif name == "":
            raise KeyError('module name can\'t be empty string ""')
        for hook in _global_module_registration_hooks.values():
            output = hook(self, name, module)
            if output is not None:
                module = output
        self._modules[name] = module

    def register_module(self, name: str, module: Optional["Module"]) -> None:
        r"""Alias for :func:`add_module`."""
        self.add_module(name, module)

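    # A short usage sketch for the three registration entry points above,
    # written as an illustrative doctest-style fragment (the ``RunningStats``
    # module below is hypothetical, not part of this file):
    #
    #     >>> # xdoctest: +SKIP("illustrative only")
    #     >>> class RunningStats(nn.Module):
    #     ...     def __init__(self, num_features: int) -> None:
    #     ...         super().__init__()
    #     ...         # persistent buffer: saved in state_dict, moved by .to()/.cuda()
    #     ...         self.register_buffer("count", torch.zeros(1))
    #     ...         # non-persistent buffer: moved with the module but not serialized
    #     ...         self.register_buffer("scratch", torch.zeros(num_features), persistent=False)
    #     ...         # learnable parameter, visible to optimizers via .parameters()
    #     ...         self.register_parameter("scale", nn.Parameter(torch.ones(num_features)))
    #     ...         # equivalent to ``self.proj = nn.Linear(num_features, num_features)``
    #     ...         self.add_module("proj", nn.Linear(num_features, num_features))
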
    def get_submodule(self, target: str) -> "Module":
        """Return the submodule given by ``target`` if it exists, otherwise throw an error.

        For example, let's say you have an ``nn.Module`` ``A`` that
        looks like this:

        .. code-block:: text

            A(
                (net_b): Module(
                    (net_c): Module(
                        (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
                    )
                    (linear): Linear(in_features=100, out_features=200, bias=True)
                )
            )

        (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested
        submodule ``net_b``, which itself has two submodules ``net_c``
        and ``linear``. ``net_c`` then has a submodule ``conv``.)

        To check whether or not we have the ``linear`` submodule, we
        would call ``get_submodule("net_b.linear")``. To check whether
        we have the ``conv`` submodule, we would call
        ``get_submodule("net_b.net_c.conv")``.

        The runtime of ``get_submodule`` is bounded by the degree
        of module nesting in ``target``. A query against
        ``named_modules`` achieves the same result, but it is O(N) in
        the number of transitive modules. So, for a simple check to see
        if some submodule exists, ``get_submodule`` should always be
        used.

        Args:
            target: The fully-qualified string name of the submodule
                to look for. (See above example for how to specify a
                fully-qualified string.)

        Returns:
            torch.nn.Module: The submodule referenced by ``target``

        Raises:
            AttributeError: If the target string references an invalid
                path or resolves to something that is not an
                ``nn.Module``
        """
        if target == "":
            return self

        atoms: List[str] = target.split(".")
        mod: torch.nn.Module = self

        for item in atoms:
            if not hasattr(mod, item):
                raise AttributeError(
                    mod._get_name() + " has no attribute `" + item + "`"
                )

            mod = getattr(mod, item)

            if not isinstance(mod, torch.nn.Module):
                raise AttributeError("`" + item + "` is not an nn.Module")

        return mod

    def set_submodule(self, target: str, module: "Module") -> None:
        """
        Set the submodule given by ``target`` if it exists, otherwise throw an error.

        For example, let's say you have an ``nn.Module`` ``A`` that
        looks like this:

        .. code-block:: text

            A(
                (net_b): Module(
                    (net_c): Module(
                        (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
                    )
                    (linear): Linear(in_features=100, out_features=200, bias=True)
                )
            )

        (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested
        submodule ``net_b``, which itself has two submodules ``net_c``
        and ``linear``. ``net_c`` then has a submodule ``conv``.)

        To override the ``Conv2d`` with a new submodule ``Linear``, you
        would call
        ``set_submodule("net_b.net_c.conv", nn.Linear(33, 16))``.

        Args:
            target: The fully-qualified string name of the submodule
                to look for. (See above example for how to specify a
                fully-qualified string.)
            module: The module to set the submodule to.

        Raises:
            ValueError: If the target string is empty
            AttributeError: If the target string references an invalid
                path or resolves to something that is not an
                ``nn.Module``
        """
        if target == "":
            raise ValueError("Cannot set the submodule without a target name!")

        atoms: List[str] = target.split(".")
        name = atoms.pop(-1)
        mod: torch.nn.Module = self

        for item in atoms:
            if not hasattr(mod, item):
                raise AttributeError(
                    mod._get_name() + " has no attribute `" + item + "`"
                )

            mod = getattr(mod, item)

            # Use isinstance instead of type here to also handle subclass of nn.Module
            if not isinstance(mod, torch.nn.Module):
                raise AttributeError("`" + item + "` is not an nn.Module")

        setattr(mod, name, module)

    def get_parameter(self, target: str) -> "Parameter":
        """Return the parameter given by ``target`` if it exists, otherwise throw an error.

        See the docstring for ``get_submodule`` for a more detailed
        explanation of this method's functionality as well as how to
        correctly specify ``target``.

        Args:
            target: The fully-qualified string name of the Parameter
                to look for. (See ``get_submodule`` for how to specify a
                fully-qualified string.)

        Returns:
            torch.nn.Parameter: The Parameter referenced by ``target``

        Raises:
            AttributeError: If the target string references an invalid
                path or resolves to something that is not an
                ``nn.Parameter``
        """
        module_path, _, param_name = target.rpartition(".")

        mod: torch.nn.Module = self.get_submodule(module_path)

        if not hasattr(mod, param_name):
            raise AttributeError(
                mod._get_name() + " has no attribute `" + param_name + "`"
            )

        param: torch.nn.Parameter = getattr(mod, param_name)

        if not isinstance(param, torch.nn.Parameter):
            raise AttributeError("`" + param_name + "` is not an nn.Parameter")

        return param

    def get_buffer(self, target: str) -> "Tensor":
        """Return the buffer given by ``target`` if it exists, otherwise throw an error.

        See the docstring for ``get_submodule`` for a more detailed
        explanation of this method's functionality as well as how to
        correctly specify ``target``.

        Args:
            target: The fully-qualified string name of the buffer
                to look for. (See ``get_submodule`` for how to specify a
                fully-qualified string.)

        Returns:
            torch.Tensor: The buffer referenced by ``target``

        Raises:
            AttributeError: If the target string references an invalid
                path or resolves to something that is not a
                buffer
        """
        module_path, _, buffer_name = target.rpartition(".")

        mod: torch.nn.Module = self.get_submodule(module_path)

        if not hasattr(mod, buffer_name):
            raise AttributeError(
                mod._get_name() + " has no attribute `" + buffer_name + "`"
            )

        buffer: torch.Tensor = getattr(mod, buffer_name)

        if buffer_name not in mod._buffers:
            raise AttributeError("`" + buffer_name + "` is not a buffer")

        return buffer

    def get_extra_state(self) -> Any:
        """Return any extra state to include in the module's state_dict.

        Implement this and a corresponding :func:`set_extra_state` for your module
        if you need to store extra state. This function is called when building the
        module's `state_dict()`.

        Note that extra state should be picklable to ensure working serialization
        of the state_dict. We only provide backwards compatibility guarantees
        for serializing Tensors; other objects may break backwards compatibility if
        their serialized pickled form changes.

        Returns:
            object: Any extra state to store in the module's state_dict
        """
        raise RuntimeError(
            "Reached a code path in Module.get_extra_state() that should never be called. "
            "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
            "to report this bug."
        )

    def set_extra_state(self, state: Any) -> None:
        """Set extra state contained in the loaded `state_dict`.

        This function is called from :func:`load_state_dict` to handle any extra state
        found within the `state_dict`. Implement this function and a corresponding
        :func:`get_extra_state` for your module if you need to store extra state within its
        `state_dict`.

        Args:
            state (dict): Extra state from the `state_dict`
        """
        raise RuntimeError(
            "Reached a code path in Module.set_extra_state() that should never be called. "
            "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
            "to report this bug."
        )

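    # An illustrative doctest-style sketch of dotted-path access and of the
    # ``get_extra_state``/``set_extra_state`` pair (the ``model`` layout and
    # the ``Counter`` module are hypothetical, not part of this file):
    #
    #     >>> # xdoctest: +SKIP("illustrative only")
    #     >>> model.get_submodule("net_b.linear") is model.net_b.linear
    #     True
    #     >>> model.get_parameter("net_b.linear.weight").requires_grad
    #     True
    #     >>> class Counter(nn.Module):
    #     ...     def __init__(self) -> None:
    #     ...         super().__init__()
    #     ...         self.steps = 0
    #     ...     def get_extra_state(self):
    #     ...         return {"steps": self.steps}  # must be picklable
    #     ...     def set_extra_state(self, state):
    #     ...         self.steps = state["steps"]
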
    def _apply(self, fn, recurse=True):
        if recurse:
            for module in self.children():
                module._apply(fn)

        def compute_should_use_set_data(tensor, tensor_applied):
            if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
                # If the new tensor has compatible tensor type as the existing tensor,
                # the current behavior is to change the tensor in-place using `.data =`,
                # and the future behavior is to overwrite the existing tensor. However,
                # changing the current behavior is a BC-breaking change, and we want it
                # to happen in future releases. So for now we introduce the
                # `torch.__future__.get_overwrite_module_params_on_conversion()`
                # global flag to let the user control whether they want the future
                # behavior of overwriting the existing tensor or not.
                return not torch.__future__.get_overwrite_module_params_on_conversion()
            else:
                return False

        should_use_swap_tensors = (
            torch.__future__.get_swap_module_params_on_conversion()
        )

        for key, param in self._parameters.items():
            if param is None:
                continue
            # Tensors stored in modules are graph leaves, and we don't want to
            # track autograd history of `param_applied`, so we have to use
            # `with torch.no_grad():`
            with torch.no_grad():
                param_applied = fn(param)
            p_should_use_set_data = compute_should_use_set_data(param, param_applied)

            # subclasses may have multiple child tensors so we need to use swap_tensors
            p_should_use_swap_tensors = (
                should_use_swap_tensors or is_traceable_wrapper_subclass(param_applied)
            )

            param_grad = param.grad
            if p_should_use_swap_tensors:
                try:
                    if param_grad is not None:
                        # Accessing param.grad makes its at::Tensor's use_count 2, which will prevent swapping.
                        # Decrement use count of the gradient by setting to None
                        param.grad = None
                    param_applied = torch.nn.Parameter(
                        param_applied, requires_grad=param.requires_grad
                    )
                    torch.utils.swap_tensors(param, param_applied)
                except Exception as e:
                    if param_grad is not None:
                        param.grad = param_grad
                    raise RuntimeError(
                        f"_apply(): Couldn't swap {self._get_name()}.{key}"
                    ) from e
                out_param = param
            elif p_should_use_set_data:
                param.data = param_applied
                out_param = param
            else:
                assert isinstance(param, Parameter)
                assert param.is_leaf
                out_param = Parameter(param_applied, param.requires_grad)
                self._parameters[key] = out_param

            if param_grad is not None:
                with torch.no_grad():
                    grad_applied = fn(param_grad)
                g_should_use_set_data = compute_should_use_set_data(
                    param_grad, grad_applied
                )
                if p_should_use_swap_tensors:
                    grad_applied.requires_grad_(param_grad.requires_grad)
                    try:
                        torch.utils.swap_tensors(param_grad, grad_applied)
                    except Exception as e:
                        raise RuntimeError(
                            f"_apply(): Couldn't swap {self._get_name()}.{key}.grad"
                        ) from e
                    out_param.grad = param_grad
                elif g_should_use_set_data:
                    assert out_param.grad is not None
                    out_param.grad.data = grad_applied
                else:
                    assert param_grad.is_leaf
                    out_param.grad = grad_applied.requires_grad_(
                        param_grad.requires_grad
                    )

        for key, buf in self._buffers.items():
            if buf is not None:
                self._buffers[key] = fn(buf)

        return self

    def apply(self: T, fn: Callable[["Module"], None]) -> T:
        r"""Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self.

        Typical use includes initializing the parameters of a model
        (see also :ref:`nn-init-doc`).

        Args:
            fn (:class:`Module` -> None): function to be applied to each submodule

        Returns:
            Module: self

        Example::

            >>> @torch.no_grad()
            >>> def init_weights(m):
            >>>     print(m)
            >>>     if type(m) == nn.Linear:
            >>>         m.weight.fill_(1.0)
            >>>         print(m.weight)
            >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
            >>> net.apply(init_weights)
            Linear(in_features=2, out_features=2, bias=True)
            Parameter containing:
            tensor([[1., 1.],
                    [1., 1.]], requires_grad=True)
            Linear(in_features=2, out_features=2, bias=True)
            Parameter containing:
            tensor([[1., 1.],
                    [1., 1.]], requires_grad=True)
            Sequential(
              (0): Linear(in_features=2, out_features=2, bias=True)
              (1): Linear(in_features=2, out_features=2, bias=True)
            )

        """
        for module in self.children():
            module.apply(fn)
        fn(self)
        return self

    def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
        r"""Move all model parameters and buffers to the GPU.

        This also makes associated parameters and buffers different objects. So
        it should be called before constructing the optimizer if the module will
        live on GPU while being optimized.

        .. note::
            This method modifies the module in-place.

        Args:
            device (int, optional): if specified, all parameters will be
                copied to that device

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.cuda(device))

    def ipu(self: T, device: Optional[Union[int, device]] = None) -> T:
        r"""Move all model parameters and buffers to the IPU.

        This also makes associated parameters and buffers different objects. So
        it should be called before constructing the optimizer if the module will
        live on IPU while being optimized.

        .. note::
            This method modifies the module in-place.

        Arguments:
            device (int, optional): if specified, all parameters will be
                copied to that device

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.ipu(device))

    def xpu(self: T, device: Optional[Union[int, device]] = None) -> T:
        r"""Move all model parameters and buffers to the XPU.

        This also makes associated parameters and buffers different objects. So
        it should be called before constructing the optimizer if the module will
        live on XPU while being optimized.

        .. note::
            This method modifies the module in-place.

        Arguments:
            device (int, optional): if specified, all parameters will be
                copied to that device

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.xpu(device))

    def mtia(self: T, device: Optional[Union[int, device]] = None) -> T:
        r"""Move all model parameters and buffers to the MTIA.

        This also makes associated parameters and buffers different objects. So
        it should be called before constructing the optimizer if the module will
        live on MTIA while being optimized.

        .. note::
            This method modifies the module in-place.

        Arguments:
            device (int, optional): if specified, all parameters will be
                copied to that device

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.mtia(device))

    def cpu(self: T) -> T:
        r"""Move all model parameters and buffers to the CPU.

        .. note::
            This method modifies the module in-place.

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.cpu())

    def type(self: T, dst_type: Union[dtype, str]) -> T:
        r"""Casts all parameters and buffers to :attr:`dst_type`.

        .. note::
            This method modifies the module in-place.

        Args:
            dst_type (type or string): the desired type

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.type(dst_type))

    def float(self: T) -> T:
        r"""Casts all floating point parameters and buffers to ``float`` datatype.

        .. note::
            This method modifies the module in-place.

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.float() if t.is_floating_point() else t)

    def double(self: T) -> T:
        r"""Casts all floating point parameters and buffers to ``double`` datatype.

        .. note::
            This method modifies the module in-place.

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.double() if t.is_floating_point() else t)

    def half(self: T) -> T:
        r"""Casts all floating point parameters and buffers to ``half`` datatype.

        .. note::
            This method modifies the module in-place.

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.half() if t.is_floating_point() else t)

    def bfloat16(self: T) -> T:
        r"""Casts all floating point parameters and buffers to ``bfloat16`` datatype.

        .. note::
            This method modifies the module in-place.

        Returns:
            Module: self
        """
        return self._apply(lambda t: t.bfloat16() if t.is_floating_point() else t)

    def to_empty(
        self: T, *, device: Optional[DeviceLikeType], recurse: bool = True
    ) -> T:
        r"""Move the parameters and buffers to the specified device without copying storage.

        Args:
            device (:class:`torch.device`): The desired device of the parameters
                and buffers in this module.
            recurse (bool): Whether parameters and buffers of submodules should
                be recursively moved to the specified device.

        Returns:
            Module: self
        """
        return self._apply(
            lambda t: torch.empty_like(t, device=device), recurse=recurse
        )

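    # ``to_empty`` is the usual path for materializing a module that was built
    # on the ``meta`` device, since ``to()`` refuses to copy out of meta
    # tensors. An illustrative doctest-style sketch, assuming the weights are
    # re-initialized or loaded from a checkpoint afterwards:
    #
    #     >>> # xdoctest: +SKIP("illustrative only")
    #     >>> with torch.device("meta"):
    #     ...     big = nn.Linear(4096, 4096)   # no real storage allocated
    #     >>> big = big.to_empty(device="cpu")  # allocate uninitialized storage
    #     >>> big.load_state_dict(checkpoint_state_dict)  # then fill it in
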
    @overload
    def to(
        self,
        device: Optional[DeviceLikeType] = ...,
        dtype: Optional[dtype] = ...,
        non_blocking: bool = ...,
    ) -> Self:
        ...

    @overload
    def to(self, dtype: dtype, non_blocking: bool = ...) -> Self:
        ...

    @overload
    def to(self, tensor: Tensor, non_blocking: bool = ...) -> Self:
        ...

    def to(self, *args, **kwargs):
        r"""Move and/or cast the parameters and buffers.

        This can be called as

        .. function:: to(device=None, dtype=None, non_blocking=False)
           :noindex:

        .. function:: to(dtype, non_blocking=False)
           :noindex:

        .. function:: to(tensor, non_blocking=False)
           :noindex:

        .. function:: to(memory_format=torch.channels_last)
           :noindex:

        Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
        floating point or complex :attr:`dtype`\ s. In addition, this method will
        only cast the floating point or complex parameters and buffers to :attr:`dtype`
        (if given). The integral parameters and buffers will be moved to
        :attr:`device`, if that is given, but with dtypes unchanged. When
        :attr:`non_blocking` is set, it tries to convert/move asynchronously
        with respect to the host if possible, e.g., moving CPU Tensors with
        pinned memory to CUDA devices.

        See below for examples.

        .. note::
            This method modifies the module in-place.

        Args:
            device (:class:`torch.device`): the desired device of the parameters
                and buffers in this module
            dtype (:class:`torch.dtype`): the desired floating point or complex dtype of
                the parameters and buffers in this module
            tensor (torch.Tensor): Tensor whose dtype and device are the desired
                dtype and device for all parameters and buffers in this module
            memory_format (:class:`torch.memory_format`): the desired memory
                format for 4D parameters and buffers in this module (keyword
                only argument)

        Returns:
            Module: self

        Examples::

            >>> # xdoctest: +IGNORE_WANT("non-deterministic")
            >>> linear = nn.Linear(2, 2)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1913, -0.3420],
                    [-0.5113, -0.2325]])
            >>> linear.to(torch.double)
            Linear(in_features=2, out_features=2, bias=True)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1913, -0.3420],
                    [-0.5113, -0.2325]], dtype=torch.float64)
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
            >>> gpu1 = torch.device("cuda:1")
            >>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
            Linear(in_features=2, out_features=2, bias=True)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1914, -0.3420],
                    [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
            >>> cpu = torch.device("cpu")
            >>> linear.to(cpu)
            Linear(in_features=2, out_features=2, bias=True)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.1914, -0.3420],
                    [-0.5112, -0.2324]], dtype=torch.float16)

            >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
            >>> linear.weight
            Parameter containing:
            tensor([[ 0.3741+0.j,  0.2382+0.j],
                    [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
            >>> linear(torch.ones(3, 2, dtype=torch.cdouble))
            tensor([[0.6122+0.j, 0.1150+0.j],
                    [0.6122+0.j, 0.1150+0.j],
                    [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)

        """
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
            *args, **kwargs
        )

        if dtype is not None:
            if not (dtype.is_floating_point or dtype.is_complex):
                raise TypeError(
                    "nn.Module.to only accepts floating point or complex "
                    f"dtypes, but got desired dtype={dtype}"
                )
            if dtype.is_complex:
                warnings.warn(
                    "Complex modules are a new feature under active development whose design may change, "
                    "and some modules might not work as expected when using complex tensors as parameters or buffers. "
                    "Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml "
                    "if a complex module does not work as expected."
                )

        def convert(t):
            try:
                if convert_to_format is not None and t.dim() in (4, 5):
                    return t.to(
                        device,
                        dtype if t.is_floating_point() or t.is_complex() else None,
                        non_blocking,
                        memory_format=convert_to_format,
                    )
                return t.to(
                    device,
                    dtype if t.is_floating_point() or t.is_complex() else None,
                    non_blocking,
                )
            except NotImplementedError as e:
                if str(e) == "Cannot copy out of meta tensor; no data!":
                    raise NotImplementedError(
                        f"{e} Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() "
                        f"when moving module from meta to a different device."
                    ) from None
                else:
                    raise

        return self._apply(convert)

    def register_full_backward_pre_hook(
        self,
        hook: Callable[["Module", _grad_t], Union[None, _grad_t]],
        prepend: bool = False,
    ) -> RemovableHandle:
        r"""Register a backward pre-hook on the module.

        The hook will be called every time the gradients for the module are computed.
        The hook should have the following signature::

            hook(module, grad_output) -> tuple[Tensor] or None

        The :attr:`grad_output` is a tuple. The hook should
        not modify its arguments, but it can optionally return a new gradient with
        respect to the output that will be used in place of :attr:`grad_output` in
        subsequent computations. Entries in :attr:`grad_output` will be ``None`` for
        all non-Tensor arguments.

        For technical reasons, when this hook is applied to a Module, its forward function will
        receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
        of each Tensor returned by the Module's forward function.

        .. warning ::
            Modifying inputs inplace is not allowed when using backward hooks and
            will raise an error.

        Args:
            hook (Callable): The user-defined hook to be registered.
            prepend (bool): If true, the provided ``hook`` will be fired before
                all existing ``backward_pre`` hooks on this
                :class:`torch.nn.modules.Module`. Otherwise, the provided
                ``hook`` will be fired after all existing ``backward_pre`` hooks
                on this :class:`torch.nn.modules.Module`. Note that global
                ``backward_pre`` hooks registered with
                :func:`register_module_full_backward_pre_hook` will fire before
                all hooks registered by this method.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``

        """
        handle = RemovableHandle(self._backward_pre_hooks)
        self._backward_pre_hooks[handle.id] = hook
        if prepend:
            self._backward_pre_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
        return handle

    def register_backward_hook(
        self, hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]]
    ) -> RemovableHandle:
        r"""Register a backward hook on the module.

        This function is deprecated in favor of :meth:`~torch.nn.Module.register_full_backward_hook` and
        the behavior of this function will change in future versions.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``

        """
        if self._is_full_backward_hook is True:
            raise RuntimeError(
                "Cannot use both regular backward hooks and full backward hooks on a "
                "single Module. Please use only one of them."
            )

        self._is_full_backward_hook = False

        handle = RemovableHandle(self._backward_hooks)
        self._backward_hooks[handle.id] = hook
        return handle

    def register_full_backward_hook(
        self,
        hook: Callable[["Module", _grad_t, _grad_t], Union[None, _grad_t]],
        prepend: bool = False,
    ) -> RemovableHandle:
        r"""Register a backward hook on the module.

        The hook will be called every time the gradients with respect to a module
        are computed, i.e. the hook will execute if and only if the gradients with
        respect to module outputs are computed. The hook should have the following
        signature::

            hook(module, grad_input, grad_output) -> tuple(Tensor) or None

        The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients
        with respect to the inputs and outputs respectively. The hook should
        not modify its arguments, but it can optionally return a new gradient with
        respect to the input that will be used in place of :attr:`grad_input` in
        subsequent computations. :attr:`grad_input` will only correspond to the inputs given
        as positional arguments and all keyword arguments are ignored. Entries
        in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor
        arguments.

        For technical reasons, when this hook is applied to a Module, its forward function will
        receive a view of each Tensor passed to the Module. Similarly the caller will receive a view
        of each Tensor returned by the Module's forward function.

        .. warning ::
            Modifying inputs or outputs inplace is not allowed when using backward hooks and
            will raise an error.

        Args:
            hook (Callable): The user-defined hook to be registered.
            prepend (bool): If true, the provided ``hook`` will be fired before
                all existing ``backward`` hooks on this
                :class:`torch.nn.modules.Module`. Otherwise, the provided
                ``hook`` will be fired after all existing ``backward`` hooks on
                this :class:`torch.nn.modules.Module`. Note that global
                ``backward`` hooks registered with
                :func:`register_module_full_backward_hook` will fire before
                all hooks registered by this method.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``

        """
        if self._is_full_backward_hook is False:
            raise RuntimeError(
                "Cannot use both regular backward hooks and full backward hooks on a "
                "single Module. Please use only one of them."
            )

        self._is_full_backward_hook = True

        handle = RemovableHandle(self._backward_hooks)
        self._backward_hooks[handle.id] = hook
        if prepend:
            self._backward_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
        return handle

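    # An illustrative doctest-style sketch of a per-module full backward hook
    # used for gradient diagnostics (the names below are hypothetical):
    #
    #     >>> # xdoctest: +SKIP("illustrative only")
    #     >>> layer = nn.Linear(4, 4)
    #     >>> def report_grads(module, grad_input, grad_output):
    #     ...     # Return None to keep grad_input unchanged; returning a tuple
    #     ...     # with the same structure would replace it.
    #     ...     print([g.norm().item() for g in grad_output if g is not None])
    #     >>> handle = layer.register_full_backward_hook(report_grads)
    #     >>> layer(torch.randn(2, 4)).sum().backward()
    #     >>> handle.remove()
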
    def _get_backward_hooks(self):
        r"""Return the backward hooks for use in the call function.

        It returns two lists, one with the full backward hooks and one with the non-full
        backward hooks.
        """
        full_backward_hooks: List[Callable] = []
        if _global_is_full_backward_hook is True:
            full_backward_hooks += _global_backward_hooks.values()
        if self._is_full_backward_hook is True:
            full_backward_hooks += self._backward_hooks.values()

        non_full_backward_hooks: List[Callable] = []
        if _global_is_full_backward_hook is False:
            non_full_backward_hooks += _global_backward_hooks.values()
        if self._is_full_backward_hook is False:
            non_full_backward_hooks += self._backward_hooks.values()

        return full_backward_hooks, non_full_backward_hooks

    def _get_backward_pre_hooks(self):
        backward_pre_hooks: List[Callable] = []
        backward_pre_hooks += _global_backward_pre_hooks.values()
        backward_pre_hooks += self._backward_pre_hooks.values()

        return backward_pre_hooks

    def _maybe_warn_non_full_backward_hook(self, inputs, result, grad_fn):
        if not isinstance(result, torch.Tensor):
            if not (
                isinstance(result, tuple)
                and all(isinstance(r, torch.Tensor) for r in result)
            ):
                warnings.warn(
                    "Using non-full backward hooks on a Module that does not return a "
                    "single Tensor or a tuple of Tensors is deprecated and will be removed "
                    "in future versions. This hook will be missing some of the grad_output. "
                    "Please use register_full_backward_hook to get the documented behavior.",
                    FutureWarning,
                    stacklevel=2,
                )
                return
        else:
            result = (result,)

        if not isinstance(inputs, torch.Tensor):
            if not (
                isinstance(inputs, tuple)
                and all(isinstance(i, torch.Tensor) for i in inputs)
            ):
                warnings.warn(
                    "Using non-full backward hooks on a Module that does not take as input a "
                    "single Tensor or a tuple of Tensors is deprecated and will be removed "
                    "in future versions. This hook will be missing some of the grad_input. "
                    "Please use register_full_backward_hook to get the documented behavior.",
                    FutureWarning,
                    stacklevel=2,
                )
                return
        else:
            inputs = (inputs,)

        # At this point we are sure that inputs and result are tuple of Tensors
        out_grad_fn = {r.grad_fn for r in result if r.grad_fn is not None}
        if len(out_grad_fn) == 0 or (
            len(out_grad_fn) == 1 and grad_fn not in out_grad_fn
        ):
            warnings.warn(
                "Using a non-full backward hook when outputs are nested in python data structure "
                "is deprecated and will be removed in future versions. This hook will be missing "
                "some grad_output.",
                FutureWarning,
                stacklevel=2,
            )
        elif len(out_grad_fn) > 1:
            warnings.warn(
                "Using a non-full backward hook when outputs are generated by different autograd Nodes "
                "is deprecated and will be removed in future versions. This hook will be missing "
                "some grad_output. Please use register_full_backward_hook to get the documented behavior.",
                FutureWarning,
                stacklevel=2,
            )
        else:
            # At this point the grad_output part of the hook will most likely be correct
            inputs_grad_fn = {i.grad_fn for i in inputs if i.grad_fn is not None}

            next_functions = {n[0] for n in grad_fn.next_functions}

            if inputs_grad_fn != next_functions:
                warnings.warn(
                    "Using a non-full backward hook when the forward contains multiple autograd Nodes "
                    "is deprecated and will be removed in future versions. This hook will be missing "
                    "some grad_input. Please use register_full_backward_hook to get the documented "
                    "behavior.",
                    FutureWarning,
                    stacklevel=2,
                )

    def register_forward_pre_hook(
        self,
        hook: Union[
            Callable[[T, Tuple[Any, ...]], Optional[Any]],
            Callable[
                [T, Tuple[Any, ...], Dict[str, Any]],
                Optional[Tuple[Any, Dict[str, Any]]],
            ],
        ],
        *,
        prepend: bool = False,
        with_kwargs: bool = False,
    ) -> RemovableHandle:
        r"""Register a forward pre-hook on the module.

        The hook will be called every time before :func:`forward` is invoked.

        If ``with_kwargs`` is false or not specified, the input contains only
        the positional arguments given to the module. Keyword arguments won't be
        passed to the hooks and only to the ``forward``. The hook can modify the
        input. The user can either return a tuple or a single modified value in
        the hook. We will wrap the value into a tuple if a single value is
        returned (unless that value is already a tuple). The hook should have
        the following signature::

            hook(module, args) -> None or modified input

        If ``with_kwargs`` is true, the forward pre-hook will be passed the
        kwargs given to the forward function. And if the hook modifies the
        input, both the args and kwargs should be returned. The hook should have
        the following signature::

            hook(module, args, kwargs) -> None or a tuple of modified input and kwargs

        Args:
            hook (Callable): The user defined hook to be registered.
            prepend (bool): If true, the provided ``hook`` will be fired before
                all existing ``forward_pre`` hooks on this
                :class:`torch.nn.modules.Module`. Otherwise, the provided
                ``hook`` will be fired after all existing ``forward_pre`` hooks
                on this :class:`torch.nn.modules.Module`. Note that global
                ``forward_pre`` hooks registered with
                :func:`register_module_forward_pre_hook` will fire before all
                hooks registered by this method.
                Default: ``False``
            with_kwargs (bool): If true, the ``hook`` will be passed the kwargs
                given to the forward function.
                Default: ``False``

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = RemovableHandle(
            self._forward_pre_hooks, extra_dict=self._forward_pre_hooks_with_kwargs
        )
        self._forward_pre_hooks[handle.id] = hook
        if with_kwargs:
            self._forward_pre_hooks_with_kwargs[handle.id] = True

        if prepend:
            self._forward_pre_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
        return handle

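    # An illustrative doctest-style sketch of a kwargs-aware forward pre-hook
    # that rewrites both args and kwargs (the names are hypothetical):
    #
    #     >>> # xdoctest: +SKIP("illustrative only")
    #     >>> def clamp_inputs(module, args, kwargs):
    #     ...     # With ``with_kwargs=True`` the hook must return (args, kwargs)
    #     ...     # or None; with the default signature it returns args only.
    #     ...     return tuple(a.clamp(-1, 1) for a in args), kwargs
    #     >>> handle = layer.register_forward_pre_hook(clamp_inputs, with_kwargs=True)
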
It can modify the input inplace but it will not have effect on 1662 forward since this is called after :func:`forward` is called. The hook 1663 should have the following signature:: 1664 1665 hook(module, args, output) -> None or modified output 1666 1667 If ``with_kwargs`` is ``True``, the forward hook will be passed the 1668 ``kwargs`` given to the forward function and be expected to return the 1669 output possibly modified. The hook should have the following signature:: 1670 1671 hook(module, args, kwargs, output) -> None or modified output 1672 1673 Args: 1674 hook (Callable): The user defined hook to be registered. 1675 prepend (bool): If ``True``, the provided ``hook`` will be fired 1676 before all existing ``forward`` hooks on this 1677 :class:`torch.nn.modules.Module`. Otherwise, the provided 1678 ``hook`` will be fired after all existing ``forward`` hooks on 1679 this :class:`torch.nn.modules.Module`. Note that global 1680 ``forward`` hooks registered with 1681 :func:`register_module_forward_hook` will fire before all hooks 1682 registered by this method. 1683 Default: ``False`` 1684 with_kwargs (bool): If ``True``, the ``hook`` will be passed the 1685 kwargs given to the forward function. 1686 Default: ``False`` 1687 always_call (bool): If ``True`` the ``hook`` will be run regardless of 1688 whether an exception is raised while calling the Module. 1689 Default: ``False`` 1690 1691 Returns: 1692 :class:`torch.utils.hooks.RemovableHandle`: 1693 a handle that can be used to remove the added hook by calling 1694 ``handle.remove()`` 1695 """ 1696 handle = RemovableHandle( 1697 self._forward_hooks, 1698 extra_dict=[ 1699 self._forward_hooks_with_kwargs, 1700 self._forward_hooks_always_called, 1701 ], 1702 ) 1703 self._forward_hooks[handle.id] = hook 1704 if with_kwargs: 1705 self._forward_hooks_with_kwargs[handle.id] = True 1706 if always_call: 1707 self._forward_hooks_always_called[handle.id] = True 1708 if prepend: 1709 self._forward_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] 1710 return handle 1711 1712 def _slow_forward(self, *input, **kwargs): 1713 tracing_state = torch._C._get_tracing_state() 1714 if not tracing_state or isinstance(self.forward, torch._C.ScriptMethod): 1715 return self.forward(*input, **kwargs) 1716 recording_scopes = torch.jit._trace._trace_module_map is not None 1717 if recording_scopes: 1718 # type ignore was added because at this point one knows that 1719 # torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any] 1720 name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None # type: ignore[index, operator] # noqa: B950 1721 if name: 1722 tracing_state.push_scope(name) 1723 else: 1724 recording_scopes = False 1725 try: 1726 result = self.forward(*input, **kwargs) 1727 finally: 1728 if recording_scopes: 1729 tracing_state.pop_scope() 1730 return result 1731 1732 def _wrapped_call_impl(self, *args, **kwargs): 1733 if self._compiled_call_impl is not None: 1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1735 else: 1736 return self._call_impl(*args, **kwargs) 1737 1738 # torchrec tests the code consistency with the following code 1739 # fmt: off 1740 def _call_impl(self, *args, **kwargs): 1741 forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward) 1742 # If we don't have any hooks, we want to skip the rest of the logic in 1743 # this function, and just call forward. 
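# Otherwise we take the slow path below: forward pre-hooks run first (global
# hooks before module-local ones), then forward itself, then forward hooks
# (again global before local). When full backward hooks or backward pre-hooks
# are registered, a BackwardHook wraps the inputs and outputs around the
# forward call; non-full backward hooks are attached to the output's grad_fn
# at the end.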
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1745 or _global_backward_pre_hooks or _global_backward_hooks 1746 or _global_forward_hooks or _global_forward_pre_hooks): 1747 return forward_call(*args, **kwargs) 1748 1749 result = None 1750 called_always_called_hooks = set() 1751 1752 def inner(): 1753 nonlocal result, args, kwargs 1754 1755 full_backward_hooks, non_full_backward_hooks = [], [] 1756 backward_pre_hooks = [] 1757 if self._backward_pre_hooks or _global_backward_pre_hooks: 1758 backward_pre_hooks = self._get_backward_pre_hooks() 1759 1760 if self._backward_hooks or _global_backward_hooks: 1761 full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks() 1762 1763 if _global_forward_pre_hooks or self._forward_pre_hooks: 1764 for hook_id, hook in ( 1765 *_global_forward_pre_hooks.items(), 1766 *self._forward_pre_hooks.items(), 1767 ): 1768 if hook_id in self._forward_pre_hooks_with_kwargs: 1769 args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc] 1770 if args_kwargs_result is not None: 1771 if isinstance(args_kwargs_result, tuple) and len(args_kwargs_result) == 2: 1772 args, kwargs = args_kwargs_result 1773 else: 1774 raise RuntimeError( 1775 "forward pre-hook must return None or a tuple " 1776 f"of (new_args, new_kwargs), but got {args_kwargs_result}." 1777 ) 1778 else: 1779 args_result = hook(self, args) 1780 if args_result is not None: 1781 if not isinstance(args_result, tuple): 1782 args_result = (args_result,) 1783 args = args_result 1784 1785 bw_hook = None 1786 if full_backward_hooks or backward_pre_hooks: 1787 bw_hook = BackwardHook(self, full_backward_hooks, backward_pre_hooks) 1788 args = bw_hook.setup_input_hook(args) 1789 1790 result = forward_call(*args, **kwargs) 1791 if _global_forward_hooks or self._forward_hooks: 1792 for hook_id, hook in ( 1793 *_global_forward_hooks.items(), 1794 *self._forward_hooks.items(), 1795 ): 1796 # mark that always called hook is run 1797 if hook_id in self._forward_hooks_always_called or hook_id in _global_forward_hooks_always_called: 1798 called_always_called_hooks.add(hook_id) 1799 1800 if hook_id in self._forward_hooks_with_kwargs: 1801 hook_result = hook(self, args, kwargs, result) 1802 else: 1803 hook_result = hook(self, args, result) 1804 1805 if hook_result is not None: 1806 result = hook_result 1807 1808 if bw_hook: 1809 if not isinstance(result, (torch.Tensor, tuple)): 1810 warnings.warn("For backward hooks to be called," 1811 " module output should be a Tensor or a tuple of Tensors" 1812 f" but received {type(result)}") 1813 result = bw_hook.setup_output_hook(result) 1814 1815 # Handle the non-full backward hooks 1816 if non_full_backward_hooks: 1817 var = result 1818 while not isinstance(var, torch.Tensor): 1819 if isinstance(var, dict): 1820 var = next(v for v in var.values() if isinstance(v, torch.Tensor)) 1821 else: 1822 var = var[0] 1823 grad_fn = var.grad_fn 1824 if grad_fn is not None: 1825 for hook in non_full_backward_hooks: 1826 grad_fn.register_hook(_WrappedHook(hook, self)) 1827 self._maybe_warn_non_full_backward_hook(args, result, grad_fn) 1828 1829 return result 1830 1831 from torch.compiler import is_compiling 1832 1833 # This is technically not behavior equivalent when compiling, but it's 1834 # incredibly unlikely we will ever support throwing an exception in NN 1835 # module, and then catching it here, and then reraising it, and then 1836 # catching it again, and expecting the resulting frame to be compiled. 
1837 # The reraise here just gunks up our exception handling for no good 1838 # reason. Don't try to run the always called hooks in event of 1839 # exception. 1840 if is_compiling(): 1841 return inner() 1842 1843 try: 1844 return inner() 1845 except Exception: 1846 # run always called hooks if they have not already been run 1847 # For now only forward hooks have the always_call option but perhaps 1848 # this functionality should be added to full backward hooks as well. 1849 for hook_id, hook in _global_forward_hooks.items(): 1850 if hook_id in _global_forward_hooks_always_called and hook_id not in called_always_called_hooks: # type: ignore[possibly-undefined] 1851 try: 1852 hook_result = hook(self, args, result) # type: ignore[possibly-undefined] 1853 if hook_result is not None: 1854 result = hook_result 1855 except Exception as e: 1856 warnings.warn("global module forward hook with ``always_call=True`` raised an exception " 1857 f"that was silenced as another error was raised in forward: {str(e)}") 1858 continue 1859 1860 for hook_id, hook in self._forward_hooks.items(): 1861 if hook_id in self._forward_hooks_always_called and hook_id not in called_always_called_hooks: # type: ignore[possibly-undefined] 1862 try: 1863 if hook_id in self._forward_hooks_with_kwargs: 1864 hook_result = hook(self, args, kwargs, result) # type: ignore[possibly-undefined] 1865 else: 1866 hook_result = hook(self, args, result) # type: ignore[possibly-undefined] 1867 if hook_result is not None: 1868 result = hook_result 1869 except Exception as e: 1870 warnings.warn("module forward hook with ``always_call=True`` raised an exception " 1871 f"that was silenced as another error was raised in forward: {str(e)}") 1872 continue 1873 # raise exception raised in try block 1874 raise 1875 # fmt: on 1876 1877 __call__: Callable[..., Any] = _wrapped_call_impl 1878 1879 def __getstate__(self): 1880 state = self.__dict__.copy() 1881 state.pop("_compiled_call_impl", None) 1882 return state 1883 1884 def __setstate__(self, state): 1885 self.__dict__.update(state) 1886 1887 # Support loading old checkpoints that don't have the following attrs: 1888 if "_forward_pre_hooks" not in self.__dict__: 1889 self._forward_pre_hooks = OrderedDict() 1890 if "_forward_pre_hooks_with_kwargs" not in self.__dict__: 1891 self._forward_pre_hooks_with_kwargs = OrderedDict() 1892 if "_forward_hooks_with_kwargs" not in self.__dict__: 1893 self._forward_hooks_with_kwargs = OrderedDict() 1894 if "_forward_hooks_always_called" not in self.__dict__: 1895 self._forward_hooks_always_called = OrderedDict() 1896 if "_state_dict_hooks" not in self.__dict__: 1897 self._state_dict_hooks = OrderedDict() 1898 if "_state_dict_pre_hooks" not in self.__dict__: 1899 self._state_dict_pre_hooks = OrderedDict() 1900 if "_load_state_dict_pre_hooks" not in self.__dict__: 1901 self._load_state_dict_pre_hooks = OrderedDict() 1902 if "_load_state_dict_post_hooks" not in self.__dict__: 1903 self._load_state_dict_post_hooks = OrderedDict() 1904 if "_non_persistent_buffers_set" not in self.__dict__: 1905 self._non_persistent_buffers_set = set() 1906 if "_is_full_backward_hook" not in self.__dict__: 1907 self._is_full_backward_hook = None 1908 if "_backward_pre_hooks" not in self.__dict__: 1909 self._backward_pre_hooks = OrderedDict() 1910 1911 # On the return type: 1912 # We choose to return `Any` in the `__getattr__` type signature instead of a more strict `Union[Tensor, Module]`. 1913 # This is done for better interop with various type checkers for the end users. 
1914 # Having a stricter return type doesn't play nicely with `register_buffer()` and forces 1915 # people to excessively use type-ignores, asserts, casts, etc. 1916 # See full discussion on the problems with returning `Union` here 1917 # https://github.com/microsoft/pyright/issues/4213 1918 def __getattr__(self, name: str) -> Any: 1919 if "_parameters" in self.__dict__: 1920 _parameters = self.__dict__["_parameters"] 1921 if name in _parameters: 1922 return _parameters[name] 1923 if "_buffers" in self.__dict__: 1924 _buffers = self.__dict__["_buffers"] 1925 if name in _buffers: 1926 return _buffers[name] 1927 if "_modules" in self.__dict__: 1928 modules = self.__dict__["_modules"] 1929 if name in modules: 1930 return modules[name] 1931 raise AttributeError( 1932 f"'{type(self).__name__}' object has no attribute '{name}'" 1933 ) 1934 1935 def __setattr__(self, name: str, value: Union[Tensor, "Module"]) -> None: 1936 def remove_from(*dicts_or_sets): 1937 for d in dicts_or_sets: 1938 if name in d: 1939 if isinstance(d, dict): 1940 del d[name] 1941 else: 1942 d.discard(name) 1943 1944 params = self.__dict__.get("_parameters") 1945 if isinstance(value, Parameter): 1946 if params is None: 1947 raise AttributeError( 1948 "cannot assign parameters before Module.__init__() call" 1949 ) 1950 remove_from( 1951 self.__dict__, 1952 self._buffers, 1953 self._modules, 1954 self._non_persistent_buffers_set, 1955 ) 1956 self.register_parameter(name, value) 1957 elif params is not None and name in params: 1958 if value is not None: 1959 raise TypeError( 1960 f"cannot assign '{torch.typename(value)}' as parameter '{name}' " 1961 "(torch.nn.Parameter or None expected)" 1962 ) 1963 self.register_parameter(name, value) 1964 else: 1965 modules = self.__dict__.get("_modules") 1966 if isinstance(value, Module): 1967 if modules is None: 1968 raise AttributeError( 1969 "cannot assign module before Module.__init__() call" 1970 ) 1971 remove_from( 1972 self.__dict__, 1973 self._parameters, 1974 self._buffers, 1975 self._non_persistent_buffers_set, 1976 ) 1977 for hook in _global_module_registration_hooks.values(): 1978 output = hook(self, name, value) 1979 if output is not None: 1980 value = output 1981 modules[name] = value 1982 elif modules is not None and name in modules: 1983 if value is not None: 1984 raise TypeError( 1985 f"cannot assign '{torch.typename(value)}' as child module '{name}' " 1986 "(torch.nn.Module or None expected)" 1987 ) 1988 for hook in _global_module_registration_hooks.values(): 1989 output = hook(self, name, value) 1990 if output is not None: 1991 value = output 1992 modules[name] = value 1993 else: 1994 buffers = self.__dict__.get("_buffers") 1995 if isinstance(value, Buffer) or buffers is not None and name in buffers: 1996 if value is not None and not isinstance(value, torch.Tensor): 1997 raise TypeError( 1998 f"cannot assign '{torch.typename(value)}' as buffer '{name}' " 1999 "(torch.nn.Buffer, torch.Tensor or None expected)" 2000 ) 2001 if isinstance(value, Buffer): 2002 persistent = value.persistent 2003 else: 2004 persistent = name not in self._non_persistent_buffers_set 2005 # === HACK === 2006 # This whole block below should just be: 2007 # self.register_buffer(name, value, persistent) 2008 2009 # But to support subclasses of nn.Module that (wrongfully) implement a 2010 # register_buffer() method that doesn't have the "persistent" 2011 # argument. 
Only pass it in if it is accepted otherwise assume 2012 # it is always true 2013 if self.register_buffer is torch.nn.Module.register_buffer: 2014 self.register_buffer(name, value, persistent) 2015 else: 2016 sign = inspect.signature(self.register_buffer) 2017 if "persistent" in sign.parameters: 2018 self.register_buffer(name, value, persistent) 2019 else: 2020 if not persistent: 2021 raise RuntimeError( 2022 "Registering a non-persistent buffer " 2023 "on a Module subclass that implements " 2024 "register_buffer() without the persistent " 2025 "argument is not allowed." 2026 ) 2027 # Assume that the implementation without the argument has the 2028 # behavior from before the argument was added: persistent=True 2029 self.register_buffer(name, value) 2030 # === HACK END === 2031 else: 2032 super().__setattr__(name, value) 2033 2034 def __delattr__(self, name): 2035 if name in self._parameters: 2036 del self._parameters[name] 2037 elif name in self._buffers: 2038 del self._buffers[name] 2039 self._non_persistent_buffers_set.discard(name) 2040 elif name in self._modules: 2041 del self._modules[name] 2042 else: 2043 super().__delattr__(name) 2044 2045 def _register_state_dict_hook(self, hook): 2046 r"""Register a post-hook for the :meth:`~torch.nn.Module.state_dict` method. 2047 2048 It should have the following signature:: 2049 hook(module, state_dict, prefix, local_metadata) -> None or state_dict 2050 2051 The registered hooks can modify the ``state_dict`` inplace or return a new one. 2052 If a new ``state_dict`` is returned, it will only be respected if it is the root 2053 module that :meth:`~nn.Module.state_dict` is called from. 2054 """ 2055 if getattr(hook, "_from_public_api", False): 2056 raise RuntimeError( 2057 "Cannot register the same function as the state dict post hook that was " 2058 "previously registered via register_state_dict_post_hook" 2059 ) 2060 handle = RemovableHandle(self._state_dict_hooks) 2061 self._state_dict_hooks[handle.id] = hook 2062 return handle 2063 2064 def register_state_dict_post_hook(self, hook): 2065 r"""Register a post-hook for the :meth:`~torch.nn.Module.state_dict` method. 2066 2067 It should have the following signature:: 2068 hook(module, state_dict, prefix, local_metadata) -> None 2069 2070 The registered hooks can modify the ``state_dict`` inplace. 2071 """ 2072 # In _register_state_dict_hook there was a bug described in 2073 # https://github.com/pytorch/pytorch/issues/117437 where the return value 2074 # was only respected for the root module but not child submodules. 2075 # We fix this in this public version by only allowing inplace modifications on 2076 # the state_dict by the hook. However, since hooks registered via both these 2077 # APIs will be added to `_state_dict_hooks` and the type of `_state_dict_hooks` 2078 # cannot be changed due to many dependencies on it, we mark a hook 2079 # as being registered via the public API by setting `_from_public_api` on it. 2080 # In the implementation of `state_dict`, if the callable does not have this 2081 # flag, the old behavior of respecting the return value will be preserved 2082 # for the root module, otherwise, we ensure that the hook returns None. 2083 hook._from_public_api = True 2084 handle = RemovableHandle(self._state_dict_hooks) 2085 self._state_dict_hooks[handle.id] = hook 2086 return handle 2087 2088 def register_state_dict_pre_hook(self, hook): 2089 r"""Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method. 
2090 2091 It should have the following signature:: 2092 hook(module, prefix, keep_vars) -> None 2093 2094 The registered hooks can be used to perform pre-processing before the ``state_dict`` 2095 call is made. 2096 """ 2097 handle = RemovableHandle(self._state_dict_pre_hooks) 2098 self._state_dict_pre_hooks[handle.id] = hook 2099 return handle 2100 2101 def _save_to_state_dict(self, destination, prefix, keep_vars): 2102 r"""Save module state to the `destination` dictionary. 2103 2104 The `destination` dictionary will contain the state 2105 of the module, but not its descendants. This is called on every 2106 submodule in :meth:`~torch.nn.Module.state_dict`. 2107 2108 In rare cases, subclasses can achieve class-specific behavior by 2109 overriding this method with custom logic. 2110 2111 Args: 2112 destination (dict): a dict where state will be stored 2113 prefix (str): the prefix for parameters and buffers used in this 2114 module 2115 """ 2116 for name, param in self._parameters.items(): 2117 if param is not None: 2118 destination[prefix + name] = param if keep_vars else param.detach() 2119 for name, buf in self._buffers.items(): 2120 if buf is not None and name not in self._non_persistent_buffers_set: 2121 destination[prefix + name] = buf if keep_vars else buf.detach() 2122 extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX 2123 if ( 2124 getattr(self.__class__, "get_extra_state", Module.get_extra_state) 2125 is not Module.get_extra_state 2126 ): 2127 destination[extra_state_key] = self.get_extra_state() 2128 2129 # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns 2130 # back that same object. But if they pass nothing, an `OrderedDict` is created and returned. 2131 T_destination = TypeVar("T_destination", bound=Dict[str, Any]) 2132 2133 @overload 2134 def state_dict( 2135 self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ... 2136 ) -> T_destination: 2137 ... 2138 2139 @overload 2140 def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]: 2141 ... 2142 2143 # TODO: Change `*args` to `*` and remove the corresponding warning in docs when BC allows. 2144 # Also remove the logic for arg parsing together. 2145 def state_dict(self, *args, destination=None, prefix="", keep_vars=False): 2146 r"""Return a dictionary containing references to the whole state of the module. 2147 2148 Both parameters and persistent buffers (e.g. running averages) are 2149 included. Keys are corresponding parameter and buffer names. 2150 Parameters and buffers set to ``None`` are not included. 2151 2152 .. note:: 2153 The returned object is a shallow copy. It contains references 2154 to the module's parameters and buffers. 2155 2156 .. warning:: 2157 Currently ``state_dict()`` also accepts positional arguments for 2158 ``destination``, ``prefix`` and ``keep_vars`` in order. However, 2159 this is being deprecated and keyword arguments will be enforced in 2160 future releases. 2161 2162 .. warning:: 2163 Please avoid the use of argument ``destination`` as it is not 2164 designed for end-users. 2165 2166 Args: 2167 destination (dict, optional): If provided, the state of module will 2168 be updated into the dict and the same object is returned. 2169 Otherwise, an ``OrderedDict`` will be created and returned. 2170 Default: ``None``. 2171 prefix (str, optional): a prefix added to parameter and buffer 2172 names to compose the keys in state_dict. Default: ``''``. 
2173 keep_vars (bool, optional): by default the :class:`~torch.Tensor` s 2174 returned in the state dict are detached from autograd. If it's 2175 set to ``True``, detaching will not be performed. 2176 Default: ``False``. 2177 2178 Returns: 2179 dict: 2180 a dictionary containing a whole state of the module 2181 2182 Example:: 2183 2184 >>> # xdoctest: +SKIP("undefined vars") 2185 >>> module.state_dict().keys() 2186 ['bias', 'weight'] 2187 2188 """ 2189 # TODO: Remove `args` and the parsing logic when BC allows. 2190 if len(args) > 0: 2191 # DeprecationWarning is ignored by default 2192 warnings.warn( 2193 "Positional args are being deprecated, use kwargs instead. Refer to " 2194 "https://pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.state_dict" 2195 " for details.", 2196 FutureWarning, 2197 stacklevel=2, 2198 ) 2199 if destination is None: 2200 destination = args[0] 2201 if len(args) > 1 and prefix == "": 2202 prefix = args[1] 2203 if len(args) > 2 and keep_vars is False: 2204 keep_vars = args[2] 2205 2206 if destination is None: 2207 destination = OrderedDict() 2208 destination._metadata = OrderedDict() 2209 2210 local_metadata = dict(version=self._version) 2211 if hasattr(destination, "_metadata"): 2212 destination._metadata[prefix[:-1]] = local_metadata 2213 2214 for hook in self._state_dict_pre_hooks.values(): 2215 hook(self, prefix, keep_vars) 2216 self._save_to_state_dict(destination, prefix, keep_vars) 2217 for name, module in self._modules.items(): 2218 if module is not None: 2219 module.state_dict( 2220 destination=destination, 2221 prefix=prefix + name + ".", 2222 keep_vars=keep_vars, 2223 ) 2224 for hook in self._state_dict_hooks.values(): 2225 hook_result = hook(self, destination, prefix, local_metadata) 2226 if not getattr(hook, "_from_public_api", False): 2227 if hook_result is not None: 2228 destination = hook_result 2229 else: 2230 if hook_result is not None: 2231 raise RuntimeError("state_dict post-hook must return None") 2232 return destination 2233 2234 def _register_load_state_dict_pre_hook(self, hook, with_module=False): 2235 r"""See :meth:`~torch.nn.Module.register_load_state_dict_pre_hook` for details. 2236 2237 A subtle difference is that if ``with_module`` is set to ``False``, then the 2238 hook will not take the ``module`` as the first argument whereas 2239 :meth:`~torch.nn.Module.register_load_state_dict_pre_hook` always takes the 2240 ``module`` as the first argument. 2241 2242 Arguments: 2243 hook (Callable): Callable hook that will be invoked before 2244 loading the state dict. 2245 with_module (bool, optional): Whether or not to pass the module 2246 instance to the hook as the first parameter. 2247 """ 2248 handle = RemovableHandle(self._load_state_dict_pre_hooks) 2249 self._load_state_dict_pre_hooks[handle.id] = _WrappedHook( 2250 hook, self if with_module else None 2251 ) 2252 return handle 2253 2254 def register_load_state_dict_pre_hook(self, hook): 2255 r"""Register a pre-hook to be run before module's :meth:`~nn.Module.load_state_dict` is called. 2256 2257 It should have the following signature:: 2258 hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None # noqa: B950 2259 2260 Arguments: 2261 hook (Callable): Callable hook that will be invoked before 2262 loading the state dict. 
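A common use is remapping keys from an older checkpoint layout before they are matched against the module's own parameter and buffer names.

        Example (illustrative sketch; ``model`` and the ``.gamma`` -> ``.weight`` rename below are hypothetical)::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> def rename_legacy_keys(module, state_dict, prefix, local_metadata,
            >>>                        strict, missing_keys, unexpected_keys, error_msgs):
            >>>     # rewrite outdated checkpoint keys in place before they are loaded
            >>>     for key in list(state_dict.keys()):
            >>>         if key.endswith(".gamma"):
            >>>             state_dict[key.replace(".gamma", ".weight")] = state_dict.pop(key)
            >>> handle = model.register_load_state_dict_pre_hook(rename_legacy_keys)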
2263 """ 2264 return self._register_load_state_dict_pre_hook(hook, with_module=True) 2265 2266 def register_load_state_dict_post_hook(self, hook): 2267 r"""Register a post-hook to be run after module's :meth:`~nn.Module.load_state_dict` is called. 2268 2269 It should have the following signature:: 2270 hook(module, incompatible_keys) -> None 2271 2272 The ``module`` argument is the current module that this hook is registered 2273 on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting 2274 of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` 2275 is a ``list`` of ``str`` containing the missing keys and 2276 ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys. 2277 2278 The given incompatible_keys can be modified inplace if needed. 2279 2280 Note that the checks performed when calling :func:`load_state_dict` with 2281 ``strict=True`` are affected by modifications the hook makes to 2282 ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either 2283 set of keys will result in an error being thrown when ``strict=True``, and 2284 clearing out both missing and unexpected keys will avoid an error. 2285 2286 Returns: 2287 :class:`torch.utils.hooks.RemovableHandle`: 2288 a handle that can be used to remove the added hook by calling 2289 ``handle.remove()`` 2290 """ 2291 handle = RemovableHandle(self._load_state_dict_post_hooks) 2292 self._load_state_dict_post_hooks[handle.id] = hook 2293 return handle 2294 2295 def _load_from_state_dict( 2296 self, 2297 state_dict, 2298 prefix, 2299 local_metadata, 2300 strict, 2301 missing_keys, 2302 unexpected_keys, 2303 error_msgs, 2304 ): 2305 r"""Copy parameters and buffers from :attr:`state_dict` into only this module, but not its descendants. 2306 2307 This is called on every submodule 2308 in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this 2309 module in input :attr:`state_dict` is provided as :attr:`local_metadata`. 2310 For state dicts without metadata, :attr:`local_metadata` is empty. 2311 Subclasses can achieve class-specific backward compatible loading using 2312 the version number at `local_metadata.get("version", None)`. 2313 Additionally, :attr:`local_metadata` can also contain the key 2314 `assign_to_params_buffers` that indicates whether keys should be 2315 assigned their corresponding tensor in the state_dict. 2316 2317 .. note:: 2318 :attr:`state_dict` is not the same object as the input 2319 :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So 2320 it can be modified. 2321 2322 Args: 2323 state_dict (dict): a dict containing parameters and 2324 persistent buffers. 2325 prefix (str): the prefix for parameters and buffers used in this 2326 module 2327 local_metadata (dict): a dict containing the metadata for this module. 
2328 2329 strict (bool): whether to strictly enforce that the keys in 2330 :attr:`state_dict` with :attr:`prefix` match the names of 2331 parameters and buffers in this module 2332 missing_keys (list of str): if ``strict=True``, add missing keys to 2333 this list 2334 unexpected_keys (list of str): if ``strict=True``, add unexpected 2335 keys to this list 2336 error_msgs (list of str): error messages should be added to this 2337 list, and will be reported together in 2338 :meth:`~torch.nn.Module.load_state_dict` 2339 """ 2340 for hook in self._load_state_dict_pre_hooks.values(): 2341 hook( 2342 state_dict, 2343 prefix, 2344 local_metadata, 2345 strict, 2346 missing_keys, 2347 unexpected_keys, 2348 error_msgs, 2349 ) 2350 2351 persistent_buffers = { 2352 k: v 2353 for k, v in self._buffers.items() 2354 if k not in self._non_persistent_buffers_set 2355 } 2356 local_name_params = itertools.chain( 2357 self._parameters.items(), persistent_buffers.items() 2358 ) 2359 local_state = {k: v for k, v in local_name_params if v is not None} 2360 assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) 2361 use_swap_tensors = torch.__future__.get_swap_module_params_on_conversion() 2362 2363 for name, param in local_state.items(): 2364 key = prefix + name 2365 if key in state_dict: 2366 input_param = state_dict[key] 2367 if not torch.overrides.is_tensor_like(input_param): 2368 error_msgs.append( 2369 f'While copying the parameter named "{key}", ' 2370 "expected torch.Tensor or Tensor-like object from checkpoint but " 2371 f"received {type(input_param)}" 2372 ) 2373 continue 2374 2375 # This is used to avoid copying uninitialized parameters into 2376 # non-lazy modules, since they don't have the hook to do the checks; 2377 # in such a case, it will error when accessing the .shape attribute. 2378 is_param_lazy = torch.nn.parameter.is_lazy(param) 2379 # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ 2380 if ( 2381 not is_param_lazy 2382 and len(param.shape) == 0 2383 and len(input_param.shape) == 1 2384 ): 2385 input_param = input_param[0] 2386 2387 if not is_param_lazy and input_param.shape != param.shape: 2388 # local shape should match the one in checkpoint 2389 error_msgs.append( 2390 f"size mismatch for {key}: copying a param with shape {input_param.shape} from checkpoint, " 2391 f"the shape in current model is {param.shape}." 2392 ) 2393 continue 2394 2395 if ( 2396 param.is_meta 2397 and not input_param.is_meta 2398 and not assign_to_params_buffers 2399 ): 2400 warnings.warn( 2401 f"for {key}: copying from a non-meta parameter in the checkpoint to a meta " 2402 "parameter in the current model, which is a no-op. 
(Did you mean to " 2403 "pass `assign=True` to assign items in the state dictionary to their " 2404 "corresponding key in the module instead of copying them in place?)" 2405 ) 2406 2407 try: 2408 with torch.no_grad(): 2409 if use_swap_tensors: 2410 new_input_param = param.module_load( 2411 input_param, assign=assign_to_params_buffers 2412 ) 2413 if id(new_input_param) == id(input_param) or id( 2414 new_input_param 2415 ) == id(param): 2416 raise RuntimeError( 2417 "module_load returned one of self or other, please .detach() " 2418 "the result if returning one of the inputs in module_load" 2419 ) 2420 if isinstance(param, torch.nn.Parameter): 2421 if not isinstance(new_input_param, torch.nn.Parameter): 2422 new_input_param = torch.nn.Parameter( 2423 new_input_param, 2424 requires_grad=param.requires_grad, 2425 ) 2426 else: 2427 new_input_param.requires_grad_(param.requires_grad) 2428 torch.utils.swap_tensors(param, new_input_param) 2429 del new_input_param 2430 elif assign_to_params_buffers: 2431 # Shape checks are already done above 2432 if isinstance(param, torch.nn.Parameter): 2433 if not isinstance(input_param, torch.nn.Parameter): 2434 input_param = torch.nn.Parameter( 2435 input_param, requires_grad=param.requires_grad 2436 ) 2437 else: 2438 input_param.requires_grad_(param.requires_grad) 2439 setattr(self, name, input_param) 2440 else: 2441 param.copy_(input_param) 2442 except Exception as ex: 2443 action = "swapping" if use_swap_tensors else "copying" 2444 error_msgs.append( 2445 f'While {action} the parameter named "{key}", ' 2446 f"whose dimensions in the model are {param.size()} and " 2447 f"whose dimensions in the checkpoint are {input_param.size()}, " 2448 f"an exception occurred : {ex.args}." 2449 ) 2450 elif strict: 2451 missing_keys.append(key) 2452 2453 extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX 2454 if ( 2455 getattr(self.__class__, "set_extra_state", Module.set_extra_state) 2456 is not Module.set_extra_state 2457 ): 2458 if extra_state_key in state_dict: 2459 self.set_extra_state(state_dict[extra_state_key]) 2460 elif strict: 2461 missing_keys.append(extra_state_key) 2462 elif strict and (extra_state_key in state_dict): 2463 unexpected_keys.append(extra_state_key) 2464 2465 if strict: 2466 for key in state_dict.keys(): 2467 if key.startswith(prefix) and key != extra_state_key: 2468 input_name = key[len(prefix) :].split(".", 1) 2469 # Must be Module if it have attributes 2470 if len(input_name) > 1: 2471 if input_name[0] not in self._modules: 2472 unexpected_keys.append(key) 2473 elif input_name[0] not in local_state: 2474 unexpected_keys.append(key) 2475 2476 def load_state_dict( 2477 self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False 2478 ): 2479 r"""Copy parameters and buffers from :attr:`state_dict` into this module and its descendants. 2480 2481 If :attr:`strict` is ``True``, then 2482 the keys of :attr:`state_dict` must exactly match the keys returned 2483 by this module's :meth:`~torch.nn.Module.state_dict` function. 2484 2485 .. warning:: 2486 If :attr:`assign` is ``True`` the optimizer must be created after 2487 the call to :attr:`load_state_dict` unless 2488 :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``. 2489 2490 Args: 2491 state_dict (dict): a dict containing parameters and 2492 persistent buffers. 2493 strict (bool, optional): whether to strictly enforce that the keys 2494 in :attr:`state_dict` match the keys returned by this module's 2495 :meth:`~torch.nn.Module.state_dict` function. 
2495 Default: ``True`` 2496 assign (bool, optional): When ``False``, the properties of the tensors 2497 in the current module are preserved, whereas when ``True``, the 2498 properties of the Tensors in the state dict are preserved. The only 2499 exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s 2500 for which the value from the module is preserved. 2501 Default: ``False`` 2502 2503 Returns: 2504 ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: 2505 * **missing_keys** is a list of str containing any keys that are expected 2506 by this module but missing from the provided ``state_dict``. 2507 * **unexpected_keys** is a list of str containing the keys that are not 2508 expected by this module but present in the provided ``state_dict``. 2509 2510 Note: 2511 If a parameter or buffer is registered as ``None`` and its corresponding key 2512 exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a 2513 ``RuntimeError``. 2514 """ 2515 if not isinstance(state_dict, Mapping): 2516 raise TypeError( 2517 f"Expected state_dict to be dict-like, got {type(state_dict)}." 2518 ) 2519 2520 missing_keys: List[str] = [] 2521 unexpected_keys: List[str] = [] 2522 error_msgs: List[str] = [] 2523 2524 # copy state_dict so _load_from_state_dict can modify it 2525 metadata = getattr(state_dict, "_metadata", None) 2526 state_dict = OrderedDict(state_dict) 2527 if metadata is not None: 2528 # mypy isn't aware that "_metadata" exists in state_dict 2529 state_dict._metadata = metadata # type: ignore[attr-defined] 2530 2531 def load(module, local_state_dict, prefix=""): 2532 local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 2533 if assign: 2534 local_metadata["assign_to_params_buffers"] = assign 2535 module._load_from_state_dict( 2536 local_state_dict, 2537 prefix, 2538 local_metadata, 2539 True, 2540 missing_keys, 2541 unexpected_keys, 2542 error_msgs, 2543 ) 2544 for name, child in module._modules.items(): 2545 if child is not None: 2546 child_prefix = prefix + name + "." 2547 child_state_dict = { 2548 k: v 2549 for k, v in local_state_dict.items() 2550 if k.startswith(child_prefix) 2551 } 2552 load(child, child_state_dict, child_prefix) # noqa: F821 2553 2554 # Note that the hook can modify missing_keys and unexpected_keys. 2555 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys) 2556 for hook in module._load_state_dict_post_hooks.values(): 2557 out = hook(module, incompatible_keys) 2558 assert out is None, ( 2559 "Hooks registered with ``register_load_state_dict_post_hook`` are not " 2560 "expected to return new values; if incompatible_keys needs to be modified, " 2561 "it should be done inplace." 2562 ) 2563 2564 load(self, state_dict) 2565 del load 2566 2567 if strict: 2568 if len(unexpected_keys) > 0: 2569 error_msgs.insert( 2570 0, 2571 "Unexpected key(s) in state_dict: {}. ".format( 2572 ", ".join(f'"{k}"' for k in unexpected_keys) 2573 ), 2574 ) 2575 if len(missing_keys) > 0: 2576 error_msgs.insert( 2577 0, 2578 "Missing key(s) in state_dict: {}. 
".format( 2579 ", ".join(f'"{k}"' for k in missing_keys) 2580 ), 2581 ) 2582 2583 if len(error_msgs) > 0: 2584 raise RuntimeError( 2585 "Error(s) in loading state_dict for {}:\n\t{}".format( 2586 self.__class__.__name__, "\n\t".join(error_msgs) 2587 ) 2588 ) 2589 return _IncompatibleKeys(missing_keys, unexpected_keys) 2590 2591 def _named_members( 2592 self, get_members_fn, prefix="", recurse=True, remove_duplicate: bool = True 2593 ): 2594 r"""Help yield various names + members of modules.""" 2595 memo = set() 2596 modules = ( 2597 self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) 2598 if recurse 2599 else [(prefix, self)] 2600 ) 2601 for module_prefix, module in modules: 2602 members = get_members_fn(module) 2603 for k, v in members: 2604 if v is None or v in memo: 2605 continue 2606 if remove_duplicate: 2607 memo.add(v) 2608 name = module_prefix + ("." if module_prefix else "") + k 2609 yield name, v 2610 2611 def parameters(self, recurse: bool = True) -> Iterator[Parameter]: 2612 r"""Return an iterator over module parameters. 2613 2614 This is typically passed to an optimizer. 2615 2616 Args: 2617 recurse (bool): if True, then yields parameters of this module 2618 and all submodules. Otherwise, yields only parameters that 2619 are direct members of this module. 2620 2621 Yields: 2622 Parameter: module parameter 2623 2624 Example:: 2625 2626 >>> # xdoctest: +SKIP("undefined vars") 2627 >>> for param in model.parameters(): 2628 >>> print(type(param), param.size()) 2629 <class 'torch.Tensor'> (20L,) 2630 <class 'torch.Tensor'> (20L, 1L, 5L, 5L) 2631 2632 """ 2633 for name, param in self.named_parameters(recurse=recurse): 2634 yield param 2635 2636 def named_parameters( 2637 self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True 2638 ) -> Iterator[Tuple[str, Parameter]]: 2639 r"""Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. 2640 2641 Args: 2642 prefix (str): prefix to prepend to all parameter names. 2643 recurse (bool): if True, then yields parameters of this module 2644 and all submodules. Otherwise, yields only parameters that 2645 are direct members of this module. 2646 remove_duplicate (bool, optional): whether to remove the duplicated 2647 parameters in the result. Defaults to True. 2648 2649 Yields: 2650 (str, Parameter): Tuple containing the name and parameter 2651 2652 Example:: 2653 2654 >>> # xdoctest: +SKIP("undefined vars") 2655 >>> for name, param in self.named_parameters(): 2656 >>> if name in ['bias']: 2657 >>> print(param.size()) 2658 2659 """ 2660 gen = self._named_members( 2661 lambda module: module._parameters.items(), 2662 prefix=prefix, 2663 recurse=recurse, 2664 remove_duplicate=remove_duplicate, 2665 ) 2666 yield from gen 2667 2668 def buffers(self, recurse: bool = True) -> Iterator[Tensor]: 2669 r"""Return an iterator over module buffers. 2670 2671 Args: 2672 recurse (bool): if True, then yields buffers of this module 2673 and all submodules. Otherwise, yields only buffers that 2674 are direct members of this module. 
2675 2676 Yields: 2677 torch.Tensor: module buffer 2678 2679 Example:: 2680 2681 >>> # xdoctest: +SKIP("undefined vars") 2682 >>> for buf in model.buffers(): 2683 >>> print(type(buf), buf.size()) 2684 <class 'torch.Tensor'> (20L,) 2685 <class 'torch.Tensor'> (20L, 1L, 5L, 5L) 2686 2687 """ 2688 for _, buf in self.named_buffers(recurse=recurse): 2689 yield buf 2690 2691 def named_buffers( 2692 self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True 2693 ) -> Iterator[Tuple[str, Tensor]]: 2694 r"""Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself. 2695 2696 Args: 2697 prefix (str): prefix to prepend to all buffer names. 2698 recurse (bool, optional): if True, then yields buffers of this module 2699 and all submodules. Otherwise, yields only buffers that 2700 are direct members of this module. Defaults to True. 2701 remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True. 2702 2703 Yields: 2704 (str, torch.Tensor): Tuple containing the name and buffer 2705 2706 Example:: 2707 2708 >>> # xdoctest: +SKIP("undefined vars") 2709 >>> for name, buf in self.named_buffers(): 2710 >>> if name in ['running_var']: 2711 >>> print(buf.size()) 2712 2713 """ 2714 gen = self._named_members( 2715 lambda module: module._buffers.items(), 2716 prefix=prefix, 2717 recurse=recurse, 2718 remove_duplicate=remove_duplicate, 2719 ) 2720 yield from gen 2721 2722 def children(self) -> Iterator["Module"]: 2723 r"""Return an iterator over immediate children modules. 2724 2725 Yields: 2726 Module: a child module 2727 """ 2728 for name, module in self.named_children(): 2729 yield module 2730 2731 def named_children(self) -> Iterator[Tuple[str, "Module"]]: 2732 r"""Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself. 2733 2734 Yields: 2735 (str, Module): Tuple containing a name and child module 2736 2737 Example:: 2738 2739 >>> # xdoctest: +SKIP("undefined vars") 2740 >>> for name, module in model.named_children(): 2741 >>> if name in ['conv4', 'conv5']: 2742 >>> print(module) 2743 2744 """ 2745 memo = set() 2746 for name, module in self._modules.items(): 2747 if module is not None and module not in memo: 2748 memo.add(module) 2749 yield name, module 2750 2751 def modules(self) -> Iterator["Module"]: 2752 r"""Return an iterator over all modules in the network. 2753 2754 Yields: 2755 Module: a module in the network 2756 2757 Note: 2758 Duplicate modules are returned only once. In the following 2759 example, ``l`` will be returned only once. 2760 2761 Example:: 2762 2763 >>> l = nn.Linear(2, 2) 2764 >>> net = nn.Sequential(l, l) 2765 >>> for idx, m in enumerate(net.modules()): 2766 ... print(idx, '->', m) 2767 2768 0 -> Sequential( 2769 (0): Linear(in_features=2, out_features=2, bias=True) 2770 (1): Linear(in_features=2, out_features=2, bias=True) 2771 ) 2772 1 -> Linear(in_features=2, out_features=2, bias=True) 2773 2774 """ 2775 for _, module in self.named_modules(): 2776 yield module 2777 2778 def named_modules( 2779 self, 2780 memo: Optional[Set["Module"]] = None, 2781 prefix: str = "", 2782 remove_duplicate: bool = True, 2783 ): 2784 r"""Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself. 
2785 2786 Args: 2787 memo: a memo to store the set of modules already added to the result 2788 prefix: a prefix that will be added to the name of the module 2789 remove_duplicate: whether to remove the duplicated module instances in the result 2790 or not 2791 2792 Yields: 2793 (str, Module): Tuple of name and module 2794 2795 Note: 2796 Duplicate modules are returned only once. In the following 2797 example, ``l`` will be returned only once. 2798 2799 Example:: 2800 2801 >>> l = nn.Linear(2, 2) 2802 >>> net = nn.Sequential(l, l) 2803 >>> for idx, m in enumerate(net.named_modules()): 2804 ... print(idx, '->', m) 2805 2806 0 -> ('', Sequential( 2807 (0): Linear(in_features=2, out_features=2, bias=True) 2808 (1): Linear(in_features=2, out_features=2, bias=True) 2809 )) 2810 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) 2811 2812 """ 2813 if memo is None: 2814 memo = set() 2815 if self not in memo: 2816 if remove_duplicate: 2817 memo.add(self) 2818 yield prefix, self 2819 for name, module in self._modules.items(): 2820 if module is None: 2821 continue 2822 submodule_prefix = prefix + ("." if prefix else "") + name 2823 yield from module.named_modules( 2824 memo, submodule_prefix, remove_duplicate 2825 ) 2826 2827 def train(self: T, mode: bool = True) -> T: 2828 r"""Set the module in training mode. 2829 2830 This has any effect only on certain modules. See documentations of 2831 particular modules for details of their behaviors in training/evaluation 2832 mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, 2833 etc. 2834 2835 Args: 2836 mode (bool): whether to set training mode (``True``) or evaluation 2837 mode (``False``). Default: ``True``. 2838 2839 Returns: 2840 Module: self 2841 """ 2842 if not isinstance(mode, bool): 2843 raise ValueError("training mode is expected to be boolean") 2844 self.training = mode 2845 for module in self.children(): 2846 module.train(mode) 2847 return self 2848 2849 def eval(self: T) -> T: 2850 r"""Set the module in evaluation mode. 2851 2852 This has any effect only on certain modules. See documentations of 2853 particular modules for details of their behaviors in training/evaluation 2854 mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, 2855 etc. 2856 2857 This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`. 2858 2859 See :ref:`locally-disable-grad-doc` for a comparison between 2860 `.eval()` and several similar mechanisms that may be confused with it. 2861 2862 Returns: 2863 Module: self 2864 """ 2865 return self.train(False) 2866 2867 def requires_grad_(self: T, requires_grad: bool = True) -> T: 2868 r"""Change if autograd should record operations on parameters in this module. 2869 2870 This method sets the parameters' :attr:`requires_grad` attributes 2871 in-place. 2872 2873 This method is helpful for freezing part of the module for finetuning 2874 or training parts of a model individually (e.g., GAN training). 2875 2876 See :ref:`locally-disable-grad-doc` for a comparison between 2877 `.requires_grad_()` and several similar mechanisms that may be confused with it. 2878 2879 Args: 2880 requires_grad (bool): whether autograd should record operations on 2881 parameters in this module. Default: ``True``. 2882 2883 Returns: 2884 Module: self 2885 """ 2886 for p in self.parameters(): 2887 p.requires_grad_(requires_grad) 2888 return self 2889 2890 def zero_grad(self, set_to_none: bool = True) -> None: 2891 r"""Reset gradients of all model parameters. 
2892 2893 See similar function under :class:`torch.optim.Optimizer` for more context. 2894 2895 Args: 2896 set_to_none (bool): instead of setting to zero, set the grads to None. 2897 See :meth:`torch.optim.Optimizer.zero_grad` for details. 2898 """ 2899 if getattr(self, "_is_replica", False): 2900 warnings.warn( 2901 "Calling .zero_grad() from a module created with nn.DataParallel() has no effect. " 2902 "The parameters are copied (in a differentiable manner) from the original module. " 2903 "This means they are not leaf nodes in autograd and so don't accumulate gradients. " 2904 "If you need gradients in your forward method, consider using autograd.grad instead." 2905 ) 2906 2907 for p in self.parameters(): 2908 if p.grad is not None: 2909 if set_to_none: 2910 p.grad = None 2911 else: 2912 if p.grad.grad_fn is not None: 2913 p.grad.detach_() 2914 else: 2915 p.grad.requires_grad_(False) 2916 p.grad.zero_() 2917 2918 def share_memory(self: T) -> T: 2919 r"""See :meth:`torch.Tensor.share_memory_`.""" 2920 return self._apply(lambda t: t.share_memory_()) 2921 2922 def _get_name(self): 2923 return self.__class__.__name__ 2924 2925 def extra_repr(self) -> str: 2926 r"""Set the extra representation of the module. 2927 2928 To print customized extra information, you should re-implement 2929 this method in your own modules. Both single-line and multi-line 2930 strings are acceptable. 2931 """ 2932 return "" 2933 2934 def __repr__(self): 2935 # We treat the extra repr like the sub-module, one item per line 2936 extra_lines = [] 2937 extra_repr = self.extra_repr() 2938 # empty string will be split into list [''] 2939 if extra_repr: 2940 extra_lines = extra_repr.split("\n") 2941 child_lines = [] 2942 for key, module in self._modules.items(): 2943 mod_str = repr(module) 2944 mod_str = _addindent(mod_str, 2) 2945 child_lines.append("(" + key + "): " + mod_str) 2946 lines = extra_lines + child_lines 2947 2948 main_str = self._get_name() + "(" 2949 if lines: 2950 # simple one-liner info, which most builtin Modules will use 2951 if len(extra_lines) == 1 and not child_lines: 2952 main_str += extra_lines[0] 2953 else: 2954 main_str += "\n " + "\n ".join(lines) + "\n" 2955 2956 main_str += ")" 2957 return main_str 2958 2959 def __dir__(self): 2960 module_attrs = dir(self.__class__) 2961 attrs = list(self.__dict__.keys()) 2962 parameters = list(self._parameters.keys()) 2963 modules = list(self._modules.keys()) 2964 buffers = list(self._buffers.keys()) 2965 keys = module_attrs + attrs + parameters + modules + buffers 2966 2967 # Eliminate attrs that are not legal Python variable names 2968 keys = [key for key in keys if not key[0].isdigit()] 2969 2970 return sorted(keys) 2971 2972 def _replicate_for_data_parallel(self): 2973 replica = self.__new__(type(self)) 2974 replica.__dict__ = self.__dict__.copy() 2975 2976 # replicas do not have parameters themselves, the replicas reference the original 2977 # module. 2978 replica._parameters = {} 2979 replica._buffers = replica._buffers.copy() 2980 replica._modules = replica._modules.copy() 2981 replica._is_replica = True # type: ignore[assignment] 2982 2983 return replica 2984 2985 def compile(self, *args, **kwargs): 2986 """ 2987 Compile this Module's forward using :func:`torch.compile`. 2988 2989 This Module's `__call__` method is compiled and all arguments are passed as-is 2990 to :func:`torch.compile`. 2991 2992 See :func:`torch.compile` for details on the arguments for this function. 
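For instance, compiling a module and then calling it as usual (illustrative; ``mod`` and its input ``x`` are assumed to be defined elsewhere)::

            >>> # xdoctest: +SKIP("undefined vars")
            >>> mod.compile()   # arguments, if any, are forwarded to torch.compile
            >>> out = mod(x)    # compilation happens lazily on the first call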
2993 """ 2994 self._compiled_call_impl = torch.compile(self._call_impl, *args, **kwargs) 2995