xref: /aosp_15_r20/external/pytorch/torchgen/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport contextlib
4*da0073e9SAndroid Build Coastguard Workerimport functools
5*da0073e9SAndroid Build Coastguard Workerimport hashlib
6*da0073e9SAndroid Build Coastguard Workerimport os
7*da0073e9SAndroid Build Coastguard Workerimport re
8*da0073e9SAndroid Build Coastguard Workerimport sys
9*da0073e9SAndroid Build Coastguard Workerimport textwrap
10*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import fields, is_dataclass
11*da0073e9SAndroid Build Coastguard Workerfrom enum import auto, Enum
12*da0073e9SAndroid Build Coastguard Workerfrom pathlib import Path
13*da0073e9SAndroid Build Coastguard Workerfrom typing import (
14*da0073e9SAndroid Build Coastguard Worker    Any,
15*da0073e9SAndroid Build Coastguard Worker    Callable,
16*da0073e9SAndroid Build Coastguard Worker    Generic,
17*da0073e9SAndroid Build Coastguard Worker    Iterable,
18*da0073e9SAndroid Build Coastguard Worker    Iterator,
19*da0073e9SAndroid Build Coastguard Worker    Literal,
20*da0073e9SAndroid Build Coastguard Worker    NoReturn,
21*da0073e9SAndroid Build Coastguard Worker    Sequence,
22*da0073e9SAndroid Build Coastguard Worker    TYPE_CHECKING,
23*da0073e9SAndroid Build Coastguard Worker    TypeVar,
24*da0073e9SAndroid Build Coastguard Worker)
25*da0073e9SAndroid Build Coastguard Workerfrom typing_extensions import Self
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Workerfrom torchgen.code_template import CodeTemplate
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Worker
30*da0073e9SAndroid Build Coastguard Workerif TYPE_CHECKING:
31*da0073e9SAndroid Build Coastguard Worker    from argparse import Namespace
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker
# Two directory levels above this file; joined with "torchgen" below to locate
# the default generator script.
REPO_ROOT = Path(__file__).absolute().parent.parent
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Worker
37*da0073e9SAndroid Build Coastguard Worker# Many of these functions share logic for defining both the definition
38*da0073e9SAndroid Build Coastguard Worker# and declaration (for example, the function signature is the same), so
39*da0073e9SAndroid Build Coastguard Worker# we organize them into one function that takes a Target to say which
40*da0073e9SAndroid Build Coastguard Worker# code we want.
41*da0073e9SAndroid Build Coastguard Worker#
42*da0073e9SAndroid Build Coastguard Worker# This is an OPEN enum (we may add more cases to it in the future), so be sure
43*da0073e9SAndroid Build Coastguard Worker# to explicitly specify with Literal[Target.XXX] or Literal[Target.XXX, Target.YYY]
44*da0073e9SAndroid Build Coastguard Worker# what targets are valid for your use.
class Target(Enum):
    """Selects which flavor of code a shared codegen function should emit.

    The enum is open (more cases may be added later), so callers should
    spell out the members they accept with ``Literal[Target.XXX]``.
    """

    # Emitted in the top level namespace (not including at)
    DEFINITION = auto()
    DECLARATION = auto()
    # Emitted inside TORCH_LIBRARY(...) { ... }
    REGISTRATION = auto()
    # Emitted inside an anonymous namespace: namespace { ... }
    ANONYMOUS_DEFINITION = auto()
    # Emitted inside a named namespace, e.g. namespace cpu { ... }
    NAMESPACED_DEFINITION = auto()
    NAMESPACED_DECLARATION = auto()
