xref: /aosp_15_r20/external/pytorch/torch/nn/modules/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import collections
3from itertools import repeat
4from typing import Any, Dict, List
5
6
7__all__ = ["consume_prefix_in_state_dict_if_present"]
8
9
def _ntuple(n, name="parse"):
    """Build a converter that normalizes its argument into a tuple.

    The returned function turns any iterable into a tuple of its elements
    (the length is not checked against ``n``) and repeats any non-iterable
    value ``n`` times. ``name`` is assigned to the converter's ``__name__``.
    """

    def parse(x):
        if not isinstance(x, collections.abc.Iterable):
            # Scalar input: broadcast it to ``n`` identical entries.
            return tuple(repeat(x, n))
        # Iterable input (strings included): keep the elements as given.
        return tuple(x)

    parse.__name__ = name
    return parse
18
19
# Fixed-arity converters: each one repeats a non-iterable argument 1, 2, 3 or
# 4 times and turns an iterable argument into a tuple of its elements.
_single = _ntuple(n=1, name="_single")
_pair = _ntuple(n=2, name="_pair")
_triple = _ntuple(n=3, name="_triple")
_quadruple = _ntuple(n=4, name="_quadruple")
24
25
def _reverse_repeat_tuple(t, n):
    r"""Return the elements of `t` in reverse order, each repeated `n` times.

    This can be used to translate padding arg used by Conv and Pooling modules
    to the ones used by `F.pad`.
    """
    expanded = []
    for item in reversed(t):
        # Emit `n` consecutive copies of the current element.
        expanded.extend(item for _ in range(n))
    return tuple(expanded)
33
34
def _list_with_default(out_size: List[int], defaults: List[int]) -> List[int]:
    """Fill the ``None`` entries of ``out_size`` from the tail of ``defaults``.

    An ``int`` or ``torch.SymInt`` ``out_size`` is returned unchanged.
    Otherwise ``defaults`` must be strictly longer than ``out_size``; each
    ``None`` entry is replaced by the value at the matching position among
    the last ``len(out_size)`` entries of ``defaults``.

    Raises:
        ValueError: if ``defaults`` is not longer than ``out_size``.
    """
    import torch

    if isinstance(out_size, (int, torch.SymInt)):
        return out_size

    n_defaults = len(defaults)
    n_requested = len(out_size)
    if n_defaults <= n_requested:
        raise ValueError(f"Input dimension should be at least {n_requested + 1}")

    # Align the trailing entries of `defaults` with `out_size`.
    trailing = defaults[-n_requested:]
    resolved = []
    for requested, fallback in zip(out_size, trailing):
        resolved.append(fallback if requested is None else requested)
    return resolved
45
46
def consume_prefix_in_state_dict_if_present(
    state_dict: Dict[str, Any],
    prefix: str,
) -> None:
    r"""Strip the prefix in state_dict in place, if any.

    .. note::
        Given a `state_dict` from a DP/DDP model, a local model can load it by applying
        `consume_prefix_in_state_dict_if_present(state_dict, "module.")` before calling
        :meth:`torch.nn.Module.load_state_dict`.

    Every key that starts with ``prefix`` is removed and re-inserted under the
    name with the prefix cut off; an existing entry with that name is
    overwritten. If ``state_dict`` has a ``_metadata`` attribute, its keys are
    renamed the same way, and the entry named by ``prefix`` without its
    trailing ``"."`` is moved to the ``""`` key.

    Args:
        state_dict (OrderedDict): a state-dict to be loaded to the model.
        prefix (str): prefix.
    """
    prefix_len = len(prefix)

    # Iterate over a snapshot of the keys: the mapping is mutated in the loop.
    for key in list(state_dict.keys()):
        if key.startswith(prefix):
            state_dict[key[prefix_len:]] = state_dict.pop(key)

    # also strip the prefix in metadata if any.
    if hasattr(state_dict, "_metadata"):
        metadata = state_dict._metadata
        # The wrapped module's own metadata entry is the prefix without its
        # trailing separator ("module" for "module."). Drop only that single
        # trailing dot so nested prefixes such as "a.b." still map to "a.b".
        bare_prefix = prefix[:-1] if prefix.endswith(".") else prefix
        for key in list(metadata.keys()):
            # for the metadata dict, the key can be:
            # '': for the DDP module, which we want to remove.
            # 'module': for the actual model.
            # 'module.xx.xx': for the rest.
            if len(key) == 0:
                continue
            if key == bare_prefix:
                # The wrapped module becomes the new root entry, replacing
                # the wrapper's '' entry.
                metadata[""] = metadata.pop(key)
            elif key.startswith(prefix):
                metadata[key[prefix_len:]] = metadata.pop(key)
82