# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import operator
from collections import abc as container_abcs, OrderedDict
from itertools import chain, islice
from typing import (
    Any,
    Dict,
    Iterable,
    Iterator,
    Mapping,
    Optional,
    overload,
    Tuple,
    TypeVar,
    Union,
)
from typing_extensions import deprecated, Self

import torch
from torch._jit_internal import _copy_to_script_wrapper
from torch.nn.parameter import Parameter

from .module import Module


__all__ = [
    "Container",
    "Sequential",
    "ModuleList",
    "ModuleDict",
    "ParameterList",
    "ParameterDict",
]

T = TypeVar("T", bound=Module)


# Copied from torch.nn.modules.module, required for a custom __repr__ for ModuleList
def _addindent(s_, numSpaces):
    s = s_.split("\n")
    # don't do anything for single-line stuff
    if len(s) == 1:
        return s_
    first = s.pop(0)
    s = [(numSpaces * " ") + line for line in s]
    s = "\n".join(s)
    s = first + "\n" + s
    return s


@deprecated(
    "`nn.Container` is deprecated. "
    "All of its functionality is now implemented in `nn.Module`. Subclass that instead.",
    category=FutureWarning,
)
class Container(Module):
    def __init__(self, **kwargs: Any) -> None:
        super().__init__()
        for key, value in kwargs.items():
            self.add_module(key, value)


class Sequential(Module):
    r"""A sequential container.

    Modules will be added to it in the order they are passed in the
    constructor. Alternatively, an ``OrderedDict`` of modules can be
    passed in. The ``forward()`` method of ``Sequential`` accepts any
    input and forwards it to the first module it contains. It then
    "chains" outputs to inputs sequentially for each subsequent module,
    finally returning the output of the last module.

    The value a ``Sequential`` provides over manually calling a sequence
    of modules is that it allows treating the whole container as a
    single module, such that performing a transformation on the
    ``Sequential`` applies to each of the modules it stores (which are
    each a registered submodule of the ``Sequential``).

    What's the difference between a ``Sequential`` and a
    :class:`torch.nn.ModuleList`? A ``ModuleList`` is exactly what it
    sounds like--a list for storing ``Module`` s! On the other hand,
    the layers in a ``Sequential`` are connected in a cascading way.

    Example::

        # Using Sequential to create a small model. When `model` is run,
        # input will first be passed to `Conv2d(1,20,5)`. The output of
        # `Conv2d(1,20,5)` will be used as the input to the first
        # `ReLU`; the output of the first `ReLU` will become the input
        # for `Conv2d(20,64,5)`. Finally, the output of
        # `Conv2d(20,64,5)` will be used as input to the second `ReLU`
        model = nn.Sequential(
                  nn.Conv2d(1,20,5),
                  nn.ReLU(),
                  nn.Conv2d(20,64,5),
                  nn.ReLU()
                )

        # Using Sequential with OrderedDict. This is functionally the
        # same as the above code
        model = nn.Sequential(OrderedDict([
                  ('conv1', nn.Conv2d(1,20,5)),
                  ('relu1', nn.ReLU()),
                  ('conv2', nn.Conv2d(20,64,5)),
                  ('relu2', nn.ReLU())
                ]))
    """

    _modules: Dict[str, Module]  # type: ignore[assignment]

    @overload
    def __init__(self, *args: Module) -> None:
        ...

    @overload
    def __init__(self, arg: "OrderedDict[str, Module]") -> None:
        ...
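    # The two ``@overload`` stubs above exist only for static type checking: they
    # declare the two supported call forms (a plain sequence of modules, or a single
    # ``OrderedDict`` mapping names to modules). They are ignored at runtime; the
    # untyped ``__init__`` below is the actual implementation and handles both forms.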

    def __init__(self, *args):
        super().__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def _get_item_by_idx(self, iterator, idx) -> T:  # type: ignore[misc, type-var]
        """Get the idx-th item of the iterator."""
        size = len(self)
        idx = operator.index(idx)
        if not -size <= idx < size:
            raise IndexError(f"index {idx} is out of range")
        idx %= size
        return next(islice(iterator, idx, None))

    @_copy_to_script_wrapper
    def __getitem__(self, idx: Union[slice, int]) -> Union["Sequential", T]:
        if isinstance(idx, slice):
            return self.__class__(OrderedDict(list(self._modules.items())[idx]))
        else:
            return self._get_item_by_idx(self._modules.values(), idx)

    def __setitem__(self, idx: int, module: Module) -> None:
        key: str = self._get_item_by_idx(self._modules.keys(), idx)
        return setattr(self, key, module)

    def __delitem__(self, idx: Union[slice, int]) -> None:
        if isinstance(idx, slice):
            for key in list(self._modules.keys())[idx]:
                delattr(self, key)
        else:
            key = self._get_item_by_idx(self._modules.keys(), idx)
            delattr(self, key)
        # To preserve numbering
        str_indices = [str(i) for i in range(len(self._modules))]
        self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))

    @_copy_to_script_wrapper
    def __len__(self) -> int:
        return len(self._modules)

    def __add__(self, other) -> "Sequential":
        if isinstance(other, Sequential):
            ret = Sequential()
            for layer in self:
                ret.append(layer)
            for layer in other:
                ret.append(layer)
            return ret
        else:
            raise ValueError(
                "add operator supports only objects "
                f"of Sequential class, but {str(type(other))} is given."
            )

    def pop(self, key: Union[int, slice]) -> Module:
        v = self[key]
        del self[key]
        return v

    def __iadd__(self, other) -> Self:
        if isinstance(other, Sequential):
            offset = len(self)
            for i, module in enumerate(other):
                self.add_module(str(i + offset), module)
            return self
        else:
            raise ValueError(
                "add operator supports only objects "
                f"of Sequential class, but {str(type(other))} is given."
            )

    def __mul__(self, other: int) -> "Sequential":
        if not isinstance(other, int):
            raise TypeError(
                f"unsupported operand type(s) for *: {type(self)} and {type(other)}"
            )
        elif other <= 0:
            raise ValueError(
                f"Non-positive multiplication factor {other} for {type(self)}"
            )
        else:
            combined = Sequential()
            offset = 0
            for _ in range(other):
                for module in self:
                    combined.add_module(str(offset), module)
                    offset += 1
            return combined

    def __rmul__(self, other: int) -> "Sequential":
        return self.__mul__(other)

    def __imul__(self, other: int) -> Self:
        if not isinstance(other, int):
            raise TypeError(
                f"unsupported operand type(s) for *: {type(self)} and {type(other)}"
            )
        elif other <= 0:
            raise ValueError(
                f"Non-positive multiplication factor {other} for {type(self)}"
            )
        else:
            len_original = len(self)
            offset = len(self)
            for _ in range(other - 1):
                for i in range(len_original):
                    self.add_module(str(i + offset), self._modules[str(i)])
                offset += len_original
            return self

    @_copy_to_script_wrapper
    def __dir__(self):
        keys = super().__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    @_copy_to_script_wrapper
    def __iter__(self) -> Iterator[Module]:
        return iter(self._modules.values())

    # NB: We can't really type check this function as the type of input
    # may change dynamically (as is tested in
    # TestScript.test_sequential_intermediary_types). Cannot annotate
    # with Any as TorchScript expects a more precise type
    def forward(self, input):
        for module in self:
            input = module(input)
        return input

    def append(self, module: Module) -> "Sequential":
        r"""Append a given module to the end.

        Args:
            module (nn.Module): module to append
        """
        self.add_module(str(len(self)), module)
        return self

    def insert(self, index: int, module: Module) -> "Sequential":
        if not isinstance(module, Module):
            raise AssertionError(f"module should be of type: {Module}")
        n = len(self._modules)
        if not (-n <= index <= n):
            raise IndexError(f"Index out of range: {index}")
        if index < 0:
            index += n
        for i in range(n, index, -1):
            self._modules[str(i)] = self._modules[str(i - 1)]
        self._modules[str(index)] = module
        return self

    def extend(self, sequential) -> "Sequential":
        for layer in sequential:
            self.append(layer)
        return self


class ModuleList(Module):
    r"""Holds submodules in a list.

    :class:`~torch.nn.ModuleList` can be indexed like a regular Python list, but
    modules it contains are properly registered, and will be visible by all
    :class:`~torch.nn.Module` methods.

    Args:
        modules (iterable, optional): an iterable of modules to add

    Example::

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

            def forward(self, x):
                # ModuleList can act as an iterable, or be indexed using ints
                for i, l in enumerate(self.linears):
                    x = self.linears[i // 2](x) + l(x)
                return x
    """

    _modules: Dict[str, Module]  # type: ignore[assignment]

    def __init__(self, modules: Optional[Iterable[Module]] = None) -> None:
        super().__init__()
        if modules is not None:
            self += modules

    def _get_abs_string_index(self, idx):
        """Get the absolute index for the list of modules."""
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError(f"index {idx} is out of range")
        if idx < 0:
            idx += len(self)
        return str(idx)

    @overload
    def __getitem__(self, idx: slice) -> "ModuleList":
        ...

    @overload
    def __getitem__(self, idx: int) -> Module:
        ...

    @_copy_to_script_wrapper
    def __getitem__(self, idx: Union[int, slice]) -> Union[Module, "ModuleList"]:
        if isinstance(idx, slice):
            return self.__class__(list(self._modules.values())[idx])
        else:
            return self._modules[self._get_abs_string_index(idx)]

    def __setitem__(self, idx: int, module: Module) -> None:
        idx = self._get_abs_string_index(idx)
        return setattr(self, str(idx), module)

    def __delitem__(self, idx: Union[int, slice]) -> None:
        if isinstance(idx, slice):
            for k in range(len(self._modules))[idx]:
                delattr(self, str(k))
        else:
            delattr(self, self._get_abs_string_index(idx))
        # To preserve numbering, self._modules is being reconstructed with modules after deletion
        str_indices = [str(i) for i in range(len(self._modules))]
        self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))

    @_copy_to_script_wrapper
    def __len__(self) -> int:
        return len(self._modules)

    @_copy_to_script_wrapper
    def __iter__(self) -> Iterator[Module]:
        return iter(self._modules.values())

    def __iadd__(self, modules: Iterable[Module]) -> Self:
        return self.extend(modules)

    def __add__(self, other: Iterable[Module]) -> "ModuleList":
        combined = ModuleList()
        for i, module in enumerate(chain(self, other)):
            combined.add_module(str(i), module)
        return combined

    def __repr__(self):
        """Return a custom repr for ModuleList that compresses repeated module representations."""
        list_of_reprs = [repr(item) for item in self]
        if len(list_of_reprs) == 0:
            return self._get_name() + "()"

        start_end_indices = [[0, 0]]
        repeated_blocks = [list_of_reprs[0]]
        for i, r in enumerate(list_of_reprs[1:], 1):
            if r == repeated_blocks[-1]:
                start_end_indices[-1][1] += 1
                continue

            start_end_indices.append([i, i])
            repeated_blocks.append(r)

        lines = []
        main_str = self._get_name() + "("
        for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
            local_repr = f"({start_id}): {b}"  # default repr

            if start_id != end_id:
                n = end_id - start_id + 1
                local_repr = f"({start_id}-{end_id}): {n} x {b}"

            local_repr = _addindent(local_repr, 2)
            lines.append(local_repr)

        main_str += "\n  " + "\n  ".join(lines) + "\n"
        main_str += ")"
        return main_str
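    # Illustration of the compressed repr produced by __repr__ above (not executed):
    # ten identical layers such as ``nn.ModuleList([nn.Linear(10, 10) for _ in range(10)])``
    # print as a single collapsed entry
    # ``(0-9): 10 x Linear(in_features=10, out_features=10, bias=True)``
    # rather than as ten separate lines.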
    @_copy_to_script_wrapper
    def __dir__(self):
        keys = super().__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    def insert(self, index: int, module: Module) -> None:
        r"""Insert a given module before a given index in the list.

        Args:
            index (int): index to insert.
            module (nn.Module): module to insert
        """
        for i in range(len(self._modules), index, -1):
            self._modules[str(i)] = self._modules[str(i - 1)]
        self._modules[str(index)] = module

    def append(self, module: Module) -> "ModuleList":
        r"""Append a given module to the end of the list.

        Args:
            module (nn.Module): module to append
        """
        self.add_module(str(len(self)), module)
        return self

    def pop(self, key: Union[int, slice]) -> Module:
        v = self[key]
        del self[key]
        return v

    def extend(self, modules: Iterable[Module]) -> Self:
        r"""Append modules from a Python iterable to the end of the list.

        Args:
            modules (iterable): iterable of modules to append
        """
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError(
                "ModuleList.extend should be called with an "
                "iterable, but got " + type(modules).__name__
            )
        offset = len(self)
        for i, module in enumerate(modules):
            self.add_module(str(offset + i), module)
        return self

    # remove forward altogether to fall back on Module's _forward_unimplemented


class ModuleDict(Module):
    r"""Holds submodules in a dictionary.

    :class:`~torch.nn.ModuleDict` can be indexed like a regular Python dictionary,
    but modules it contains are properly registered, and will be visible by all
    :class:`~torch.nn.Module` methods.

    :class:`~torch.nn.ModuleDict` is an **ordered** dictionary that respects

    * the order of insertion, and

    * in :meth:`~torch.nn.ModuleDict.update`, the order of the merged
      ``OrderedDict``, ``dict`` (started from Python 3.6) or another
      :class:`~torch.nn.ModuleDict` (the argument to
      :meth:`~torch.nn.ModuleDict.update`).

    Note that :meth:`~torch.nn.ModuleDict.update` with other unordered mapping
    types (e.g., Python's plain ``dict`` before Python version 3.6) does not
    preserve the order of the merged mapping.

    Args:
        modules (iterable, optional): a mapping (dictionary) of (string: module)
            or an iterable of key-value pairs of type (string, module)

    Example::

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.choices = nn.ModuleDict({
                        'conv': nn.Conv2d(10, 10, 3),
                        'pool': nn.MaxPool2d(3)
                })
                self.activations = nn.ModuleDict([
                        ['lrelu', nn.LeakyReLU()],
                        ['prelu', nn.PReLU()]
                ])

            def forward(self, x, choice, act):
                x = self.choices[choice](x)
                x = self.activations[act](x)
                return x
    """

    _modules: Dict[str, Module]  # type: ignore[assignment]

    def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
        super().__init__()
        if modules is not None:
            self.update(modules)

    @_copy_to_script_wrapper
    def __getitem__(self, key: str) -> Module:
        return self._modules[key]

    def __setitem__(self, key: str, module: Module) -> None:
        self.add_module(key, module)

    def __delitem__(self, key: str) -> None:
        del self._modules[key]

    @_copy_to_script_wrapper
    def __len__(self) -> int:
        return len(self._modules)

    @_copy_to_script_wrapper
    def __iter__(self) -> Iterator[str]:
        return iter(self._modules)

    @_copy_to_script_wrapper
    def __contains__(self, key: str) -> bool:
        return key in self._modules

    def clear(self) -> None:
        """Remove all items from the ModuleDict."""
        self._modules.clear()

    def pop(self, key: str) -> Module:
        r"""Remove key from the ModuleDict and return its module.

        Args:
            key (str): key to pop from the ModuleDict
        """
        v = self[key]
        del self[key]
        return v

    @_copy_to_script_wrapper
    def keys(self) -> Iterable[str]:
        r"""Return an iterable of the ModuleDict keys."""
        return self._modules.keys()

    @_copy_to_script_wrapper
    def items(self) -> Iterable[Tuple[str, Module]]:
        r"""Return an iterable of the ModuleDict key/value pairs."""
        return self._modules.items()

    @_copy_to_script_wrapper
    def values(self) -> Iterable[Module]:
        r"""Return an iterable of the ModuleDict values."""
        return self._modules.values()

    def update(self, modules: Mapping[str, Module]) -> None:
        r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys.

        .. note::
            If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or
            an iterable of key-value pairs, the order of new elements in it is preserved.

        Args:
            modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`,
                or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`)
        """
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError(
                "ModuleDict.update should be called with an "
                "iterable of key/value pairs, but got " + type(modules).__name__
            )

        if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)):
            for key, module in modules.items():
                self[key] = module
        else:
            # modules here can be a list with two items
            for j, m in enumerate(modules):
                if not isinstance(m, container_abcs.Iterable):
                    raise TypeError(
                        "ModuleDict update sequence element "
                        "#" + str(j) + " should be Iterable; is " + type(m).__name__
                    )
                if not len(m) == 2:
                    raise ValueError(
                        "ModuleDict update sequence element "
                        "#" + str(j) + " has length " + str(len(m)) + "; 2 is required"
                    )
                # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
                # that's too cumbersome to type correctly with overloads, so we add an ignore here
                self[m[0]] = m[1]  # type: ignore[assignment]

    # remove forward altogether to fall back on Module's _forward_unimplemented


class ParameterList(Module):
    r"""Holds parameters in a list.

    :class:`~torch.nn.ParameterList` can be used like a regular Python
    list, but Tensors that are :class:`~torch.nn.Parameter` are properly registered,
    and will be visible by all :class:`~torch.nn.Module` methods.

    Note that the constructor, assigning an element of the list, the
    :meth:`~torch.nn.ParameterList.append` method and the :meth:`~torch.nn.ParameterList.extend`
    method will convert any :class:`~torch.Tensor` into :class:`~torch.nn.Parameter`.

    Args:
        parameters (iterable, optional): an iterable of elements to add to the list.

    Example::

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])

            def forward(self, x):
                # ParameterList can act as an iterable, or be indexed using ints
                for i, p in enumerate(self.params):
                    x = self.params[i // 2].mm(x) + p.mm(x)
                return x
    """

    def __init__(self, values: Optional[Iterable[Any]] = None) -> None:
        super().__init__()
        self._size = 0
        if values is not None:
            self += values

    def _get_abs_string_index(self, idx):
        """Get the absolute index for the list of modules."""
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError(f"index {idx} is out of range")
        if idx < 0:
            idx += len(self)
        return str(idx)

    @overload
    def __getitem__(self, idx: int) -> Any:
        ...

    @overload
    def __getitem__(self: T, idx: slice) -> T:
        ...

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            out = self.__class__()
            for i in range(start, stop, step):
                out.append(self[i])
            return out
        else:
            idx = self._get_abs_string_index(idx)
            return getattr(self, str(idx))

    def __setitem__(self, idx: int, param: Any) -> None:
        # Note that all other functions that add an entry to the list part of
        # the ParameterList end up here. So this is the only place where we need
        # to wrap things into Parameter if needed.
        # Objects added via setattr() are not in the list part and thus won't
        # call into this function.
        idx = self._get_abs_string_index(idx)
        if isinstance(param, torch.Tensor) and not isinstance(param, Parameter):
            param = Parameter(param)
        return setattr(self, str(idx), param)

    def __len__(self) -> int:
        return self._size

    def __iter__(self) -> Iterator[Any]:
        return iter(self[i] for i in range(len(self)))

    def __iadd__(self, parameters: Iterable[Any]) -> Self:
        return self.extend(parameters)

    def __dir__(self):
        keys = super().__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    def append(self, value: Any) -> "ParameterList":
        """Append a given value at the end of the list.

        Args:
            value (Any): value to append
        """
        new_idx = len(self)
        self._size += 1
        self[new_idx] = value
        return self

    def extend(self, values: Iterable[Any]) -> Self:
        """Append values from a Python iterable to the end of the list.

        Args:
            values (iterable): iterable of values to append
        """
        # Tensor is an iterable but we never want to unpack it here
        if not isinstance(values, container_abcs.Iterable) or isinstance(
            values, torch.Tensor
        ):
            raise TypeError(
                "ParameterList.extend should be called with an "
                "iterable, but got " + type(values).__name__
            )
        for value in values:
            self.append(value)
        return self

    def extra_repr(self) -> str:
        child_lines = []
        for k, p in enumerate(self):
            if isinstance(p, torch.Tensor):
                size_str = "x".join(str(size) for size in p.size())
                if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
                    device_str = f" ({p.device})"
                else:
                    device_str = ""
                parastr = "{} containing: [{} of size {}{}]".format(
                    "Parameter" if isinstance(p, Parameter) else "Tensor",
                    p.dtype,
                    size_str,
                    device_str,
                )
                child_lines.append("  (" + str(k) + "): " + parastr)
            else:
                child_lines.append(
                    "  (" + str(k) + "): Object of type: " + type(p).__name__
                )

        tmpstr = "\n".join(child_lines)
        return tmpstr

    def __call__(self, *args, **kwargs):
        raise RuntimeError("ParameterList should not be called.")


class ParameterDict(Module):
    r"""Holds parameters in a dictionary.

    ParameterDict can be indexed like a regular Python dictionary, but Parameters it
    contains are properly registered, and will be visible by all Module methods.
    Other objects are treated as would be done by a regular Python dictionary.

    :class:`~torch.nn.ParameterDict` is an **ordered** dictionary.
    :meth:`~torch.nn.ParameterDict.update` with other unordered mapping
    types (e.g., Python's plain ``dict``) does not preserve the order of the
    merged mapping. On the other hand, ``OrderedDict`` or another :class:`~torch.nn.ParameterDict`
    will preserve their ordering.

    Note that the constructor, assigning an element of the dictionary and the
    :meth:`~torch.nn.ParameterDict.update` method will convert any :class:`~torch.Tensor` into
    :class:`~torch.nn.Parameter`.

    Args:
        values (iterable, optional): a mapping (dictionary) of
            (string : Any) or an iterable of key-value pairs
            of type (string, Any)

    Example::

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.params = nn.ParameterDict({
                        'left': nn.Parameter(torch.randn(5, 10)),
                        'right': nn.Parameter(torch.randn(5, 10))
                })

            def forward(self, x, choice):
                x = self.params[choice].mm(x)
                return x
    """

    def __init__(self, parameters: Any = None) -> None:
        super().__init__()
        self._keys: Dict[str, None] = {}
        if parameters is not None:
            self.update(parameters)

    def _key_to_attr(self, key: str) -> str:
        if not isinstance(key, str):
            raise TypeError(
                "Index given to ParameterDict cannot be used as a key as it is "
                f"not a string (type is '{type(key).__name__}'). Open an issue on "
                "github if you need non-string keys."
            )
        else:
            # Use the key as-is so that `.named_parameters()` returns the right thing
            return key

    def __getitem__(self, key: str) -> Any:
        attr = self._key_to_attr(key)
        return getattr(self, attr)

    def __setitem__(self, key: str, value: Any) -> None:
        # Note that all other functions that add an entry to the dictionary part of
        # the ParameterDict end up here. So this is the only place where we need
        # to wrap things into Parameter if needed.
        # Objects added via setattr() are not in the dictionary part and thus won't
        # call into this function.
        self._keys[key] = None
        attr = self._key_to_attr(key)
        if isinstance(value, torch.Tensor) and not isinstance(value, Parameter):
            value = Parameter(value)
        setattr(self, attr, value)

    def __delitem__(self, key: str) -> None:
        del self._keys[key]
        attr = self._key_to_attr(key)
        delattr(self, attr)

    def __len__(self) -> int:
        return len(self._keys)

    def __iter__(self) -> Iterator[str]:
        return iter(self._keys)

    def __reversed__(self) -> Iterator[str]:
        return reversed(list(self._keys))

    def copy(self) -> "ParameterDict":
        """Return a copy of this :class:`~torch.nn.ParameterDict` instance."""
        # We have to use an OrderedDict because the ParameterDict constructor
        # behaves differently on plain dict vs OrderedDict
        return ParameterDict(OrderedDict((k, self[k]) for k in self._keys))

    def __contains__(self, key: str) -> bool:
        return key in self._keys

    def setdefault(self, key: str, default: Optional[Any] = None) -> Any:
        """Set the default for a key in the ParameterDict.

        If key is in the ParameterDict, return its value.
        If not, insert `key` with a parameter `default` and return `default`.
        `default` defaults to `None`.

        Args:
            key (str): key to set default for
            default (Any): the parameter set to the key
        """
        if key not in self:
            self[key] = default
        return self[key]

    def clear(self) -> None:
        """Remove all items from the ParameterDict."""
        for k in self._keys.copy():
            del self[k]

    def pop(self, key: str) -> Any:
        r"""Remove key from the ParameterDict and return its parameter.

        Args:
            key (str): key to pop from the ParameterDict
        """
        v = self[key]
        del self[key]
        return v

    def popitem(self) -> Tuple[str, Any]:
        """Remove and return the last inserted `(key, parameter)` pair from the ParameterDict."""
        k, _ = self._keys.popitem()
        # We need the key in the _keys to be able to access/del
        self._keys[k] = None
        val = self[k]
        del self[k]
        return k, val

    def get(self, key: str, default: Optional[Any] = None) -> Any:
        r"""Return the parameter associated with key if present. Otherwise return default if provided, None if not.

        Args:
            key (str): key to get from the ParameterDict
            default (Parameter, optional): value to return if key not present
        """
        return self[key] if key in self else default

    def fromkeys(
        self, keys: Iterable[str], default: Optional[Any] = None
    ) -> "ParameterDict":
        r"""Return a new ParameterDict with the keys provided.

        Args:
            keys (iterable, string): keys to make the new ParameterDict from
            default (Parameter, optional): value to set for all keys
        """
        return ParameterDict((k, default) for k in keys)

    def keys(self) -> Iterable[str]:
        r"""Return an iterable of the ParameterDict keys."""
        return self._keys.keys()

    def items(self) -> Iterable[Tuple[str, Any]]:
        r"""Return an iterable of the ParameterDict key/value pairs."""
        return ((k, self[k]) for k in self._keys)

    def values(self) -> Iterable[Any]:
        r"""Return an iterable of the ParameterDict values."""
        return (self[k] for k in self._keys)

    def update(self, parameters: Union[Mapping[str, Any], "ParameterDict"]) -> None:
        r"""Update the :class:`~torch.nn.ParameterDict` with key-value pairs from ``parameters``, overwriting existing keys.

        .. note::
            If :attr:`parameters` is an ``OrderedDict``, a :class:`~torch.nn.ParameterDict`, or
            an iterable of key-value pairs, the order of new elements in it is preserved.

        Args:
            parameters (iterable): a mapping (dictionary) from string to
                :class:`~torch.nn.Parameter`, or an iterable of
                key-value pairs of type (string, :class:`~torch.nn.Parameter`)
        """
        if not isinstance(parameters, container_abcs.Iterable):
            raise TypeError(
                "ParameterDict.update should be called with an "
                "iterable of key/value pairs, but got " + type(parameters).__name__
            )

        if isinstance(parameters, (OrderedDict, ParameterDict)):
            for key, parameter in parameters.items():
                self[key] = parameter
        elif isinstance(parameters, container_abcs.Mapping):
            for key, parameter in sorted(parameters.items()):
                self[key] = parameter
        else:
            for j, p in enumerate(parameters):
                if not isinstance(p, container_abcs.Iterable):
                    raise TypeError(
                        "ParameterDict update sequence element "
                        "#" + str(j) + " should be Iterable; is " + type(p).__name__
                    )
                if not len(p) == 2:
                    raise ValueError(
                        "ParameterDict update sequence element "
                        "#" + str(j) + " has length " + str(len(p)) + "; 2 is required"
                    )
                # parameters as length-2 list too cumbersome to type, see ModuleDict.update comment
                self[p[0]] = p[1]  # type: ignore[assignment]

    def extra_repr(self) -> str:
        child_lines = []
        for k, p in self.items():
            if isinstance(p, torch.Tensor):
                size_str = "x".join(str(size) for size in p.size())
                if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
                    device_str = f" ({p.device})"
                else:
                    device_str = ""
                parastr = "{} containing: [{} of size {}{}]".format(
                    "Parameter" if isinstance(p, Parameter) else "Tensor",
                    torch.typename(p),
                    size_str,
                    device_str,
                )
                child_lines.append("  (" + str(k) + "): " + parastr)
            else:
                child_lines.append(
                    "  (" + str(k) + "): Object of type: " + type(p).__name__
                )
        tmpstr = "\n".join(child_lines)
        return tmpstr

    def __call__(self, input):
        raise RuntimeError("ParameterDict should not be called.")

    def __or__(self, other: "ParameterDict") -> "ParameterDict":
        copy = self.copy()
        copy.update(other)
        return copy

    def __ror__(self, other: "ParameterDict") -> "ParameterDict":
        copy = other.copy()
        copy.update(self)
        return copy

    def __ior__(self, other: "ParameterDict") -> Self:
        self.update(other)
        return self
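

# Illustrative sketch of how the list-style operations defined on ``Sequential``
# above compose (comments only, nothing here is executed on import; ``nn`` is
# assumed to be ``torch.nn`` and the layer sizes are arbitrary):
#
#     block = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
#     block.append(nn.Linear(4, 2))       # now Linear, ReLU, Linear
#     block.insert(1, nn.Dropout(0.5))    # now Linear, Dropout, ReLU, Linear
#     last = block.pop(-1)                # removes and returns the trailing Linear
#     stacked = block * 2                 # repeats the same (shared) module objects twice
#     joined = block + nn.Sequential(nn.Sigmoid())  # concatenates two Sequentials
#
# Note that multiplication reuses the existing module objects, so the repeated
# entries share parameters rather than being independent copies.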