xref: /aosp_15_r20/external/pytorch/benchmarks/instruction_counts/core/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2import atexit
3import re
4import shutil
5import textwrap
6from typing import List, Optional, Tuple
7
8from core.api import GroupedBenchmark, TimerArgs
9from core.types import Definition, FlatIntermediateDefinition, Label
10
11from torch.utils.benchmark.utils.common import _make_temp_dir
12
13
# Path of the scratch directory handed out by `get_temp_dir()`. It stays
# `None` until the first call creates the directory and stores it here.
_TEMPDIR: Optional[str] = None
15
16
def get_temp_dir() -> str:
    """Return the benchmark scratch directory, creating it on first use.

    The path is cached in the module-level `_TEMPDIR`, so every call after
    the first returns the same string. When the directory is created, a
    `shutil.rmtree` call for it is registered with `atexit` so that it is
    removed at interpreter shutdown.
    """
    global _TEMPDIR

    # Fast path: the directory was already created by an earlier call.
    if _TEMPDIR is not None:
        return _TEMPDIR

    _TEMPDIR = _make_temp_dir(
        prefix="instruction_count_microbenchmarks",
        gc_dev_shm=True,
    )
    # Register cleanup exactly once, right after the directory is created.
    atexit.register(shutil.rmtree, path=_TEMPDIR)
    return _TEMPDIR
25
26
def _flatten(
    key_prefix: Label, sub_schema: Definition, result: FlatIntermediateDefinition
) -> None:
    """Recursively copy the leaves of `sub_schema` into `result`.

    Each key of `sub_schema` is turned into a tuple of strings and appended
    to `key_prefix`:
      - `None` contributes nothing (the prefix is used unchanged),
      - a tuple of strings is appended as-is,
      - a single string is appended as a one-element tuple.

    A value that is a `TimerArgs` or `GroupedBenchmark` is stored in `result`
    under the combined key; any other value must be a dict and is flattened
    recursively with the combined key as its prefix. `result` is mutated in
    place and nothing is returned. Malformed keys or values, and duplicate
    combined keys, trip an assertion.
    """
    for raw_key, entry in sub_schema.items():
        # Normalize the raw key into a (possibly empty) tuple of strings.
        if raw_key is None:
            suffix: Label = ()
        elif isinstance(raw_key, tuple):
            assert all(isinstance(part, str) for part in raw_key)
            suffix = raw_key
        else:
            assert isinstance(raw_key, str)
            suffix = (raw_key,)

        key: Label = key_prefix + suffix

        # Anything that is not a leaf must be a nested definition: recurse.
        if not isinstance(entry, (TimerArgs, GroupedBenchmark)):
            assert isinstance(entry, dict)
            _flatten(key_prefix=key, sub_schema=entry, result=result)
            continue

        assert key not in result, f"duplicate key: {key}"
        result[key] = entry
47
48
def flatten(schema: Definition) -> FlatIntermediateDefinition:
    """Convert a nested benchmark definition into a flat one.

    See types.py for an explanation of nested vs. flat definitions.

    The returned dict maps tuples of strings to `TimerArgs` or
    `GroupedBenchmark` instances; this shape is re-checked with assertions
    before returning.
    """
    flat: FlatIntermediateDefinition = {}
    _flatten(key_prefix=(), sub_schema=schema, result=flat)

    # Sanity check: every key is a tuple of strings and every value a leaf.
    for label, leaf in flat.items():
        assert isinstance(label, tuple)
        assert all(isinstance(part, str) for part in label)
        assert isinstance(leaf, (TimerArgs, GroupedBenchmark))

    return flat
60
61
def parse_stmts(stmts: str) -> Tuple[str, str]:
    """Helper function for side-by-side Python and C++ stmts.

    For more complex statements, it can be useful to see Python and C++ code
    side by side. To this end, we provide an **extremely restricted** way
    to define Python and C++ code side-by-side. After dedenting and stripping
    `stmts`, the expected layout is:
      - Line 1: the column header, i.e. "Python" padded to 40 characters,
        then " | C++".
      - Line 2: the separator, i.e. 40 dashes, " | ", 40 dashes.
      - Every following line: exactly 40 characters of Python, then " |",
        optionally followed by a space and the C++ code.

    Non-obvious caveats:
      - Width for the left (Python) column MUST be 40 characters.
      - The column separator is " | ", not "|". Whitespace matters.

    Args:
        stmts: The side-by-side listing. Common leading indentation and
            surrounding whitespace are removed before parsing.

    Returns:
        A `(python_code, cpp_code)` tuple. Each element is the corresponding
        column's lines joined with newlines. Python lines keep their padding
        to 40 columns; a row with no C++ code contributes an empty line.

    Raises:
        ValueError: If there are fewer than three lines, or if the header,
            the separator, or any code line does not match the layout above.
    """
    stmts = textwrap.dedent(stmts).strip()
    lines: List[str] = stmts.splitlines(keepends=False)
    # Raise rather than assert: this validates caller input and must not
    # disappear when Python runs with -O.
    if len(lines) < 3:
        raise ValueError(f"Invalid string:\n{stmts}")

    column_header_pattern = r"^Python\s{35}\| C\+\+(\s*)$"
    # The "|" has to be escaped. Unescaped it is regex alternation, which
    # accepts any line that merely starts with 40 dashes and a space, or
    # merely ends with a space and 40 dashes.
    separation_pattern = r"^[-]{40} \| [-]{40}\s*$"
    # Compiled once because it is applied to every code line below.
    code_pattern = re.compile(r"^(.{40}) \|($| (.*)$)")

    if re.search(column_header_pattern, lines[0]) is None:
        raise ValueError(
            f"Column header `{lines[0]}` "
            f"does not match pattern `{column_header_pattern}`"
        )

    if re.search(separation_pattern, lines[1]) is None:
        raise ValueError(
            f"Separator `{lines[1]}` "
            f"does not match pattern `{separation_pattern}`"
        )

    py_lines: List[str] = []
    cpp_lines: List[str] = []
    for line in lines[2:]:
        line_match = code_pattern.search(line)
        if line_match is None:
            raise ValueError(f"Invalid line `{line}`")

        py_code = line_match.group(1)
        # Group 3 is None when the row ends right after the " |" separator.
        cpp_code = line_match.group(3) or ""
        py_lines.append(py_code)
        cpp_lines.append(cpp_code)

        # Make sure we can round trip for correctness. This is an internal
        # invariant of the parser (not input validation), so it stays an
        # assert.
        round_tripped = f"{py_code:<40} | {cpp_code:<40}".rstrip()
        assert round_tripped == line.rstrip(), f"Failed to round trip `{line}`"

    return "\n".join(py_lines), "\n".join(cpp_lines)
104