import json
import time

import torch


"""PyTorch performance microbenchmarks.

This module contains PyTorch-specific functionalities for performance
microbenchmarks.
"""


class TorchBenchmarkBase(torch.nn.Module):
    """This is a base class used to create PyTorch operator benchmarks.
    module_name is the name of the operator being benchmarked.
    test_name is the name of a specific test, created by concatenating
    all of its inputs.
    """

    def __init__(self):
        super().__init__()
        self.user_given_name = None
        self._pass_count = 0
        self._num_inputs_require_grads = 0

    def _set_backward_test(self, is_backward):
        self._is_backward = is_backward

    def auto_set(self):
        """This is used to automatically set requires_grad for the backward path.
        It is implemented with two counters: one tracks how many times init has
        been called, the other how many times this function itself has been
        called. The very first time init is called, this function counts how
        many inputs require a gradient. In each subsequent init call, it returns
        True for exactly one input.
        Here is an example:
            ...
            self.v1 = torch.rand(M, N, K, requires_grad=self.auto_set())
            self.v2 = torch.rand(M, N, K, requires_grad=self.auto_set())
            ...
        """
        if not self._is_backward:
            return False

        if self._pass_count == 0:
            self._num_inputs_require_grads += 1
            return True
        else:
            self._auto_set_counter += 1
            return self._pass_count == self._auto_set_counter

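    # A walk-through of auto_set() over successive passes, assuming the benchmark
    # driver re-inits the op once per input that requires a gradient and advances
    # _pass_count / resets _auto_set_counter between passes (that bookkeeping
    # lives outside this class):
    #   pass 0: every auto_set() call returns True, and _num_inputs_require_grads
    #           ends up equal to the number of calls (e.g. 2 for v1 and v2 above).
    #   pass 1: only the 1st call returns True, so only v1 requires a gradient.
    #   pass 2: only the 2nd call returns True, so only v2 requires a gradient.
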
    def extract_inputs_tuple(self):
        self.inputs_tuple = tuple(self.inputs.values())

    @torch.jit.export
    def get_inputs(self):
        # Need to convert the inputs to tuple outside of JIT so that
        # JIT can infer the size of the inputs.
        return self.inputs_tuple

    @torch.jit.export
    def forward_impl(self):
        # This is to supply the inputs to the forward function which
        # will be called in both the eager and JIT mode of local runs
        return self.forward(*self.get_inputs())

    @torch.jit.export
    def forward_consume(self, iters: int):
        # _consume is used to avoid the dead-code-elimination optimization
        for _ in range(iters):
            torch.ops.operator_benchmark._consume(self.forward_impl())

    def module_name(self):
        """This is used to label the operator being benchmarked."""
        if self.user_given_name:
            return self.user_given_name
        return self.__class__.__name__

    def set_module_name(self, name):
        self.user_given_name = name

    def test_name(self, **kargs):
        """This is a globally unique name which can be used to
        label a specific test.
        """

        # This is a list of attributes which will not be included
        # in the test name.
        skip_key_list = ["device"]

        test_name_str = []
        for key in kargs:
            value = kargs[key]
            test_name_str.append(
                ("" if key in skip_key_list else key)
                + str(value if type(value) != bool else int(value))
            )
        name = (self.module_name() + "_" + "_".join(test_name_str)).replace(" ", "")
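        # e.g. kargs {"M": 8, "N": 2, "K": 1} on a module named "add" yields
        # "add_M8_N2_K1"; a "device" entry contributes only its value (the key
        # name is dropped) and bool values are rendered as 0/1.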
        return name


class PyTorchOperatorTestCase:
    """This class includes all the information needed to benchmark an operator.
    op_bench: a user-defined class (child of TorchBenchmarkBase)
    which includes the inputs, the operator, etc.
    test_config: a namedtuple that includes test_name, input_config, tag, run_backward.
    When run_backward is false, the run_forward method will be executed.
    When run_backward is true, run_forward and _output_mean will be
    executed to generate the output; then run_backward will be executed.
    """

    def __init__(self, op_bench, test_config):
        self.test_config = test_config
        self.op_bench = op_bench
        self.place_holder_tensor = torch.ones(1)
        self.framework = "PyTorch"
        self.time_series = []
        self._jit_forward_graph = None

    def _generate_jit_forward_graph(self):
        """Generate a graph for the forward function via scripting."""
        scripted_op_bench = torch.jit.script(self.op_bench)
        return scripted_op_bench.forward_consume

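    # Note: the scripted forward_consume loops `iters` times inside the
    # TorchScript graph, so the Python call overhead is paid once per
    # measurement rather than once per iteration (unlike the eager
    # run_forward below, which loops in Python).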
    def run_jit_forward(self, num_runs, print_per_iter=False, cuda_sync=False):
        """Run the forward path of an op with JIT mode."""
        if self._jit_forward_graph is None:
            self._jit_forward_graph = self._generate_jit_forward_graph()
        self._jit_forward_graph(num_runs)

    def _print_per_iter(self):
        # print last 50 values
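        # Each record is printed as a single line that downstream observers can
        # grep/parse, e.g. (illustrative values):
        #   PyTorchObserver {"type": "add_M8_N2_K1", "metric": "latency", "unit": "ms", "value": "0.123"}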
        length = min(len(self.time_series), 50)
        for i in range(length):
            print(
                "PyTorchObserver "
                + json.dumps(
                    {
                        "type": self.test_config.test_name,
                        "metric": "latency",
                        "unit": "ms",
                        "value": str(self.time_series[length - i - 1]),
                    }
                )
            )

    def run_forward(self, num_runs, print_per_iter, cuda_sync):
        """Run the forward path of an op with eager mode."""
        if print_per_iter:
            for _ in range(num_runs):
                start_time = time.time()
                self.output = self.op_bench.forward_impl()
                if cuda_sync:
                    torch.cuda.synchronize(torch.cuda.current_device())
                end_time = time.time()
                self.time_series.append((end_time - start_time) * 1e3)
        else:
            for _ in range(num_runs):
                self.output = self.op_bench.forward_impl()
            if cuda_sync:
                torch.cuda.synchronize(torch.cuda.current_device())

    def _output_mean(self):
        """Dummy function for gradient calculation.

        TODO (mingzhe): it is not necessary to sum up everything by myself;
        torch.autograd.backward does take a gradient tensor. By default, it
        is the same shape as the output tensor, with all 1s. Mathematically,
        that is the same as summing the output, so we should be able to get
        rid of this method.
        """
        self.mean = self.output.mean()

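    # The TODO above refers to the equivalence (up to a constant scale factor
    # for mean vs. sum) between seeding backward with a tensor of ones and
    # reducing the output first, e.g.:
    #   out.backward(torch.ones_like(out))   computes the same gradients as
    #   out.sum().backward()
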
    def run_backward(self, num_runs, print_per_iter=False):
        """Run the backward path of an op in many iterations."""
        # TODO: can we use JIT here to reduce python overhead?
        for _ in range(num_runs):
            self.mean.backward(retain_graph=True)


def create_pytorch_op_test_case(op_bench, test_config):
    """This method is used to generate a test. func_name is a globally unique
    string. For the PyTorch add operator with M=8, N=2, K=1, tag=long, here
    are the values of the members in test_case:
    op.module_name: add
    framework: PyTorch
    test_config: TestConfig(test_name='add_M8_N2_K1', input_config='M: 8, N: 2, K: 1',
        tag='long', run_backward=False)
    func_name: addPyTorchTestConfig(test_name='add_M8_N2_K1', input_config='M: 8, N: 2, K: 1',
                                    tag='long', run_backward=False)
    """
    test_case = PyTorchOperatorTestCase(op_bench, test_config)
    test_config = test_case.test_config
    op = test_case.op_bench
    func_name = f"{op.module_name()}{test_case.framework}{str(test_config)}"
    return (func_name, test_case)
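

# A minimal, self-contained usage sketch. The _DemoAddBenchmark subclass and the
# local TestConfig namedtuple below are illustrative assumptions made for this
# example; in the real harness the op class and TestConfig come from the
# operator_benchmark framework modules.
if __name__ == "__main__":
    from collections import namedtuple

    TestConfig = namedtuple(
        "TestConfig", ["test_name", "input_config", "tag", "run_backward"]
    )

    class _DemoAddBenchmark(TorchBenchmarkBase):
        def __init__(self, M, N, K, device):
            super().__init__()
            # Inputs are stored in a dict; extract_inputs_tuple() turns them
            # into the tuple consumed by forward_impl().
            self.inputs = {
                "input_one": torch.rand(M, N, K, device=device),
                "input_two": torch.rand(M, N, K, device=device),
            }

        def forward(self, input_one, input_two):
            return torch.add(input_one, input_two)

    op = _DemoAddBenchmark(M=8, N=2, K=1, device="cpu")
    op.set_module_name("add")
    op.extract_inputs_tuple()
    config = TestConfig("add_M8_N2_K1", "M: 8, N: 2, K: 1", "short", False)
    func_name, test_case = create_pytorch_op_test_case(op, config)
    # Eager-mode forward only; the JIT path needs the operator_benchmark
    # _consume extension op, which is not loaded here.
    test_case.run_forward(num_runs=10, print_per_iter=False, cuda_sync=False)
    print(func_name)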
197