56*da0073e9SAndroid Build Coastguard Worker
57*da0073e9SAndroid Build Coastguard Worker
# Pattern template for finding a whole identifier: after ``.format("foo")`` it
# matches "foo" in "foo, bar" but not in "foobar".  Used to search for the
# occurrence of a parameter in a derivative formula.
IDENT_REGEX = r"(^|\W){}($|\W)"
61*da0073e9SAndroid Build Coastguard Worker
62*da0073e9SAndroid Build Coastguard Worker
63*da0073e9SAndroid Build Coastguard Worker# TODO: Use a real parser here; this will get bamboozled
def split_name_params(schema: str) -> tuple[str, list[str]]:
    """Split a schema string such as ``name.overload(a, b)`` into its parts.

    Returns the base name (the overload suffix is dropped) and the text
    between the outermost parentheses split on ``", "``.  An empty parameter
    list therefore yields ``[""]``.

    Raises:
        RuntimeError: if ``schema`` does not start with ``name(...)`` or
            ``name.overload(...)``.
    """
    match = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
    if match is None:
        raise RuntimeError(f"Unsupported function schema: {schema}")
    base_name = match.group(1)
    param_text = match.group(3)
    return base_name, param_text.split(", ")
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard Worker
72*da0073e9SAndroid Build Coastguard WorkerT = TypeVar("T")
73*da0073e9SAndroid Build Coastguard WorkerS = TypeVar("S")
74*da0073e9SAndroid Build Coastguard Worker
75*da0073e9SAndroid Build Coastguard Worker# These two functions purposely return generators in analogy to map()
76*da0073e9SAndroid Build Coastguard Worker# so that you don't mix up when you need to list() them
77*da0073e9SAndroid Build Coastguard Worker
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Worker# Map over function that may return None; omit Nones from output sequence
def mapMaybe(func: Callable[[T], S | None], xs: Iterable[T]) -> Iterator[S]:
    """Lazily yield ``func(x)`` for each ``x`` in ``xs``, skipping ``None`` results.

    Only ``None`` is dropped; other falsy results (``0``, ``""``) are kept.
    """
    yield from (result for result in map(func, xs) if result is not None)
85*da0073e9SAndroid Build Coastguard Worker
86*da0073e9SAndroid Build Coastguard Worker
87*da0073e9SAndroid Build Coastguard Worker# Map over function that returns sequences and cat them all together
def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
    """Lazily yield every element of ``func(x)`` for each ``x`` in ``xs``, in order."""
    yield from (element for x in xs for element in func(x))
91*da0073e9SAndroid Build Coastguard Worker
92*da0073e9SAndroid Build Coastguard Worker
93*da0073e9SAndroid Build Coastguard Worker# Conveniently add error context to exceptions raised.  Lets us
94*da0073e9SAndroid Build Coastguard Worker# easily say that an error occurred while processing a specific
95*da0073e9SAndroid Build Coastguard Worker# context.
@contextlib.contextmanager
def context(msg_fn: Callable[[], str]) -> Iterator[None]:
    """Append context text to any ``Exception`` escaping the ``with`` body.

    ``msg_fn`` is only called when an exception is caught.  Its result is
    indented by two spaces and placed on a new line after the exception's
    first argument (or becomes the sole argument when there are none).  The
    same exception object is then re-raised.
    """
    try:
        yield
    except Exception as err:
        # TODO: this does the wrong thing with KeyError
        note = textwrap.indent(msg_fn(), "  ")
        if err.args:
            first, *remaining = err.args
            err.args = (f"{first}\n{note}", *remaining)
        else:
            err.args = (note,)
        raise
107*da0073e9SAndroid Build Coastguard Worker
108*da0073e9SAndroid Build Coastguard Worker
109*da0073e9SAndroid Build Coastguard Worker# A little trick from https://github.com/python/mypy/issues/6366
110*da0073e9SAndroid Build Coastguard Worker# for getting mypy to do exhaustiveness checking
111*da0073e9SAndroid Build Coastguard Worker# TODO: put this somewhere else, maybe
def assert_never(x: NoReturn) -> NoReturn:
    """Always raise ``AssertionError`` naming the runtime type of ``x``.

    Typed to accept ``NoReturn`` so that mypy reports any call site it
    considers reachable, giving exhaustiveness checking.
    """
    type_name = type(x).__name__
    raise AssertionError(f"Unhandled type: {type_name}")
