# mypy: allow-untyped-defs
import collections
from itertools import repeat
from typing import Any, Dict, List


__all__ = ["consume_prefix_in_state_dict_if_present"]


def _ntuple(n, name="parse"):
    """Build a callable that normalizes its argument into a tuple.

    The returned callable converts an iterable argument to a tuple as-is (its
    length is NOT checked against ``n``) and repeats a non-iterable argument
    ``n`` times.

    Args:
        n: number of times a non-iterable argument is repeated.
        name: value assigned to the returned callable's ``__name__``.
    """

    def parse(x):
        # Any iterable (including a str) is passed through element-wise.
        if isinstance(x, collections.abc.Iterable):
            return tuple(x)
        return tuple(repeat(x, n))

    parse.__name__ = name
    return parse


_single = _ntuple(1, "_single")
_pair = _ntuple(2, "_pair")
_triple = _ntuple(3, "_triple")
_quadruple = _ntuple(4, "_quadruple")


def _reverse_repeat_tuple(t, n):
    r"""Reverse the order of `t` and repeat each element for `n` times.

    This can be used to translate padding arg used by Conv and Pooling modules
    to the ones used by `F.pad`.
    """
    return tuple(x for x in reversed(t) for _ in range(n))


def _list_with_default(out_size: List[int], defaults: List[int]) -> List[int]:
    """Replace ``None`` entries of ``out_size`` with trailing entries of ``defaults``.

    Args:
        out_size: requested sizes. A plain ``int`` or ``torch.SymInt`` is
            returned unchanged; otherwise each ``None`` entry is replaced by
            the matching entry among the last ``len(out_size)`` of ``defaults``.
        defaults: fallback sizes; must be strictly longer than ``out_size``.

    Returns:
        ``out_size`` itself for scalar input, else a new list with defaults filled in.

    Raises:
        ValueError: if ``len(defaults) <= len(out_size)``.
    """
    # Imported lazily (presumably to avoid an import cycle at module load — confirm).
    import torch

    if isinstance(out_size, (int, torch.SymInt)):
        return out_size
    if len(defaults) <= len(out_size):
        raise ValueError(f"Input dimension should be at least {len(out_size) + 1}")
    return [
        v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size) :])
    ]


def consume_prefix_in_state_dict_if_present(
    state_dict: Dict[str, Any],
    prefix: str,
) -> None:
    r"""Strip the prefix in state_dict in place, if any.

    .. note::
        Given a `state_dict` from a DP/DDP model, a local model can load it by applying
        `consume_prefix_in_state_dict_if_present(state_dict, "module.")` before calling
        :meth:`torch.nn.Module.load_state_dict`.

    Every key that starts with ``prefix`` is renamed to the remainder of the key;
    other keys are left untouched. If a stripped key collides with an existing
    key, the existing entry is overwritten. If ``state_dict`` carries a
    ``_metadata`` mapping, its keys are rewritten the same way, and the entry of
    the module that owns the prefix (e.g. ``"module"`` for prefix ``"module."``)
    becomes the root entry ``""``, replacing the previous root entry.

    Args:
        state_dict (OrderedDict): a state-dict to be loaded to the model.
        prefix (str): the prefix to remove from the keys, usually ending in
            ``"."`` (e.g. ``"module."``).
    """
    # Snapshot the keys: the dict is mutated while we walk them.
    keys = list(state_dict.keys())
    for key in keys:
        if key.startswith(prefix):
            newkey = key[len(prefix) :]
            state_dict[newkey] = state_dict.pop(key)

    # also strip the prefix in metadata if any.
    if hasattr(state_dict, "_metadata"):
        # The metadata key of the module that owns the prefix is the prefix
        # without its trailing separator: "module." -> "module", "a.b." -> "a.b".
        # Only the final "." is dropped; inner dots are part of the module path.
        bare_prefix = prefix[:-1] if prefix.endswith(".") else prefix
        keys = list(state_dict._metadata.keys())
        for key in keys:
            # for the metadata dict, the key can be:
            # '': for the DDP module, which we want to remove.
            # 'module': for the actual model.
            # 'module.xx.xx': for the rest.
            if len(key) == 0:
                continue
            # handling both, 'module' case and 'module.' cases
            if key == bare_prefix or key.startswith(prefix):
                newkey = "" if key == bare_prefix else key[len(prefix) :]
                state_dict._metadata[newkey] = state_dict._metadata.pop(key)