import itertools
import operator

import numpy as np
import scipy.special

import torch

from . import benchmark


# A template class for elementwise operations.
# A derived class overrides the class variables below to customize its behavior.
class ElementBench(benchmark.Benchmark):
    """Elementwise benchmark over four random 1-D inputs of length ``N``.

    The computation is ``binary(u(d1), u(d2)) + binary(u(d3), u(d4))`` where
    ``u`` is the configured unary op (identity when unset) and ``binary`` is
    the configured binary op (addition when unset).  See ``_eval``.
    """

    # List of customization class variables.
    # Name fragment used by ``module()``; also searched for "rand".
    op_str = None
    # Binary op applied in ``forward`` / ``reference`` (None -> x + y).
    binary_op_pt_func = None
    binary_op_np_func = None
    # Unary op applied in ``forward`` / ``reference`` (None -> identity).
    unary_op_pt_func = None
    unary_op_np_func = None
    # True: each of the four inputs is used as-is.
    # False: d2..d4 are derived from d1 by adding small constants.
    split_input = True

    def __init__(self, mode, device, dtype, N):
        super().__init__(mode, device, dtype)
        self.N = N
        self.d1 = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d2 = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d3 = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d4 = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.d1, self.d2, self.d3, self.d4]
        # Ops whose name contains "rand" are flagged as non-deterministic.
        self.deterministic = "rand" not in self.op_str

    def _eval(self, d1, d2, d3, d4, binary_op, unary_op):
        """Evaluate the benchmark expression with the given ops.

        ``binary_op`` / ``unary_op`` may be None, in which case addition and
        identity are used respectively.  The same code path serves both the
        torch (``forward``) and NumPy (``reference``) evaluations.
        """
        if binary_op is None:

            def binary_op(x, y):
                return x + y

        if unary_op is None:

            def unary_op(x):
                return x

        if self.split_input:
            d1 = unary_op(d1)
            d2 = unary_op(d2)
            d3 = unary_op(d3)
            d4 = unary_op(d4)
        else:
            # Shared-input mode: everything is derived from d1.  d1 itself is
            # transformed last because the other three read its raw value.
            d2 = unary_op(d1 + 0.001)
            d3 = unary_op(d1 + 0.002)
            d4 = unary_op(d1 + 0.003)
            d1 = unary_op(d1)
        a = binary_op(d1, d2)
        b = binary_op(d3, d4)
        c = a + b
        return c

    def forward(self, d1, d2, d3, d4):
        """Evaluate the expression with the torch ops configured on the class."""
        binary_op = self.__class__.binary_op_pt_func
        unary_op = self.__class__.unary_op_pt_func
        return self._eval(d1, d2, d3, d4, binary_op, unary_op)

    def reference(self):
        """Evaluate the expression with the NumPy ops on ``self.numpy`` views
        of the four stored inputs."""
        binary_op = self.__class__.binary_op_np_func
        unary_op = self.__class__.unary_op_np_func
        [d1, d2, d3, d4] = [self.numpy(d) for d in [self.d1, self.d2, self.d3, self.d4]]
        return self._eval(d1, d2, d3, d4, binary_op, unary_op)

    def config(self):
        return [self.N]

    @classmethod
    def module(cls):
        return "element_" + cls.op_str

    def memory_workload(self):
        """Return ``{"sol": ..., "algorithmic": ...}`` as multiples of ``N``.

        Each value is ``N`` times a buffer count chosen from the mode, the
        ``split_input`` flag and whether the op name contains "rand".
        NOTE(review): the counts appear to model buffers read/written; the
        exact meaning of "sol" vs "algorithmic" is defined by the benchmark
        framework and is not visible in this file.
        """
        input_count = len(self.inputs)
        if self.mode == "fwd":
            if self.split_input:
                sol_count = input_count + 1
                algorithmic_count = input_count + 1
            else:
                sol_count = 1 + 1
                algorithmic_count = 1 + 1
            if "rand" in self.op_str:
                sol_count = 1
                algorithmic_count = 1
        else:
            if self.split_input:
                sol_count = (input_count + 1) + (1 + input_count)
                algorithmic_count = (input_count + 1) + ((2 + 1) * input_count)
            else:
                sol_count = 1 + 1
                algorithmic_count = 1 + 1
            if "rand" in self.op_str:
                sol_count = 1
                algorithmic_count = 1

        buffer_size = self.N
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        return [[1 << 25]]