114*da0073e9SAndroid Build Coastguard Worker
115*da0073e9SAndroid Build Coastguard Worker
@functools.lru_cache(maxsize=None)
def _read_template(template_fn: str) -> CodeTemplate:
    """Load the template at path ``template_fn``; results are memoized per path."""
    template = CodeTemplate.from_file(template_fn)
    return template
119*da0073e9SAndroid Build Coastguard Worker
120*da0073e9SAndroid Build Coastguard Worker
121*da0073e9SAndroid Build Coastguard Worker# String hash that's stable across different executions, unlike builtin hash
def string_stable_hash(s: str) -> int:
    """Return a hash of ``s`` that is identical across interpreter runs.

    The value is the SHA-1 digest of the latin-1 encoding of ``s`` read as a
    little-endian integer.
    """
    digest = hashlib.sha1(s.encode("latin1")).digest()
    return int.from_bytes(digest, byteorder="little")
125*da0073e9SAndroid Build Coastguard Worker
126*da0073e9SAndroid Build Coastguard Worker
127*da0073e9SAndroid Build Coastguard Worker# A small abstraction for writing out generated files and keeping track
128*da0073e9SAndroid Build Coastguard Worker# of what files have been written (so you can write out a list of output
129*da0073e9SAndroid Build Coastguard Worker# files)
class FileManager:
    """Write generated files and remember which ones were produced.

    Output paths are formed as ``f"{install_dir}/{filename}"`` and recorded in
    ``filenames`` so that the list of outputs can be emitted later with
    ``write_outputs``.  When ``dry_run`` is true, ``write_with_template`` (and
    everything built on it) still records the paths but renders no template
    and writes no file.
    """

    install_dir: str
    template_dir: str
    dry_run: bool
    filenames: set[str]

    def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
        self.install_dir = install_dir
        self.template_dir = template_dir
        self.filenames = set()
        self.dry_run = dry_run

    def _write_if_changed(self, filename: str, contents: str) -> None:
        """Write ``contents`` to ``filename`` unless the file already holds it.

        A file that cannot be opened for reading is treated as changed.
        """
        old_contents: str | None
        try:
            with open(filename) as f:
                old_contents = f.read()
        except OSError:
            old_contents = None
        if contents != old_contents:
            # Create output directory if it doesn't exist.  A bare file name
            # has an empty dirname, and os.makedirs("") raises, so skip it.
            out_dir = os.path.dirname(filename)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            with open(filename, "w") as f:
                f.write(contents)

    # Read from template file and replace pattern with callable (type could be dict or str).
    def substitute_with_template(
        self, template_fn: str, env_callable: Callable[[], str | dict[str, Any]]
    ) -> str:
        """Produce the text for one output file.

        ``env_callable`` is invoked once.  If it returns a ``str`` that string
        is returned unchanged and no template is read.  If it returns a
        ``dict`` the template ``template_dir/template_fn`` is substituted with
        it; a ``generated_comment`` entry is added when the dict has none.
        """
        template_path = os.path.join(self.template_dir, template_fn)
        env = env_callable()
        if isinstance(env, dict):
            if "generated_comment" not in env:
                # Name the running script in the comment, falling back to
                # torchgen/gen.py when __main__ has no usable __file__.
                generator_default = REPO_ROOT / "torchgen" / "gen.py"
                try:
                    generator = Path(
                        sys.modules["__main__"].__file__ or generator_default
                    ).absolute()
                except (KeyError, AttributeError):
                    generator = generator_default.absolute()

                # Prefer a repo-relative path; scripts outside the repo are
                # identified by file name only.
                try:
                    generator_path = generator.relative_to(REPO_ROOT).as_posix()
                except ValueError:
                    generator_path = generator.name

                env = {
                    **env,  # copy the original dict instead of mutating it
                    "generated_comment": (
                        "@" + f"generated by {generator_path} from {template_fn}"
                    ),
                }
            template = _read_template(template_path)
            return template.substitute(env)
        elif isinstance(env, str):
            return env
        else:
            assert_never(env)

    def write_with_template(
        self,
        filename: str,
        template_fn: str,
        env_callable: Callable[[], str | dict[str, Any]],
    ) -> None:
        """Record ``install_dir/filename`` as an output and, unless this is a
        dry run, render ``template_fn`` and write the result if it changed.

        Raises:
            AssertionError: if the same output path is written twice.
        """
        filename = f"{self.install_dir}/{filename}"
        assert filename not in self.filenames, f"duplicate file write {filename}"
        self.filenames.add(filename)
        if not self.dry_run:
            substitute_out = self.substitute_with_template(
                template_fn=template_fn,
                env_callable=env_callable,
            )
            self._write_if_changed(filename=filename, contents=substitute_out)

    def write(
        self,
        filename: str,
        env_callable: Callable[[], str | dict[str, Any]],
    ) -> None:
        """Write ``filename`` using the template that has the same name."""
        self.write_with_template(filename, filename, env_callable)

    def write_sharded(
        self,
        filename: str,
        items: Iterable[T],
        *,
        key_fn: Callable[[T], str],
        env_callable: Callable[[T], dict[str, list[str]]],
        num_shards: int,
        base_env: dict[str, Any] | None = None,
        sharded_keys: set[str],
    ) -> None:
        """Distribute ``items`` over ``num_shards`` output files plus one
        "Everything" file that receives all of them.

        Each item goes to shard ``string_stable_hash(key_fn(item)) % num_shards``.
        The lists returned by ``env_callable(item)`` are appended to that
        shard's environment and to the "Everything" environment; every key
        returned must be listed in ``sharded_keys``.  ``base_env`` entries are
        copied into every environment first.  Files are named by inserting the
        shard id (``_0``, ``_1``, ... or ``Everything``) before the extension
        of ``filename``, and ``filename`` itself is the template name.
        """
        everything: dict[str, Any] = {"shard_id": "Everything"}
        shards: list[dict[str, Any]] = [
            {"shard_id": f"_{i}"} for i in range(num_shards)
        ]
        all_shards = [everything] + shards

        if base_env is not None:
            for shard in all_shards:
                shard.update(base_env)

        for key in sharded_keys:
            for shard in all_shards:
                if key in shard:
                    assert isinstance(
                        shard[key], list
                    ), "sharded keys in base_env must be a list"
                    # Each shard needs its own list so that appends do not
                    # leak between shards (or back into base_env).
                    shard[key] = shard[key].copy()
                else:
                    shard[key] = []

        def merge_env(into: dict[str, list[str]], from_: dict[str, list[str]]) -> None:
            for k, v in from_.items():
                assert k in sharded_keys, f"undeclared sharded key {k}"
                into[k] += v

        if self.dry_run:
            # Dry runs don't write any templates, so incomplete environments are fine
            items = ()

        for item in items:
            key = key_fn(item)
            sid = string_stable_hash(key) % num_shards
            env = env_callable(item)

            merge_env(shards[sid], env)
            merge_env(everything, env)

        # Split at the last "." so the shard id lands before the extension.
        dot_pos = filename.rfind(".")
        if dot_pos == -1:
            dot_pos = len(filename)
        base_filename = filename[:dot_pos]
        extension = filename[dot_pos:]

        for shard in all_shards:
            shard_id = shard["shard_id"]
            self.write_with_template(
                f"{base_filename}{shard_id}{extension}", filename, lambda: shard
            )

        # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
        self.filenames.discard(
            f"{self.install_dir}/{base_filename}Everything{extension}"
        )

    def write_outputs(self, variable_name: str, filename: str) -> None:
        """Write a file containing the list of all outputs which are
        generated by this script."""
        content = "set({}\n    {})".format(
            variable_name,
            "\n    ".join('"' + name + '"' for name in sorted(self.filenames)),
        )
        self._write_if_changed(filename, content)

    def template_dir_for_comments(self) -> str:
        """
        This needs to be deterministic. The template dir is an absolute path
        that varies across builds. So, just use the path relative to this file,
        which will point to the codegen source but will be stable.
        """
        return os.path.relpath(self.template_dir, os.path.dirname(__file__))
