# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import operator
from collections import abc as container_abcs, OrderedDict
from itertools import chain, islice
from typing import (
    Any,
    Dict,
    Iterable,
    Iterator,
    Mapping,
    Optional,
    overload,
    Tuple,
    TypeVar,
    Union,
)
from typing_extensions import deprecated, Self

import torch
from torch._jit_internal import _copy_to_script_wrapper
from torch.nn.parameter import Parameter

from .module import Module


__all__ = [
    "Container",
    "Sequential",
    "ModuleList",
    "ModuleDict",
    "ParameterList",
    "ParameterDict",
]

T = TypeVar("T", bound=Module)


# Copied from torch.nn.modules.module, required for a custom __repr__ for ModuleList
def _addindent(s_, numSpaces):
    s = s_.split("\n")
    # don't do anything for single-line stuff
    if len(s) == 1:
        return s_
    first = s.pop(0)
    s = [(numSpaces * " ") + line for line in s]
    s = "\n".join(s)
    s = first + "\n" + s
    return s


@deprecated(
53    "`nn.Container` is deprecated. "
54    "All of it's functionality is now implemented in `nn.Module`. Subclass that instead.",
    category=FutureWarning,
)
class Container(Module):
    def __init__(self, **kwargs: Any) -> None:
        super().__init__()
        for key, value in kwargs.items():
            self.add_module(key, value)


class Sequential(Module):
    r"""A sequential container.

    Modules will be added to it in the order they are passed in the
    constructor. Alternatively, an ``OrderedDict`` of modules can be
    passed in. The ``forward()`` method of ``Sequential`` accepts any
    input and forwards it to the first module it contains. It then
    "chains" outputs to inputs sequentially for each subsequent module,
    finally returning the output of the last module.

    The value a ``Sequential`` provides over manually calling a sequence
    of modules is that it allows treating the whole container as a
    single module, such that performing a transformation on the
    ``Sequential`` applies to each of the modules it stores (which are
    each a registered submodule of the ``Sequential``).

    What's the difference between a ``Sequential`` and a
    :class:`torch.nn.ModuleList`? A ``ModuleList`` is exactly what it
    sounds like--a list for storing ``Module`` s! On the other hand,
    the layers in a ``Sequential`` are connected in a cascading way.

    Example::

        # Using Sequential to create a small model. When `model` is run,
        # input will first be passed to `Conv2d(1,20,5)`. The output of
        # `Conv2d(1,20,5)` will be used as the input to the first
        # `ReLU`; the output of the first `ReLU` will become the input
        # for `Conv2d(20,64,5)`. Finally, the output of
        # `Conv2d(20,64,5)` will be used as input to the second `ReLU`
        model = nn.Sequential(
                  nn.Conv2d(1,20,5),
                  nn.ReLU(),
                  nn.Conv2d(20,64,5),
                  nn.ReLU()
                )

        # Using Sequential with OrderedDict. This is functionally the
        # same as the above code
        model = nn.Sequential(OrderedDict([
                  ('conv1', nn.Conv2d(1,20,5)),
                  ('relu1', nn.ReLU()),
                  ('conv2', nn.Conv2d(20,64,5)),
                  ('relu2', nn.ReLU())
                ]))
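
        # A ``Sequential`` can also be indexed and sliced (see
        # ``__getitem__`` below); slicing returns a new ``Sequential``
        # over the selected modules. A minimal sketch:
        first_conv = model[0]   # the Conv2d(1,20,5) module
        tail = model[2:]        # Sequential holding Conv2d(20,64,5) and ReLU()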
108    """
109
110    _modules: Dict[str, Module]  # type: ignore[assignment]
111
112    @overload
113    def __init__(self, *args: Module) -> None:
114        ...
115
116    @overload
117    def __init__(self, arg: "OrderedDict[str, Module]") -> None:
118        ...
119
120    def __init__(self, *args):
121        super().__init__()
122        if len(args) == 1 and isinstance(args[0], OrderedDict):
123            for key, module in args[0].items():
124                self.add_module(key, module)
125        else:
126            for idx, module in enumerate(args):
127                self.add_module(str(idx), module)
128
129    def _get_item_by_idx(self, iterator, idx) -> T:  # type: ignore[misc, type-var]
130        """Get the idx-th item of the iterator."""
131        size = len(self)
132        idx = operator.index(idx)
133        if not -size <= idx < size:
134            raise IndexError(f"index {idx} is out of range")
135        idx %= size
136        return next(islice(iterator, idx, None))
137
138    @_copy_to_script_wrapper
139    def __getitem__(self, idx: Union[slice, int]) -> Union["Sequential", T]:
140        if isinstance(idx, slice):
141            return self.__class__(OrderedDict(list(self._modules.items())[idx]))
142        else:
143            return self._get_item_by_idx(self._modules.values(), idx)
144
145    def __setitem__(self, idx: int, module: Module) -> None:
146        key: str = self._get_item_by_idx(self._modules.keys(), idx)
147        return setattr(self, key, module)
148
149    def __delitem__(self, idx: Union[slice, int]) -> None:
150        if isinstance(idx, slice):
151            for key in list(self._modules.keys())[idx]:
152                delattr(self, key)
153        else:
154            key = self._get_item_by_idx(self._modules.keys(), idx)
155            delattr(self, key)
156        # To preserve numbering
157        str_indices = [str(i) for i in range(len(self._modules))]
158        self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
159
160    @_copy_to_script_wrapper
161    def __len__(self) -> int:
162        return len(self._modules)
163
164    def __add__(self, other) -> "Sequential":
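        r"""Return a new ``Sequential`` containing the modules of ``self`` followed by those of ``other``."""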
        if isinstance(other, Sequential):
            ret = Sequential()
            for layer in self:
                ret.append(layer)
            for layer in other:
                ret.append(layer)
            return ret
        else:
            raise ValueError(
                "add operator supports only objects "
                f"of Sequential class, but {str(type(other))} is given."
            )

    def pop(self, key: Union[int, slice]) -> Module:
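        r"""Remove the module at ``key`` from the ``Sequential`` and return it.

        Args:
            key (Union[int, slice]): index (or slice) of the module(s) to remove
        """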
        v = self[key]
        del self[key]
        return v

    def __iadd__(self, other) -> Self:
        if isinstance(other, Sequential):
            offset = len(self)
            for i, module in enumerate(other):
                self.add_module(str(i + offset), module)
            return self
        else:
            raise ValueError(
                "add operator supports only objects "
                f"of Sequential class, but {str(type(other))} is given."
            )

    def __mul__(self, other: int) -> "Sequential":
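        r"""Return a new ``Sequential`` repeating this one ``other`` times (``seq * n``).

        Note:
            The same module instances are re-registered for each repeat, so
            their parameters are shared across repeats rather than copied.
        """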
        if not isinstance(other, int):
            raise TypeError(
                f"unsupported operand type(s) for *: {type(self)} and {type(other)}"
            )
        elif other <= 0:
            raise ValueError(
                f"Non-positive multiplication factor {other} for {type(self)}"
            )
        else:
            combined = Sequential()
            offset = 0
            for _ in range(other):
                for module in self:
                    combined.add_module(str(offset), module)
                    offset += 1
            return combined

    def __rmul__(self, other: int) -> "Sequential":
        return self.__mul__(other)

    def __imul__(self, other: int) -> Self:
        if not isinstance(other, int):
            raise TypeError(
                f"unsupported operand type(s) for *: {type(self)} and {type(other)}"
            )
        elif other <= 0:
            raise ValueError(
                f"Non-positive multiplication factor {other} for {type(self)}"
            )
        else:
            len_original = len(self)
            offset = len(self)
            for _ in range(other - 1):
                for i in range(len_original):
                    self.add_module(str(i + offset), self._modules[str(i)])
                offset += len_original
            return self

    @_copy_to_script_wrapper
    def __dir__(self):
        keys = super().__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    @_copy_to_script_wrapper
    def __iter__(self) -> Iterator[Module]:
        return iter(self._modules.values())

    # NB: We can't really type check this function as the type of input
    # may change dynamically (as is tested in
    # TestScript.test_sequential_intermediary_types).  Cannot annotate
    # with Any as TorchScript expects a more precise type
    def forward(self, input):
        for module in self:
            input = module(input)
        return input

    def append(self, module: Module) -> "Sequential":
        r"""Append a given module to the end.

        Args:
            module (nn.Module): module to append
        """
        self.add_module(str(len(self)), module)
        return self

    def insert(self, index: int, module: Module) -> "Sequential":
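        r"""Insert a module before the given index, shifting later modules up by one.

        Args:
            index (int): index to insert before; negative values count from the end
            module (nn.Module): module to insert
        """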
        if not isinstance(module, Module):
            raise AssertionError(f"module should be of type: {Module}")
        n = len(self._modules)
        if not (-n <= index <= n):
            raise IndexError(f"Index out of range: {index}")
        if index < 0:
            index += n
        for i in range(n, index, -1):
            self._modules[str(i)] = self._modules[str(i - 1)]
        self._modules[str(index)] = module
        return self

    def extend(self, sequential) -> "Sequential":
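        r"""Append each module from ``sequential`` to the end of this ``Sequential``.

        Args:
            sequential (iterable): iterable of modules to append
        """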
        for layer in sequential:
            self.append(layer)
        return self


class ModuleList(Module):
    r"""Holds submodules in a list.

    :class:`~torch.nn.ModuleList` can be indexed like a regular Python list, but
    modules it contains are properly registered, and will be visible by all
    :class:`~torch.nn.Module` methods.

    Args:
        modules (iterable, optional): an iterable of modules to add

    Example::

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

            def forward(self, x):
                # ModuleList can act as an iterable, or be indexed using ints
                for i, l in enumerate(self.linears):
                    x = self.linears[i // 2](x) + l(x)
                return x
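
    ``ModuleList`` also supports slicing (see ``__getitem__`` below): a slice
    returns a new ``ModuleList`` over the selected modules. A minimal sketch::

        layers = nn.ModuleList([nn.Linear(10, 10) for _ in range(4)])
        evens = layers[::2]  # a new ModuleList holding layers 0 and 2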
303    """
304
305    _modules: Dict[str, Module]  # type: ignore[assignment]
306
307    def __init__(self, modules: Optional[Iterable[Module]] = None) -> None:
308        super().__init__()
309        if modules is not None:
310            self += modules
311
312    def _get_abs_string_index(self, idx):
313        """Get the absolute index for the list of modules."""
314        idx = operator.index(idx)
315        if not (-len(self) <= idx < len(self)):
316            raise IndexError(f"index {idx} is out of range")
317        if idx < 0:
318            idx += len(self)
319        return str(idx)
320
321    @overload
322    def __getitem__(self, idx: slice) -> "ModuleList":
323        ...
324
325    @overload
326    def __getitem__(self, idx: int) -> Module:
327        ...
328
329    @_copy_to_script_wrapper
330    def __getitem__(self, idx: Union[int, slice]) -> Union[Module, "ModuleList"]:
331        if isinstance(idx, slice):
332            return self.__class__(list(self._modules.values())[idx])
333        else:
334            return self._modules[self._get_abs_string_index(idx)]
335
336    def __setitem__(self, idx: int, module: Module) -> None:
337        idx = self._get_abs_string_index(idx)
338        return setattr(self, str(idx), module)
339
340    def __delitem__(self, idx: Union[int, slice]) -> None:
341        if isinstance(idx, slice):
342            for k in range(len(self._modules))[idx]:
343                delattr(self, str(k))
344        else:
345            delattr(self, self._get_abs_string_index(idx))
346        # To preserve numbering, self._modules is being reconstructed with modules after deletion
347        str_indices = [str(i) for i in range(len(self._modules))]
348        self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
349
350    @_copy_to_script_wrapper
351    def __len__(self) -> int:
352        return len(self._modules)
353
354    @_copy_to_script_wrapper
355    def __iter__(self) -> Iterator[Module]:
356        return iter(self._modules.values())
357
358    def __iadd__(self, modules: Iterable[Module]) -> Self:
359        return self.extend(modules)
360
361    def __add__(self, other: Iterable[Module]) -> "ModuleList":
362        combined = ModuleList()
363        for i, module in enumerate(chain(self, other)):
364            combined.add_module(str(i), module)
365        return combined
366
367    def __repr__(self):
368        """Return a custom repr for ModuleList that compresses repeated module representations."""
369        list_of_reprs = [repr(item) for item in self]
370        if len(list_of_reprs) == 0:
371            return self._get_name() + "()"
372
373        start_end_indices = [[0, 0]]
374        repeated_blocks = [list_of_reprs[0]]
375        for i, r in enumerate(list_of_reprs[1:], 1):
376            if r == repeated_blocks[-1]:
377                start_end_indices[-1][1] += 1
378                continue
379
380            start_end_indices.append([i, i])
381            repeated_blocks.append(r)
382
383        lines = []
384        main_str = self._get_name() + "("
385        for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
386            local_repr = f"({start_id}): {b}"  # default repr
387
388            if start_id != end_id:
389                n = end_id - start_id + 1
390                local_repr = f"({start_id}-{end_id}): {n} x {b}"
391
392            local_repr = _addindent(local_repr, 2)
393            lines.append(local_repr)
394
395        main_str += "\n  " + "\n  ".join(lines) + "\n"
396        main_str += ")"
397        return main_str
398
399    @_copy_to_script_wrapper
400    def __dir__(self):
401        keys = super().__dir__()
402        keys = [key for key in keys if not key.isdigit()]
403        return keys
404
405    def insert(self, index: int, module: Module) -> None:
406        r"""Insert a given module before a given index in the list.
407
408        Args:
409            index (int): index to insert.
410            module (nn.Module): module to insert
411        """
412        for i in range(len(self._modules), index, -1):
413            self._modules[str(i)] = self._modules[str(i - 1)]
414        self._modules[str(index)] = module
415
416    def append(self, module: Module) -> "ModuleList":
417        r"""Append a given module to the end of the list.
418
419        Args:
420            module (nn.Module): module to append
421        """
422        self.add_module(str(len(self)), module)
423        return self
424
425    def pop(self, key: Union[int, slice]) -> Module:
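        r"""Remove the module at ``key`` from the list and return it.

        Args:
            key (Union[int, slice]): index (or slice) of the module(s) to remove
        """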
        v = self[key]
        del self[key]
        return v

    def extend(self, modules: Iterable[Module]) -> Self:
        r"""Append modules from a Python iterable to the end of the list.

        Args:
            modules (iterable): iterable of modules to append
        """
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError(
                "ModuleList.extend should be called with an "
                "iterable, but got " + type(modules).__name__
            )
        offset = len(self)
        for i, module in enumerate(modules):
            self.add_module(str(offset + i), module)
        return self

    # remove forward altogether to fall back on Module's _forward_unimplemented


class ModuleDict(Module):
    r"""Holds submodules in a dictionary.

    :class:`~torch.nn.ModuleDict` can be indexed like a regular Python dictionary,
    but modules it contains are properly registered, and will be visible by all
    :class:`~torch.nn.Module` methods.

    :class:`~torch.nn.ModuleDict` is an **ordered** dictionary that respects

    * the order of insertion, and

    * in :meth:`~torch.nn.ModuleDict.update`, the order of the merged
      ``OrderedDict``, ``dict`` (starting from Python 3.6) or another
      :class:`~torch.nn.ModuleDict` (the argument to
      :meth:`~torch.nn.ModuleDict.update`).

    Note that :meth:`~torch.nn.ModuleDict.update` with other unordered mapping
    types (e.g., Python's plain ``dict`` before Python version 3.6) does not
    preserve the order of the merged mapping.

    Args:
        modules (iterable, optional): a mapping (dictionary) of (string: module)
            or an iterable of key-value pairs of type (string, module)

    Example::

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.choices = nn.ModuleDict({
                        'conv': nn.Conv2d(10, 10, 3),
                        'pool': nn.MaxPool2d(3)
                })
                self.activations = nn.ModuleDict([
                        ['lrelu', nn.LeakyReLU()],
                        ['prelu', nn.PReLU()]
                ])

            def forward(self, x, choice, act):
                x = self.choices[choice](x)
                x = self.activations[act](x)
                return x
    """

    _modules: Dict[str, Module]  # type: ignore[assignment]

    def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
        super().__init__()
        if modules is not None:
            self.update(modules)

    @_copy_to_script_wrapper
    def __getitem__(self, key: str) -> Module:
        return self._modules[key]

    def __setitem__(self, key: str, module: Module) -> None:
        self.add_module(key, module)

    def __delitem__(self, key: str) -> None:
        del self._modules[key]

    @_copy_to_script_wrapper
    def __len__(self) -> int:
        return len(self._modules)

    @_copy_to_script_wrapper
    def __iter__(self) -> Iterator[str]:
        return iter(self._modules)

    @_copy_to_script_wrapper
    def __contains__(self, key: str) -> bool:
        return key in self._modules

    def clear(self) -> None:
        """Remove all items from the ModuleDict."""
        self._modules.clear()

    def pop(self, key: str) -> Module:
        r"""Remove key from the ModuleDict and return its module.

        Args:
            key (str): key to pop from the ModuleDict
        """
        v = self[key]
        del self[key]
        return v

    @_copy_to_script_wrapper
    def keys(self) -> Iterable[str]:
        r"""Return an iterable of the ModuleDict keys."""
        return self._modules.keys()

    @_copy_to_script_wrapper
    def items(self) -> Iterable[Tuple[str, Module]]:
        r"""Return an iterable of the ModuleDict key/value pairs."""
        return self._modules.items()

    @_copy_to_script_wrapper
    def values(self) -> Iterable[Module]:
        r"""Return an iterable of the ModuleDict values."""
        return self._modules.values()

    def update(self, modules: Mapping[str, Module]) -> None:
        r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys.

        .. note::
            If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or
            an iterable of key-value pairs, the order of new elements in it is preserved.

        Args:
            modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`,
                or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`)
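
        Example (a minimal sketch of both accepted argument forms)::

            d = nn.ModuleDict({'relu': nn.ReLU()})
            d.update({'gelu': nn.GELU()})    # mapping of name -> module
            d.update([('tanh', nn.Tanh())])  # iterable of (name, module) pairs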
561        """
562        if not isinstance(modules, container_abcs.Iterable):
563            raise TypeError(
564                "ModuleDict.update should be called with an "
565                "iterable of key/value pairs, but got " + type(modules).__name__
566            )
567
568        if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)):
569            for key, module in modules.items():
570                self[key] = module
571        else:
572            # modules here can be a list with two items
573            for j, m in enumerate(modules):
574                if not isinstance(m, container_abcs.Iterable):
575                    raise TypeError(
576                        "ModuleDict update sequence element "
577                        "#" + str(j) + " should be Iterable; is" + type(m).__name__
                    )
                if not len(m) == 2:
                    raise ValueError(
                        "ModuleDict update sequence element "
                        "#" + str(j) + " has length " + str(len(m)) + "; 2 is required"
                    )
                # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
                # that's too cumbersome to type correctly with overloads, so we add an ignore here
                self[m[0]] = m[1]  # type: ignore[assignment]

    # remove forward altogether to fall back on Module's _forward_unimplemented


class ParameterList(Module):
    r"""Holds parameters in a list.

    :class:`~torch.nn.ParameterList` can be used like a regular Python
    list, but Tensors that are :class:`~torch.nn.Parameter` are properly registered,
    and will be visible by all :class:`~torch.nn.Module` methods.

    Note that the constructor, assigning an element of the list, and the
    :meth:`~torch.nn.ParameterList.append` and :meth:`~torch.nn.ParameterList.extend`
    methods will convert any :class:`~torch.Tensor` into :class:`~torch.nn.Parameter`.

    Args:
        parameters (iterable, optional): an iterable of elements to add to the list.

    Example::

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])

            def forward(self, x):
                # ParameterList can act as an iterable, or be indexed using ints
                for i, p in enumerate(self.params):
                    x = self.params[i // 2].mm(x) + p.mm(x)
                return x
    """

    def __init__(self, values: Optional[Iterable[Any]] = None) -> None:
        super().__init__()
        self._size = 0
        if values is not None:
            self += values

    def _get_abs_string_index(self, idx):
626        """Get the absolute index for the list of modules."""
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError(f"index {idx} is out of range")
        if idx < 0:
            idx += len(self)
        return str(idx)

    @overload
    def __getitem__(self, idx: int) -> Any:
        ...

    @overload
    def __getitem__(self: T, idx: slice) -> T:
        ...

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            out = self.__class__()
            for i in range(start, stop, step):
                out.append(self[i])
            return out
        else:
            idx = self._get_abs_string_index(idx)
            return getattr(self, str(idx))

    def __setitem__(self, idx: int, param: Any) -> None:
        # Note that all other functions that add an entry to the list part of
        # the ParameterList end up here. So this is the only place where we need
        # to wrap things into Parameter if needed.
        # Objects added via setattr() are not in the list part and thus won't
        # call into this function.
        idx = self._get_abs_string_index(idx)
        if isinstance(param, torch.Tensor) and not isinstance(param, Parameter):
            param = Parameter(param)
        return setattr(self, str(idx), param)

    def __len__(self) -> int:
        return self._size

    def __iter__(self) -> Iterator[Any]:
        return iter(self[i] for i in range(len(self)))

    def __iadd__(self, parameters: Iterable[Any]) -> Self:
        return self.extend(parameters)

    def __dir__(self):
        keys = super().__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    def append(self, value: Any) -> "ParameterList":
        """Append a given value at the end of the list.

        Args:
            value (Any): value to append
        """
        new_idx = len(self)
        self._size += 1
        self[new_idx] = value
        return self

    def extend(self, values: Iterable[Any]) -> Self:
        """Append values from a Python iterable to the end of the list.

        Args:
            values (iterable): iterable of values to append
        """
        # Tensor is an iterable but we never want to unpack it here
        if not isinstance(values, container_abcs.Iterable) or isinstance(
            values, torch.Tensor
        ):
            raise TypeError(
                "ParameterList.extend should be called with an "
                "iterable, but got " + type(values).__name__
            )
        for value in values:
            self.append(value)
        return self

    def extra_repr(self) -> str:
        child_lines = []
        for k, p in enumerate(self):
            if isinstance(p, torch.Tensor):
                size_str = "x".join(str(size) for size in p.size())
                if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
                    device_str = f" ({p.device})"
                else:
                    device_str = ""
                parastr = "{} containing: [{} of size {}{}]".format(
                    "Parameter" if isinstance(p, Parameter) else "Tensor",
                    p.dtype,
                    size_str,
                    device_str,
                )
                child_lines.append("  (" + str(k) + "): " + parastr)
            else:
                child_lines.append(
                    "  (" + str(k) + "): Object of type: " + type(p).__name__
                )

        tmpstr = "\n".join(child_lines)
        return tmpstr

    def __call__(self, *args, **kwargs):
        raise RuntimeError("ParameterList should not be called.")


class ParameterDict(Module):
    r"""Holds parameters in a dictionary.

    ParameterDict can be indexed like a regular Python dictionary, but Parameters it
    contains are properly registered, and will be visible by all Module methods.
    Other objects are treated as they would be by a regular Python dictionary.

    :class:`~torch.nn.ParameterDict` is an **ordered** dictionary.
    :meth:`~torch.nn.ParameterDict.update` with other unordered mapping
    types (e.g., Python's plain ``dict``) does not preserve the order of the
    merged mapping. On the other hand, ``OrderedDict`` or another :class:`~torch.nn.ParameterDict`
    will preserve their ordering.

    Note that the constructor, assigning an element of the dictionary, and the
    :meth:`~torch.nn.ParameterDict.update` method will convert any :class:`~torch.Tensor` into
    :class:`~torch.nn.Parameter`.

    Args:
        values (iterable, optional): a mapping (dictionary) of
            (string : Any) or an iterable of key-value pairs
            of type (string, Any)

    Example::

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.params = nn.ParameterDict({
                        'left': nn.Parameter(torch.randn(5, 10)),
                        'right': nn.Parameter(torch.randn(5, 10))
                })

            def forward(self, x, choice):
                x = self.params[choice].mm(x)
                return x
    """

    def __init__(self, parameters: Any = None) -> None:
        super().__init__()
        self._keys: Dict[str, None] = {}
        if parameters is not None:
            self.update(parameters)

    def _key_to_attr(self, key: str) -> str:
        if not isinstance(key, str):
            raise TypeError(
                "Index given to ParameterDict cannot be used as a key as it is "
                f"not a string (type is '{type(key).__name__}'). Open an issue on "
                "github if you need non-string keys."
            )
        else:
            # Use the key as-is so that `.named_parameters()` returns the right thing
            return key

    def __getitem__(self, key: str) -> Any:
        attr = self._key_to_attr(key)
        return getattr(self, attr)

    def __setitem__(self, key: str, value: Any) -> None:
        # Note that all other functions that add an entry to the dictionary part of
        # the ParameterDict end up here. So this is the only place where we need
        # to wrap things into Parameter if needed.
        # Objects added via setattr() are not in the dictionary part and thus won't
        # call into this function.
        self._keys[key] = None
        attr = self._key_to_attr(key)
        if isinstance(value, torch.Tensor) and not isinstance(value, Parameter):
            value = Parameter(value)
        setattr(self, attr, value)

    def __delitem__(self, key: str) -> None:
        del self._keys[key]
        attr = self._key_to_attr(key)
        delattr(self, attr)

    def __len__(self) -> int:
        return len(self._keys)

    def __iter__(self) -> Iterator[str]:
        return iter(self._keys)

    def __reversed__(self) -> Iterator[str]:
        return reversed(list(self._keys))

    def copy(self) -> "ParameterDict":
        """Return a copy of this :class:`~torch.nn.ParameterDict` instance."""
        # We have to use an OrderedDict because the ParameterDict constructor
        # behaves differently on plain dict vs OrderedDict
        return ParameterDict(OrderedDict((k, self[k]) for k in self._keys))

    def __contains__(self, key: str) -> bool:
        return key in self._keys

    def setdefault(self, key: str, default: Optional[Any] = None) -> Any:
829        """Set the default for a key in the Parameterdict.
830
831        If key is in the ParameterDict, return its value.
832        If not, insert `key` with a parameter `default` and return `default`.
833        `default` defaults to `None`.

        Args:
            key (str): key to set default for
            default (Any): the parameter set to the key
        """
        if key not in self:
            self[key] = default
        return self[key]

    def clear(self) -> None:
        """Remove all items from the ParameterDict."""
        for k in self._keys.copy():
            del self[k]

    def pop(self, key: str) -> Any:
        r"""Remove key from the ParameterDict and return its parameter.

        Args:
            key (str): key to pop from the ParameterDict
        """
        v = self[key]
        del self[key]
        return v

    def popitem(self) -> Tuple[str, Any]:
        """Remove and return the last inserted `(key, parameter)` pair from the ParameterDict."""
        k, _ = self._keys.popitem()
        # We need the key in the _keys to be able to access/del
        self._keys[k] = None
        val = self[k]
        del self[k]
        return k, val

    def get(self, key: str, default: Optional[Any] = None) -> Any:
        r"""Return the parameter associated with key if present. Otherwise return default if provided, None if not.

        Args:
            key (str): key to get from the ParameterDict
            default (Parameter, optional): value to return if key not present
        """
        return self[key] if key in self else default

    def fromkeys(
        self, keys: Iterable[str], default: Optional[Any] = None
    ) -> "ParameterDict":
        r"""Return a new ParameterDict with the keys provided.

        Args:
            keys (iterable, string): keys to make the new ParameterDict from
            default (Parameter, optional): value to set for all keys
        """
        return ParameterDict((k, default) for k in keys)

    def keys(self) -> Iterable[str]:
        r"""Return an iterable of the ParameterDict keys."""
        return self._keys.keys()

    def items(self) -> Iterable[Tuple[str, Any]]:
        r"""Return an iterable of the ParameterDict key/value pairs."""
        return ((k, self[k]) for k in self._keys)

    def values(self) -> Iterable[Any]:
        r"""Return an iterable of the ParameterDict values."""
        return (self[k] for k in self._keys)

    def update(self, parameters: Union[Mapping[str, Any], "ParameterDict"]) -> None:
        r"""Update the :class:`~torch.nn.ParameterDict` with key-value pairs from ``parameters``, overwriting existing keys.

        .. note::
            If :attr:`parameters` is an ``OrderedDict``, a :class:`~torch.nn.ParameterDict`, or
            an iterable of key-value pairs, the order of new elements in it is preserved.

        Args:
            parameters (iterable): a mapping (dictionary) from string to
                :class:`~torch.nn.Parameter`, or an iterable of
                key-value pairs of type (string, :class:`~torch.nn.Parameter`)
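
        Example (a minimal sketch; note that the implementation below inserts
        the items of a plain ``dict`` in sorted key order)::

            d = nn.ParameterDict({'b': nn.Parameter(torch.zeros(2))})
            d.update({'c': nn.Parameter(torch.ones(2)), 'a': nn.Parameter(torch.ones(2))})
            list(d.keys())  # ['b', 'a', 'c']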
910        """
911        if not isinstance(parameters, container_abcs.Iterable):
912            raise TypeError(
913                "ParametersDict.update should be called with an "
914                "iterable of key/value pairs, but got " + type(parameters).__name__
915            )
916
917        if isinstance(parameters, (OrderedDict, ParameterDict)):
918            for key, parameter in parameters.items():
919                self[key] = parameter
920        elif isinstance(parameters, container_abcs.Mapping):
921            for key, parameter in sorted(parameters.items()):
922                self[key] = parameter
923        else:
924            for j, p in enumerate(parameters):
925                if not isinstance(p, container_abcs.Iterable):
926                    raise TypeError(
927                        "ParameterDict update sequence element "
928                        "#" + str(j) + " should be Iterable; is" + type(p).__name__
                    )
                if not len(p) == 2:
                    raise ValueError(
                        "ParameterDict update sequence element "
                        "#" + str(j) + " has length " + str(len(p)) + "; 2 is required"
                    )
                # parameters as length-2 list too cumbersome to type, see ModuleDict.update comment
                self[p[0]] = p[1]  # type: ignore[assignment]

    def extra_repr(self) -> str:
        child_lines = []
        for k, p in self.items():
            if isinstance(p, torch.Tensor):
                size_str = "x".join(str(size) for size in p.size())
                if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
                    device_str = f" ({p.device})"
                else:
                    device_str = ""
                parastr = "{} containing: [{} of size {}{}]".format(
                    "Parameter" if isinstance(p, Parameter) else "Tensor",
                    torch.typename(p),
                    size_str,
                    device_str,
                )
                child_lines.append("  (" + str(k) + "): " + parastr)
            else:
                child_lines.append(
                    "  (" + str(k) + "): Object of type: " + type(p).__name__
                )
        tmpstr = "\n".join(child_lines)
        return tmpstr

    def __call__(self, input):
        raise RuntimeError("ParameterDict should not be called.")

    def __or__(self, other: "ParameterDict") -> "ParameterDict":
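        """Return a new ParameterDict merging ``self`` and ``other``; keys in ``other`` take precedence (the ``|`` operator)."""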
        copy = self.copy()
        copy.update(other)
        return copy

    def __ror__(self, other: "ParameterDict") -> "ParameterDict":
        copy = other.copy()
        copy.update(self)
        return copy

    def __ior__(self, other: "ParameterDict") -> Self:
        self.update(other)
        return self