def _register_element_variants(op_list, kind):
    """Create and register one ``ElementBench`` subclass per (split, op) pair.

    ``op_list`` holds ``[name, pt_func]`` or ``[name, pt_func, np_func]``
    entries; with two items the same callable is used for both torch and
    NumPy.  ``kind`` is ``"binary"`` or ``"unary"`` and selects which pair of
    class variables (``<kind>_op_pt_func`` / ``<kind>_op_np_func``) is set.
    """
    for split_input, entry in itertools.product([True, False], op_list):
        if len(entry) == 2:
            op_name, op_pt_func = entry
            op_np_func = op_pt_func
        elif len(entry) == 3:
            op_name, op_pt_func, op_np_func = entry
        else:
            # Previously such an entry silently reused the functions of the
            # preceding op (or raised NameError on the first iteration).
            raise ValueError(
                "op entry must have 2 or 3 items, got %d: %r" % (len(entry), entry)
            )
        split_str = "split" if split_input else "shared"
        op_str = split_str + "_" + op_name
        # Make a copy of ElementBench
        bm_cls = type("ElementBench_" + op_str, (ElementBench,), {})
        bm_cls.op_str = op_str
        setattr(bm_cls, kind + "_op_pt_func", op_pt_func)
        setattr(bm_cls, kind + "_op_np_func", op_np_func)
        bm_cls.split_input = split_input
        benchmark.register_benchmark_class(bm_cls)


def register_element_ops():
    """Register split/shared ``ElementBench`` variants for every listed op.

    Binary-op variants are registered first, then unary-op variants; within
    each group all "split" variants precede the "shared" ones.
    """
    binary_op_list = [
        ["mul", operator.mul],
        ["add", operator.add],
        ["sub", operator.sub],
        # The small offset keeps the divisor away from zero.
        ["div", lambda a, b: a / (b + 1e-4)],
        [
            "pow",
            torch.pow,
            np.power,
        ],  # no fusion triggered
        ["max", torch.max, np.maximum],
        ["min", torch.min, np.minimum],
    ]

    unary_op_list = [
        ["erf", torch.erf, scipy.special.erf],
        ["exp", torch.exp, np.exp],
        ["sin", torch.sin, np.sin],
        ["cos", torch.cos, np.cos],
        ["rand_like", torch.rand_like, lambda x: np.random.rand(*x.shape)],
    ]

    _register_element_variants(binary_op_list, "binary")
    _register_element_variants(unary_op_list, "unary")


register_element_ops()


class SimpleElementBench(benchmark.Benchmark):
    """Minimal elementwise benchmark: two chained scalar additions on one
    random 1-D input of length ``N``."""

    def __init__(self, mode, device, dtype, N):
        super().__init__(mode, device, dtype)
        self.N = N
        self.data = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.data]

    def forward(self, data):
        a = data + 0.001
        b = a + 0.002
        return b

    def reference(self):
        """Compute the same two additions as ``forward`` on a NumPy view.

        The previous body was copied from ``ElementBench.reference`` and read
        attributes (``d1``..``d4``, ``_eval``, the ``*_np_func`` class
        variables) that this class does not define.
        """
        # ``self.inputs[0]`` is ``self.data`` here, and remains the current
        # input when a subclass replaces ``self.inputs``.
        data = self.numpy(self.inputs[0])
        a = data + 0.001
        b = a + 0.002
        return b

    def config(self):
        return [self.N]

    @staticmethod
    def input_iterable():
        return True

    @classmethod
    def module(cls):
        return "simple_element"

    def memory_workload(self):
        """Return ``{"sol": 2 * N, "algorithmic": 2 * N}`` for every mode."""
        sol_count = 2
        algorithmic_count = 2

        buffer_size = self.N
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        return [[1 << 25]]


benchmark.register_benchmark_class(SimpleElementBench)


class DynamicSimpleElementBench(benchmark.DynamicShape, SimpleElementBench):
    """``SimpleElementBench`` variant whose input can be re-created with a
    shape obtained from ``rand_shape``."""

    def __init__(self, mode, device, dtype, N):
        benchmark.DynamicShape.__init__(self)
        SimpleElementBench.__init__(self, mode, device, dtype, N)

    @classmethod
    def module(cls):
        return "simple_dynamic_element"

    def instantiate_input(self):
        """Replace ``self.inputs`` with a fresh random tensor.

        The length comes from ``self.rand_shape([self.N])`` (presumably
        provided by ``benchmark.DynamicShape`` -- confirm there).
        """
        (N,) = self.rand_shape([self.N])
        data = self.rand(
            [N], device=self.device, dtype=self.dtype, requires_grad=self.requires_grad
        )
        self.inputs = [data]


benchmark.register_benchmark_class(DynamicSimpleElementBench)