xref: /aosp_15_r20/external/pytorch/torch/package/file_structure_representation.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import Dict, List
3
4from .glob_group import GlobGroup, GlobPattern
5
6
__all__ = ["Directory"]


class Directory:
    """A file structure representation.

    Organized as a tree of Directory nodes. Each node maps the names of its
    children to their Directory nodes in ``children``. Nodes created for files
    have ``is_dir=False``. Directories for a package are created by calling
    :meth:`PackageImporter.file_structure`.
    """

    def __init__(self, name: str, is_dir: bool):
        self.name = name
        self.is_dir = is_dir
        # Child name -> child node (both files and subdirectories live here).
        self.children: Dict[str, Directory] = {}

    def _get_dir(self, dirs: List[str]) -> "Directory":
        """Builds path of Directories if not yet built and returns last directory
        in list.

        Missing intermediate directories are created. An existing child with a
        matching name is reused as-is, without checking its ``is_dir`` flag.

        Args:
            dirs (List[str]): List of directory names that are treated like a path.

        Returns:
            :class:`Directory`: The last Directory specified in the dirs list, or
            ``self`` when ``dirs`` is empty.
        """
        # Iterative walk: avoids recursion and re-slicing ``dirs`` at every level.
        node = self
        for dir_name in dirs:
            child = node.children.get(dir_name)
            if child is None:
                child = Directory(dir_name, True)
                node.children[dir_name] = child
            node = child
        return node

    def _add_file(self, file_path: str) -> None:
        """Adds a file to a Directory.

        Args:
            file_path (str): "/"-separated path of the file to add. The last
                element is added as a file, and the preceding elements are added
                as directories. An existing entry with the same final name is
                replaced.
        """
        *dir_names, file_name = file_path.split("/")
        parent = self._get_dir(dir_names)
        parent.children[file_name] = Directory(file_name, False)

    def has_file(self, filename: str) -> bool:
        """Checks if a path is present in a :class:`Directory`.

        The path is interpreted relative to this node. Only name membership is
        checked at each level, so a path whose final component names a
        directory also returns ``True``.

        Args:
            filename (str): "/"-separated path of the file to search for.

        Returns:
            bool: If a :class:`Directory` contains the specified path.
        """
        node = self
        remaining = filename
        while True:
            child_name, sep, remaining = remaining.partition("/")
            child = node.children.get(child_name)
            if child is None:
                return False
            if not sep:
                # No more path components: the whole path was found.
                return True
            node = child

    def __str__(self) -> str:
        str_list: List[str] = []
        self._stringify_tree(str_list)
        return "".join(str_list)

    def _stringify_tree(
        self,
        str_list: List[str],
        preamble: str = "",
        dir_ptr: str = "\u2500\u2500\u2500 ",
    ) -> None:
        """Recursive method to generate print-friendly version of a Directory.

        Args:
            str_list (List[str]): Accumulator to which rendered lines (each
                ending in a newline) are appended.
            preamble (str): Prefix inherited from ancestors (indentation and
                vertical bars).
            dir_ptr (str): Connector drawn before this node's name. It defaults
                to a plain horizontal rule, which is what ``__str__`` uses for
                the root.
        """
        space = "    "
        branch = "\u2502   "
        tee = "\u251c\u2500\u2500 "
        last = "\u2514\u2500\u2500 "

        # Add this directory's representation.
        str_list.append(f"{preamble}{dir_ptr}{self.name}\n")

        # A node drawn with a tee has later siblings, so its children must keep
        # the parent's vertical bar going; otherwise just pad with blanks.
        child_preamble = preamble + (branch if dir_ptr == tee else space)

        dir_keys = sorted(k for k, v in self.children.items() if v.is_dir)
        file_keys = sorted(k for k, v in self.children.items() if not v.is_dir)

        # Subdirectories are listed before files, so the final subdirectory only
        # gets the "last" connector when no files follow it.
        for index, key in enumerate(dir_keys):
            is_final = index == len(dir_keys) - 1 and not file_keys
            self.children[key]._stringify_tree(
                str_list, child_preamble, last if is_final else tee
            )
        for index, file_name in enumerate(file_keys):
            pointer = last if index == len(file_keys) - 1 else tee
            str_list.append(f"{child_preamble}{pointer}{file_name}\n")
108
109
def _create_directory_from_file_list(
    filename: str,
    file_list: List[str],
    include: "GlobPattern" = "**",
    exclude: "GlobPattern" = (),
) -> Directory:
    """Return a :class:`Directory` file structure representation created from a list of files.

    Each entry of ``file_list`` is tested against a ``GlobGroup`` built from
    ``include``/``exclude`` using "/" as the path separator. Matching entries
    are added beneath a new top-level directory, in list order.

    Args:
        filename (str): The name given to the top-level directory that will be the
            relative root for all file paths found in the file_list.

        file_list (List[str]): List of "/"-separated file paths to add to the
            top-level directory.

        include (GlobPattern): An optional pattern that limits what is included
            from the file_list to files whose name matches the pattern.

        exclude (GlobPattern): An optional pattern that excludes files whose
            name matches the pattern.

    Returns:
        :class:`Directory`: a :class:`Directory` file structure representation created from a list of files.
    """
    matcher = GlobGroup(include, exclude=exclude, separator="/")
    root = Directory(filename, True)
    # filter() is lazy, so each path is matched and then added before the next is examined.
    for path in filter(matcher.matches, file_list):
        root._add_file(path)
    return root
139