from __future__ import annotations

import contextlib
import functools
import hashlib
import os
import re
import sys
import textwrap
from dataclasses import fields, is_dataclass
from enum import auto, Enum
from pathlib import Path
from typing import (
    Any,
    Callable,
    Generic,
    Iterable,
    Iterator,
    Literal,
    NoReturn,
    Sequence,
    TYPE_CHECKING,
    TypeVar,
)
from typing_extensions import Self

from torchgen.code_template import CodeTemplate


if TYPE_CHECKING:
    # Only needed for type annotations (make_file_manager's `options`).
    from argparse import Namespace


# The directory two levels above this file. Presumably the repository root:
# generated-file headers compute the generator script's path relative to it
# (e.g. "torchgen/gen.py").
REPO_ROOT = Path(__file__).absolute().parent.parent


# Many of these functions share logic for defining both the definition
# and declaration (for example, the function signature is the same), so
# we organize them into one function that takes a Target to say which
# code we want.
#
# This is an OPEN enum (we may add more cases to it in the future), so be sure
# to explicitly specify with Literal[Target.XXX] or Literal[Target.XXX, Target.YYY]
# what targets are valid for your use.
class Target(Enum):
    """Which flavor of generated C++ a shared codegen routine should emit."""

    # top level namespace (not including at)
    DEFINITION = auto()
    DECLARATION = auto()
    # TORCH_LIBRARY(...) { ... }
    REGISTRATION = auto()
    # namespace { ... }
    ANONYMOUS_DEFINITION = auto()
    # namespace cpu { ... }
    NAMESPACED_DEFINITION = auto()
    NAMESPACED_DECLARATION = auto()
# Matches "foo" in "foo, bar" but not "foobar". Used to search for the
# occurrence of a parameter in the derivative formula
IDENT_REGEX = r"(^|\W){}($|\W)"


# TODO: Use a real parser here; this will get bamboozled
def split_name_params(schema: str) -> tuple[str, list[str]]:
    """Break a schema string such as ``"foo.overload(int a, int b)"`` apart.

    Returns the base name (any ``.overload`` suffix is discarded) and the
    parameter text split on ``", "``. An empty parameter list therefore
    yields ``[""]``.

    Raises:
        RuntimeError: if ``schema`` does not start with ``name(...)``.
    """
    parsed = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
    if not parsed:
        raise RuntimeError(f"Unsupported function schema: {schema}")
    base_name, param_text = parsed.group(1, 3)
    return base_name, param_text.split(", ")


T = TypeVar("T")
S = TypeVar("S")

# These two functions purposely return generators in analogy to map()
# so that you don't mix up when you need to list() them


# Map over function that may return None; omit Nones from output sequence
def mapMaybe(func: Callable[[T], S | None], xs: Iterable[T]) -> Iterator[S]:
    """Lazily yield ``func(x)`` for each ``x`` in ``xs``, skipping ``None`` results.

    Only ``None`` is dropped; other falsy results such as ``0`` or ``""``
    are kept.
    """
    mapped = (func(element) for element in xs)
    yield from (value for value in mapped if value is not None)
