# mypy: allow-untyped-defs
import re
from typing import Iterable, Union


GlobPattern = Union[str, Iterable[str]]


class GlobGroup:
    """A set of patterns that candidate strings will be matched against.

    A candidate is composed of a list of segments separated by ``separator``, e.g. "foo.bar.baz".

    A pattern contains one or more segments. Segments can be:
        - A literal string (e.g. "foo"), which matches exactly.
        - A string containing a wildcard (e.g. "torch*", or "foo*baz*"). The wildcard matches
          any string, including the empty string, but never crosses a separator.
        - A double wildcard ("**"). This matches against zero or more complete segments.

    Examples:
        ``torch.**``: matches ``torch`` and all its submodules, e.g. ``torch.nn`` and ``torch.nn.functional``.
        ``torch.*``: matches ``torch.nn`` or ``torch.functional``, but not ``torch.nn.functional``.
        ``torch*.**``: matches ``torch``, ``torchvision``, and all their submodules.

    A candidate will match the ``GlobGroup`` if it matches any of the ``include`` patterns and
    none of the ``exclude`` patterns.

    Args:
        include (Union[str, Iterable[str]]): A string or list of strings,
            each representing a pattern to be matched against. A candidate
            will match if it matches *any* include pattern.
        exclude (Union[str, Iterable[str]]): A string or list of strings,
            each representing a pattern to be matched against. A candidate
            will be excluded from matching if it matches *any* exclude pattern.
        separator (str): A string that delimits segments in candidates and
            patterns. By default this is "." which corresponds to how modules are
            named in Python. Another common value for this is "/", which is
            the Unix path separator.
    """

    def __init__(
        self, include: GlobPattern, *, exclude: GlobPattern = (), separator: str = "."
    ):
        self._dbg = f"GlobGroup(include={include}, exclude={exclude})"
        self.include = GlobGroup._glob_list(include, separator)
        self.exclude = GlobGroup._glob_list(exclude, separator)
        self.separator = separator

    def __str__(self) -> str:
        return self._dbg

    def __repr__(self) -> str:
        return self._dbg

    def matches(self, candidate: str) -> bool:
        """Return True if ``candidate`` matches any include pattern and no exclude pattern."""
        # Compiled patterns expect every segment, including the first, to be
        # preceded by the separator (see ``_glob_to_re``), so prepend one here.
        candidate = self.separator + candidate
        return any(p.fullmatch(candidate) for p in self.include) and all(
            not p.fullmatch(candidate) for p in self.exclude
        )

    @staticmethod
    def _glob_list(elems: GlobPattern, separator: str = "."):
        """Compile a single pattern string, or each pattern in an iterable, into a list of regexes."""
        # A bare str is itself iterable, so it must be special-cased to avoid
        # treating each character as a separate pattern.
        if isinstance(elems, str):
            return [GlobGroup._glob_to_re(elems, separator)]
        return [GlobGroup._glob_to_re(e, separator) for e in elems]

    @staticmethod
    def _glob_to_re(pattern: str, separator: str = "."):
        """Translate one glob ``pattern`` into a compiled regular expression.

        Every segment of the resulting regex begins with the (escaped) separator,
        so it must be matched against a candidate that has had ``separator``
        prepended. For example, ``import torch`` is matched as ``.torch`` when
        the separator is ``"."``; this avoids special-casing the first segment.

        Wildcards are built from a character class that excludes the characters
        of ``separator``, so ``*`` and ``**`` never match any character that
        appears in the separator.

        Raises:
            ValueError: if ``**`` appears in a segment together with other text.
        """
        # Escape once and reuse both as a literal and inside the character
        # class. Inserting the raw separator into "[^...]" breaks for separators
        # such as "\\", which would yield the invalid class "[^\]".
        sep = re.escape(separator)
        not_sep = "[^" + sep + "]"

        def component_to_re(component: str) -> str:
            if "**" in component:
                if component == "**":
                    # Zero or more complete, non-empty segments.
                    return "(" + sep + not_sep + "+)*"
                raise ValueError("** can only appear as an entire path segment")
            # Literal text is escaped; each single "*" becomes a run of
            # non-separator characters (possibly empty).
            return sep + (not_sep + "*").join(
                re.escape(x) for x in component.split("*")
            )

        return re.compile(
            "".join(component_to_re(c) for c in pattern.split(separator))
        )