1import time 2from functools import cached_property, wraps 3from itertools import chain 4from statistics import median 5from typing import Any, Callable, Dict, List, Tuple 6from typing_extensions import Concatenate, ParamSpec, Self, TypeVar 7 8import torch 9from torch._dynamo.utils import counters 10 11 12logger = torch._logging.getArtifactLogger(__name__, "benchmarking") 13 14 15MILLISECONDS_PER_SECOND = 1000 16 17P = ParamSpec("P") 18T = TypeVar("T") 19 20 21def maybe_time( 22 fn: Callable[Concatenate[Any, P], T] 23) -> Callable[Concatenate[Any, P], T]: 24 """Wrapper that logs the duration of `fn`, in milliseconds, along with a representation 25 of the function's args and kwargs, if logging is enabled. It is expected that `fn` is 26 a method of `Benchmarker` or one of its subclasses; typing limitations prevent us from 27 declaring this directly. If logging is disabled, this becomes a no-op. 28 """ 29 30 # no-op if benchmarking-specific logging is disabled 31 if not torch._logging._internal.log_state.is_artifact_enabled("benchmarking"): 32 return fn 33 34 @wraps(fn) 35 def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T: 36 start_t = time.perf_counter() 37 result = fn(*args, **kwargs) 38 logger.debug( 39 "Call `benchmarking.%s.%s(*args=%r, **kwargs=%r)` took %f milliseconds.", 40 self.__class__.__name__, 41 fn.__name__, 42 args, 43 kwargs, 44 (time.perf_counter() - start_t) * MILLISECONDS_PER_SECOND, 45 ) 46 return result 47 48 return wrapper 49 50 51def count(fn: Callable[Concatenate[Any, P], T]) -> Callable[Concatenate[Any, P], T]: 52 """Wrapper that increments relevant dynamo counters on `fn` call. It is expected that 53 `fn` is a method of `Benchmarker` or one of its subclass; typing limitations prevent 54 us from declaring this directly. The counter incrementation follows the formula, 55 56 `counters["inductor"]["benchmarking.Foo.bar] += 1` 57 58 where `Foo` is the class whose' instance called the function, and `bar` is the function name. 59 """ 60 61 @wraps(fn) 62 def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T: 63 counters["inductor"][ 64 "benchmarking." + self.__class__.__name__ + "." + fn.__name__ 65 ] += 1 66 return fn(self, *args, **kwargs) 67 68 return wrapper 69 70 71class Benchmarker: 72 def __init__(self: Self) -> None: 73 pass 74 75 @maybe_time 76 @count 77 def benchmark( 78 self: Self, 79 fn: Callable[..., Any], 80 fn_args: Tuple[Any], 81 fn_kwargs: Dict[str, Any], 82 **kwargs: Any, 83 ) -> float: 84 """Benchmark `fn(*fn_args, *fn_kwargs)` and return the runtime, in milliseconds (the 85 actual runtime calculation is dictated by the benchmarking implementation, but may be 86 one of [mean, median, minimum, etc.]). Functions as a convenience wrapper around 87 device-specific implementations, like `benchmark_cpu` and `benchmark_gpu`. Raises 88 `ValueError(...)` if we can't safely infer the device type of `fn`; for example, 89 if multiple device types are found in `fn_args` and `fn_kwargs`, or if no device 90 types are found. 91 92 Arguments: 93 - fn: The function to benchmark. 94 - fn_args: The function's arguments. 95 - fn_kwargs: The function's kwargs. 96 97 Keyword Arguments: 98 - **kwargs: The benchmarking implementation's kwargs. 99 100 Returns: 101 - The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds. 102 """ 103 inferred_device = None 104 for arg_or_kwarg in chain(fn_args, fn_kwargs.values()): 105 if not isinstance(arg_or_kwarg, torch.Tensor): 106 continue 107 if inferred_device is None: 108 inferred_device = arg_or_kwarg.device 109 elif arg_or_kwarg.device != inferred_device: 110 raise ValueError( 111 "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!" 112 ) 113 if inferred_device is None: 114 raise ValueError( 115 "Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly." # noqa: B950 116 ) 117 _callable = lambda: fn(*fn_args, **fn_kwargs) # noqa: E731 118 if inferred_device == torch.device("cpu"): 119 return self.benchmark_cpu(_callable, **kwargs) 120 # TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking 121 # implementation which was written specifically with CUDA devices in mind, we may want to 122 # explore alternate implementations for other device types. 123 return self.benchmark_gpu(_callable, **kwargs) 124 125 @maybe_time 126 @count 127 def benchmark_cpu( 128 self: Self, _callable: Callable[[], Any], warmup: int = 20, rep: int = 100 129 ) -> float: 130 """Benchmark the CPU callable, `_callable`, and return the median runtime, 131 in milliseconds. 132 133 Arguments: 134 - _callable: The CPU callable to benchmark. 135 136 Keyword Arguments: 137 - warmup: Optionally, the duration, in milliseconds, to run `_callable` 138 before benchmarking starts. 139 - rep: Optionally, the duration, in milliseconds, to run `_callable` 140 during benchmarking. 141 142 Returns: 143 - The median runtime of `_callable`, in milliseconds. 144 """ 145 146 def run_for(ms: int) -> List[float]: 147 timings = [] 148 run_start_t = time.perf_counter() 149 while True: 150 start_t = time.perf_counter() 151 _callable() 152 end_t = time.perf_counter() 153 timings.append((end_t - start_t) * MILLISECONDS_PER_SECOND) 154 if ((end_t - run_start_t) * MILLISECONDS_PER_SECOND) > ms: 155 break 156 return timings 157 158 run_for(warmup) 159 return median(run_for(rep)) 160 161 @count 162 def benchmark_gpu(self: Self, *args: Any, **kwargs: Any) -> float: 163 raise NotImplementedError 164 165 166class TritonBenchmarker(Benchmarker): 167 @cached_property 168 @maybe_time 169 @count 170 def triton_do_bench(self: Self) -> Callable[..., Any]: 171 """Lazily import Triton's `do_bench`.""" 172 try: 173 from triton.testing import do_bench 174 except ImportError as e: 175 raise NotImplementedError("requires Triton") from e 176 return do_bench 177 178 @maybe_time 179 @count 180 def benchmark_gpu(self: Self, _callable: Callable[[], Any], **kwargs: Any) -> float: 181 """Benchmark the GPU callable, `_callable`, and return the runtime, in milliseconds. 182 183 Arguments: 184 - _callable: The GPU callable to benchmark. 185 186 Keyword Arguments: 187 - quantiles: Optionally, a tuple of floats denoting the requested quantiles. 188 - return_mode: Optionally, the requested return mode. Currently, Triton's 189 `do_bench` supports min, max, mean, and median return modes. 190 - **kwargs: Additional kwargs passed to Triton's `do_bench`. 191 192 Returns: 193 - The runtime of `callable`, in milliseconds. If `kwargs["quantiles"]` is specified, 194 this is the first requested quantile. Else, if `kwargs["return_mode"]` is specified, 195 this is the requested return mode. Otherwise, this is the median. 196 """ 197 if "quantiles" in kwargs: 198 return self.triton_do_bench(_callable, **kwargs)[0] 199 elif "return_mode" in kwargs: 200 return self.triton_do_bench(_callable, **kwargs) 201 return self.triton_do_bench(_callable, **kwargs, return_mode="median") 202 203 204benchmarker = TritonBenchmarker() 205