# Owner(s): ["module: unknown"]
import math
import os
from collections import defaultdict
from typing import Any, Callable, Dict, List, Set, Tuple
from typing_extensions import Self

import torch
import torch.utils._pytree as pytree
from torch._guards import active_fake_mode
from torch._inductor.utils import get_device_tflops, get_gpu_dram_gbps
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._tools.mod_tracker import ModTracker
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.flop_counter import flop_registry


aten = torch.ops.aten

# This value is hard-coded here:
# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117
_PYTORCH_MIN_ALLOCATE = (
    2**9 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) == 0 else 1
)
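# Illustrative example (assuming the default 512-byte rounding above): a float32
# tensor with 100 elements occupies 400 bytes of storage, but the CUDA caching
# allocator accounts for it as ceil(400 / 512) * 512 = 512 bytes.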

# No fall-back kernel needed/exists for view ops
_VIEW_OPS = {
    aten.lift_fresh,
    aten.t,
    aten.transpose,
    aten.view,
    aten.detach,
    aten._unsafe_view,
    aten.split,
    aten.adjoint,
    aten.as_strided,
    aten.diagonal,
    aten.expand,
    aten.expand_as,
    aten.movedim,
    aten.permute,
    aten.select,
    aten.squeeze,
    aten.mT,
    aten.mH,
    aten.real,
    aten.imag,
    aten.view_as,
    aten.unflatten,
    aten.unfold,
    aten.unbind,
    aten.unsqueeze,
    aten.vsplit,
    aten.hsplit,
    aten.split_with_sizes,
    aten.swapaxes,
    aten.swapdims,
    aten.chunk,
}
# Tensor creation ops can be skipped when benchmarking
_CREATE_OPS = {
    aten.randint,
    aten.randn,
    aten.rand,
    aten.randn_like,
    aten.rand_like,
    aten.randint_like,
    aten.arange,
    aten.ones_like,
    aten.zeros_like,
}

_IGNORE_OPS = _VIEW_OPS | _CREATE_OPS

__all__ = ["RuntimeEstimator"]


class RuntimeEstimator(TorchDispatchMode):
    """
    Estimates the GPU runtime in milliseconds using various estimation methods under the ``FakeTensorMode``.

    This class provides a ``TorchDispatchMode`` based context manager that can be used to estimate the eager
    runtime of PyTorch functions. It supports two estimation modes, benchmarking (`operator-level-benchmark`) and
    roofline cost modeling (`operator-level-cost-model`).
    For modules executed under this context manager, it aggregates the forward and backward operation runtimes
    and also records their execution orders.

    Attributes:
        mod_runtimes (Dict[str, Dict[str, float]]): A dictionary of module runtimes. The key to the outer dictionary
            is the fully qualified name (FQN) of the module. For each module the forward and backward runtimes of the
            operations are aggregated in the inner dictionary keyed by 'fw' and 'bw'.
        mod_fw_pre_order (List[str]): List of module FQNs in pre-forward execution order.
        mod_bw_pre_order (List[str]): List of module FQNs in pre-backward execution order.
        mod_fw_post_order (List[str]): List of module FQNs in post-forward execution order.
        mod_bw_post_order (List[str]): List of module FQNs in post-backward execution order.
        total_runtime (float): The total estimated runtime in milliseconds.

    Note:
        1) The benchmarking estimate mode will execute kernels on GPU and assumes that every operation can run in
            isolation without causing an OOM error. It is also designed to be used only under ``FakeTensorMode``.
        2) Currently wrapper tensor sub-classes such as ``DTensor`` won't produce correct estimates. We plan to support
            them in future PRs.
        3) We only estimate the compute time; if your code has communication, it will not be considered. Again, we will
            support this in future PRs.

    Example usage:

        .. code-block:: python

            runtime_estimator = RuntimeEstimator()
            with FakeTensorMode():
                module = ...
                optimizer = ...
                inp = ...
                with runtime_estimator(estimate_mode_type="operator-level-cost-model"):
                    loss = module(inp)
                    loss.backward()
                    optimizer.step()
                    optimizer.zero_grad()
                runtime_estimator.display_modulewise_stats()
    """

    _float_types: Set[torch.dtype] = {
        torch.float16,
        torch.bfloat16,
        torch.float32,
        torch.float64,
    }
    _no_fallback_kernel: Set[torch._ops._OpNamespace] = set()
    fake_mode: FakeTensorMode

    def __init__(self) -> None:
        super().__init__()
        self._estimate: Callable
        self._estimate_mode_type: str
        self._mod_tracker = ModTracker()
        self.mod_runtimes: Dict[str, Dict[str, float]] = defaultdict(
            lambda: defaultdict(lambda: 0.0)
        )
        self.mod_fw_pre_order: List[str] = []
        self.mod_bw_pre_order: List[str] = []
        self.mod_fw_post_order: List[str] = []
        self.mod_bw_post_order: List[str] = []
        self.total_runtime: float = 0.0

    # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_subclasses/fake_tensor.py#L1969  # noqa: PGH004,B950
    # NB: returns fake tensors
    @classmethod
    def _maybe_run_and_benchmark_fallback_kernel(  # type: ignore[no-untyped-def]
        cls,
        func,
        args,
        kwargs,
        orig_not_implemented_exception,
    ):
        """
        Runs and benchmarks a fallback kernel for a given function.

        Args:
            func (Callable): The function to benchmark.
            args (Tuple): The arguments to pass to the function.
            kwargs (Dict[str, Any]): The keyword arguments to pass to the function.
            orig_not_implemented_exception (Exception): The original exception to raise if the fallback kernel
                is not implemented.

        Returns:
            Tuple[Any, float]: A tuple containing the result of the function and
                the mean operation time in milliseconds.
        """
        # these should all be supported, just to be safe
        # avoid fallback for operators which inplace modify metadata
        # because the input fake tensors would be unmodified
        if torch.Tag.inplace_view in func.tags:  # type: ignore[attr-defined]
            raise orig_not_implemented_exception

        inp_impls = {}
        flat_args, args_spec = pytree.tree_flatten((args, kwargs))
        # Don't use in_kernel_invocation_manager(fake_mode) as we want to do
        # REAL compute (not with meta device)
        with no_dispatch():

            def to_real_tensor(e):  # type: ignore[no-untyped-def]
                if cls.fake_mode.is_our_fake(e):
                    if e.dtype in cls._float_types:
                        out = torch.rand_like(e, device=e.fake_device)
                    else:
                        out = torch.ones_like(e, device=e.fake_device)
                    if e.is_sparse:
                        out._coalesced_(e.is_coalesced())
                    inp_impls[id(out)] = e
                    return out
                return e

            flat_args = [to_real_tensor(a) for a in flat_args]
            args, kwargs = pytree.tree_unflatten(flat_args, args_spec)
            r = func(*args, **kwargs)
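            # Benchmark the real kernel: run a couple of warmup iterations, then
            # time `actual_iters` runs with CUDA events and take the mean.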
            warmup_iters, actual_iters = 2, 3
            for _ in range(warmup_iters):
                func(*args, **kwargs)
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record(torch.cuda.current_stream())
            for _ in range(actual_iters):
                func(*args, **kwargs)
            end_event.record(torch.cuda.current_stream())
            torch.cuda.synchronize()
            cuda_time = start_event.elapsed_time(end_event)
            mean_op_time = cuda_time / actual_iters

        storages = set()

        for e in flat_args:
            if isinstance(e, torch.Tensor):
                if not e.is_sparse:
                    storages.add(e._typed_storage()._cdata)

        # TODO: also check metadata change on inputs
        # proper aliasing/metadata relationship between outputs and inputs will
        # not be set up, bc of conversion to device, unless we can reuse an
        # input impl

        def map_out(e):  # type: ignore[no-untyped-def]
            if id(e) not in inp_impls and (
                isinstance(e, torch.Tensor)
                and not e.is_sparse
                and e._typed_storage()._cdata in storages
            ):
                raise orig_not_implemented_exception

            if isinstance(e, torch.Tensor):
                if id(e) in inp_impls:
                    return inp_impls[id(e)]
                else:
                    return cls.fake_mode.fake_tensor_converter.from_real_tensor(
                        cls.fake_mode, e
                    )
            else:
                return e

        return (pytree.tree_map(map_out, r), mean_op_time)

    @classmethod
    def _benchmark_estimate(cls, func, args, kwargs) -> Tuple[Any, float]:  # type: ignore[no-untyped-def]
        """
        Estimates the runtime of a function using benchmarking.

        Args:
            func: The function to estimate.
            args: The arguments to pass to the function.
            kwargs: The keyword arguments to pass to the function.

        Returns:
            Tuple[Any, float]: A tuple containing the result of the function and
                the mean operation time in milliseconds.
        """
        assert isinstance(
            cls.fake_mode, FakeTensorMode
        ), "Initialize/Assign FakeTensorMode before using this function"
        mean_op_time = 0.0
        if func._overloadpacket not in _VIEW_OPS:
            try:
                res, mean_op_time = cls._maybe_run_and_benchmark_fallback_kernel(
                    func,
                    args,
                    kwargs,
                    NotImplementedError,
                )
                return (res, mean_op_time)
            except NotImplementedError:
                cls._no_fallback_kernel.add(func._overloadpacket)
        res = func(*args, **kwargs or {})
        return (res, mean_op_time)

    # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_inductor/scheduler.py#L589  # noqa: PGH004,B950
    @classmethod
    def _roofline_estimate(cls, func, args, kwargs) -> Tuple[Any, float]:  # type: ignore[no-untyped-def]
        """
        Estimates the runtime of a function using a roofline cost model.

        Args:
            func: The function to estimate.
            args: The arguments to pass to the function.
            kwargs: The keyword arguments to pass to the function.

        Returns:
            Tuple[Any, float]: A tuple containing the result of the function and
                the mean operation time in milliseconds.
        """
        assert (
            torch.cuda.is_available()
        ), "Roofline estimation needs to access CUDA capabilities to make estimations"

        def get_num_bytes(t: torch.Tensor) -> int:
            """
            Calculates the memory consumption of a tensor.

            Args:
                t (torch.Tensor): The input tensor.

            Returns:
                int: The memory consumption of the tensor in bytes.
            """
            num_bytes = t.untyped_storage().nbytes()
            mem_consumed = (
                math.ceil(num_bytes / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE
            )
            return mem_consumed

        def get_compute_time(func_packet, args, kwargs, out, out_dtypes) -> float:  # type: ignore[no-untyped-def]
            """
            Estimates the compute time of an aten operator.

            Args:
                func_packet: The operator overload packet.
                args: The arguments to the operator.
                kwargs: The keyword arguments to the operator.
                out: The output of the operator.
                out_dtypes: The output data types.

            Returns:
                float: The estimated compute time in nanoseconds.
            """
            if func_packet in flop_registry:
                assert (
                    len(out_dtypes) == 1
                ), f"Only support single out dtype got {out_dtypes} for {func_packet}"
                dtype = out_dtypes.pop()
                # This actually gives peta-FLOPs/s hence multiply by 1e15 to get the FLOPs/s
                peak_gpu_flops = get_device_tflops(dtype) * 1e15
                # We can expect to achieve 75% of theoretical peak flops
                factor = 0.75
                peak_empirical_flops = factor * peak_gpu_flops
                flop_count_func = flop_registry[func_packet]
                # We divide by a factor of 2 to get the MACs (multiply and accumulate)
                flop_count = flop_count_func(*args, **kwargs, out_val=out) / 2
                # We multiply by 1e9 to get the time in nanoseconds
                compute_time = (flop_count / peak_empirical_flops) * 1e9
                return compute_time
            return 0.0

        def get_transfer_time(flat_args_kwargs, flat_outs) -> float:  # type: ignore[no-untyped-def]
            """
            Estimates the memory transfer time of input and output tensors.

            Args:
                flat_args_kwargs (List[torch.Tensor]): The flat list of arguments and keyword arguments.
                flat_outs (List[torch.Tensor]): The flat list of outputs.

            Returns:
                float: The estimated memory transfer time in nanoseconds.
            """
            gpu_memory_bandwidth = get_gpu_dram_gbps()
            read_bytes = sum(
                get_num_bytes(t)
                for t in flat_args_kwargs
                if isinstance(t, torch.Tensor)
            )
            write_bytes = sum(
                get_num_bytes(t) for t in flat_outs if isinstance(t, torch.Tensor)
            )
            counted_bytes = read_bytes + write_bytes
            # The GPU memory bandwidth is in GB/s so the transfer time is in nanoseconds
            transfer_time = counted_bytes / gpu_memory_bandwidth
            return transfer_time

        # Roofline Cost Model Explanation

        # The roofline cost model estimates the execution time of an operator based on
        # the device's empirical maximum FLOPs/sec (pi) and device DRAM bandwidth (beta).

        # Variables:
        # - pi: Maximum empirical FLOPs/sec of the device
        # - beta: Maximum empirical device DRAM bandwidth (bytes/sec) of the device
        # - I: Arithmetic intensity of the operator (FLOPs/bytes)
        # - op_flops: FLOPs required by the operator
        # - op_bytes: Bytes transferred to and from DRAM for the operator

        # Calculation Steps:
        # 1. Calculate arithmetic intensity: I = op_flops / op_bytes
        # 2. Calculate estimated FLOPs/sec: est_flops_sec = min(pi, beta * I)
        # 3. Calculate estimated operator time: estimated_op_time = op_flops / est_flops_sec
        #    This simplifies to: estimated_op_time = max(op_flops / pi, op_flops / (beta * I))
        #    Further simplifying: estimated_op_time = max(op_flops / pi, op_bytes / beta)

        # Simplified Formulas:
        # - compute_time = op_flops / pi
        # - transfer_time = op_bytes / beta
        # - estimated_op_time = max(compute_time, transfer_time)

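        # Worked example of the formulas above (hypothetical numbers, for
        # illustration only): a 4096 x 4096 x 4096 fp16 matmul has
        # op_flops ~= 2 * 4096**3 ~= 1.4e11 and op_bytes ~= 3 * 4096 * 4096 * 2
        # ~= 1.0e8. Assuming pi = 3e14 FLOPs/sec and beta = 2e12 bytes/sec,
        # compute_time ~= 0.46 ms and transfer_time ~= 0.05 ms, so the op is
        # compute-bound and estimated_op_time ~= 0.46 ms.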
        kwargs = kwargs if kwargs else {}
        out = func(*args, **kwargs)
        op_time = 0.0
        func_packet = func._overloadpacket
        if func_packet not in _IGNORE_OPS:
            flat_args_kwargs, args_spec = pytree.tree_flatten((args, kwargs))
            flat_outs, out_spec = pytree.tree_flatten(out)
            transfer_time = get_transfer_time(flat_args_kwargs, flat_outs)

            out_dtypes = {
                t.dtype
                for t in flat_outs
                if isinstance(t, torch.Tensor) and t.dtype in cls._float_types
            }

            args, kwargs = pytree.tree_unflatten(flat_args_kwargs, args_spec)
            out = pytree.tree_unflatten(flat_outs, out_spec)

            compute_time = get_compute_time(func_packet, args, kwargs, out, out_dtypes)
            # We get the estimated time as the max of the transfer time and
            # compute time. We divide by 1e6 to get the time in ms
            op_time = max(transfer_time, compute_time) / 1e6

        return (out, op_time)

    def display_modulewise_stats(self, depth: int = 2) -> None:
        """
        Displays module-wise statistics collected by ``RuntimeEstimator``.

        Prints the pre-forward and pre-backward execution orders.
        Displays the module-wise forward and backward runtimes in milliseconds.

        Args:
            depth (int): The maximum depth of module hierarchy to display (defaults to 2).
        """
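        # Example output line for a module (illustrative FQN and values only):
        #   model.layers.0 fw: 1.234ms bw: 2.345ms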
        print("Pre-Forward Execution Order: ")
        for mod_fqn in self.mod_fw_pre_order:
            mod_depth = mod_fqn.count(".") + 1
            if mod_depth > depth:
                continue
            print(mod_fqn)
        print("Pre-Backward Execution Order: ")
        for mod_fqn in self.mod_bw_pre_order:
            mod_depth = mod_fqn.count(".") + 1
            if mod_depth > depth:
                continue
            print(mod_fqn)
        for mod_fqn, runtimes in self.mod_runtimes.items():
            mod_depth = mod_fqn.count(".") + 1
            if mod_depth > depth:
                continue
            print(
                f"{mod_fqn} fw: {runtimes.get('fw', 0.0):.3f}ms bw: {runtimes.get('bw', 0.0):.3f}ms"
            )

    def __torch_dispatch__(self, func, types, args=..., kwargs=None):  # type: ignore[no-untyped-def]
        # TODO: @sanketpurandare: Flatten tensors by desugaring the tensor subclasses
        # TODO: @sanketpurandare: Add logic for incorporating communication time
        res, op_time = self._estimate(func, args, kwargs)
        for par in self._mod_tracker.parents:
            if self._mod_tracker.is_bw:
                self.mod_runtimes[par]["bw"] += op_time
            else:
                self.mod_runtimes[par]["fw"] += op_time
        self.total_runtime += op_time
        return res

    def __call__(self, estimate_mode_type: str) -> Self:
        """
        Sets the estimate mode type.

        Currently supported modes:
            - "operator-level-benchmark": Estimates runtime using operator benchmarking.
            - "operator-level-cost-model": Estimates runtime using roofline cost model.

        Args:
            estimate_mode_type (str): The type of estimate mode to use.

        Returns:
            RuntimeEstimator: The runtime estimator instance.

        Raises:
            NotImplementedError: If the estimate mode type is not supported.
        """
        if estimate_mode_type == "operator-level-benchmark":
            self._estimate = RuntimeEstimator._benchmark_estimate
        elif estimate_mode_type == "operator-level-cost-model":
            self._estimate = RuntimeEstimator._roofline_estimate
        else:
            raise NotImplementedError(
                f"estimate_mode_type {estimate_mode_type} not supported"
            )
        self._estimate_mode_type = estimate_mode_type
        return self
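    # Illustrative usage of the benchmarking mode (hypothetical ``module`` and
    # ``inp``; requires a GPU since the ops are actually executed):
    #
    #     est = RuntimeEstimator()
    #     with FakeTensorMode():
    #         with est("operator-level-benchmark"):
    #             loss = module(inp)
    #             loss.backward()
    #         est.display_modulewise_stats()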

    def __enter__(self) -> Self:
        fake_mode = active_fake_mode()
        assert isinstance(
            fake_mode, FakeTensorMode
        ), "No FakeTensorMode found, designed to be used under FakeTensorMode"
        RuntimeEstimator.fake_mode = fake_mode
        self.total_runtime = 0.0
        self.mod_runtimes = defaultdict(lambda: defaultdict(lambda: 0.0))
        self.mod_fw_pre_order.clear()
        self.mod_bw_pre_order.clear()
        self.mod_fw_post_order.clear()
        self.mod_bw_post_order.clear()
        self._mod_tracker.register_user_hooks(
            pre_fw_hook=lambda mod, inp: self.mod_fw_pre_order.append(
                self._mod_tracker.get_known_fqn(mod)
            ),
            pre_bw_hook=lambda mod, g_out: self.mod_bw_pre_order.append(
                self._mod_tracker.get_known_fqn(mod)
            ),
            post_fw_hook=lambda mod, inp, out: self.mod_fw_post_order.append(
                self._mod_tracker.get_known_fqn(mod)
            ),
            post_bw_hook=lambda mod, g_inp: self.mod_bw_post_order.append(
                self._mod_tracker.get_known_fqn(mod)
            ),
        )
        self._mod_tracker.__enter__()
        super().__enter__()
        return self

    def __exit__(self, *args: Any) -> None:
        print(
            f"Estimated ({self._estimate_mode_type}) "
            f"total_time: {self.total_runtime:.3f} ms"
        )
        if len(self._no_fallback_kernel) > 0:
            print("no_fallback_kernel: ", list(self._no_fallback_kernel))
        super().__exit__(*args)
        self._mod_tracker.clear_user_hooks()
        self._mod_tracker.__exit__()