293*da0073e9SAndroid Build Coastguard Worker
294*da0073e9SAndroid Build Coastguard Worker
295*da0073e9SAndroid Build Coastguard Worker# Helper function to generate file manager
def make_file_manager(
    options: Namespace, install_dir: str | None = None
) -> FileManager:
    """Build a ``FileManager`` from parsed command-line ``options``.

    Templates are read from ``<options.source_path>/templates``.  Output goes
    to ``install_dir`` when it is given and non-empty, otherwise to
    ``options.install_dir``.
    """
    templates = os.path.join(options.source_path, "templates")
    destination = install_dir or options.install_dir
    return FileManager(
        install_dir=destination, template_dir=templates, dry_run=options.dry_run
    )
304*da0073e9SAndroid Build Coastguard Worker
305*da0073e9SAndroid Build Coastguard Worker
306*da0073e9SAndroid Build Coastguard Worker# Helper function to create a pretty representation for dataclasses
307*da0073e9SAndroid Build Coastguard Workerdef dataclass_repr(
308*da0073e9SAndroid Build Coastguard Worker    obj: Any,
309*da0073e9SAndroid Build Coastguard Worker    indent: int = 0,
310*da0073e9SAndroid Build Coastguard Worker    width: int = 80,
311*da0073e9SAndroid Build Coastguard Worker) -> str:
312*da0073e9SAndroid Build Coastguard Worker    # built-in pprint module support dataclasses from python 3.10
313*da0073e9SAndroid Build Coastguard Worker    if sys.version_info >= (3, 10):
314*da0073e9SAndroid Build Coastguard Worker        from pprint import pformat
315*da0073e9SAndroid Build Coastguard Worker
316*da0073e9SAndroid Build Coastguard Worker        return pformat(obj, indent, width)
317*da0073e9SAndroid Build Coastguard Worker
318*da0073e9SAndroid Build Coastguard Worker    return _pformat(obj, indent=indent, width=width)
319*da0073e9SAndroid Build Coastguard Worker
320*da0073e9SAndroid Build Coastguard Worker
def _pformat(
    obj: Any,
    indent: int,
    width: int,
    curr_indent: int = 0,
) -> str:
    """Render dataclass ``obj`` as ``ClassName(field=value, ...)``, one field per line.

    Only fields declared with ``repr=True`` are shown.  Continuation lines are
    aligned under the first field, i.e. just past ``ClassName(``.

    Raises:
        AssertionError: if ``obj`` is not a dataclass.
    """
    assert is_dataclass(obj), f"obj should be a dataclass, received: {type(obj)}"

    cls_name = obj.__class__.__name__
    # Column at which the fields start: the class name plus the opening paren.
    body_indent = curr_indent + len(cls_name) + 1

    def render(field_name: str, value: Any) -> str:
        # Nested values start after "<field_name>=".
        # dict, list, set and tuple also add indent as done in pprint
        value_indent = body_indent + len(field_name) + 1
        if is_dataclass(value):
            return _pformat(value, indent, width, value_indent)
        if isinstance(value, dict):
            return _format_dict(value, indent, width, value_indent)
        if isinstance(value, (list, set, tuple)):
            return _format_list(value, indent, width, value_indent)
        return repr(value)

    visible = [(f.name, getattr(obj, f.name)) for f in fields(obj) if f.repr]
    rendered = [f"{name}={render(name, value)}" for name, value in visible]

    separator = ",\n" + " " * body_indent
    return f"{cls_name}({separator.join(rendered)})"
