xref: /aosp_15_r20/external/pytorch/torch/package/_mangling.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""Import mangling.
3See mangling.md for details.
4"""
5import re
6
7
# Process-wide counter; each PackageMangler takes the current value and
# bumps it, so every mangler gets its own "<torch_package_N>" parent.
_mangle_index = 0


class PackageMangler:
    """
    Used on import, to ensure that all modules imported have a shared mangle parent.

    Each instance reserves the next value of the module-level ``_mangle_index``
    counter and uses ``<torch_package_N>`` as its mangle parent, so manglers
    created in the same process get distinct parents.
    """

    def __init__(self) -> None:
        global _mangle_index
        # NOTE(review): the read-then-increment below is not synchronized;
        # concurrent construction from several threads could reuse an index.
        self._mangle_index = _mangle_index
        # Increment the global index so the next mangler gets a fresh parent.
        _mangle_index += 1
        # Angle brackets are used so that there is almost no chance of
        # confusing this module for a real module. Plus, it is Python's
        # preferred way of denoting special modules.
        self._mangle_parent = f"<torch_package_{self._mangle_index}>"

    def mangle(self, name: str) -> str:
        """Return ``name`` prefixed with this mangler's parent.

        For example, ``"foo.bar"`` becomes ``"<torch_package_N>.foo.bar"``.

        Raises:
            AssertionError: if ``name`` is empty. This is raised explicitly
                rather than through an ``assert`` statement so the check is
                not stripped when Python runs with ``-O``.
        """
        if not name:
            raise AssertionError("cannot mangle an empty name")
        return self._mangle_parent + "." + name

    def demangle(self, mangled: str) -> str:
        """
        Note: This only demangles names that were mangled by this specific
        PackageMangler. It will pass through names created by a different
        PackageMangler instance.
        """
        prefix = self._mangle_parent + "."
        if mangled.startswith(prefix):
            # The parent contains no ".", so everything after the prefix is
            # exactly the original (unmangled) name.
            return mangled[len(prefix):]

        # wasn't a mangled name (or was mangled by a different instance)
        return mangled

    def parent_name(self) -> str:
        """Return this mangler's parent, e.g. ``"<torch_package_0>"``."""
        return self._mangle_parent
44
45
# A mangle parent as produced by PackageMangler ("<torch_package_N>"), which
# must either be the whole name or be followed by the "." separator. Without
# the boundary, names like "<torch_package_0>foo" would be treated as mangled.
_MANGLED_NAME_RE = re.compile(r"<torch_package_\d+>(?:\.|\Z)")


def is_mangled(name: str) -> bool:
    """Return True if ``name`` starts with a PackageMangler parent.

    Accepts a bare parent (``"<torch_package_0>"``) or a parent followed by
    ``"."`` and the original name (``"<torch_package_0>.foo.bar"``).
    """
    return _MANGLED_NAME_RE.match(name) is not None
48
49
def demangle(name: str) -> str:
    """
    Strip the mangle parent from ``name``, if it has one.

    Note: Unlike PackageMangler.demangle, this version works on any
    mangled name, irrespective of which PackageMangler created it.
    A bare mangle parent such as '<torch_package_0>' demangles to the
    empty string; names that are not mangled are returned unchanged.
    """
    if not is_mangled(name):
        return name
    # str.partition yields "" as the tail when no "." is present, which
    # covers the bare-parent case without a separate branch.
    _head, _dot, tail = name.partition(".")
    return tail
61
62
def get_mangle_prefix(name: str) -> str:
    """Return the mangle parent of ``name``.

    For a mangled name this is everything before the first "." (e.g.
    '<torch_package_0>' for '<torch_package_0>.foo.bar'); a name that is
    not mangled is returned unchanged.
    """
    if not is_mangled(name):
        return name
    head, _dot, _tail = name.partition(".")
    return head
65