1"""Timer class based on the timeit.Timer class, but torch aware.""" 2import enum 3import timeit 4import textwrap 5from typing import overload, Any, Callable, Dict, List, NoReturn, Optional, Tuple, Type, Union 6 7import torch 8from torch.utils.benchmark.utils import common, cpp_jit 9from torch.utils.benchmark.utils._stubs import TimerClass, TimeitModuleType 10from torch.utils.benchmark.utils.valgrind_wrapper import timer_interface as valgrind_timer_interface 11 12 13__all__ = ["Timer", "timer", "Language"] 14 15 16if torch.backends.cuda.is_built() and torch.cuda.is_available(): # type: ignore[no-untyped-call] 17 def timer() -> float: 18 torch.cuda.synchronize() 19 return timeit.default_timer() 20elif torch._C._get_privateuse1_backend_name() != "privateuseone": 21 privateuse1_device_handler = getattr(torch, torch._C._get_privateuse1_backend_name(), None) \ 22 if torch._C._get_privateuse1_backend_name() != "cpu" else None 23 24 def timer() -> float: 25 if privateuse1_device_handler: 26 privateuse1_device_handler.synchronize() 27 return timeit.default_timer() 28else: 29 timer = timeit.default_timer 30 31 32class Language(enum.Enum): 33 PYTHON = 0 34 CPP = 1 35 36 37class CPPTimer: 38 def __init__( 39 self, 40 stmt: str, 41 setup: str, 42 global_setup: str, 43 timer: Callable[[], float], 44 globals: Dict[str, Any], 45 ) -> None: 46 if timer is not timeit.default_timer: 47 raise NotImplementedError( 48 "PyTorch was built with CUDA and a GPU is present; however " 49 "Timer does not yet support GPU measurements. If your " 50 "code is CPU only, pass `timer=timeit.default_timer` to the " 51 "Timer's constructor to indicate this. (Note that this will " 52 "produce incorrect results if the GPU is in fact used, as " 53 "Timer will not synchronize CUDA.)" 54 ) 55 56 if globals: 57 raise ValueError("C++ timing does not support globals.") 58 59 self._stmt: str = textwrap.dedent(stmt) 60 self._setup: str = textwrap.dedent(setup) 61 self._global_setup: str = textwrap.dedent(global_setup) 62 self._timeit_module: Optional[TimeitModuleType] = None 63 64 def timeit(self, number: int) -> float: 65 if self._timeit_module is None: 66 self._timeit_module = cpp_jit.compile_timeit_template( 67 stmt=self._stmt, 68 setup=self._setup, 69 global_setup=self._global_setup, 70 ) 71 72 return self._timeit_module.timeit(number) 73 74 75class Timer: 76 """Helper class for measuring execution time of PyTorch statements. 77 78 For a full tutorial on how to use this class, see: 79 https://pytorch.org/tutorials/recipes/recipes/benchmark.html 80 81 The PyTorch Timer is based on `timeit.Timer` (and in fact uses 82 `timeit.Timer` internally), but with several key differences: 83 84 1) Runtime aware: 85 Timer will perform warmups (important as some elements of PyTorch are 86 lazily initialized), set threadpool size so that comparisons are 87 apples-to-apples, and synchronize asynchronous CUDA functions when 88 necessary. 89 90 2) Focus on replicates: 91 When measuring code, and particularly complex kernels / models, 92 run-to-run variation is a significant confounding factor. It is 93 expected that all measurements should include replicates to quantify 94 noise and allow median computation, which is more robust than mean. 95 To that effect, this class deviates from the `timeit` API by 96 conceptually merging `timeit.Timer.repeat` and `timeit.Timer.autorange`. 97 (Exact algorithms are discussed in method docstrings.) The `timeit` 98 method is replicated for cases where an adaptive strategy is not 99 desired. 

    3) Optional metadata:
        When defining a Timer, one can optionally specify `label`, `sub_label`,
        `description`, and `env`. (Defined below.) These fields are included in
        the representation of the result object and used by the `Compare` class
        to group and display results for comparison.

    4) Instruction counts:
        In addition to wall times, Timer can run a statement under Callgrind
        and report instructions executed.

    Directly analogous to `timeit.Timer` constructor arguments:

        `stmt`, `setup`, `timer`, `globals`

    PyTorch Timer specific constructor arguments:

        `label`, `sub_label`, `description`, `env`, `num_threads`

    Args:
        stmt: Code snippet to be run in a loop and timed.

        setup: Optional setup code. Used to define variables used in `stmt`.

        global_setup: (C++ only)
            Code which is placed at the top level of the file for things like
            `#include` statements.

        timer:
            Callable which returns the current time. If PyTorch was built
            without CUDA or there is no GPU present, this defaults to
            `timeit.default_timer`; otherwise it will synchronize CUDA before
            measuring the time.

        globals:
            A dict which defines the global variables when `stmt` is being
            executed. This is the other method for providing variables which
            `stmt` needs.

        label:
            String which summarizes `stmt`. For instance, if `stmt` is
            "torch.nn.functional.relu(torch.add(x, 1, out=out))"
            one might set label to "ReLU(x + 1)" to improve readability.

        sub_label:
            Provide supplemental information to disambiguate measurements
            with identical stmt or label. For instance, in our example
            above sub_label might be "float" or "int", so that it is easy
            to differentiate:
                "ReLU(x + 1): (float)"

                "ReLU(x + 1): (int)"
            when printing Measurements or summarizing using `Compare`.

        description:
            String to distinguish measurements with identical label and
            sub_label. The principal use of `description` is to signal to
            `Compare` the columns of data. For instance one might set it
            based on the input size to create a table of the form::

                                        | n=1 | n=4 | ...
                ------------- ...
                ReLU(x + 1): (float)    | ... | ... | ...
                ReLU(x + 1): (int)      | ... | ... | ...

            using `Compare`. It is also included when printing a Measurement.

        env:
            This tag indicates that otherwise identical tasks were run in
            different environments, and are therefore not equivalent, for
            instance when A/B testing a change to a kernel. `Compare` will
            treat Measurements with different `env` specification as distinct
            when merging replicate runs.

        num_threads:
            The size of the PyTorch threadpool when executing `stmt`. Single
            threaded performance is important as both a key inference workload
            and a good indicator of intrinsic algorithmic efficiency, so the
            default is set to one. This is in contrast to the default PyTorch
            threadpool size which tries to utilize all cores.
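
    A minimal usage sketch tying these arguments together (illustrative only;
    the statement, tensor shapes, and labels below are arbitrary
    placeholders)::

        from torch.utils.benchmark import Timer

        t = Timer(
            stmt="torch.nn.functional.relu(torch.add(x, 1, out=out))",
            setup="x = torch.ones((1024,)); out = torch.empty((1024,))",
            label="ReLU(x + 1)",
            sub_label="float",
            description="n=1024",
        )
        print(t.blocked_autorange(min_run_time=0.2))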
181 """ 182 183 _timer_cls: Type[TimerClass] = timeit.Timer 184 185 def __init__( 186 self, 187 stmt: str = "pass", 188 setup: str = "pass", 189 global_setup: str = "", 190 timer: Callable[[], float] = timer, 191 globals: Optional[Dict[str, Any]] = None, 192 label: Optional[str] = None, 193 sub_label: Optional[str] = None, 194 description: Optional[str] = None, 195 env: Optional[str] = None, 196 num_threads: int = 1, 197 language: Union[Language, str] = Language.PYTHON, 198 ): 199 if not isinstance(stmt, str): 200 raise ValueError("Currently only a `str` stmt is supported.") 201 202 # We copy `globals` to prevent mutations from leaking. 203 # (For instance, `eval` adds the `__builtins__` key) 204 self._globals = dict(globals or {}) 205 206 timer_kwargs = {} 207 if language in (Language.PYTHON, "py", "python"): 208 # Include `torch` if not specified as a convenience feature. 209 self._globals.setdefault("torch", torch) 210 self._language: Language = Language.PYTHON 211 if global_setup: 212 raise ValueError( 213 f"global_setup is C++ only, got `{global_setup}`. Most " 214 "likely this code can simply be moved to `setup`." 215 ) 216 217 elif language in (Language.CPP, "cpp", "c++"): 218 assert self._timer_cls is timeit.Timer, "_timer_cls has already been swapped." 219 self._timer_cls = CPPTimer 220 setup = ("" if setup == "pass" else setup) 221 self._language = Language.CPP 222 timer_kwargs["global_setup"] = global_setup 223 224 else: 225 raise ValueError(f"Invalid language `{language}`.") 226 227 # Convenience adjustment so that multi-line code snippets defined in 228 # functions do not IndentationError (Python) or look odd (C++). The 229 # leading newline removal is for the initial newline that appears when 230 # defining block strings. For instance: 231 # textwrap.dedent(""" 232 # print("This is a stmt") 233 # """) 234 # produces '\nprint("This is a stmt")\n'. 235 # 236 # Stripping this down to 'print("This is a stmt")' doesn't change 237 # what gets executed, but it makes __repr__'s nicer. 238 stmt = textwrap.dedent(stmt) 239 stmt = (stmt[1:] if stmt and stmt[0] == "\n" else stmt).rstrip() 240 setup = textwrap.dedent(setup) 241 setup = (setup[1:] if setup and setup[0] == "\n" else setup).rstrip() 242 243 self._timer = self._timer_cls( 244 stmt=stmt, 245 setup=setup, 246 timer=timer, 247 globals=valgrind_timer_interface.CopyIfCallgrind.unwrap_all(self._globals), 248 **timer_kwargs, 249 ) 250 self._task_spec = common.TaskSpec( 251 stmt=stmt, 252 setup=setup, 253 global_setup=global_setup, 254 label=label, 255 sub_label=sub_label, 256 description=description, 257 env=env, 258 num_threads=num_threads, 259 ) 260 261 def _timeit(self, number: int) -> float: 262 # Even calling a timer in C++ takes ~50 ns, so no real operation should 263 # take less than 1 ns. (And this prevents divide by zero errors.) 264 return max(self._timer.timeit(number), 1e-9) 265 266 def timeit(self, number: int = 1000000) -> common.Measurement: 267 """Mirrors the semantics of timeit.Timer.timeit(). 268 269 Execute the main statement (`stmt`) `number` times. 
        https://docs.python.org/3/library/timeit.html#timeit.Timer.timeit
        """
        with common.set_torch_threads(self._task_spec.num_threads):
            # Warmup
            self._timeit(number=max(int(number // 100), 2))

            return common.Measurement(
                number_per_run=number,
                raw_times=[self._timeit(number=number)],
                task_spec=self._task_spec
            )

    def repeat(self, repeat: int = -1, number: int = -1) -> None:
        raise NotImplementedError("See `Timer.blocked_autorange`.")

    def autorange(self, callback: Optional[Callable[[int, float], NoReturn]] = None) -> None:
        raise NotImplementedError("See `Timer.blocked_autorange`.")

    def _threaded_measurement_loop(
        self,
        number: int,
        time_hook: Callable[[], float],
        stop_hook: Callable[[List[float]], bool],
        min_run_time: float,
        max_run_time: Optional[float] = None,
        callback: Optional[Callable[[int, float], NoReturn]] = None
    ) -> List[float]:
        total_time = 0.0
        can_stop = False
        times: List[float] = []
        with common.set_torch_threads(self._task_spec.num_threads):
            while (total_time < min_run_time) or (not can_stop):
                time_spent = time_hook()
                times.append(time_spent)
                total_time += time_spent
                if callback:
                    callback(number, time_spent)
                can_stop = stop_hook(times)
                if max_run_time and total_time > max_run_time:
                    break
        return times

    def _estimate_block_size(self, min_run_time: float) -> int:
        with common.set_torch_threads(self._task_spec.num_threads):
            # Estimate the block size needed for measurement to be negligible
            # compared to the inner loop. This also serves as a warmup.
            overhead = torch.tensor([self._timeit(0) for _ in range(5)]).median().item()
            number = 1
            while True:
                time_taken = self._timeit(number)
                relative_overhead = overhead / time_taken
                if relative_overhead <= 1e-4 and time_taken >= min_run_time / 1000:
                    break
                if time_taken > min_run_time:
                    break
                # Avoid overflow in C++ pybind11 interface
                if number * 10 > 2147483647:
                    break
                number *= 10
        return number

    def blocked_autorange(
        self,
        callback: Optional[Callable[[int, float], NoReturn]] = None,
        min_run_time: float = 0.2,
    ) -> common.Measurement:
        """Measure many replicates while keeping timer overhead to a minimum.

        At a high level, blocked_autorange executes the following pseudo-code::

            `setup`

            total_time = 0
            while total_time < min_run_time
                start = timer()
                for _ in range(block_size):
                    `stmt`
                total_time += (timer() - start)

        Note the variable `block_size` in the inner loop. The choice of block
        size is important to measurement quality, and must balance two
        competing objectives:

            1) A small block size results in more replicates and generally
               better statistics.

            2) A large block size better amortizes the cost of `timer`
               invocation, and results in a less biased measurement. This is
               important because CUDA synchronization time is non-trivial
               (order single to low double digit microseconds) and would
               otherwise bias the measurement.

        blocked_autorange sets block_size by running a warmup period,
        increasing block size until timer overhead is less than 0.1% of
        the overall computation. This value is then used for the main
        measurement loop.

        Returns:
            A `Measurement` object that contains measured runtimes and
            repetition counts, and can be used to compute statistics
            (mean, median, etc.).
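
        A sketch of how one might inspect the returned `Measurement` (the
        statement and the attributes printed below are illustrative, not
        exhaustive)::

            m = Timer(
                "y = torch.relu(x)",
                setup="x = torch.randn(1024)",
            ).blocked_autorange()
            print(m.median)      # robust summary statistic
            print(m.mean)        # arithmetic mean of the per-run times
            print(len(m.times))  # number of replicate blocks collected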
371 """ 372 number = self._estimate_block_size(min_run_time) 373 374 def time_hook() -> float: 375 return self._timeit(number) 376 377 def stop_hook(times: List[float]) -> bool: 378 return True 379 380 times = self._threaded_measurement_loop( 381 number, time_hook, stop_hook, 382 min_run_time=min_run_time, 383 callback=callback) 384 385 return common.Measurement( 386 number_per_run=number, 387 raw_times=times, 388 task_spec=self._task_spec 389 ) 390 391 def adaptive_autorange( 392 self, 393 threshold: float = 0.1, 394 *, 395 min_run_time: float = 0.01, 396 max_run_time: float = 10.0, 397 callback: Optional[Callable[[int, float], NoReturn]] = None, 398 ) -> common.Measurement: 399 """Similar to `blocked_autorange` but also checks for variablility in measurements 400 and repeats until iqr/median is smaller than `threshold` or `max_run_time` is reached. 401 402 403 At a high level, adaptive_autorange executes the following pseudo-code:: 404 405 `setup` 406 407 times = [] 408 while times.sum < max_run_time 409 start = timer() 410 for _ in range(block_size): 411 `stmt` 412 times.append(timer() - start) 413 414 enough_data = len(times)>3 and times.sum > min_run_time 415 small_iqr=times.iqr/times.mean<threshold 416 417 if enough_data and small_iqr: 418 break 419 420 Args: 421 threshold: value of iqr/median threshold for stopping 422 423 min_run_time: total runtime needed before checking `threshold` 424 425 max_run_time: total runtime for all measurements regardless of `threshold` 426 427 Returns: 428 A `Measurement` object that contains measured runtimes and 429 repetition counts, and can be used to compute statistics. 430 (mean, median, etc.) 431 """ 432 number = self._estimate_block_size(min_run_time=0.05) 433 434 def time_hook() -> float: 435 return self._timeit(number) 436 437 def stop_hook(times: List[float]) -> bool: 438 if len(times) > 3: 439 return common.Measurement( 440 number_per_run=number, 441 raw_times=times, 442 task_spec=self._task_spec 443 ).meets_confidence(threshold=threshold) 444 return False 445 times = self._threaded_measurement_loop( 446 number, time_hook, stop_hook, min_run_time, max_run_time, callback=callback) 447 448 return common.Measurement( 449 number_per_run=number, 450 raw_times=times, 451 task_spec=self._task_spec 452 ) 453 454 @overload 455 def collect_callgrind( 456 self, 457 number: int, 458 *, 459 repeats: None, 460 collect_baseline: bool, 461 retain_out_file: bool, 462 ) -> valgrind_timer_interface.CallgrindStats: 463 ... 464 465 @overload 466 def collect_callgrind( 467 self, 468 number: int, 469 *, 470 repeats: int, 471 collect_baseline: bool, 472 retain_out_file: bool, 473 ) -> Tuple[valgrind_timer_interface.CallgrindStats, ...]: 474 ... 475 476 def collect_callgrind( 477 self, 478 number: int = 100, 479 *, 480 repeats: Optional[int] = None, 481 collect_baseline: bool = True, 482 retain_out_file: bool = False, 483 ) -> Any: 484 """Collect instruction counts using Callgrind. 485 486 Unlike wall times, instruction counts are deterministic 487 (modulo non-determinism in the program itself and small amounts of 488 jitter from the Python interpreter.) This makes them ideal for detailed 489 performance analysis. This method runs `stmt` in a separate process 490 so that Valgrind can instrument the program. Performance is severely 491 degraded due to the instrumentation, however this is ameliorated by 492 the fact that a small number of iterations is generally sufficient to 493 obtain good measurements. 

        In order to use this method, `valgrind`, `callgrind_control`, and
        `callgrind_annotate` must be installed.

        Because there is a process boundary between the caller (this process)
        and the `stmt` execution, `globals` cannot contain arbitrary in-memory
        data structures (unlike the timing methods). Instead, globals are
        restricted to builtins, `nn.Module`s, and TorchScripted functions/modules
        to reduce the surprise factor from serialization and subsequent
        deserialization. The `GlobalsBridge` class provides more detail on this
        subject. Take particular care with `nn.Module`s: they rely on pickle and
        you may need to add an import to `setup` for them to transfer properly.

        By default, a profile for an empty statement will be collected and
        cached to indicate how many instructions are from the Python loop which
        drives `stmt`.

        Returns:
            A `CallgrindStats` object which provides instruction counts and
            some basic facilities for analyzing and manipulating results.
        """
        if not isinstance(self._task_spec.stmt, str):
            raise ValueError("`collect_callgrind` currently only supports string `stmt`")

        if repeats is not None and repeats < 1:
            raise ValueError("If specified, `repeats` must be >= 1")

        # Check that the statement is valid. It doesn't guarantee success, but it's much
        # simpler and quicker to raise an exception for a faulty `stmt` or `setup` in
        # the parent process rather than the valgrind subprocess.
        self._timeit(1)
        is_python = (self._language == Language.PYTHON)
        assert is_python or not self._globals
        result = valgrind_timer_interface.wrapper_singleton().collect_callgrind(
            task_spec=self._task_spec,
            globals=self._globals,
            number=number,
            repeats=repeats or 1,
            collect_baseline=collect_baseline and is_python,
            is_python=is_python,
            retain_out_file=retain_out_file,
        )

        return (result[0] if repeats is None else result)
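

# Illustrative usage sketch of the two autorange strategies documented above.
# The statement, tensor sizes, and labels are arbitrary placeholders, and the
# block only runs when this module is executed directly.
if __name__ == "__main__":
    example_timer = Timer(
        stmt="torch.mm(a, b)",
        setup="a = torch.rand(64, 64); b = torch.rand(64, 64)",
        label="mm",
        description="64x64",
    )
    # Fixed-block strategy: the block size is chosen once, then replicated
    # until `min_run_time` has elapsed.
    print(example_timer.blocked_autorange(min_run_time=0.2))
    # Adaptive strategy: keeps collecting blocks until iqr/median falls below
    # the threshold or `max_run_time` is reached.
    print(example_timer.adaptive_autorange())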