import time
from functools import cached_property, wraps
from itertools import chain
from statistics import median
from typing import Any, Callable, Dict, List, Tuple
from typing_extensions import Concatenate, ParamSpec, Self, TypeVar

import torch
from torch._dynamo.utils import counters


logger = torch._logging.getArtifactLogger(__name__, "benchmarking")


MILLISECONDS_PER_SECOND = 1000

P = ParamSpec("P")
T = TypeVar("T")


def maybe_time(
    fn: Callable[Concatenate[Any, P], T]
) -> Callable[Concatenate[Any, P], T]:
    """Wrapper that logs the duration of `fn`, in milliseconds, along with a representation
    of the function's args and kwargs, if logging is enabled. It is expected that `fn` is
    a method of `Benchmarker` or one of its subclasses; typing limitations prevent us from
    declaring this directly. If logging is disabled, this becomes a no-op.
    """

    # no-op if benchmarking-specific logging is disabled
    if not torch._logging._internal.log_state.is_artifact_enabled("benchmarking"):
        return fn

    @wraps(fn)
    def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T:
        start_t = time.perf_counter()
        result = fn(self, *args, **kwargs)
        logger.debug(
            "Call `benchmarking.%s.%s(*args=%r, **kwargs=%r)` took %f milliseconds.",
            self.__class__.__name__,
            fn.__name__,
            args,
            kwargs,
            (time.perf_counter() - start_t) * MILLISECONDS_PER_SECOND,
        )
        return result

    return wrapper


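# Illustrative note (comments only): with the "benchmarking" log artifact enabled,
# e.g. via `TORCH_LOGS="benchmarking"`, each decorated method call emits a debug
# line of the form
#
#     Call `benchmarking.TritonBenchmarker.benchmark_gpu(*args=(...,), **kwargs={})` took 1.234567 milliseconds.
#
# When the artifact is disabled, the original method is returned unchanged, so the
# wrapper adds no per-call overhead.

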
def count(fn: Callable[Concatenate[Any, P], T]) -> Callable[Concatenate[Any, P], T]:
    """Wrapper that increments relevant dynamo counters on `fn` call. It is expected that
    `fn` is a method of `Benchmarker` or one of its subclasses; typing limitations prevent
    us from declaring this directly. The counter incrementation follows the formula,

    `counters["inductor"]["benchmarking.Foo.bar"] += 1`

    where `Foo` is the class whose instance called the function, and `bar` is the function name.
    """

    @wraps(fn)
    def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T:
        counters["inductor"][
            "benchmarking." + self.__class__.__name__ + "." + fn.__name__
        ] += 1
        return fn(self, *args, **kwargs)

    return wrapper


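# Illustrative sketch: after a call such as `benchmarker.benchmark_gpu(fn)`, the
# corresponding dynamo counter can be inspected following the
# `benchmarking.<ClassName>.<method_name>` key pattern described above, e.g.
#
#     from torch._dynamo.utils import counters
#     n = counters["inductor"]["benchmarking.TritonBenchmarker.benchmark_gpu"]

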
class Benchmarker:
    def __init__(self: Self) -> None:
        pass

    @maybe_time
    @count
    def benchmark(
        self: Self,
        fn: Callable[..., Any],
        fn_args: Tuple[Any, ...],
        fn_kwargs: Dict[str, Any],
        **kwargs: Any,
    ) -> float:
        """Benchmark `fn(*fn_args, **fn_kwargs)` and return the runtime, in milliseconds (the
        actual runtime calculation is dictated by the benchmarking implementation, but may be
        one of [mean, median, minimum, etc.]). Functions as a convenience wrapper around
        device-specific implementations, like `benchmark_cpu` and `benchmark_gpu`. Raises
        `ValueError(...)` if we can't safely infer the device type of `fn`; for example,
        if multiple device types are found in `fn_args` and `fn_kwargs`, or if no device
        types are found.

        Arguments:
        - fn: The function to benchmark.
        - fn_args: The function's arguments.
        - fn_kwargs: The function's kwargs.

        Keyword Arguments:
        - **kwargs: The benchmarking implementation's kwargs.

        Returns:
        - The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds.
        """
        inferred_device = None
        for arg_or_kwarg in chain(fn_args, fn_kwargs.values()):
            if not isinstance(arg_or_kwarg, torch.Tensor):
                continue
            if inferred_device is None:
                inferred_device = arg_or_kwarg.device
            elif arg_or_kwarg.device != inferred_device:
                raise ValueError(
                    "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!"
                )
        if inferred_device is None:
            raise ValueError(
                "Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly."  # noqa: B950
            )
        _callable = lambda: fn(*fn_args, **fn_kwargs)  # noqa: E731
        if inferred_device == torch.device("cpu"):
            return self.benchmark_cpu(_callable, **kwargs)
        # TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking
        # implementation, which was written specifically with CUDA devices in mind; we may want to
        # explore alternate implementations for other device types.
        return self.benchmark_gpu(_callable, **kwargs)

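    # Usage sketch (illustrative shapes; the second call requires a CUDA device):
    # the device type is inferred from the tensor arguments, so CPU tensors dispatch
    # to `benchmark_cpu` and non-CPU tensors to `benchmark_gpu`; mixing device types,
    # or passing no tensors at all, raises a ValueError.
    #
    #     a = torch.randn(256, 256)                 # CPU -> benchmark_cpu
    #     cpu_ms = benchmarker.benchmark(torch.mm, (a, a), {})
    #
    #     b = torch.randn(256, 256, device="cuda")  # CUDA -> benchmark_gpu
    #     gpu_ms = benchmarker.benchmark(torch.mm, (b, b), {})
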
    @maybe_time
    @count
    def benchmark_cpu(
        self: Self, _callable: Callable[[], Any], warmup: int = 20, rep: int = 100
    ) -> float:
        """Benchmark the CPU callable, `_callable`, and return the median runtime,
        in milliseconds.

        Arguments:
        - _callable: The CPU callable to benchmark.

        Keyword Arguments:
        - warmup: Optionally, the duration, in milliseconds, to run `_callable`
        before benchmarking starts.
        - rep: Optionally, the duration, in milliseconds, to run `_callable`
        during benchmarking.

        Returns:
        - The median runtime of `_callable`, in milliseconds.
        """

        def run_for(ms: int) -> List[float]:
            timings = []
            run_start_t = time.perf_counter()
            while True:
                start_t = time.perf_counter()
                _callable()
                end_t = time.perf_counter()
                timings.append((end_t - start_t) * MILLISECONDS_PER_SECOND)
                if ((end_t - run_start_t) * MILLISECONDS_PER_SECOND) > ms:
                    break
            return timings

        run_for(warmup)
        return median(run_for(rep))

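    # Usage sketch (illustrative values): `benchmark_cpu` can also be called
    # directly with a zero-argument callable, optionally shortening the warmup
    # and measurement windows (both given in milliseconds):
    #
    #     a = torch.randn(128, 128)
    #     ms = benchmarker.benchmark_cpu(lambda: torch.mm(a, a), warmup=5, rep=20)
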
    @count
    def benchmark_gpu(self: Self, *args: Any, **kwargs: Any) -> float:
        raise NotImplementedError


class TritonBenchmarker(Benchmarker):
    @cached_property
    @maybe_time
    @count
    def triton_do_bench(self: Self) -> Callable[..., Any]:
        """Lazily import Triton's `do_bench`."""
        try:
            from triton.testing import do_bench
        except ImportError as e:
            raise NotImplementedError("requires Triton") from e
        return do_bench

    @maybe_time
    @count
    def benchmark_gpu(self: Self, _callable: Callable[[], Any], **kwargs: Any) -> float:
        """Benchmark the GPU callable, `_callable`, and return the runtime, in milliseconds.

        Arguments:
        - _callable: The GPU callable to benchmark.

        Keyword Arguments:
        - quantiles: Optionally, a tuple of floats denoting the requested quantiles.
        - return_mode: Optionally, the requested return mode. Currently, Triton's
        `do_bench` supports min, max, mean, and median return modes.
        - **kwargs: Additional kwargs passed to Triton's `do_bench`.

        Returns:
        - The runtime of `_callable`, in milliseconds. If `kwargs["quantiles"]` is specified,
        this is the first requested quantile. Else, if `kwargs["return_mode"]` is specified,
        this is the value for the requested return mode. Otherwise, this is the median.
        """
        if "quantiles" in kwargs:
            return self.triton_do_bench(_callable, **kwargs)[0]
        elif "return_mode" in kwargs:
            return self.triton_do_bench(_callable, **kwargs)
        return self.triton_do_bench(_callable, **kwargs, return_mode="median")

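# Usage sketch (illustrative kwargs; requires a CUDA device and Triton):
#
#     x = torch.randn(1024, 1024, device="cuda")
#     fn = lambda: torch.mm(x, x)
#     p50 = benchmarker.benchmark_gpu(fn)                         # median (default)
#     p20 = benchmarker.benchmark_gpu(fn, quantiles=(0.2, 0.8))   # first requested quantile
#     best = benchmarker.benchmark_gpu(fn, return_mode="min")     # fastest measured run
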

benchmarker = TritonBenchmarker()

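# The shared instance above is the expected entry point; callers typically do, e.g.,
#
#     from torch._inductor.runtime.benchmarking import benchmarker
#
#     ms = benchmarker.benchmark(fn, fn_args, fn_kwargs)
#
# rather than constructing their own `TritonBenchmarker`.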