# Owner(s): ["module: unknown"]
import math
import os
from collections import defaultdict
from typing import Any, Callable, Dict, List, Set, Tuple
from typing_extensions import Self

import torch
import torch.utils._pytree as pytree
from torch._guards import active_fake_mode
from torch._inductor.utils import get_device_tflops, get_gpu_dram_gbps
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._tools.mod_tracker import ModTracker
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.flop_counter import flop_registry


aten = torch.ops.aten

# This value is hard-coded here:
# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117
_PYTORCH_MIN_ALLOCATE = (
    2**9 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) == 0 else 1
)

# No fall-back kernel needed/exists for view ops
_VIEW_OPS = {
    aten.lift_fresh,
    aten.t,
    aten.transpose,
    aten.view,
    aten.detach,
    aten._unsafe_view,
    aten.split,
    aten.adjoint,
    aten.as_strided,
    aten.diagonal,
    aten.expand,
    aten.expand_as,
    aten.movedim,
    aten.permute,
    aten.select,
    aten.squeeze,
    aten.mT,
    aten.mH,
    aten.real,
    aten.imag,
    aten.view_as,
    aten.unflatten,
    aten.unfold,
    aten.unbind,
    aten.unsqueeze,
    aten.vsplit,
    aten.hsplit,
    aten.split_with_sizes,
    aten.swapaxes,
    aten.swapdims,
    aten.chunk,
}
# We can ignore benchmarking tensor create ops
_CREATE_OPS = {
    aten.randint,
    aten.randn,
    aten.rand,
    aten.randn_like,
    aten.rand_like,
    aten.randint_like,
    aten.arange,
    aten.ones_like,
    aten.zeros_like,
}

_IGNORE_OPS = _VIEW_OPS | _CREATE_OPS

__all__ = ["RuntimeEstimator"]


class RuntimeEstimator(TorchDispatchMode):
    """
    Estimates the GPU runtime in milliseconds using various estimation methods under the ``FakeTensorMode``.

    This class provides a ``TorchDispatchMode`` based context manager that can be used to estimate the eager
    runtime of PyTorch functions. It supports two estimation modes: benchmarking (`operator-level-benchmark`) and
    roofline cost modeling (`operator-level-cost-model`).
    For modules executed under this context manager, it aggregates the forward and backward operation runtimes
    and also records their execution orders.

    Attributes:
        mod_runtimes (Dict[str, Dict[str, float]]): A dictionary of module runtimes. The key to the outer dictionary
            is the fully qualified name (FQN) of the module. For each module the forward and backward runtimes of the
            operations are aggregated in the inner dictionary keyed by 'fw' and 'bw'.
        mod_fw_pre_order (List[str]): List of module FQNs in pre-forward execution order.
        mod_bw_pre_order (List[str]): List of module FQNs in pre-backward execution order.
        mod_fw_post_order (List[str]): List of module FQNs in post-forward execution order.
        mod_bw_post_order (List[str]): List of module FQNs in post-backward execution order.
        total_runtime (float): The total estimated runtime in milliseconds.

    Note:
        1) The benchmarking estimate mode will execute kernels on the GPU and assumes that every operation can run
            in isolation without causing an OOM error. It is also designed to be used only under ``FakeTensorMode``.
        2) Currently, wrapper tensor subclasses such as ``DTensor`` won't produce correct estimates. We plan to
            support them in future PRs.
        3) We only estimate compute time; if your code has communication, it will not be considered. Again, we
            will support this in future PRs.

    Example usage:

    .. code-block:: python

        runtime_estimator = RuntimeEstimator()
        with FakeTensorMode():
            module = ...
            optimizer = ...
            inp = ...
            with runtime_estimator(estimate_mode_type="operator-level-cost-model"):
                loss = module(inp)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            runtime_estimator.display_modulewise_stats()
    """

    _float_types: Set[torch.dtype] = {
        torch.float16,
        torch.bfloat16,
        torch.float32,
        torch.float64,
    }
    _no_fallback_kernel: Set[torch._ops._OpNamespace] = set()
    fake_mode: FakeTensorMode

    def __init__(self) -> None:
        super().__init__()
        self._estimate: Callable
        self._estimate_mode_type: str
        self._mod_tracker = ModTracker()
        self.mod_runtimes: Dict[str, Dict[str, float]] = defaultdict(
            lambda: defaultdict(lambda: 0.0)
        )
        self.mod_fw_pre_order: List[str] = []
        self.mod_bw_pre_order: List[str] = []
        self.mod_fw_post_order: List[str] = []
        self.mod_bw_post_order: List[str] = []
        self.total_runtime: float = 0.0

    # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_subclasses/fake_tensor.py#L1969  # noqa: PGH004,B950
    # NB: returns fake tensors
    @classmethod
    def _maybe_run_and_benchmark_fallback_kernel(  # type: ignore[no-untyped-def]
        cls,
        func,
        args,
        kwargs,
        orig_not_implemented_exception,
    ):
        """
        Runs and benchmarks a fallback kernel for a given function.

        Args:
            func (Callable): The function to benchmark.
            args (Tuple): The arguments to pass to the function.
            kwargs (Dict[str, Any]): The keyword arguments to pass to the function.
            orig_not_implemented_exception (Exception): The original exception to raise if the fallback kernel
                is not implemented.

        Returns:
            Tuple[Any, float]: A tuple containing the result of the function and
                the mean operation time in milliseconds.
        """
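        # Descriptive note (added for clarity, not part of the original implementation):
        # the flow below materializes real tensors in place of the fake inputs, runs a
        # couple of warmup iterations, times `actual_iters` real executions with CUDA
        # events, and maps the real outputs back to fake tensors so downstream tracing
        # stays under ``FakeTensorMode``.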
170 """ 171 # these should all be supported, just to be safe 172 # avoid fallback for operators which inplace modify metadata 173 # because the input fake tensors would be umodified 174 if torch.Tag.inplace_view in func.tags: # type: ignore[attr-defined] 175 raise orig_not_implemented_exception 176 177 inp_impls = {} 178 flat_args, args_spec = pytree.tree_flatten((args, kwargs)) 179 # Don't use in_kernel_invocation_manager(fake_mode) as we want to do 180 # REAL compute (not with meta device) 181 with no_dispatch(): 182 183 def to_real_tensor(e): # type: ignore[no-untyped-def] 184 if cls.fake_mode.is_our_fake(e): 185 if e.dtype in cls._float_types: 186 out = torch.rand_like(e, device=e.fake_device) 187 else: 188 out = torch.ones_like(e, device=e.fake_device) 189 if e.is_sparse: 190 out._coalesced_(e.is_coalesced()) 191 inp_impls[id(out)] = e 192 return out 193 return e 194 195 flat_args = [to_real_tensor(a) for a in flat_args] 196 args, kwargs = pytree.tree_unflatten(flat_args, args_spec) 197 r = func(*args, **kwargs) 198 warmup_iters, actual_iters = 2, 3 199 for _ in range(warmup_iters): 200 func(*args, **kwargs) 201 start_event = torch.cuda.Event(enable_timing=True) 202 end_event = torch.cuda.Event(enable_timing=True) 203 start_event.record(torch.cuda.current_stream()) 204 for _ in range(actual_iters): 205 func(*args, **kwargs) 206 end_event.record(torch.cuda.current_stream()) 207 torch.cuda.synchronize() 208 cuda_time = start_event.elapsed_time(end_event) 209 mean_op_time = cuda_time / actual_iters 210 211 storages = set() 212 213 for e in flat_args: 214 if isinstance(e, torch.Tensor): 215 if not e.is_sparse: 216 storages.add(e._typed_storage()._cdata) 217 218 # TODO: also check metadata change on inputs 219 # proper aliasing/metadata relationship between outputs and inputs will 220 # not be set up, bc of conversion to device, unless we can reuse an 221 # input impl 222 223 def map_out(e): # type: ignore[no-untyped-def] 224 if id(e) not in inp_impls and ( 225 isinstance(e, torch.Tensor) 226 and not e.is_sparse 227 and e._typed_storage()._cdata in storages 228 ): 229 raise orig_not_implemented_exception 230 231 if isinstance(e, torch.Tensor): 232 if id(e) in inp_impls: 233 return inp_impls[id(e)] 234 else: 235 return cls.fake_mode.fake_tensor_converter.from_real_tensor( 236 cls.fake_mode, e 237 ) 238 else: 239 return e 240 241 return (pytree.tree_map(map_out, r), mean_op_time) 242 243 @classmethod 244 def _benchmark_estimate(cls, func, args, kwargs) -> Tuple[Any, float]: # type: ignore[no-untyped-def] 245 """ 246 Estimates the runtime of a function using benchmarking. 247 248 Args: 249 func: The function to estimate. 250 args: The arguments to pass to the function. 251 kwargs: The keyword arguments to pass to the function. 252 res: The result of the function. 253 254 Returns: 255 Tuple[Any, float]: A tuple containing the result of the function and 256 the mean operation time in milliseconds. 
257 """ 258 assert isinstance( 259 cls.fake_mode, FakeTensorMode 260 ), "Initialize/Assign FakeTensorMode before using this function" 261 mean_op_time = 0.0 262 if func._overloadpacket not in _VIEW_OPS: 263 try: 264 res, mean_op_time = cls._maybe_run_and_benchmark_fallback_kernel( 265 func, 266 args, 267 kwargs, 268 NotImplementedError, 269 ) 270 return (res, mean_op_time) 271 except NotImplementedError: 272 cls._no_fallback_kernel.add(func._overloadpacket) 273 res = func(*args, **kwargs or {}) 274 return (res, mean_op_time) 275 276 # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_inductor/scheduler.py#L589 # noqa: PGH004,B950 277 @classmethod 278 def _roofline_estimate(cls, func, args, kwargs) -> Tuple[Any, float]: # type: ignore[no-untyped-def] 279 """ 280 Estimates the runtime of a function using a roofline cost model. 281 282 Args: 283 func: The function to estimate. 284 args: The arguments to pass to the function. 285 kwargs: The keyword arguments to pass to the function. 286 out: The output of the function. 287 288 Returns: 289 Tuple[Any, float]: A tuple containing the result of the function and 290 the mean operation time in milliseconds. 291 """ 292 assert ( 293 torch.cuda.is_available() 294 ), "Roofline estimation needs to access CUDA capabilities to make estimations" 295 296 def get_num_bytes(t: torch.Tensor) -> int: 297 """ 298 Calculates the memory consumption of a tensor. 299 300 Args: 301 t (torch.Tensor): The input tensor. 302 303 Returns: 304 int: The memory consumption of the tensor in bytes. 305 """ 306 num_bytes = t.untyped_storage().nbytes() 307 mem_consumed = ( 308 math.ceil(num_bytes / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE 309 ) 310 return mem_consumed 311 312 def get_compute_time(func_packet, args, kwargs, out, out_dtypes) -> float: # type: ignore[no-untyped-def] 313 """ 314 Estimates the compute time of an aten operator. 315 316 Args: 317 func_packet: The operator overload packet. 318 args: The arguments to the operator. 319 kwargs: The keyword arguments to the operator. 320 out: The output of the operator. 321 out_dtypes: The output data types. 322 323 Returns: 324 float: The estimated compute time in nanoseconds. 325 """ 326 if func_packet in flop_registry: 327 assert ( 328 len(out_dtypes) == 1 329 ), f"Only support single out dtype got {out_dtypes} for {func_packet}" 330 dtype = out_dtypes.pop() 331 # This actually gives peta-FLOPs/s hence multiply by 1e15 to get the FLOPs/s 332 peak_gpu_flops = get_device_tflops(dtype) * 1e15 333 # We can expect to achieve 75% of theoretical peak flops 334 factor = 0.75 335 peak_empirical_flops = factor * peak_gpu_flops 336 flop_count_func = flop_registry[func_packet] 337 # We divide by a factor of 2 to get the MACs (multiply and accumulate) 338 flop_count = flop_count_func(*args, **kwargs, out_val=out) / 2 339 # We multiply by 1e9 to get the time in nano seconds 340 compute_time = (flop_count / peak_empirical_flops) * 1e9 341 return compute_time 342 return 0.0 343 344 def get_transfer_time(flat_args_kwargs, flat_outs) -> float: # type: ignore[no-untyped-def] 345 """ 346 Estimates the memory transfer time of input and output tensors. 347 348 Args: 349 flat_args_kwargs (List[torch.Tensor]): The flat list of arguments and keyword arguments. 350 flat_outs (List[torch.Tensor]): The flat list of outputs. 351 352 Returns: 353 float: The estimated memory transfer time in nanoseconds. 
354 """ 355 gpu_memory_bandwidth = get_gpu_dram_gbps() 356 read_bytes = sum( 357 get_num_bytes(t) 358 for t in flat_args_kwargs 359 if isinstance(t, torch.Tensor) 360 ) 361 write_bytes = sum( 362 get_num_bytes(t) for t in flat_outs if isinstance(t, torch.Tensor) 363 ) 364 counted_bytes = read_bytes + write_bytes 365 # The GPU memory bandwidth is in GB/s so the transfer time is in nanoseconds 366 transfer_time = counted_bytes / gpu_memory_bandwidth 367 return transfer_time 368 369 # Roofline Cost Model Explanation 370 371 # The roofline cost model estimates the execution time of an operator based on 372 # the device's empirical maximum FLOPs/sec (pi) and device DRAM bandwidth (beta). 373 374 # Variables: 375 # - pi: Maximum empirical FLOPs/sec of the device 376 # - beta: Maximum empirical device DRAM bandwidth (bytes/sec) of the device 377 # - I: Arithmetic intensity of the operator (FLOPs/bytes) 378 # - op_flops: FLOPs required by the operator 379 # - op_bytes: Bytes transferred to and from DRAM for the operator 380 381 # Calculation Steps: 382 # 1. Calculate arithmetic intensity: I = op_flops / op_bytes 383 # 2. Calculate estimated FLOPs/sec: est_flops_sec = min(pi, beta * I) 384 # 3. Calculate estimated operator time: estimated_op_time = op_flops / est_flops_sec 385 # This simplifies to: estimated_op_time = max(op_flops / pi, op_flops / (beta * I)) 386 # Further simplifying: estimated_op_time = max(op_flops / pi, op_bytes / beta) 387 388 # Simplified Formulas: 389 # - compute_time = op_flops / pi 390 # - transfer_time = op_bytes / beta 391 # - estimated_op_time = max(compute_time, transfer_time) 392 393 kwargs = kwargs if kwargs else {} 394 out = func(*args, **kwargs) 395 op_time = 0.0 396 func_packet = func._overloadpacket 397 if func_packet not in _IGNORE_OPS: 398 flat_args_kwargs, args_spec = pytree.tree_flatten((args, kwargs)) 399 flat_outs, out_spec = pytree.tree_flatten(out) 400 transfer_time = get_transfer_time(flat_args_kwargs, flat_outs) 401 402 out_dtypes = { 403 t.dtype 404 for t in flat_outs 405 if isinstance(t, torch.Tensor) and t.dtype in cls._float_types 406 } 407 408 args, kwargs = pytree.tree_unflatten(flat_args_kwargs, args_spec) 409 out = pytree.tree_unflatten(flat_outs, out_spec) 410 411 compute_time = get_compute_time(func_packet, args, kwargs, out, out_dtypes) 412 # We get the estimated time as the max of the transfer time and 413 # compute time. We divide by 1e6 to get the time in ms 414 op_time = max(transfer_time, compute_time) / 1e6 415 416 return (out, op_time) 417 418 def display_modulewise_stats(self, depth: int = 2) -> None: 419 """ 420 Displays module-wise statistics collected by ``RuntimeEstimator``. 421 422 Prints the pre-forward and pre-backward execution orders. 423 Displays the module-wise forward and backward runtimes in milliseconds. 424 425 Args: 426 depth (int): The maximum depth of module hierarchy to display (default to 2). 
427 """ 428 print("Pre-Forward Execution Order: ") 429 for mod_fqn in self.mod_fw_pre_order: 430 mod_depth = mod_fqn.count(".") + 1 431 if mod_depth > depth: 432 continue 433 print(mod_fqn) 434 print("Pre-Backward Execution Order: ") 435 for mod_fqn in self.mod_bw_pre_order: 436 mod_depth = mod_fqn.count(".") + 1 437 if mod_depth > depth: 438 continue 439 print(mod_fqn) 440 for mod_fqn, runtimes in self.mod_runtimes.items(): 441 mod_depth = mod_fqn.count(".") + 1 442 if mod_depth > depth: 443 continue 444 print( 445 f"{mod_fqn} fw: {runtimes.get('fw', 0.0):.3f}ms bw: {runtimes.get('bw', 0.0):.3f}ms" 446 ) 447 448 def __torch_dispatch__(self, func, types, args=..., kwargs=None): # type: ignore[no-untyped-def] 449 # TODO: @sanketpurandare: Flatten tensors by desugaring the tensor subclasses 450 # TODO: @sanketpurandare: Add logic for incorporating communication time 451 res, op_time = self._estimate(func, args, kwargs) 452 for par in self._mod_tracker.parents: 453 if self._mod_tracker.is_bw: 454 self.mod_runtimes[par]["bw"] += op_time 455 else: 456 self.mod_runtimes[par]["fw"] += op_time 457 self.total_runtime += op_time 458 return res 459 460 def __call__(self, estimate_mode_type: str) -> Self: 461 """ 462 Sets the estimate mode type. 463 464 Currently supported modes: 465 - "operator-level-benchmark": Estimates runtime using operator benchmarking. 466 - "operator-level-cost-model": Estimates runtime using roofline cost model. 467 468 Args: 469 estimate_mode_type (str): The type of estimate mode to use. 470 471 Returns: 472 RuntimeEstimator: The runtime estimator instance. 473 474 Raises: 475 NotImplementedError: If the estimate mode type is not supported. 476 """ 477 if estimate_mode_type == "operator-level-benchmark": 478 self._estimate = RuntimeEstimator._benchmark_estimate 479 elif estimate_mode_type == "operator-level-cost-model": 480 self._estimate = RuntimeEstimator._roofline_estimate 481 else: 482 raise NotImplementedError( 483 f"estimate_mode_type {estimate_mode_type} not supported" 484 ) 485 self._estimate_mode_type = estimate_mode_type 486 return self 487 488 def __enter__(self) -> Self: 489 fake_mode = active_fake_mode() 490 assert isinstance( 491 fake_mode, FakeTensorMode 492 ), "No FakeTensorMode found, designed to used under FakeTensorMode" 493 RuntimeEstimator.fake_mode = fake_mode 494 self.total_runtime = 0.0 495 self.mod_runtimes = defaultdict(lambda: defaultdict(lambda: 0.0)) 496 self.mod_fw_pre_order.clear() 497 self.mod_bw_pre_order.clear() 498 self.mod_fw_post_order.clear() 499 self.mod_bw_post_order.clear() 500 self._mod_tracker.register_user_hooks( 501 pre_fw_hook=lambda mod, inp: self.mod_fw_pre_order.append( 502 self._mod_tracker.get_known_fqn(mod) 503 ), 504 pre_bw_hook=lambda mod, g_out: self.mod_bw_pre_order.append( 505 self._mod_tracker.get_known_fqn(mod) 506 ), 507 post_fw_hook=lambda mod, inp, out: self.mod_fw_post_order.append( 508 self._mod_tracker.get_known_fqn(mod) 509 ), 510 post_bw_hook=lambda mod, g_inp: self.mod_bw_post_order.append( 511 self._mod_tracker.get_known_fqn(mod) 512 ), 513 ) 514 self._mod_tracker.__enter__() 515 super().__enter__() 516 return self 517 518 def __exit__(self, *args: Any) -> None: 519 print( 520 f"Estimated ({self._estimate_mode_type})" 521 f"total_time: {self.total_runtime:.3f} ms" 522 ) 523 if len(self._no_fallback_kernel) > 0: 524 print("no_fallback_kernel: ", list(self._no_fallback_kernel)) 525 super().__exit__(*args) 526 self._mod_tracker.clear_user_hooks() 527 self._mod_tracker.__exit__() 528