"""JIT C++ strings into executables."""
import atexit
import os
import re
import shutil
import textwrap
import threading
from typing import Any, List, Optional

import torch
from torch.utils.benchmark.utils._stubs import CallgrindModuleType, TimeitModuleType
from torch.utils.benchmark.utils.common import _make_temp_dir
from torch.utils import cpp_extension


# Serializes the lazy initialization done in this module (the compat bindings
# and the per-process build root / generated source files).
LOCK = threading.Lock()
SOURCE_ROOT = os.path.split(os.path.abspath(__file__))[0]

# We create the build root once per process (lazily, on first use) so that
# separate processes will have separate build roots, but threads will share
# the same build root. `cpp_extension` uses build root as part of the cache
# key, so per-invocation build roots (e.g. a different build root per
# `_compile_template` call) would lead to a 0% cache hit rate and spurious
# recompilation. Consider the following:
# ```
#   setup = "auto x = torch::ones({1024, 1024});"
#   stmt = "torch::mm(x, x);"
#   for num_threads in [1, 2, 4, 8]:
#       print(Timer(stmt, setup, num_threads=num_threads, language="c++").blocked_autorange())
# ```
# `setup` and `stmt` do not change, so we can reuse the executable from the
# first pass through the loop.
_BUILD_ROOT: Optional[str] = None


def _get_build_root() -> str:
    """Return the process-wide build directory, creating it on first use.

    The directory is created with `_make_temp_dir` and registered for removal
    at interpreter exit. This function does no locking of its own; the call
    site in `_compile_template` invokes it while holding `LOCK`.
    """
    global _BUILD_ROOT
    if _BUILD_ROOT is None:
        _BUILD_ROOT = _make_temp_dir(prefix="benchmark_utils_jit_build")
        atexit.register(shutil.rmtree, _BUILD_ROOT)
    return _BUILD_ROOT


# BACK_TESTING_NOTE:
#   There are two workflows where this code could be used. One is the obvious
#   case where someone simply builds or installs PyTorch and uses Timer.
#   The other is that the entire `torch/utils/benchmark` folder from a CURRENT
#   PyTorch checkout is copy-pasted into a much OLDER version of the PyTorch
#   source code. This is what we refer to here as "back testing". The rationale
#   is that we might want to use current tooling to study some aspect of an
#   earlier version of PyTorch. (e.g. a regression.)
#
#   The problem is that Timer relies on several aspects of core PyTorch, namely
#   some binding functions for Valgrind symbols in `torch._C` and the
#   `torch.__config__._cxx_flags()` method. If we were to naively copy code
#   around this wouldn't work as the symbols of interest aren't present in
#   earlier versions of PyTorch. In order to work around this, we must add back
#   testing shims. These shims will never activate during normal use, but will
#   allow Timer to function outside of the "correct" version of PyTorch by
#   emulating functionality that was added later.
#
#   These shims are temporary, and as Timer becomes more integrated with
#   PyTorch the cost and complexity of such shims will increase. Once back
#   testing is no longer required (which is to say we have done enough historic
#   analysis and the shims no longer justify their maintenance and code
#   complexity costs) back testing paths will be removed.

CXX_FLAGS: Optional[List[str]]
if hasattr(torch.__config__, "_cxx_flags"):
    try:
        CXX_FLAGS = torch.__config__._cxx_flags().strip().split()
        # Always build with debug symbols.
        if "-g" not in CXX_FLAGS:
            CXX_FLAGS.append("-g")
        # remove "-W" flags to allow build benchmarks
        # with a relaxed constraint of compiler versions
        CXX_FLAGS = [flag for flag in CXX_FLAGS if not flag.startswith("-W")]

    except RuntimeError:
        # We are in FBCode.
        CXX_FLAGS = None
else:
    # FIXME: Remove when back testing is no longer required.
    CXX_FLAGS = ["-O2", "-fPIC", "-g"]

EXTRA_INCLUDE_PATHS: List[str] = [os.path.join(SOURCE_ROOT, "valgrind_wrapper")]
CONDA_PREFIX = os.getenv("CONDA_PREFIX")
if CONDA_PREFIX is not None:
    # Load will automatically search /usr/include, but not conda include.
    EXTRA_INCLUDE_PATHS.append(os.path.join(CONDA_PREFIX, "include"))


COMPAT_CALLGRIND_BINDINGS: Optional[CallgrindModuleType] = None


def get_compat_bindings() -> CallgrindModuleType:
    """Return the compat Callgrind bindings, building them on first call.

    The extension is compiled from `valgrind_wrapper/compat_bindings.cpp`
    (relative to this file) and cached in `COMPAT_CALLGRIND_BINDINGS`; the
    check-and-build sequence runs under `LOCK` so concurrent callers trigger
    at most one `cpp_extension.load` from here.
    """
    global COMPAT_CALLGRIND_BINDINGS
    with LOCK:
        if COMPAT_CALLGRIND_BINDINGS is None:
            COMPAT_CALLGRIND_BINDINGS = cpp_extension.load(
                name="callgrind_bindings",
                sources=[os.path.join(
                    SOURCE_ROOT,
                    "valgrind_wrapper",
                    "compat_bindings.cpp"
                )],
                extra_cflags=CXX_FLAGS,
                extra_include_paths=EXTRA_INCLUDE_PATHS,
            )
    return COMPAT_CALLGRIND_BINDINGS