354*da0073e9SAndroid Build Coastguard Worker
355*da0073e9SAndroid Build Coastguard Worker
def _format_dict(
    attr: dict[Any, Any],
    indent: int,
    width: int,
    curr_indent: int,
) -> str:
    """Render ``attr`` as ``{key: value, ...}``.

    Dataclass values are rendered with ``_pformat``; all other values and all
    keys use ``repr``.
    """
    curr_indent += indent + 3

    def render_value(value: Any, key_text: str) -> str:
        if is_dataclass(value):
            return _pformat(value, indent, width, curr_indent + len(key_text))
        return repr(value)

    entries = []
    for key, value in attr.items():
        key_text = repr(key)
        entries.append(f"{key_text}: {render_value(value, key_text)}")

    return _format(entries, indent, width, curr_indent, "{", "}")
374*da0073e9SAndroid Build Coastguard Worker
375*da0073e9SAndroid Build Coastguard Worker
def _format_list(
    attr: list[Any] | set[Any] | tuple[Any, ...],
    indent: int,
    width: int,
    curr_indent: int,
) -> str:
    """Render a list, set or tuple, using ``_pformat`` for dataclass elements.

    Lists are wrapped in square brackets; anything else (sets as well as
    tuples) is wrapped in parentheses.
    """
    curr_indent += indent + 1

    rendered = []
    for item in attr:
        if is_dataclass(item):
            rendered.append(_pformat(item, indent, width, curr_indent))
        else:
            rendered.append(repr(item))

    if isinstance(attr, list):
        start, end = "[", "]"
    else:
        start, end = "(", ")"
    return _format(rendered, indent, width, curr_indent, start, end)
