# mypy: ignore-errors
import atexit
import re
import shutil
import textwrap
from typing import List, Optional, Tuple

from core.api import GroupedBenchmark, TimerArgs
from core.types import Definition, FlatIntermediateDefinition, Label

from torch.utils.benchmark.utils.common import _make_temp_dir


# Path of the lazily created scratch directory; populated by `get_temp_dir`.
_TEMPDIR: Optional[str] = None


def get_temp_dir() -> str:
    """Return the module-wide temporary directory, creating it on first use.

    The first call creates the directory through `_make_temp_dir` and
    registers `shutil.rmtree` with `atexit` so the directory is removed when
    the interpreter exits. Later calls return the cached path.
    """
    global _TEMPDIR
    if _TEMPDIR is None:
        _TEMPDIR = _make_temp_dir(
            prefix="instruction_count_microbenchmarks", gc_dev_shm=True
        )
        atexit.register(shutil.rmtree, path=_TEMPDIR)
    return _TEMPDIR


def _flatten(
    key_prefix: Label, sub_schema: Definition, result: FlatIntermediateDefinition
) -> None:
    """Recursively copy the leaves of `sub_schema` into `result`.

    Each key of `sub_schema` may be a tuple of strings, a single string, or
    `None` (which contributes nothing to the label). The normalized key is
    appended to `key_prefix` to form the flat label. `TimerArgs` and
    `GroupedBenchmark` values are stored in `result`; any other value must be
    a dict and is descended into.

    `result` is mutated in place. An assertion fails on a duplicate label or
    on a key/value of an unexpected type.
    """
    for k, value in sub_schema.items():
        # Normalize every key form to a tuple so labels concatenate uniformly.
        if isinstance(k, tuple):
            assert all(isinstance(ki, str) for ki in k)
            key_suffix: Label = k
        elif k is None:
            key_suffix = ()
        else:
            assert isinstance(k, str)
            key_suffix = (k,)

        key: Label = key_prefix + key_suffix
        if isinstance(value, (TimerArgs, GroupedBenchmark)):
            assert key not in result, f"duplicate key: {key}"
            result[key] = value
        else:
            assert isinstance(value, dict)
            _flatten(key_prefix=key, sub_schema=value, result=result)


def flatten(schema: Definition) -> FlatIntermediateDefinition:
    """See types.py for an explanation of nested vs. flat definitions.

    Returns a new dict mapping tuple-of-string labels to the `TimerArgs` /
    `GroupedBenchmark` leaves found in `schema`.
    """
    result: FlatIntermediateDefinition = {}
    _flatten(key_prefix=(), sub_schema=schema, result=result)

    # Ensure that we produced a valid flat definition.
    for k, v in result.items():
        assert isinstance(k, tuple)
        assert all(isinstance(ki, str) for ki in k)
        assert isinstance(v, (TimerArgs, GroupedBenchmark))
    return result


def parse_stmts(stmts: str) -> Tuple[str, str]:
    """Helper function for side-by-side Python and C++ stmts.

    For more complex statements, it can be useful to see Python and C++ code
    side by side. To this end, we provide an **extremely restricted** way
    to define Python and C++ code side-by-side. The schema should be mostly
    self explanatory, with the following non-obvious caveats:
      - Width for the left (Python) column MUST be 40 characters.
      - The column separator is " | ", not "|". Whitespace matters.

    The input is dedented and stripped, and must then consist of a header
    line (`Python`, padded to 40 characters, then ` | C++`), a separator line
    (40 dashes, ` | `, 40 dashes), and at least one code line.

    Returns:
        A `(python_source, cpp_source)` tuple. Each element is the
        corresponding column's lines joined with newlines. Python lines keep
        their padding to 40 characters; a missing C++ cell becomes "".

    Raises:
        ValueError: if the header line or any code line does not match the
            expected layout.
        AssertionError: if there are fewer than three lines, the separator
            line is malformed, or a parsed line fails to round trip.
    """
    stmts = textwrap.dedent(stmts).strip()
    lines: List[str] = stmts.splitlines(keepends=False)
    assert len(lines) >= 3, f"Invalid string:\n{stmts}"

    column_header_pattern = r"^Python\s{35}\| C\+\+(\s*)$"
    # The `|` must be escaped. Unescaped it is regex alternation, which would
    # accept any line that merely starts with 40 dashes and a space OR merely
    # ends with a space and 40 dashes, without checking the " | " separator.
    separation_pattern = r"^[-]{40} \| [-]{40}\s*$"
    code_pattern = r"^(.{40}) \|($| (.*)$)"

    column_match = re.search(column_header_pattern, lines[0])
    if column_match is None:
        raise ValueError(
            f"Column header `{lines[0]}` "
            f"does not match pattern `{column_header_pattern}`"
        )

    assert re.search(separation_pattern, lines[1]), (
        f"Separator line `{lines[1]}` "
        f"does not match pattern `{separation_pattern}`"
    )

    # Compiled once since it is applied to every code line.
    code_re = re.compile(code_pattern)

    py_lines: List[str] = []
    cpp_lines: List[str] = []
    for line in lines[2:]:
        line_match = code_re.search(line)
        if line_match is None:
            raise ValueError(f"Invalid line `{line}`")
        py_part = line_match.group(1)
        # Group 3 is None when the line ends right after the "|".
        cpp_part = line_match.group(3) or ""
        py_lines.append(py_part)
        cpp_lines.append(cpp_part)

        # Make sure we can round trip for correctness.
        line_from_stmts = f"{py_part:<40} | {cpp_part:<40}".rstrip()
        assert line_from_stmts == line.rstrip(), f"Failed to round trip `{line}`"

    return "\n".join(py_lines), "\n".join(cpp_lines)