def _render_template(*, stmt: str, setup: str, global_setup: str, src: str) -> str:
    """Splice the user supplied C++ snippets into the template text `src`.

    Each `// ..._TEMPLATE_LOCATION` marker in `src` is replaced by the
    corresponding snippet. Snippets are inserted verbatim: backslashes and
    group-reference-like text in the C++ code (e.g. `"\\n"` inside a string
    literal) are NOT interpreted as regex replacement escapes.
    """
    for marker, code, indentation in (
        ("// GLOBAL_SETUP_TEMPLATE_LOCATION", global_setup, 0),
        ("// SETUP_TEMPLATE_LOCATION", setup, 4),
        ("// STMT_TEMPLATE_LOCATION", stmt, 8)
    ):
        # C++ doesn't care about indentation so this code isn't load
        # bearing the way it is with Python, but this makes the source
        # look nicer if a human has to look at it. The leading indentation of
        # the first line is dropped because the marker itself is already
        # indented inside the template.
        replacement = textwrap.indent(code, " " * indentation)[indentation:]

        # A callable replacement is used on purpose: a plain string would be
        # treated by `re.sub` as a replacement template, turning `\n` in the
        # user's C++ into a real newline and raising on sequences like `\d`.
        src = re.sub(marker, lambda _match, _text=replacement: _text, src)
    return src


def _compile_template(
    *,
    stmt: str,
    setup: str,
    global_setup: str,
    src: str,
    is_standalone: bool
) -> Any:
    """Fill in the template `src` and compile it with `cpp_extension.load`.

    Args:
        stmt: C++ code substituted at `// STMT_TEMPLATE_LOCATION`.
        setup: C++ code substituted at `// SETUP_TEMPLATE_LOCATION`.
        global_setup: C++ code substituted at
            `// GLOBAL_SETUP_TEMPLATE_LOCATION`.
        src: Text of the C++ template.
        is_standalone: Forwarded to `cpp_extension.load`; when False the
            result is loaded as a Python module instead.

    Returns:
        Whatever `cpp_extension.load` returns for the requested mode.
    """
    src = _render_template(stmt=stmt, setup=setup, global_setup=global_setup, src=src)

    # We want to isolate different Timers. However `cpp_extension` will
    # cache builds which will significantly reduce the cost of repeated
    # invocations.
    with LOCK:
        # Identical rendered sources map to the same name (and therefore the
        # same build directory) within this process.
        name = f"timer_cpp_{abs(hash(src))}"
        build_dir = os.path.join(_get_build_root(), name)
        os.makedirs(build_dir, exist_ok=True)

        src_path = os.path.join(build_dir, "timer_src.cpp")
        with open(src_path, "w") as f:
            f.write(src)

    # `cpp_extension` has its own locking scheme, so we don't need our lock.
    return cpp_extension.load(
        name=name,
        sources=[src_path],
        build_directory=build_dir,
        extra_cflags=CXX_FLAGS,
        extra_include_paths=EXTRA_INCLUDE_PATHS,
        is_python_module=not is_standalone,
        is_standalone=is_standalone,
    )


def compile_timeit_template(*, stmt: str, setup: str, global_setup: str) -> TimeitModuleType:
    """Compile `timeit_template.cpp` with the given snippets as a Python module.

    The template is read from the directory containing this file. The loaded
    object is asserted to satisfy `TimeitModuleType` before being returned.
    """
    template_path: str = os.path.join(SOURCE_ROOT, "timeit_template.cpp")
    with open(template_path) as f:
        src: str = f.read()

    module = _compile_template(stmt=stmt, setup=setup, global_setup=global_setup, src=src, is_standalone=False)
    assert isinstance(module, TimeitModuleType)
    return module


def compile_callgrind_template(*, stmt: str, setup: str, global_setup: str) -> str:
    """Compile the Callgrind timer template with the given snippets (standalone).

    The template is read from `valgrind_wrapper/timer_callgrind_template.cpp`
    next to this file and built with `is_standalone=True`. The result is
    asserted to be a `str` and returned (per the `cpp_extension.load`
    documentation, the path of the built executable).
    """
    template_path: str = os.path.join(SOURCE_ROOT, "valgrind_wrapper", "timer_callgrind_template.cpp")
    with open(template_path) as f:
        src: str = f.read()

    target = _compile_template(stmt=stmt, setup=setup, global_setup=global_setup, src=src, is_standalone=True)
    assert isinstance(target, str)
    return target