389*da0073e9SAndroid Build Coastguard Worker
390*da0073e9SAndroid Build Coastguard Worker
def _format(
    fields_str: list[str],
    indent: int,
    width: int,
    curr_indent: int,
    start: str,
    end: str,
) -> str:
    """Join already-rendered elements between ``start`` and ``end``.

    Elements are separated by ``", "``.  When ``repr(fields_str)`` is at least
    ``width`` characters long, each element after the first goes on its own
    line, indented by ``curr_indent`` spaces.  ``indent`` spaces follow the
    opening delimiter.
    """
    # if it exceeds the max width then we place one element per line
    one_per_line = len(repr(fields_str)) >= width
    separator = ", \n" + " " * curr_indent if one_per_line else ", "
    leading = " " * indent
    return f"{start}{leading}{separator.join(fields_str)}{end}"
408*da0073e9SAndroid Build Coastguard Worker
409*da0073e9SAndroid Build Coastguard Worker
class NamespaceHelper:
    """A helper for constructing the namespace open and close strings for a nested set of namespaces.

    e.g. for namespace_str torch::lazy,

    prologue:
    namespace torch {
    namespace lazy {

    epilogue:
    } // namespace lazy
    } // namespace torch

    Note: an empty ``namespace_str`` still splits into a single empty name, so
    the prologue/epilogue then contain one namespace line with an empty name.
    """

    def __init__(
        self, namespace_str: str, entity_name: str = "", max_level: int = 2
    ) -> None:
        """Build the prologue/epilogue strings for ``namespace_str``.

        Args:
            namespace_str: "::"-joined namespace path such as ``torch::lazy``.
            entity_name: optional class/function name associated with the
                namespace; stored as-is and exposed via ``entity_name``.
            max_level: maximum number of "::"-separated components allowed.

        Raises:
            AssertionError: if ``namespace_str`` has more than ``max_level``
                components.
        """
        # cpp_namespace can be a colon joined string such as torch::lazy
        cpp_namespaces = namespace_str.split("::")
        # Only allow a certain level of namespaces. Raised explicitly (rather
        # than via ``assert``) so the check is not stripped under ``python -O``;
        # the exception type and message are the same as before.
        if len(cpp_namespaces) > max_level:
            raise AssertionError(
                f"Codegen doesn't support more than {max_level} level(s) of custom "
                f"namespace. Got {namespace_str}."
            )
        self.cpp_namespace_ = namespace_str
        self.namespaces_ = cpp_namespaces
        self.entity_name_ = entity_name
        # One opening line per namespace, outermost first.
        self.prologue_ = "\n".join(f"namespace {n} {{" for n in cpp_namespaces)
        # Closing lines mirror the prologue, innermost first.
        self.epilogue_ = "\n".join(
            f"}} // namespace {n}" for n in reversed(cpp_namespaces)
        )

    @staticmethod
    def from_namespaced_entity(
        namespaced_entity: str, max_level: int = 2
    ) -> NamespaceHelper:
        """
        Generate helper from nested namespaces as well as class/function name. E.g.: "torch::lazy::add"

        Everything before the last "::" becomes the namespace string and the
        remainder becomes the entity name. Without any "::" the namespace
        string is empty and the whole input is the entity name.
        """
        namespace_str, _, entity_name = namespaced_entity.rpartition("::")
        return NamespaceHelper(
            namespace_str=namespace_str, entity_name=entity_name, max_level=max_level
        )

    @property
    def prologue(self) -> str:
        """Newline-joined ``namespace X {`` lines, outermost namespace first."""
        return self.prologue_

    @property
    def epilogue(self) -> str:
        """Newline-joined ``} // namespace X`` lines, innermost namespace first."""
        return self.epilogue_

    @property
    def entity_name(self) -> str:
        """The entity name given at construction ("" if none was given)."""
        return self.entity_name_

    def get_cpp_namespace(self, default: str = "") -> str:
        """
        Return the namespace string from joining all the namespaces by "::" (hence no leading "::").
        Return default if namespace string is empty.
        """
        return self.cpp_namespace_ if self.cpp_namespace_ else default
