import itertools
import operator

import numpy as np
import scipy.special

import torch

from . import benchmark


class ElementBench(benchmark.Benchmark):
    """Template benchmark for elementwise operations.

    The class is not registered directly.  ``register_element_ops`` creates
    one subclass per operation and overrides the class variables below to
    select which binary or unary op is applied and how inputs are fed.
    """

    # Customization class variables, overridden by generated subclasses.
    op_str = None  # operation label; used by module() and the "rand" checks
    binary_op_pt_func = None  # binary op used by forward(); falsy -> x + y
    binary_op_np_func = None  # binary op used by reference(); falsy -> x + y
    unary_op_pt_func = None  # unary op used by forward(); falsy -> identity
    unary_op_np_func = None  # unary op used by reference(); falsy -> identity
    split_input = True  # True: four separate inputs; False: all derived from d1

    def __init__(self, mode, device, dtype, N):
        """Create four input buffers of ``N`` elements each.

        ``mode``, ``device`` and ``dtype`` are forwarded to the base class;
        ``device`` and ``dtype`` are also passed to ``self.rand``.
        """
        super().__init__(mode, device, dtype)
        self.N = N
        self.d1, self.d2, self.d3, self.d4 = (
            self.rand(
                [N], device=device, dtype=dtype, requires_grad=self.requires_grad
            )
            for _ in range(4)
        )
        self.inputs = [self.d1, self.d2, self.d3, self.d4]
        # Ops whose label contains "rand" are flagged as non-deterministic.
        self.deterministic = "rand" not in self.op_str

    def _eval(self, d1, d2, d3, d4, binary_op, unary_op):
        """Compute ``binary_op(u(d1), u(d2)) + binary_op(u(d3), u(d4))``.

        ``u`` is ``unary_op``.  A falsy ``binary_op`` defaults to addition
        and a falsy ``unary_op`` defaults to the identity.  When
        ``split_input`` is false, ``d2``..``d4`` are ignored and replaced by
        slightly shifted copies of ``d1``.
        """
        if not binary_op:

            def binary_op(x, y):
                return x + y

        if not unary_op:

            def unary_op(x):
                return x

        if self.split_input:
            d1 = unary_op(d1)
            d2 = unary_op(d2)
            d3 = unary_op(d3)
            d4 = unary_op(d4)
        else:
            # All operands come from the raw d1, so d1 itself is transformed
            # last, after the shifted copies have been taken.
            d2 = unary_op(d1 + 0.001)
            d3 = unary_op(d1 + 0.002)
            d4 = unary_op(d1 + 0.003)
            d1 = unary_op(d1)

        left = binary_op(d1, d2)
        right = binary_op(d3, d4)
        return left + right

    def forward(self, d1, d2, d3, d4):
        """Evaluate the benchmark expression with the ``*_pt_func`` ops."""
        # Read the ops from the class, not the instance, so that plain Python
        # functions stored as class attributes are not bound as methods.
        cls = self.__class__
        return self._eval(
            d1, d2, d3, d4, cls.binary_op_pt_func, cls.unary_op_pt_func
        )

    def reference(self):
        """Evaluate the same expression on ``self.numpy`` copies of the inputs."""
        cls = self.__class__
        d1, d2, d3, d4 = (
            self.numpy(d) for d in (self.d1, self.d2, self.d3, self.d4)
        )
        return self._eval(
            d1, d2, d3, d4, cls.binary_op_np_func, cls.unary_op_np_func
        )

    def config(self):
        """Return the configuration of this instance: ``[N]``."""
        return [self.N]

    @classmethod
    def module(cls):
        """Return the benchmark name, ``"element_" + op_str``."""
        return "element_" + cls.op_str

    def memory_workload(self):
        """Return the ``"sol"`` and ``"algorithmic"`` workloads.

        Both values are a buffer count multiplied by ``N``.
        """
        input_count = len(self.inputs)
        if "rand" in self.op_str:
            # "rand" ops count a single buffer in every mode.
            sol_count = 1
            algorithmic_count = 1
        elif not self.split_input:
            # Shared input: two buffers are counted in every mode.
            sol_count = 1 + 1
            algorithmic_count = 1 + 1
        elif self.mode == "fwd":
            sol_count = input_count + 1
            algorithmic_count = input_count + 1
        else:
            sol_count = (input_count + 1) + (1 + input_count)
            algorithmic_count = (input_count + 1) + ((2 + 1) * input_count)

        buffer_size = self.N
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Return the default configuration list: a single ``[N]`` with N = 2**25."""
        return [[1 << 25]]


def _register_op_variants(op_list, pt_attr, np_attr):
    """Register a split and a shared ``ElementBench`` subclass for each op.

    Each entry of ``op_list`` is ``[name, pt_func]`` or
    ``[name, pt_func, np_func]``; with two items the same callable is used
    for both the forward and the reference computation.  ``pt_attr`` and
    ``np_attr`` name the ``ElementBench`` class variables that receive them.
    """
    for split_input, op_entry in itertools.product([True, False], op_list):
        op_name, op_pt_func, *np_funcs = op_entry
        op_np_func = np_funcs[0] if np_funcs else op_pt_func
        op_str = ("split" if split_input else "shared") + "_" + op_name
        # Make a copy of ElementBench specialized for this op.
        bm_cls = type(
            "ElementBench_" + op_str,
            (ElementBench,),
            {
                "op_str": op_str,
                pt_attr: op_pt_func,
                np_attr: op_np_func,
                "split_input": split_input,
            },
        )
        benchmark.register_benchmark_class(bm_cls)


def register_element_ops():
    """Generate and register an ``ElementBench`` subclass for every op variant.

    Binary ops are registered first, then unary ops.  Within each group all
    "split" variants are registered before the "shared" ones.
    """
    binary_op_list = [
        ["mul", operator.mul],
        ["add", operator.add],
        ["sub", operator.sub],
        # The small offset keeps the divisor away from zero.
        ["div", lambda a, b: a / (b + 1e-4)],
        ["pow", torch.pow, np.power],  # no fusion triggered
        ["max", torch.max, np.maximum],
        ["min", torch.min, np.minimum],
    ]

    unary_op_list = [
        ["erf", torch.erf, scipy.special.erf],
        ["exp", torch.exp, np.exp],
        ["sin", torch.sin, np.sin],
        ["cos", torch.cos, np.cos],
        ["rand_like", torch.rand_like, lambda x: np.random.rand(*x.shape)],
    ]

    _register_op_variants(binary_op_list, "binary_op_pt_func", "binary_op_np_func")
    _register_op_variants(unary_op_list, "unary_op_pt_func", "unary_op_np_func")


register_element_ops()


class SimpleElementBench(benchmark.Benchmark):
    """Benchmark that adds two scalar constants to a single 1-D input."""

    def __init__(self, mode, device, dtype, N):
        """Create one input buffer of ``N`` elements."""
        super().__init__(mode, device, dtype)
        self.N = N
        self.data = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.data]

    def forward(self, data):
        """Return ``data + 0.001 + 0.002``, computed as two separate additions."""
        a = data + 0.001
        b = a + 0.002
        return b

    def reference(self):
        """Return the ``forward`` computation applied to ``self.numpy(self.data)``.

        The previous body was copied from ``ElementBench`` and read
        attributes this class never defines (``d1``..``d4``, ``_eval``,
        ``binary_op_np_func``, ``unary_op_np_func``), so it always raised
        ``AttributeError``.
        """
        data = self.numpy(self.data)
        a = data + 0.001
        b = a + 0.002
        return b

    def config(self):
        """Return the configuration of this instance: ``[N]``."""
        return [self.N]

    @staticmethod
    def input_iterable():
        return True

    @classmethod
    def module(cls):
        """Return the benchmark name."""
        return "simple_element"

    def memory_workload(self):
        """Return the ``"sol"`` and ``"algorithmic"`` workloads.

        Two buffers of ``N`` elements are counted regardless of ``self.mode``.
        """
        sol_count = 2
        algorithmic_count = 2

        buffer_size = self.N
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Return the default configuration list: a single ``[N]`` with N = 2**25."""
        return [[1 << 25]]


benchmark.register_benchmark_class(SimpleElementBench)


class DynamicSimpleElementBench(benchmark.DynamicShape, SimpleElementBench):
    """``SimpleElementBench`` variant whose input size comes from ``rand_shape``."""

    def __init__(self, mode, device, dtype, N):
        # Both bases are initialized explicitly because their constructors
        # take different arguments.
        benchmark.DynamicShape.__init__(self)
        SimpleElementBench.__init__(self, mode, device, dtype, N)

    @classmethod
    def module(cls):
        """Return the benchmark name."""
        return "simple_dynamic_element"

    def instantiate_input(self):
        """Assign ``self.inputs`` a new tensor sized by ``self.rand_shape([self.N])``."""
        (N,) = self.rand_shape([self.N])
        data = self.rand(
            [N], device=self.device, dtype=self.dtype, requires_grad=self.requires_grad
        )
        self.inputs = [data]


benchmark.register_benchmark_class(DynamicSimpleElementBench)