import json
import time

import torch


"""PyTorch performance microbenchmarks.

This module contains PyTorch-specific functionalities for performance
microbenchmarks.
"""


class TorchBenchmarkBase(torch.nn.Module):
    """This is a base class used to create a PyTorch operator benchmark.
    module_name is the name of the operator being benchmarked.
    test_name is the name (it's created by concatenating all the
    inputs) of a specific test.
    """

    def __init__(self):
        super().__init__()
        self.user_given_name = None
        self._pass_count = 0
        self._num_inputs_require_grads = 0

    def _set_backward_test(self, is_backward):
        self._is_backward = is_backward

    def auto_set(self):
        """This is used to automatically set requires_grad for the backward path.
        It is implemented based on two counters. One counter saves the number of
        times init has been called. The other counter saves the number of times
        this function itself has been called. The very first time init is called,
        this function counts how many inputs require a gradient. In each of the
        following init calls, this function will return only one true value.
        Here is an example:
            ...
            self.v1 = torch.rand(M, N, K, requires_grad=self.auto_set())
            self.v2 = torch.rand(M, N, K, requires_grad=self.auto_set())
            ...
        """
        if not self._is_backward:
            return False

        if self._pass_count == 0:
            self._num_inputs_require_grads += 1
            return True
        else:
            self._auto_set_counter += 1
            return self._pass_count == self._auto_set_counter

    def extract_inputs_tuple(self):
        self.inputs_tuple = tuple(self.inputs.values())

    @torch.jit.export
    def get_inputs(self):
        # Need to convert the inputs to a tuple outside of JIT so that
        # JIT can infer the size of the inputs.
        return self.inputs_tuple

    @torch.jit.export
    def forward_impl(self):
        # This is to supply the inputs to the forward function, which
        # will be called in both the eager and JIT modes of local runs.
        return self.forward(*self.get_inputs())

    @torch.jit.export
    def forward_consume(self, iters: int):
        # _consume is used to avoid the dead-code-elimination optimization.
        for _ in range(iters):
            torch.ops.operator_benchmark._consume(self.forward_impl())

    def module_name(self):
        """This is used to label the operator being benchmarked."""
        if self.user_given_name:
            return self.user_given_name
        return self.__class__.__name__

    def set_module_name(self, name):
        self.user_given_name = name

    def test_name(self, **kargs):
        """This is a globally unique name which can be used to
        label a specific test.
        """

        # This is a list of attributes which will not be included
        # in the test name.
        skip_key_list = ["device"]

        test_name_str = []
        for key in kargs:
            value = kargs[key]
            test_name_str.append(
                ("" if key in skip_key_list else key)
                + str(value if type(value) != bool else int(value))
            )
        name = (self.module_name() + "_" + "_".join(test_name_str)).replace(" ", "")
        return name


class PyTorchOperatorTestCase:
    """This class includes all the information needed to benchmark an operator.
    op_bench: a user-defined class (child of TorchBenchmarkBase) which
        includes the inputs, the operator, etc.
    test_config: a namedtuple that includes test_name, input_shape, tag, and run_backward.
        When run_backward is false, the run_forward method will be executed.
        When run_backward is true, run_forward_eager and _output_mean will be
        executed to generate the output. Then, run_backward will be executed.
    """

    def __init__(self, op_bench, test_config):
        self.test_config = test_config
        self.op_bench = op_bench
        self.place_holder_tensor = torch.ones(1)
        self.framework = "PyTorch"
        self.time_series = []
        self._jit_forward_graph = None

    def _generate_jit_forward_graph(self):
        """Generate a graph for the forward function via scripting."""
        scripted_op_bench = torch.jit.script(self.op_bench)
        return scripted_op_bench.forward_consume

    def run_jit_forward(self, num_runs, print_per_iter=False, cuda_sync=False):
        """Run the forward path of an op in JIT mode."""
        if self._jit_forward_graph is None:
            self._jit_forward_graph = self._generate_jit_forward_graph()
        self._jit_forward_graph(num_runs)

    def _print_per_iter(self):
        # Print the last 50 values.
        length = min(len(self.time_series), 50)
        for i in range(length):
            print(
                "PyTorchObserver "
                + json.dumps(
                    {
                        "type": self.test_config.test_name,
                        "metric": "latency",
                        "unit": "ms",
                        "value": str(self.time_series[length - i - 1]),
                    }
                )
            )

    def run_forward(self, num_runs, print_per_iter, cuda_sync):
        """Run the forward path of an op in eager mode."""
        if print_per_iter:
            for _ in range(num_runs):
                start_time = time.time()
                self.output = self.op_bench.forward_impl()
                if cuda_sync:
                    torch.cuda.synchronize(torch.cuda.current_device())
                end_time = time.time()
                self.time_series.append((end_time - start_time) * 1e3)
        else:
            for _ in range(num_runs):
                self.output = self.op_bench.forward_impl()
            if cuda_sync:
                torch.cuda.synchronize(torch.cuda.current_device())

    def _output_mean(self):
        """TODO (mingzhe): it is not necessary to sum up everything by myself,
        torch.autograd.backward does take a gradient tensor. By default, it
        is the same shape as your output tensor, with all 1s.
        Mathematically, it is the same as if the output is summed together.
        So we should be able to get rid of this method.
        Dummy function for gradient calculation.
        """
        self.mean = self.output.mean()

    def run_backward(self, num_runs, print_per_iter=False):
        """Run the backward path of an op in many iterations."""
        # TODO: can we use JIT here to reduce Python overhead?
        for _ in range(num_runs):
            self.mean.backward(retain_graph=True)


def create_pytorch_op_test_case(op_bench, test_config):
    """This method is used to generate a test. func_name is a globally unique
    string. For the PyTorch add operator with M=8, N=2, K=1, tag=long, here
    are the values for the members in test_case:
    op.module_name: add
    framework: PyTorch
    test_config: TestConfig(test_name='add_M8_N2_K1', input_config='M: 8, N: 2, K: 1',
        tag='long', run_backward=False)
    func_name: addPyTorchTestConfig(test_name='add_M8_N2_K1', input_config='M: 8, N: 2, K: 1',
        tag='long', run_backward=False)
    """
    test_case = PyTorchOperatorTestCase(op_bench, test_config)
    test_config = test_case.test_config
    op = test_case.op_bench
    func_name = f"{op.module_name()}{test_case.framework}{str(test_config)}"
    return (func_name, test_case)
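

# The sketch below illustrates, under stated assumptions, how this module is
# typically used: a benchmark class derived from TorchBenchmarkBase is paired
# with a config namedtuple and wrapped by create_pytorch_op_test_case. The
# AddBenchmark class and the local TestConfig namedtuple are hypothetical
# stand-ins for what the operator_benchmark harness normally provides; the
# harness is also responsible for calling init, _set_backward_test, and for
# managing the auto_set counters.
if __name__ == "__main__":
    from collections import namedtuple

    # Hypothetical config type; the real harness supplies its own TestConfig.
    TestConfig = namedtuple(
        "TestConfig", ["test_name", "input_config", "tag", "run_backward"]
    )

    class AddBenchmark(TorchBenchmarkBase):
        """Illustrative benchmark: elementwise addition of two random tensors."""

        def init(self, M, N, K, device):
            # self.inputs must be populated before extract_inputs_tuple().
            self.inputs = {
                "input_one": torch.rand(M, N, K, device=device),
                "input_two": torch.rand(M, N, K, device=device),
            }
            self.set_module_name("add")

        def forward(self, input_one, input_two):
            return input_one + input_two

    op_bench = AddBenchmark()
    op_bench._set_backward_test(is_backward=False)
    op_bench.init(M=8, N=2, K=1, device="cpu")
    op_bench.extract_inputs_tuple()

    test_config = TestConfig(
        test_name=op_bench.test_name(M=8, N=2, K=1),  # "add_M8_N2_K1"
        input_config="M: 8, N: 2, K: 1",
        tag="short",
        run_backward=False,
    )
    func_name, test_case = create_pytorch_op_test_case(op_bench, test_config)

    # Forward-only eager run; cuda_sync only matters when the inputs live on GPU.
    test_case.run_forward(num_runs=10, print_per_iter=False, cuda_sync=False)
    print(func_name)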