473*da0073e9SAndroid Build Coastguard Worker
474*da0073e9SAndroid Build Coastguard Worker
class OrderedSet(Generic[T]):
    """A set that iterates its elements in first-insertion order.

    Elements are stored as the keys of a dict (values are always ``None``), so
    they must be hashable. Because ``__eq__`` is defined without ``__hash__``,
    instances themselves are not hashable.
    """

    # Keys are the set's elements in insertion order; values are unused.
    storage: dict[T, Literal[None]]

    def __init__(self, iterable: Iterable[T] | None = None) -> None:
        """Create a set from ``iterable`` (duplicates keep their first position), or empty if None."""
        self.storage = {} if iterable is None else dict.fromkeys(iterable)

    def __contains__(self, item: T) -> bool:
        return item in self.storage

    def __iter__(self) -> Iterator[T]:
        """Iterate over the elements in insertion order."""
        return iter(self.storage.keys())

    def update(self, items: OrderedSet[T]) -> None:
        """Add every element of ``items`` in place; elements already present keep their position."""
        self.storage.update(items.storage)

    def add(self, item: T) -> None:
        """Add ``item``; if it is already present its position is unchanged."""
        self.storage[item] = None

    def copy(self) -> OrderedSet[T]:
        """Return a new OrderedSet with its own (shallow-copied) storage."""
        ret: OrderedSet[T] = OrderedSet()
        ret.storage = self.storage.copy()
        return ret

    @staticmethod
    def union(*args: OrderedSet[T]) -> OrderedSet[T]:
        """Return a new set with the elements of all ``args``, ordered by first appearance.

        None of the arguments is modified. Calling with no arguments returns an
        empty set instead of failing on the missing first operand.
        """
        ret: OrderedSet[T] = OrderedSet()
        for s in args:
            ret.update(s)
        return ret

    def __or__(self, other: OrderedSet[T]) -> OrderedSet[T]:
        """``self | other``: a new set; neither operand is modified."""
        return OrderedSet.union(self, other)

    def __ior__(self, other: OrderedSet[T]) -> Self:
        """``self |= other``: in-place update, returns ``self``."""
        self.update(other)
        return self

    def __eq__(self, other: object) -> bool:
        """Compare by membership only; insertion order does not matter.

        Against another OrderedSet the backing dicts are compared (dict
        equality ignores ordering). Any other object is compared against a
        plain ``set`` of this set's elements.
        """
        if isinstance(other, OrderedSet):
            return self.storage == other.storage
        else:
            return set(self.storage.keys()) == other
520