# mypy: allow-untyped-defs

import torch._C


def format_time(time_us=None, time_ms=None, time_s=None):
    """Format a duration given as exactly one of microseconds, milliseconds, or seconds."""
    assert sum([time_us is not None, time_ms is not None, time_s is not None]) == 1

    US_IN_SECOND = 1e6
    US_IN_MS = 1e3

    if time_us is None:
        if time_ms is not None:
            time_us = time_ms * US_IN_MS
        elif time_s is not None:
            time_us = time_s * US_IN_SECOND
        else:
            raise AssertionError("Shouldn't reach here :)")

    if time_us >= US_IN_SECOND:
        return f'{time_us / US_IN_SECOND:.3f}s'
    if time_us >= US_IN_MS:
        return f'{time_us / US_IN_MS:.3f}ms'
    return f'{time_us:.3f}us'
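
# A minimal usage sketch for format_time (illustrative values, not part of the
# original module):
#
#   format_time(time_us=250.0)   -> '250.000us'
#   format_time(time_ms=12.5)    -> '12.500ms'
#   format_time(time_s=3.2)      -> '3.200s'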


class ExecutionStats:
    def __init__(self, c_stats, benchmark_config):
        self._c_stats = c_stats
        self.benchmark_config = benchmark_config

    @property
    def latency_avg_ms(self):
        return self._c_stats.latency_avg_ms

    @property
    def num_iters(self):
        return self._c_stats.num_iters

    @property
    def iters_per_second(self):
        """Return the total number of iterations per second across all calling threads."""
        return self.num_iters / self.total_time_seconds

    @property
    def total_time_seconds(self):
        """Return total execution time in seconds, assuming the calling threads ran in parallel."""
        return self.num_iters * (
            self.latency_avg_ms / 1000.0) / self.benchmark_config.num_calling_threads

    def __str__(self):
        return '\n'.join([
            "Average latency per example: " + format_time(time_ms=self.latency_avg_ms),
            f"Total number of iterations: {self.num_iters}",
            f"Total number of iterations per second (across all threads): {self.iters_per_second:.2f}",
            "Total time: " + format_time(time_s=self.total_time_seconds)
        ])
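
# A worked example of the stats arithmetic above (hypothetical numbers, for
# illustration only): with num_iters=1000, latency_avg_ms=5.0 and
# num_calling_threads=4,
#
#   total_time_seconds = 1000 * (5.0 / 1000.0) / 4 = 1.25
#   iters_per_second   = 1000 / 1.25              = 800.0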


class ThroughputBenchmark:
    """
    This class is a wrapper around the C++ component throughput_benchmark::ThroughputBenchmark.

    The wrapper is responsible for executing a PyTorch module (nn.Module or
    ScriptModule) under an inference-server-like load. It can emulate multiple
    calling threads hitting the single provided module. In the future we plan to
    enhance this component to support inter- and intra-op parallelism as well as
    multiple models running in a single process.

    Please note that even though nn.Module is supported, it might incur an overhead
    from the need to hold the GIL every time we execute Python code or pass around
    inputs as Python objects. As soon as you have a ScriptModule version of your
    model for inference deployment, it is better to switch to using it in this
    benchmark.

    Example::

        >>> # xdoctest: +SKIP("undefined vars")
        >>> from torch.utils import ThroughputBenchmark
        >>> bench = ThroughputBenchmark(my_module)
        >>> # Pre-populate the benchmark's data set with the inputs
        >>> for input in inputs:
        ...     # Both args and kwargs work, same as for any PyTorch Module / ScriptModule
        ...     bench.add_input(input[0], x2=input[1])
        >>> # The inputs supplied above are used at random during the execution
        >>> stats = bench.benchmark(
        ...     num_calling_threads=4,
        ...     num_warmup_iters=100,
        ...     num_iters=1000,
        ... )
        >>> print("Avg latency (ms): {}".format(stats.latency_avg_ms))
        >>> print("Number of iterations: {}".format(stats.num_iters))
    """

    def __init__(self, module):
        if isinstance(module, torch.jit.ScriptModule):
            # A ScriptModule hands its underlying C++ module to the benchmark directly.
            self._benchmark = torch._C.ThroughputBenchmark(module._c)
        else:
            # A plain nn.Module is called through Python (and hence the GIL) on every iteration.
            self._benchmark = torch._C.ThroughputBenchmark(module)

    def run_once(self, *args, **kwargs):
        """
        Given an input id (input_idx), run the benchmark once and return the prediction.

        This is useful for testing that the benchmark actually runs the module you
        want it to run. input_idx here is an index into the inputs array populated
        by calling the add_input() method.
        """
        return self._benchmark.run_once(*args, **kwargs)

    def add_input(self, *args, **kwargs):
        """
        Store a single input to the module in the benchmark memory and keep it there.

        During the benchmark execution every thread is going to pick a random
        input from all the inputs ever supplied to the benchmark via this
        function.
        """
        self._benchmark.add_input(*args, **kwargs)
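
    # An illustrative sketch only (tensor shapes and the keyword name are arbitrary
    # assumptions): inputs are supplied positionally and/or by keyword, mirroring the
    # module's forward() signature, e.g.
    #
    #   bench.add_input(torch.randn(8, 16), mask=torch.ones(8, 16))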

    def benchmark(
            self,
            num_calling_threads=1,
            num_warmup_iters=10,
            num_iters=100,
            profiler_output_path=""):
        """
        Run a benchmark on the module.

        Args:
            num_calling_threads (int): Number of threads calling the module concurrently.

            num_warmup_iters (int): Warmup iters are used to make sure we run a module
                a few times before actually measuring things. This way we avoid cold
                caches and any other similar problems. This is the number of warmup
                iterations each thread runs separately.

            num_iters (int): Number of iterations the benchmark should run for.
                This number is separate from the warmup iterations and is shared
                across all the threads. Once num_iters iterations have been reached
                across all the threads, execution stops, although the total number of
                iterations might end up slightly larger; the actual count is reported
                as stats.num_iters, where stats is the result of this function.

            profiler_output_path (str): Location to save the Autograd Profiler trace.
                If not empty, the Autograd Profiler will be enabled for the main
                benchmark execution (but not the warmup phase), and the full trace
                will be saved to the file path provided by this argument.

        This function returns an ExecutionStats object wrapping the
        BenchmarkExecutionStats result defined via pybind11. Among others, it exposes:
            - num_iters - the number of iterations the benchmark actually made
            - latency_avg_ms - the average time it took to run inference on one
              input example, in milliseconds
        """
        config = torch._C.BenchmarkConfig()
        config.num_calling_threads = num_calling_threads
        config.num_warmup_iters = num_warmup_iters
        config.num_iters = num_iters
        config.profiler_output_path = profiler_output_path
        c_stats = self._benchmark.benchmark(config)
        return ExecutionStats(c_stats, config)

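
# A minimal, self-contained usage sketch (not part of the original module): it scripts a
# tiny model and benchmarks it under two calling threads. The model, tensor shapes, and
# iteration counts below are arbitrary assumptions chosen purely for illustration.
if __name__ == "__main__":
    import torch

    class _TinyModel(torch.nn.Module):
        def forward(self, x):
            # Small compute-bound body so the benchmark has something to measure.
            return torch.relu(x @ x.t())

    scripted = torch.jit.script(_TinyModel())
    bench = ThroughputBenchmark(scripted)

    # Pre-populate the benchmark with a handful of random inputs; calling threads pick
    # from these at random during the run.
    for _ in range(8):
        bench.add_input(torch.randn(32, 32))

    stats = bench.benchmark(
        num_calling_threads=2,
        num_warmup_iters=10,
        num_iters=100,
    )
    print(stats)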