xref: /aosp_15_r20/external/pytorch/torch/utils/benchmark/utils/cpp_jit.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""JIT C++ strings into executables."""
2import atexit
3import os
4import re
5import shutil
6import textwrap
7import threading
8from typing import Any, List, Optional
9
10import torch
11from torch.utils.benchmark.utils._stubs import CallgrindModuleType, TimeitModuleType
12from torch.utils.benchmark.utils.common import _make_temp_dir
13from torch.utils import cpp_extension
14
15
# Module-wide lock guarding the lazily initialized state below.
LOCK = threading.Lock()
# Absolute path of the directory that contains this file.
SOURCE_ROOT = os.path.dirname(os.path.abspath(__file__))
18
# We create the build root once per process (lazily, on first use; see
# `_get_build_root`) so that separate processes will have separate build
# roots, but threads will share the same build root.
# `cpp_extension` uses build root as part of the cache key, so per-invocation
# uuids (e.g. different build root per _compile_template call) would lead to
# a 0% cache hit rate and spurious recompilation. Consider the following:
#   ```
#   setup = "auto x = torch::ones({1024, 1024});"
#   stmt = "torch::mm(x, x);"
#   for num_threads in [1, 2, 4, 8]:
#     print(Timer(stmt, setup, num_threads=num_threads, language="c++").blocked_autorange())
#   ```
# `setup` and `stmt` do not change, so we can reuse the executable from the
# first pass through the loop.
_BUILD_ROOT: Optional[str] = None


def _get_build_root() -> str:
    """Return the per-process build root, creating it on first use.

    The first call makes a temporary directory via `_make_temp_dir` and
    registers `shutil.rmtree` on it with `atexit`. Later calls return the
    cached path.
    """
    global _BUILD_ROOT
    if _BUILD_ROOT is not None:
        return _BUILD_ROOT

    build_root = _make_temp_dir(prefix="benchmark_utils_jit_build")
    _BUILD_ROOT = build_root
    # Remove the directory (and every cached build in it) at interpreter exit.
    atexit.register(shutil.rmtree, build_root)
    return build_root
40
41
42# BACK_TESTING_NOTE:
43#   There are two workflows where this code could be used. One is the obvious
44#   case where someone simply builds or installs PyTorch and uses Timer.
45#   The other is that the entire `torch/utils/benchmark` folder from a CURRENT
46#   PyTorch checkout is copy-pasted into a much OLDER version of the PyTorch
47#   source code. This is what we refer to here as "back testing". The rationale
48#   is that we might want to use current tooling to study some aspect of an
49#   earlier version of PyTorch. (e.g. a regression.)
50#
51#   The problem is that Timer relies on several aspects of core PyTorch, namely
52#   some binding functions for Valgrind symbols in `torch._C` and the
53#   `torch.__config__._cxx_flags()` method. If we were to naively copy code
54#   around this wouldn't work as the symbols of interest aren't present in
55#   earlier versions of PyTorch. In order to work around this, we must add back
56#   testing shims. These shims will never activate during normal use, but will
57#   allow Timer to function outside of the "correct" version of PyTorch by
58#   emulating functionality that was added later.
59#
60#   These shims are temporary, and as Timer becomes more integrated with
61#   PyTorch the cost and complexity of such shims will increase. Once back
62#   testing is no longer required (which is to say we have done enough historic
63#   analysis and the shims no longer justify their maintenance and code
64#   complexity costs) back testing paths will be removed.
65
# Compiler flags forwarded to `cpp_extension.load`; `None` means "no extras".
CXX_FLAGS: Optional[List[str]]
if not hasattr(torch.__config__, "_cxx_flags"):
    # FIXME: Remove when back testing is no longer required.
    CXX_FLAGS = ["-O2", "-fPIC", "-g"]
else:
    try:
        CXX_FLAGS = torch.__config__._cxx_flags().strip().split()
    except RuntimeError:
        # We are in FBCode.
        CXX_FLAGS = None
    else:
        # Drop "-W" flags so that benchmarks can be built with a relaxed
        # constraint on compiler versions.
        CXX_FLAGS = [flag for flag in CXX_FLAGS if not flag.startswith("-W")]
        # Always build with debug symbols.
        if "-g" not in CXX_FLAGS:
            CXX_FLAGS.append("-g")
83
# The `valgrind_wrapper` directory next to this file is always searched.
EXTRA_INCLUDE_PATHS: List[str] = [os.path.join(SOURCE_ROOT, "valgrind_wrapper")]
CONDA_PREFIX = os.environ.get("CONDA_PREFIX")
if CONDA_PREFIX is not None:
    # Load will automatically search /usr/include, but not conda include.
    EXTRA_INCLUDE_PATHS += [os.path.join(CONDA_PREFIX, "include")]
89
90
COMPAT_CALLGRIND_BINDINGS: Optional[CallgrindModuleType] = None


def get_compat_bindings() -> CallgrindModuleType:
    """Return the compat Callgrind bindings, building them on first use.

    The first call compiles `valgrind_wrapper/compat_bindings.cpp` with
    `cpp_extension.load` while holding `LOCK`. Later calls return the cached
    module.
    """
    global COMPAT_CALLGRIND_BINDINGS
    with LOCK:
        if COMPAT_CALLGRIND_BINDINGS is None:
            bindings_source = os.path.join(
                SOURCE_ROOT, "valgrind_wrapper", "compat_bindings.cpp"
            )
            COMPAT_CALLGRIND_BINDINGS = cpp_extension.load(
                name="callgrind_bindings",
                sources=[bindings_source],
                extra_cflags=CXX_FLAGS,
                extra_include_paths=EXTRA_INCLUDE_PATHS,
            )
    return COMPAT_CALLGRIND_BINDINGS
107
108
def _compile_template(
    *,
    stmt: str,
    setup: str,
    global_setup: str,
    src: str,
    is_standalone: bool
) -> Any:
    """Splice user code into a C++ template and JIT compile the result.

    Args:
        stmt: C++ code inserted at `// STMT_TEMPLATE_LOCATION`.
        setup: C++ code inserted at `// SETUP_TEMPLATE_LOCATION`.
        global_setup: C++ code inserted at
            `// GLOBAL_SETUP_TEMPLATE_LOCATION`.
        src: Template source containing the marker comments above.
        is_standalone: Forwarded to `cpp_extension.load` as `is_standalone`;
            its negation is forwarded as `is_python_module`.

    Returns:
        Whatever `cpp_extension.load` returns. Callers in this file expect a
        module when `is_standalone` is False and a `str` when it is True.
    """
    for before, after, indentation in (
        ("// GLOBAL_SETUP_TEMPLATE_LOCATION", global_setup, 0),
        ("// SETUP_TEMPLATE_LOCATION", setup, 4),
        ("// STMT_TEMPLATE_LOCATION", stmt, 8)
    ):
        # C++ doesn't care about indentation so this code isn't load
        # bearing the way it is with Python, but this makes the source
        # look nicer if a human has to look at it.
        replacement = textwrap.indent(after, " " * indentation)[indentation:]

        # Pass a callable rather than a string as the replacement: `re.sub`
        # processes backslash escapes in string replacements, which would
        # turn e.g. a C++ "\n" into a literal newline and raise `re.error`
        # on sequences such as "\d". A callable's return value is inserted
        # verbatim. (The markers themselves contain no regex metacharacters.)
        src = re.sub(before, lambda _match, text=replacement: text, src)

    # We want to isolate different Timers. However `cpp_extension` will
    # cache builds which will significantly reduce the cost of repeated
    # invocations.
    with LOCK:
        # Identical sources map to the same name and build directory, which
        # is what allows the build to be reused.
        name = f"timer_cpp_{abs(hash(src))}"
        build_dir = os.path.join(_get_build_root(), name)
        os.makedirs(build_dir, exist_ok=True)

        src_path = os.path.join(build_dir, "timer_src.cpp")
        with open(src_path, "w") as f:
            f.write(src)

    # `cpp_extension` has its own locking scheme, so we don't need our lock.
    return cpp_extension.load(
        name=name,
        sources=[src_path],
        build_directory=build_dir,
        extra_cflags=CXX_FLAGS,
        extra_include_paths=EXTRA_INCLUDE_PATHS,
        is_python_module=not is_standalone,
        is_standalone=is_standalone,
    )
153
154
def compile_timeit_template(*, stmt: str, setup: str, global_setup: str) -> TimeitModuleType:
    """Compile `timeit_template.cpp` with the given snippets spliced in.

    Reads the template that sits next to this file, hands it to
    `_compile_template` with `is_standalone=False`, and returns the result
    after asserting that it is a `TimeitModuleType`.
    """
    with open(os.path.join(SOURCE_ROOT, "timeit_template.cpp")) as template_file:
        template_src: str = template_file.read()

    timeit_module = _compile_template(
        stmt=stmt,
        setup=setup,
        global_setup=global_setup,
        src=template_src,
        is_standalone=False,
    )
    assert isinstance(timeit_module, TimeitModuleType)
    return timeit_module
163
164
def compile_callgrind_template(*, stmt: str, setup: str, global_setup: str) -> str:
    """Compile the Callgrind timer template with the given snippets spliced in.

    Reads `valgrind_wrapper/timer_callgrind_template.cpp` from next to this
    file, hands it to `_compile_template` with `is_standalone=True`, and
    returns the result after asserting that it is a `str`.
    """
    callgrind_template = os.path.join(
        SOURCE_ROOT, "valgrind_wrapper", "timer_callgrind_template.cpp"
    )
    with open(callgrind_template) as template_file:
        template_src: str = template_file.read()

    executable = _compile_template(
        stmt=stmt,
        setup=setup,
        global_setup=global_setup,
        src=template_src,
        is_standalone=True,
    )
    assert isinstance(executable, str)
    return executable
173