"""Timer class based on the timeit.Timer class, but torch aware."""
import enum
import timeit
import textwrap
from typing import overload, Any, Callable, Dict, List, NoReturn, Optional, Tuple, Type, Union

import torch
from torch.utils.benchmark.utils import common, cpp_jit
from torch.utils.benchmark.utils._stubs import TimerClass, TimeitModuleType
from torch.utils.benchmark.utils.valgrind_wrapper import timer_interface as valgrind_timer_interface


__all__ = ["Timer", "timer", "Language"]


if torch.backends.cuda.is_built() and torch.cuda.is_available():  # type: ignore[no-untyped-call]
    def timer() -> float:
        torch.cuda.synchronize()
        return timeit.default_timer()
elif torch._C._get_privateuse1_backend_name() != "privateuseone":
    privateuse1_device_handler = getattr(torch, torch._C._get_privateuse1_backend_name(), None) \
        if torch._C._get_privateuse1_backend_name() != "cpu" else None

    def timer() -> float:
        if privateuse1_device_handler:
            privateuse1_device_handler.synchronize()
        return timeit.default_timer()
else:
    timer = timeit.default_timer


class Language(enum.Enum):
    PYTHON = 0
    CPP = 1


class CPPTimer:
    def __init__(
        self,
        stmt: str,
        setup: str,
        global_setup: str,
        timer: Callable[[], float],
        globals: Dict[str, Any],
    ) -> None:
        if timer is not timeit.default_timer:
            raise NotImplementedError(
                "PyTorch was built with CUDA and a GPU is present; however "
                "Timer does not yet support GPU measurements. If your "
                "code is CPU only, pass `timer=timeit.default_timer` to the "
                "Timer's constructor to indicate this. (Note that this will "
                "produce incorrect results if the GPU is in fact used, as "
                "Timer will not synchronize CUDA.)"
            )

        if globals:
            raise ValueError("C++ timing does not support globals.")

        self._stmt: str = textwrap.dedent(stmt)
        self._setup: str = textwrap.dedent(setup)
        self._global_setup: str = textwrap.dedent(global_setup)
        self._timeit_module: Optional[TimeitModuleType] = None

    def timeit(self, number: int) -> float:
        if self._timeit_module is None:
            self._timeit_module = cpp_jit.compile_timeit_template(
                stmt=self._stmt,
                setup=self._setup,
                global_setup=self._global_setup,
            )

        return self._timeit_module.timeit(number)


class Timer:
    """Helper class for measuring execution time of PyTorch statements.

    For a full tutorial on how to use this class, see:
    https://pytorch.org/tutorials/recipes/recipes/benchmark.html

    The PyTorch Timer is based on `timeit.Timer` (and in fact uses
    `timeit.Timer` internally), but with several key differences:

    1) Runtime aware:
        Timer will perform warmups (important as some elements of PyTorch are
        lazily initialized), set threadpool size so that comparisons are
        apples-to-apples, and synchronize asynchronous CUDA functions when
        necessary.

    2) Focus on replicates:
        When measuring code, and particularly complex kernels / models,
        run-to-run variation is a significant confounding factor. It is
        expected that all measurements should include replicates to quantify
        noise and allow median computation, which is more robust than mean.
        To that effect, this class deviates from the `timeit` API by
        conceptually merging `timeit.Timer.repeat` and `timeit.Timer.autorange`.
        (Exact algorithms are discussed in method docstrings.) The `timeit`
        method is replicated for cases where an adaptive strategy is not
        desired.

    3) Optional metadata:
        When defining a Timer, one can optionally specify `label`, `sub_label`,
        `description`, and `env`. (Defined later) These fields are included in
        the representation of result objects and used by the `Compare` class to
        group and display results for comparison.

    4) Instruction counts:
        In addition to wall times, Timer can run a statement under Callgrind
        and report instructions executed.

    Directly analogous to `timeit.Timer` constructor arguments:

        `stmt`, `setup`, `timer`, `globals`

    PyTorch Timer specific constructor arguments:

        `label`, `sub_label`, `description`, `env`, `num_threads`

    Args:
        stmt: Code snippet to be run in a loop and timed.

        setup: Optional setup code. Used to define variables used in `stmt`.

        global_setup: (C++ only)
            Code which is placed at the top level of the file for things like
            `#include` statements.

        timer:
            Callable which returns the current time. If PyTorch was built
            without CUDA or there is no GPU present, this defaults to
            `timeit.default_timer`; otherwise it will synchronize CUDA before
            measuring the time.

        globals:
            A dict which defines the global variables when `stmt` is being
            executed. This is the other method for providing variables which
            `stmt` needs.

        label:
            String which summarizes `stmt`. For instance, if `stmt` is
            "torch.nn.functional.relu(torch.add(x, 1, out=out))"
            one might set label to "ReLU(x + 1)" to improve readability.

        sub_label:
            Provide supplemental information to disambiguate measurements
            with identical stmt or label. For instance, in our example
            above sub_label might be "float" or "int", so that it is easy
            to differentiate:
            "ReLU(x + 1): (float)"

            "ReLU(x + 1): (int)"
            when printing Measurements or summarizing using `Compare`.

        description:
            String to distinguish measurements with identical label and
            sub_label. The principal use of `description` is to signal to
            `Compare` the columns of data. For instance one might set it
            based on the input size to create a table of the form: ::

                                        | n=1 | n=4 | ...
                                        ------------- ...
                ReLU(x + 1): (float)    | ... | ... | ...
                ReLU(x + 1): (int)      | ... | ... | ...


            using `Compare`. It is also included when printing a Measurement.

        env:
            This tag indicates that otherwise identical tasks were run in
            different environments, and are therefore not equivalent, for
            instance when A/B testing a change to a kernel. `Compare` will
            treat Measurements with different `env` specification as distinct
            when merging replicate runs.

        num_threads:
            The size of the PyTorch threadpool when executing `stmt`. Single
            threaded performance is important as both a key inference workload
            and a good indicator of intrinsic algorithmic efficiency, so the
            default is set to one. This is in contrast to the default PyTorch
            threadpool size which tries to utilize all cores.
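
    A minimal usage sketch (the statement, shapes, and labels below are
    illustrative, not prescribed by this class)::

        from torch.utils.benchmark import Timer

        t = Timer(
            stmt="y = x @ x",
            setup="x = torch.ones((64, 64))",
            label="matmul",
            sub_label="float",
            description="64x64",
            num_threads=1,
        )
        print(t.blocked_autorange(min_run_time=0.2))

    Multiple resulting Measurements can then be passed to `Compare` to group
    and display them side by side.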
    """

    _timer_cls: Type[TimerClass] = timeit.Timer

    def __init__(
        self,
        stmt: str = "pass",
        setup: str = "pass",
        global_setup: str = "",
        timer: Callable[[], float] = timer,
        globals: Optional[Dict[str, Any]] = None,
        label: Optional[str] = None,
        sub_label: Optional[str] = None,
        description: Optional[str] = None,
        env: Optional[str] = None,
        num_threads: int = 1,
        language: Union[Language, str] = Language.PYTHON,
    ):
        if not isinstance(stmt, str):
            raise ValueError("Currently only a `str` stmt is supported.")

        # We copy `globals` to prevent mutations from leaking.
        # (For instance, `eval` adds the `__builtins__` key)
        self._globals = dict(globals or {})

        timer_kwargs = {}
        if language in (Language.PYTHON, "py", "python"):
            # Include `torch` if not specified as a convenience feature.
            self._globals.setdefault("torch", torch)
            self._language: Language = Language.PYTHON
            if global_setup:
                raise ValueError(
                    f"global_setup is C++ only, got `{global_setup}`. Most "
                    "likely this code can simply be moved to `setup`."
                )

        elif language in (Language.CPP, "cpp", "c++"):
            assert self._timer_cls is timeit.Timer, "_timer_cls has already been swapped."
            self._timer_cls = CPPTimer
            setup = ("" if setup == "pass" else setup)
            self._language = Language.CPP
            timer_kwargs["global_setup"] = global_setup

        else:
            raise ValueError(f"Invalid language `{language}`.")

        # Convenience adjustment so that multi-line code snippets defined in
        # functions do not raise an IndentationError (Python) or look odd (C++).
        # The leading newline removal is for the initial newline that appears
        # when defining block strings. For instance:
        #   textwrap.dedent("""
        #     print("This is a stmt")
        #   """)
        # produces '\nprint("This is a stmt")\n'.
        #
        # Stripping this down to 'print("This is a stmt")' doesn't change
        # what gets executed, but it makes __repr__'s nicer.
        stmt = textwrap.dedent(stmt)
        stmt = (stmt[1:] if stmt and stmt[0] == "\n" else stmt).rstrip()
        setup = textwrap.dedent(setup)
        setup = (setup[1:] if setup and setup[0] == "\n" else setup).rstrip()

        self._timer = self._timer_cls(
            stmt=stmt,
            setup=setup,
            timer=timer,
            globals=valgrind_timer_interface.CopyIfCallgrind.unwrap_all(self._globals),
            **timer_kwargs,
        )
        self._task_spec = common.TaskSpec(
            stmt=stmt,
            setup=setup,
            global_setup=global_setup,
            label=label,
            sub_label=sub_label,
            description=description,
            env=env,
            num_threads=num_threads,
        )

    def _timeit(self, number: int) -> float:
        # Even calling a timer in C++ takes ~50 ns, so no real operation should
        # take less than 1 ns. (And this prevents divide by zero errors.)
        return max(self._timer.timeit(number), 1e-9)

    def timeit(self, number: int = 1000000) -> common.Measurement:
        """Mirrors the semantics of timeit.Timer.timeit().

        Execute the main statement (`stmt`) `number` times.
        https://docs.python.org/3/library/timeit.html#timeit.Timer.timeit
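
        A minimal sketch (the statement is illustrative, not prescribed)::

            m = Timer("torch.ones((8, 8)) + 1").timeit(number=1000)
            print(m.mean)  # average seconds per invocation of `stmt`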
        """
        with common.set_torch_threads(self._task_spec.num_threads):
            # Warmup
            self._timeit(number=max(int(number // 100), 2))

            return common.Measurement(
                number_per_run=number,
                raw_times=[self._timeit(number=number)],
                task_spec=self._task_spec
            )

    def repeat(self, repeat: int = -1, number: int = -1) -> None:
        raise NotImplementedError("See `Timer.blocked_autorange`.")

    def autorange(self, callback: Optional[Callable[[int, float], NoReturn]] = None) -> None:
        raise NotImplementedError("See `Timer.blocked_autorange`.")

    def _threaded_measurement_loop(
        self,
        number: int,
        time_hook: Callable[[], float],
        stop_hook: Callable[[List[float]], bool],
        min_run_time: float,
        max_run_time: Optional[float] = None,
        callback: Optional[Callable[[int, float], NoReturn]] = None
    ) -> List[float]:
        total_time = 0.0
        can_stop = False
        times: List[float] = []
        with common.set_torch_threads(self._task_spec.num_threads):
            while (total_time < min_run_time) or (not can_stop):
                time_spent = time_hook()
                times.append(time_spent)
                total_time += time_spent
                if callback:
                    callback(number, time_spent)
                can_stop = stop_hook(times)
                if max_run_time and total_time > max_run_time:
                    break
        return times

    def _estimate_block_size(self, min_run_time: float) -> int:
        with common.set_torch_threads(self._task_spec.num_threads):
            # Estimate the block size needed for measurement to be negligible
            # compared to the inner loop. This also serves as a warmup.
            overhead = torch.tensor([self._timeit(0) for _ in range(5)]).median().item()
            number = 1
            while True:
                time_taken = self._timeit(number)
                relative_overhead = overhead / time_taken
                if relative_overhead <= 1e-4 and time_taken >= min_run_time / 1000:
                    break
                if time_taken > min_run_time:
                    break
                # Avoid overflow in C++ pybind11 interface
                if number * 10 > 2147483647:
                    break
                number *= 10
        return number

    def blocked_autorange(
        self,
        callback: Optional[Callable[[int, float], NoReturn]] = None,
        min_run_time: float = 0.2,
    ) -> common.Measurement:
        """Measure many replicates while keeping timer overhead to a minimum.

        At a high level, blocked_autorange executes the following pseudo-code::

            `setup`

            total_time = 0
            while total_time < min_run_time
                start = timer()
                for _ in range(block_size):
                    `stmt`
                total_time += (timer() - start)

        Note the variable `block_size` in the inner loop. The choice of block
        size is important to measurement quality, and must balance two
        competing objectives:

            1) A small block size results in more replicates and generally
               better statistics.

            2) A large block size better amortizes the cost of `timer`
               invocation, and results in a less biased measurement. This is
               important because CUDA synchronization time is non-trivial
               (order single to low double digit microseconds) and would
               otherwise bias the measurement.

        blocked_autorange sets block_size by running a warmup period,
        increasing block size until timer overhead is less than 0.1% of
        the overall computation. This value is then used for the main
        measurement loop.

        Returns:
            A `Measurement` object that contains measured runtimes and
            repetition counts, and can be used to compute statistics.
            (mean, median, etc.)
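
        A minimal sketch (the statement and shapes are illustrative, not
        prescribed)::

            t = Timer("y = x + 1", setup="x = torch.ones((128, 128))")
            m = t.blocked_autorange(min_run_time=0.2)
            print(m.median, len(m.times))  # per-run median and number of measured blocks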
        """
        number = self._estimate_block_size(min_run_time)

        def time_hook() -> float:
            return self._timeit(number)

        def stop_hook(times: List[float]) -> bool:
            return True

        times = self._threaded_measurement_loop(
            number, time_hook, stop_hook,
            min_run_time=min_run_time,
            callback=callback)

        return common.Measurement(
            number_per_run=number,
            raw_times=times,
            task_spec=self._task_spec
        )

    def adaptive_autorange(
            self,
            threshold: float = 0.1,
            *,
            min_run_time: float = 0.01,
            max_run_time: float = 10.0,
            callback: Optional[Callable[[int, float], NoReturn]] = None,
    ) -> common.Measurement:
        """Similar to `blocked_autorange` but also checks for variability in measurements
        and repeats until iqr/median is smaller than `threshold` or `max_run_time` is reached.

        At a high level, adaptive_autorange executes the following pseudo-code::

            `setup`

            times = []
            while times.sum < max_run_time
                start = timer()
                for _ in range(block_size):
                    `stmt`
                times.append(timer() - start)

                enough_data = len(times) > 3 and times.sum > min_run_time
                small_iqr = times.iqr / times.median < threshold

                if enough_data and small_iqr:
                    break

        Args:
            threshold: value of iqr/median threshold for stopping

            min_run_time: total runtime needed before checking `threshold`

            max_run_time: total runtime for all measurements regardless of `threshold`

        Returns:
            A `Measurement` object that contains measured runtimes and
            repetition counts, and can be used to compute statistics.
            (mean, median, etc.)
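
        A minimal sketch (the statement and shape are illustrative, not
        prescribed)::

            t = Timer("y = x * 2", setup="x = torch.ones((256, 256))")
            m = t.adaptive_autorange(threshold=0.1)
            print(m)  # stops once the IQR / median ratio drops below `threshold`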
        """
        number = self._estimate_block_size(min_run_time=0.05)

        def time_hook() -> float:
            return self._timeit(number)

        def stop_hook(times: List[float]) -> bool:
            if len(times) > 3:
                return common.Measurement(
                    number_per_run=number,
                    raw_times=times,
                    task_spec=self._task_spec
                ).meets_confidence(threshold=threshold)
            return False
        times = self._threaded_measurement_loop(
            number, time_hook, stop_hook, min_run_time, max_run_time, callback=callback)

        return common.Measurement(
            number_per_run=number,
            raw_times=times,
            task_spec=self._task_spec
        )

    @overload
    def collect_callgrind(
        self,
        number: int,
        *,
        repeats: None,
        collect_baseline: bool,
        retain_out_file: bool,
    ) -> valgrind_timer_interface.CallgrindStats:
        ...

    @overload
    def collect_callgrind(
        self,
        number: int,
        *,
        repeats: int,
        collect_baseline: bool,
        retain_out_file: bool,
    ) -> Tuple[valgrind_timer_interface.CallgrindStats, ...]:
        ...

    def collect_callgrind(
        self,
        number: int = 100,
        *,
        repeats: Optional[int] = None,
        collect_baseline: bool = True,
        retain_out_file: bool = False,
    ) -> Any:
        """Collect instruction counts using Callgrind.

        Unlike wall times, instruction counts are deterministic
        (modulo non-determinism in the program itself and small amounts of
        jitter from the Python interpreter.) This makes them ideal for detailed
        performance analysis. This method runs `stmt` in a separate process
        so that Valgrind can instrument the program. Performance is severely
        degraded due to the instrumentation, however this is ameliorated by
        the fact that a small number of iterations is generally sufficient to
        obtain good measurements.

        In order to use this method `valgrind`, `callgrind_control`, and
        `callgrind_annotate` must be installed.

        Because there is a process boundary between the caller (this process)
        and the `stmt` execution, `globals` cannot contain arbitrary in-memory
        data structures (unlike the timing methods). Instead, globals are
        restricted to builtins, `nn.Module`s, and TorchScripted functions/modules
        to reduce the surprise factor from serialization and subsequent
        deserialization. The `GlobalsBridge` class provides more detail on this
        subject. Take particular care with nn.Modules: they rely on pickle and
        you may need to add an import to `setup` for them to transfer properly.

        By default, a profile for an empty statement will be collected and
        cached to indicate how many instructions are from the Python loop which
        drives `stmt`.

        Returns:
            A `CallgrindStats` object which provides instruction counts and
            some basic facilities for analyzing and manipulating results.
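
        A minimal sketch (requires Valgrind and its tools on the system; the
        statement is illustrative, not prescribed)::

            t = Timer("y = x + 1", setup="x = torch.ones((8, 8))")
            stats = t.collect_callgrind(number=10)
            print(stats.counts(denoise=True))  # total instructions, filtering interpreter noise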
        """
        if not isinstance(self._task_spec.stmt, str):
            raise ValueError("`collect_callgrind` currently only supports string `stmt`")

        if repeats is not None and repeats < 1:
            raise ValueError("If specified, `repeats` must be >= 1")

        # Check that the statement is valid. It doesn't guarantee success, but it's much
        # simpler and quicker to raise an exception for a faulty `stmt` or `setup` in
        # the parent process rather than the valgrind subprocess.
        self._timeit(1)
        is_python = (self._language == Language.PYTHON)
        assert is_python or not self._globals
        result = valgrind_timer_interface.wrapper_singleton().collect_callgrind(
            task_spec=self._task_spec,
            globals=self._globals,
            number=number,
            repeats=repeats or 1,
            collect_baseline=collect_baseline and is_python,
            is_python=is_python,
            retain_out_file=retain_out_file,
        )

        return (result[0] if repeats is None else result)
538