87*da0073e9SAndroid Build Coastguard Worker# Map over function that returns sequences and cat them all together 88*da0073e9SAndroid Build Coastguard Workerdef concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]: 89*da0073e9SAndroid Build Coastguard Worker for x in xs: 90*da0073e9SAndroid Build Coastguard Worker yield from func(x) 91*da0073e9SAndroid Build Coastguard Worker 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Worker# Conveniently add error context to exceptions raised. Lets us 94*da0073e9SAndroid Build Coastguard Worker# easily say that an error occurred while processing a specific 95*da0073e9SAndroid Build Coastguard Worker# context. 96*da0073e9SAndroid Build Coastguard Worker@contextlib.contextmanager 97*da0073e9SAndroid Build Coastguard Workerdef context(msg_fn: Callable[[], str]) -> Iterator[None]: 98*da0073e9SAndroid Build Coastguard Worker try: 99*da0073e9SAndroid Build Coastguard Worker yield 100*da0073e9SAndroid Build Coastguard Worker except Exception as e: 101*da0073e9SAndroid Build Coastguard Worker # TODO: this does the wrong thing with KeyError 102*da0073e9SAndroid Build Coastguard Worker msg = msg_fn() 103*da0073e9SAndroid Build Coastguard Worker msg = textwrap.indent(msg, " ") 104*da0073e9SAndroid Build Coastguard Worker msg = f"{e.args[0]}\n{msg}" if e.args else msg 105*da0073e9SAndroid Build Coastguard Worker e.args = (msg,) + e.args[1:] 106*da0073e9SAndroid Build Coastguard Worker raise 107*da0073e9SAndroid Build Coastguard Worker 108*da0073e9SAndroid Build Coastguard Worker 109*da0073e9SAndroid Build Coastguard Worker# A little trick from https://github.com/python/mypy/issues/6366 110*da0073e9SAndroid Build Coastguard Worker# for getting mypy to do exhaustiveness checking 111*da0073e9SAndroid Build Coastguard Worker# TODO: put this somewhere else, maybe 112*da0073e9SAndroid Build Coastguard Workerdef assert_never(x: NoReturn) -> NoReturn: 113*da0073e9SAndroid Build Coastguard 
Worker raise AssertionError(f"Unhandled type: {type(x).__name__}") 114*da0073e9SAndroid Build Coastguard Worker 115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Worker@functools.lru_cache(maxsize=None) 117*da0073e9SAndroid Build Coastguard Workerdef _read_template(template_fn: str) -> CodeTemplate: 118*da0073e9SAndroid Build Coastguard Worker return CodeTemplate.from_file(template_fn) 119*da0073e9SAndroid Build Coastguard Worker 120*da0073e9SAndroid Build Coastguard Worker 121*da0073e9SAndroid Build Coastguard Worker# String hash that's stable across different executions, unlike builtin hash 122*da0073e9SAndroid Build Coastguard Workerdef string_stable_hash(s: str) -> int: 123*da0073e9SAndroid Build Coastguard Worker sha1 = hashlib.sha1(s.encode("latin1")).digest() 124*da0073e9SAndroid Build Coastguard Worker return int.from_bytes(sha1, byteorder="little") 125*da0073e9SAndroid Build Coastguard Worker 126*da0073e9SAndroid Build Coastguard Worker 127*da0073e9SAndroid Build Coastguard Worker# A small abstraction for writing out generated files and keeping track 128*da0073e9SAndroid Build Coastguard Worker# of what files have been written (so you can write out a list of output 129*da0073e9SAndroid Build Coastguard Worker# files) 130*da0073e9SAndroid Build Coastguard Workerclass FileManager: 131*da0073e9SAndroid Build Coastguard Worker install_dir: str 132*da0073e9SAndroid Build Coastguard Worker template_dir: str 133*da0073e9SAndroid Build Coastguard Worker dry_run: bool 134*da0073e9SAndroid Build Coastguard Worker filenames: set[str] 135*da0073e9SAndroid Build Coastguard Worker 136*da0073e9SAndroid Build Coastguard Worker def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None: 137*da0073e9SAndroid Build Coastguard Worker self.install_dir = install_dir 138*da0073e9SAndroid Build Coastguard Worker self.template_dir = template_dir 139*da0073e9SAndroid Build Coastguard Worker self.filenames = set() 
140*da0073e9SAndroid Build Coastguard Worker self.dry_run = dry_run 141*da0073e9SAndroid Build Coastguard Worker 142*da0073e9SAndroid Build Coastguard Worker def _write_if_changed(self, filename: str, contents: str) -> None: 143*da0073e9SAndroid Build Coastguard Worker old_contents: str | None 144*da0073e9SAndroid Build Coastguard Worker try: 145*da0073e9SAndroid Build Coastguard Worker with open(filename) as f: 146*da0073e9SAndroid Build Coastguard Worker old_contents = f.read() 147*da0073e9SAndroid Build Coastguard Worker except OSError: 148*da0073e9SAndroid Build Coastguard Worker old_contents = None 149*da0073e9SAndroid Build Coastguard Worker if contents != old_contents: 150*da0073e9SAndroid Build Coastguard Worker # Create output directory if it doesn't exist 151*da0073e9SAndroid Build Coastguard Worker os.makedirs(os.path.dirname(filename), exist_ok=True) 152*da0073e9SAndroid Build Coastguard Worker with open(filename, "w") as f: 153*da0073e9SAndroid Build Coastguard Worker f.write(contents) 154*da0073e9SAndroid Build Coastguard Worker 155*da0073e9SAndroid Build Coastguard Worker # Read from template file and replace pattern with callable (type could be dict or str). 
156*da0073e9SAndroid Build Coastguard Worker def substitute_with_template( 157*da0073e9SAndroid Build Coastguard Worker self, template_fn: str, env_callable: Callable[[], str | dict[str, Any]] 158*da0073e9SAndroid Build Coastguard Worker ) -> str: 159*da0073e9SAndroid Build Coastguard Worker template_path = os.path.join(self.template_dir, template_fn) 160*da0073e9SAndroid Build Coastguard Worker env = env_callable() 161*da0073e9SAndroid Build Coastguard Worker if isinstance(env, dict): 162*da0073e9SAndroid Build Coastguard Worker if "generated_comment" not in env: 163*da0073e9SAndroid Build Coastguard Worker generator_default = REPO_ROOT / "torchgen" / "gen.py" 164*da0073e9SAndroid Build Coastguard Worker try: 165*da0073e9SAndroid Build Coastguard Worker generator = Path( 166*da0073e9SAndroid Build Coastguard Worker sys.modules["__main__"].__file__ or generator_default 167*da0073e9SAndroid Build Coastguard Worker ).absolute() 168*da0073e9SAndroid Build Coastguard Worker except (KeyError, AttributeError): 169*da0073e9SAndroid Build Coastguard Worker generator = generator_default.absolute() 170*da0073e9SAndroid Build Coastguard Worker 171*da0073e9SAndroid Build Coastguard Worker try: 172*da0073e9SAndroid Build Coastguard Worker generator_path = generator.relative_to(REPO_ROOT).as_posix() 173*da0073e9SAndroid Build Coastguard Worker except ValueError: 174*da0073e9SAndroid Build Coastguard Worker generator_path = generator.name 175*da0073e9SAndroid Build Coastguard Worker 176*da0073e9SAndroid Build Coastguard Worker env = { 177*da0073e9SAndroid Build Coastguard Worker **env, # copy the original dict instead of mutating it 178*da0073e9SAndroid Build Coastguard Worker "generated_comment": ( 179*da0073e9SAndroid Build Coastguard Worker "@" + f"generated by {generator_path} from {template_fn}" 180*da0073e9SAndroid Build Coastguard Worker ), 181*da0073e9SAndroid Build Coastguard Worker } 182*da0073e9SAndroid Build Coastguard Worker template = _read_template(template_path) 
183*da0073e9SAndroid Build Coastguard Worker return template.substitute(env) 184*da0073e9SAndroid Build Coastguard Worker elif isinstance(env, str): 185*da0073e9SAndroid Build Coastguard Worker return env 186*da0073e9SAndroid Build Coastguard Worker else: 187*da0073e9SAndroid Build Coastguard Worker assert_never(env) 188*da0073e9SAndroid Build Coastguard Worker 189*da0073e9SAndroid Build Coastguard Worker def write_with_template( 190*da0073e9SAndroid Build Coastguard Worker self, 191*da0073e9SAndroid Build Coastguard Worker filename: str, 192*da0073e9SAndroid Build Coastguard Worker template_fn: str, 193*da0073e9SAndroid Build Coastguard Worker env_callable: Callable[[], str | dict[str, Any]], 194*da0073e9SAndroid Build Coastguard Worker ) -> None: 195*da0073e9SAndroid Build Coastguard Worker filename = f"{self.install_dir}/{filename}" 196*da0073e9SAndroid Build Coastguard Worker assert filename not in self.filenames, "duplicate file write {filename}" 197*da0073e9SAndroid Build Coastguard Worker self.filenames.add(filename) 198*da0073e9SAndroid Build Coastguard Worker if not self.dry_run: 199*da0073e9SAndroid Build Coastguard Worker substitute_out = self.substitute_with_template( 200*da0073e9SAndroid Build Coastguard Worker template_fn=template_fn, 201*da0073e9SAndroid Build Coastguard Worker env_callable=env_callable, 202*da0073e9SAndroid Build Coastguard Worker ) 203*da0073e9SAndroid Build Coastguard Worker self._write_if_changed(filename=filename, contents=substitute_out) 204*da0073e9SAndroid Build Coastguard Worker 205*da0073e9SAndroid Build Coastguard Worker def write( 206*da0073e9SAndroid Build Coastguard Worker self, 207*da0073e9SAndroid Build Coastguard Worker filename: str, 208*da0073e9SAndroid Build Coastguard Worker env_callable: Callable[[], str | dict[str, Any]], 209*da0073e9SAndroid Build Coastguard Worker ) -> None: 210*da0073e9SAndroid Build Coastguard Worker self.write_with_template(filename, filename, env_callable) 211*da0073e9SAndroid Build 
Coastguard Worker 212*da0073e9SAndroid Build Coastguard Worker def write_sharded( 213*da0073e9SAndroid Build Coastguard Worker self, 214*da0073e9SAndroid Build Coastguard Worker filename: str, 215*da0073e9SAndroid Build Coastguard Worker items: Iterable[T], 216*da0073e9SAndroid Build Coastguard Worker *, 217*da0073e9SAndroid Build Coastguard Worker key_fn: Callable[[T], str], 218*da0073e9SAndroid Build Coastguard Worker env_callable: Callable[[T], dict[str, list[str]]], 219*da0073e9SAndroid Build Coastguard Worker num_shards: int, 220*da0073e9SAndroid Build Coastguard Worker base_env: dict[str, Any] | None = None, 221*da0073e9SAndroid Build Coastguard Worker sharded_keys: set[str], 222*da0073e9SAndroid Build Coastguard Worker ) -> None: 223*da0073e9SAndroid Build Coastguard Worker everything: dict[str, Any] = {"shard_id": "Everything"} 224*da0073e9SAndroid Build Coastguard Worker shards: list[dict[str, Any]] = [ 225*da0073e9SAndroid Build Coastguard Worker {"shard_id": f"_{i}"} for i in range(num_shards) 226*da0073e9SAndroid Build Coastguard Worker ] 227*da0073e9SAndroid Build Coastguard Worker all_shards = [everything] + shards 228*da0073e9SAndroid Build Coastguard Worker 229*da0073e9SAndroid Build Coastguard Worker if base_env is not None: 230*da0073e9SAndroid Build Coastguard Worker for shard in all_shards: 231*da0073e9SAndroid Build Coastguard Worker shard.update(base_env) 232*da0073e9SAndroid Build Coastguard Worker 233*da0073e9SAndroid Build Coastguard Worker for key in sharded_keys: 234*da0073e9SAndroid Build Coastguard Worker for shard in all_shards: 235*da0073e9SAndroid Build Coastguard Worker if key in shard: 236*da0073e9SAndroid Build Coastguard Worker assert isinstance( 237*da0073e9SAndroid Build Coastguard Worker shard[key], list 238*da0073e9SAndroid Build Coastguard Worker ), "sharded keys in base_env must be a list" 239*da0073e9SAndroid Build Coastguard Worker shard[key] = shard[key].copy() 240*da0073e9SAndroid Build Coastguard Worker else: 
241*da0073e9SAndroid Build Coastguard Worker shard[key] = [] 242*da0073e9SAndroid Build Coastguard Worker 243*da0073e9SAndroid Build Coastguard Worker def merge_env(into: dict[str, list[str]], from_: dict[str, list[str]]) -> None: 244*da0073e9SAndroid Build Coastguard Worker for k, v in from_.items(): 245*da0073e9SAndroid Build Coastguard Worker assert k in sharded_keys, f"undeclared sharded key {k}" 246*da0073e9SAndroid Build Coastguard Worker into[k] += v 247*da0073e9SAndroid Build Coastguard Worker 248*da0073e9SAndroid Build Coastguard Worker if self.dry_run: 249*da0073e9SAndroid Build Coastguard Worker # Dry runs don't write any templates, so incomplete environments are fine 250*da0073e9SAndroid Build Coastguard Worker items = () 251*da0073e9SAndroid Build Coastguard Worker 252*da0073e9SAndroid Build Coastguard Worker for item in items: 253*da0073e9SAndroid Build Coastguard Worker key = key_fn(item) 254*da0073e9SAndroid Build Coastguard Worker sid = string_stable_hash(key) % num_shards 255*da0073e9SAndroid Build Coastguard Worker env = env_callable(item) 256*da0073e9SAndroid Build Coastguard Worker 257*da0073e9SAndroid Build Coastguard Worker merge_env(shards[sid], env) 258*da0073e9SAndroid Build Coastguard Worker merge_env(everything, env) 259*da0073e9SAndroid Build Coastguard Worker 260*da0073e9SAndroid Build Coastguard Worker dot_pos = filename.rfind(".") 261*da0073e9SAndroid Build Coastguard Worker if dot_pos == -1: 262*da0073e9SAndroid Build Coastguard Worker dot_pos = len(filename) 263*da0073e9SAndroid Build Coastguard Worker base_filename = filename[:dot_pos] 264*da0073e9SAndroid Build Coastguard Worker extension = filename[dot_pos:] 265*da0073e9SAndroid Build Coastguard Worker 266*da0073e9SAndroid Build Coastguard Worker for shard in all_shards: 267*da0073e9SAndroid Build Coastguard Worker shard_id = shard["shard_id"] 268*da0073e9SAndroid Build Coastguard Worker self.write_with_template( 269*da0073e9SAndroid Build Coastguard Worker 
f"{base_filename}{shard_id}{extension}", filename, lambda: shard 270*da0073e9SAndroid Build Coastguard Worker ) 271*da0073e9SAndroid Build Coastguard Worker 272*da0073e9SAndroid Build Coastguard Worker # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled 273*da0073e9SAndroid Build Coastguard Worker self.filenames.discard( 274*da0073e9SAndroid Build Coastguard Worker f"{self.install_dir}/{base_filename}Everything{extension}" 275*da0073e9SAndroid Build Coastguard Worker ) 276*da0073e9SAndroid Build Coastguard Worker 277*da0073e9SAndroid Build Coastguard Worker def write_outputs(self, variable_name: str, filename: str) -> None: 278*da0073e9SAndroid Build Coastguard Worker """Write a file containing the list of all outputs which are 279*da0073e9SAndroid Build Coastguard Worker generated by this script.""" 280*da0073e9SAndroid Build Coastguard Worker content = "set({}\n {})".format( 281*da0073e9SAndroid Build Coastguard Worker variable_name, 282*da0073e9SAndroid Build Coastguard Worker "\n ".join('"' + name + '"' for name in sorted(self.filenames)), 283*da0073e9SAndroid Build Coastguard Worker ) 284*da0073e9SAndroid Build Coastguard Worker self._write_if_changed(filename, content) 285*da0073e9SAndroid Build Coastguard Worker 286*da0073e9SAndroid Build Coastguard Worker def template_dir_for_comments(self) -> str: 287*da0073e9SAndroid Build Coastguard Worker """ 288*da0073e9SAndroid Build Coastguard Worker This needs to be deterministic. The template dir is an absolute path 289*da0073e9SAndroid Build Coastguard Worker that varies across builds. So, just use the path relative to this file, 290*da0073e9SAndroid Build Coastguard Worker which will point to the codegen source but will be stable. 
291*da0073e9SAndroid Build Coastguard Worker """ 292*da0073e9SAndroid Build Coastguard Worker return os.path.relpath(self.template_dir, os.path.dirname(__file__)) 293*da0073e9SAndroid Build Coastguard Worker 294*da0073e9SAndroid Build Coastguard Worker 295*da0073e9SAndroid Build Coastguard Worker# Helper function to generate file manager 296*da0073e9SAndroid Build Coastguard Workerdef make_file_manager( 297*da0073e9SAndroid Build Coastguard Worker options: Namespace, install_dir: str | None = None 298*da0073e9SAndroid Build Coastguard Worker) -> FileManager: 299*da0073e9SAndroid Build Coastguard Worker template_dir = os.path.join(options.source_path, "templates") 300*da0073e9SAndroid Build Coastguard Worker install_dir = install_dir if install_dir else options.install_dir 301*da0073e9SAndroid Build Coastguard Worker return FileManager( 302*da0073e9SAndroid Build Coastguard Worker install_dir=install_dir, template_dir=template_dir, dry_run=options.dry_run 303*da0073e9SAndroid Build Coastguard Worker ) 304*da0073e9SAndroid Build Coastguard Worker 305*da0073e9SAndroid Build Coastguard Worker 306*da0073e9SAndroid Build Coastguard Worker# Helper function to create a pretty representation for dataclasses 307*da0073e9SAndroid Build Coastguard Workerdef dataclass_repr( 308*da0073e9SAndroid Build Coastguard Worker obj: Any, 309*da0073e9SAndroid Build Coastguard Worker indent: int = 0, 310*da0073e9SAndroid Build Coastguard Worker width: int = 80, 311*da0073e9SAndroid Build Coastguard Worker) -> str: 312*da0073e9SAndroid Build Coastguard Worker # built-in pprint module support dataclasses from python 3.10 313*da0073e9SAndroid Build Coastguard Worker if sys.version_info >= (3, 10): 314*da0073e9SAndroid Build Coastguard Worker from pprint import pformat 315*da0073e9SAndroid Build Coastguard Worker 316*da0073e9SAndroid Build Coastguard Worker return pformat(obj, indent, width) 317*da0073e9SAndroid Build Coastguard Worker 318*da0073e9SAndroid Build Coastguard Worker return 
_pformat(obj, indent=indent, width=width) 319*da0073e9SAndroid Build Coastguard Worker 320*da0073e9SAndroid Build Coastguard Worker 321*da0073e9SAndroid Build Coastguard Workerdef _pformat( 322*da0073e9SAndroid Build Coastguard Worker obj: Any, 323*da0073e9SAndroid Build Coastguard Worker indent: int, 324*da0073e9SAndroid Build Coastguard Worker width: int, 325*da0073e9SAndroid Build Coastguard Worker curr_indent: int = 0, 326*da0073e9SAndroid Build Coastguard Worker) -> str: 327*da0073e9SAndroid Build Coastguard Worker assert is_dataclass(obj), f"obj should be a dataclass, received: {type(obj)}" 328*da0073e9SAndroid Build Coastguard Worker 329*da0073e9SAndroid Build Coastguard Worker class_name = obj.__class__.__name__ 330*da0073e9SAndroid Build Coastguard Worker # update current indentation level with class name 331*da0073e9SAndroid Build Coastguard Worker curr_indent += len(class_name) + 1 332*da0073e9SAndroid Build Coastguard Worker 333*da0073e9SAndroid Build Coastguard Worker fields_list = [(f.name, getattr(obj, f.name)) for f in fields(obj) if f.repr] 334*da0073e9SAndroid Build Coastguard Worker 335*da0073e9SAndroid Build Coastguard Worker fields_str = [] 336*da0073e9SAndroid Build Coastguard Worker for name, attr in fields_list: 337*da0073e9SAndroid Build Coastguard Worker # update the current indent level with the field name 338*da0073e9SAndroid Build Coastguard Worker # dict, list, set and tuple also add indent as done in pprint 339*da0073e9SAndroid Build Coastguard Worker _curr_indent = curr_indent + len(name) + 1 340*da0073e9SAndroid Build Coastguard Worker if is_dataclass(attr): 341*da0073e9SAndroid Build Coastguard Worker str_repr = _pformat(attr, indent, width, _curr_indent) 342*da0073e9SAndroid Build Coastguard Worker elif isinstance(attr, dict): 343*da0073e9SAndroid Build Coastguard Worker str_repr = _format_dict(attr, indent, width, _curr_indent) 344*da0073e9SAndroid Build Coastguard Worker elif isinstance(attr, (list, set, tuple)): 
345*da0073e9SAndroid Build Coastguard Worker str_repr = _format_list(attr, indent, width, _curr_indent) 346*da0073e9SAndroid Build Coastguard Worker else: 347*da0073e9SAndroid Build Coastguard Worker str_repr = repr(attr) 348*da0073e9SAndroid Build Coastguard Worker 349*da0073e9SAndroid Build Coastguard Worker fields_str.append(f"{name}={str_repr}") 350*da0073e9SAndroid Build Coastguard Worker 351*da0073e9SAndroid Build Coastguard Worker indent_str = curr_indent * " " 352*da0073e9SAndroid Build Coastguard Worker body = f",\n{indent_str}".join(fields_str) 353*da0073e9SAndroid Build Coastguard Worker return f"{class_name}({body})" 354*da0073e9SAndroid Build Coastguard Worker 355*da0073e9SAndroid Build Coastguard Worker 356*da0073e9SAndroid Build Coastguard Workerdef _format_dict( 357*da0073e9SAndroid Build Coastguard Worker attr: dict[Any, Any], 358*da0073e9SAndroid Build Coastguard Worker indent: int, 359*da0073e9SAndroid Build Coastguard Worker width: int, 360*da0073e9SAndroid Build Coastguard Worker curr_indent: int, 361*da0073e9SAndroid Build Coastguard Worker) -> str: 362*da0073e9SAndroid Build Coastguard Worker curr_indent += indent + 3 363*da0073e9SAndroid Build Coastguard Worker dict_repr = [] 364*da0073e9SAndroid Build Coastguard Worker for k, v in attr.items(): 365*da0073e9SAndroid Build Coastguard Worker k_repr = repr(k) 366*da0073e9SAndroid Build Coastguard Worker v_str = ( 367*da0073e9SAndroid Build Coastguard Worker _pformat(v, indent, width, curr_indent + len(k_repr)) 368*da0073e9SAndroid Build Coastguard Worker if is_dataclass(v) 369*da0073e9SAndroid Build Coastguard Worker else repr(v) 370*da0073e9SAndroid Build Coastguard Worker ) 371*da0073e9SAndroid Build Coastguard Worker dict_repr.append(f"{k_repr}: {v_str}") 372*da0073e9SAndroid Build Coastguard Worker 373*da0073e9SAndroid Build Coastguard Worker return _format(dict_repr, indent, width, curr_indent, "{", "}") 374*da0073e9SAndroid Build Coastguard Worker 375*da0073e9SAndroid Build Coastguard 
def _format_list(
    attr: list[Any] | set[Any] | tuple[Any, ...],
    indent: int,
    width: int,
    curr_indent: int,
) -> str:
    """Render a list/set/tuple field for the ``_pformat`` fallback.

    Dataclass elements are formatted recursively with ``_pformat``; other
    elements use ``repr``. Brackets follow Python literal syntax:
    - lists: ``[...]``
    - sets: ``{...}``, or ``set()`` when empty, since ``{}`` would read as a dict
    - tuples: ``(...)``, with a trailing comma for 1-tuples
    """
    curr_indent += indent + 1
    list_repr = [
        _pformat(item, indent, width, curr_indent) if is_dataclass(item) else repr(item)
        for item in attr
    ]
    if isinstance(attr, list):
        start, end = "[", "]"
    elif isinstance(attr, set):
        if not attr:
            return "set()"
        start, end = "{", "}"
    else:
        # A 1-tuple needs a trailing comma to be distinguishable from a
        # parenthesized expression.
        start, end = "(", (",)" if len(attr) == 1 else ")")
    return _format(list_repr, indent, width, curr_indent, start, end)


def _format(
    fields_str: list[str],
    indent: int,
    width: int,
    curr_indent: int,
    start: str,
    end: str,
) -> str:
    """Join pre-rendered element strings between ``start`` and ``end``.

    Short form: elements are joined with ``", "`` on one line.

    Long form: used when ``repr(fields_str)`` is at least ``width``
    characters. Every element after the first then starts a new line,
    indented by ``curr_indent`` spaces.

    In both forms, ``indent`` spaces are inserted right after ``start``.
    """
    delimiter, curr_indent_str = "", ""
    # if it exceed the max width then we place one element per line
    if len(repr(fields_str)) >= width:
        delimiter = "\n"
        curr_indent_str = " " * curr_indent

    indent_str = " " * indent
    body = f", {delimiter}{curr_indent_str}".join(fields_str)
    return f"{start}{indent_str}{body}{end}"


class NamespaceHelper:
    """A helper for constructing the namespace open and close strings for a nested set of namespaces.

    e.g. for namespace_str torch::lazy,

    prologue:
    namespace torch {
    namespace lazy {

    epilogue:
    } // namespace lazy
    } // namespace torch
    """

    def __init__(
        self, namespace_str: str, entity_name: str = "", max_level: int = 2
    ) -> None:
        # cpp_namespace can be a colon joined string such as torch::lazy
        cpp_namespaces = namespace_str.split("::")
        assert (
            len(cpp_namespaces) <= max_level
        ), f"Codegen doesn't support more than {max_level} level(s) of custom namespace. Got {namespace_str}."
        # NOTE(review): an empty namespace_str splits to [""], which yields the
        # prologue "namespace  {" (an anonymous namespace) rather than no
        # namespace at all -- confirm callers expect this.
        self.cpp_namespace_ = namespace_str
        self.prologue_ = "\n".join([f"namespace {n} {{" for n in cpp_namespaces])
        self.epilogue_ = "\n".join(
            [f"}} // namespace {n}" for n in reversed(cpp_namespaces)]
        )
        self.namespaces_ = cpp_namespaces
        self.entity_name_ = entity_name

    @staticmethod
    def from_namespaced_entity(
        namespaced_entity: str, max_level: int = 2
    ) -> NamespaceHelper:
        """
        Build a helper from a fully qualified name such as "torch::lazy::add".

        Everything before the last "::" becomes the namespace string, and the
        last component becomes ``entity_name``.
        """
        names = namespaced_entity.split("::")
        entity_name = names[-1]
        namespace_str = "::".join(names[:-1])
        return NamespaceHelper(
            namespace_str=namespace_str, entity_name=entity_name, max_level=max_level
        )

    @property
    def prologue(self) -> str:
        """Lines opening each namespace, outermost first."""
        return self.prologue_

    @property
    def epilogue(self) -> str:
        """Lines closing each namespace, innermost first."""
        return self.epilogue_

    @property
    def entity_name(self) -> str:
        """The trailing class/function name, or "" if none was given."""
        return self.entity_name_

    # Only allow certain level of namespaces
    def get_cpp_namespace(self, default: str = "") -> str:
        """
        Return the namespace string from joining all the namespaces by "::" (hence no leading "::").
        Return default if namespace string is empty.
        """
        return self.cpp_namespace_ if self.cpp_namespace_ else default


class OrderedSet(Generic[T]):
    """Insertion-ordered set backed by the keys of a dict whose values are all None."""

    storage: dict[T, Literal[None]]

    def __init__(self, iterable: Iterable[T] | None = None) -> None:
        if iterable is None:
            self.storage = {}
        else:
            self.storage = dict.fromkeys(iterable)

    def __contains__(self, item: T) -> bool:
        return item in self.storage

    def __iter__(self) -> Iterator[T]:
        return iter(self.storage.keys())

    def update(self, items: OrderedSet[T]) -> None:
        self.storage.update(items.storage)

    def add(self, item: T) -> None:
        self.storage[item] = None

    def copy(self) -> OrderedSet[T]:
        ret: OrderedSet[T] = OrderedSet()
        ret.storage = self.storage.copy()
        return ret
Coastguard Worker 501*da0073e9SAndroid Build Coastguard Worker @staticmethod 502*da0073e9SAndroid Build Coastguard Worker def union(*args: OrderedSet[T]) -> OrderedSet[T]: 503*da0073e9SAndroid Build Coastguard Worker ret = args[0].copy() 504*da0073e9SAndroid Build Coastguard Worker for s in args[1:]: 505*da0073e9SAndroid Build Coastguard Worker ret.update(s) 506*da0073e9SAndroid Build Coastguard Worker return ret 507*da0073e9SAndroid Build Coastguard Worker 508*da0073e9SAndroid Build Coastguard Worker def __or__(self, other: OrderedSet[T]) -> OrderedSet[T]: 509*da0073e9SAndroid Build Coastguard Worker return OrderedSet.union(self, other) 510*da0073e9SAndroid Build Coastguard Worker 511*da0073e9SAndroid Build Coastguard Worker def __ior__(self, other: OrderedSet[T]) -> Self: 512*da0073e9SAndroid Build Coastguard Worker self.update(other) 513*da0073e9SAndroid Build Coastguard Worker return self 514*da0073e9SAndroid Build Coastguard Worker 515*da0073e9SAndroid Build Coastguard Worker def __eq__(self, other: object) -> bool: 516*da0073e9SAndroid Build Coastguard Worker if isinstance(other, OrderedSet): 517*da0073e9SAndroid Build Coastguard Worker return self.storage == other.storage 518*da0073e9SAndroid Build Coastguard Worker else: 519*da0073e9SAndroid Build Coastguard Worker return set(self.storage.keys()) == other 520