import itertools
import operator

import numpy as np

import torch

from . import benchmark


class BroadcastMulBench(benchmark.Benchmark):
    """Add two rank-3 operands that broadcast to an ``[M, N, K]`` result.

    Despite the "Mul" in the class name, ``forward`` computes ``d1 + d2``.
    The ``case`` string selects which dimensions are size 1:

    * ``"row"``: d1 ``[M, N, 1]``, d2 ``[M, 1, K]``
    * ``"mid"``: d1 ``[M, N, 1]``, d2 ``[1, N, K]``
    * ``"col"``: d1 ``[M, 1, K]``, d2 ``[1, N, K]``

    Raises:
        ValueError: if ``case`` is not one of the names above.
    """

    def __init__(self, mode, device, dtype, case, M, N, K):
        super().__init__(mode, device, dtype)
        self.case = case
        self.M, self.N, self.K = M, N, K

        if case == "row":
            shapes = ([M, N, 1], [M, 1, K])
        elif case == "mid":
            shapes = ([M, N, 1], [1, N, K])
        elif case == "col":
            shapes = ([M, 1, K], [1, N, K])
        else:
            raise ValueError(f"invalid case: {case}")

        # Create d1 first, then d2, in the order the shapes are listed.
        self.d1, self.d2 = (
            self.rand(
                shape, device=device, dtype=dtype, requires_grad=self.requires_grad
            )
            for shape in shapes
        )
        self.inputs = [self.d1, self.d2]

    def forward(self, d1, d2):
        return d1 + d2

    def reference(self):
        lhs = self.numpy(self.d1)
        rhs = self.numpy(self.d2)
        return lhs + rhs

    def config(self):
        return [self.M, self.N, self.K]

    @staticmethod
    def default_configs():
        return [[128, 256, 128]]

    def memory_workload(self):
        """Return the sol/algorithmic workload as multiples of ``M * N * K``."""
        sol_count, algorithmic_count = (
            (1, 1) if self.mode == "fwd" else (1 + 1, 1 + (1 + 1))
        )
        # NOTE(review): BroadcastThreeArgs and BroadcastBench multiply their
        # buffer size by an extra factor of 4 while this class does not;
        # confirm which unit the harness expects.
        buffer_size = self.M * self.N * self.K
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }


class BroadcastRowBench(BroadcastMulBench):
    """BroadcastMulBench using the ``"row"`` shape case."""

    def __init__(self, mode, device, dtype, M, N, K):
        super().__init__(mode, device, dtype, "row", M, N, K)

    @staticmethod
    def module():
        return "broadcast_row"


class BroadcastMidBench(BroadcastMulBench):
    """BroadcastMulBench using the ``"mid"`` shape case."""

    def __init__(self, mode, device, dtype, M, N, K):
        super().__init__(mode, device, dtype, "mid", M, N, K)

    @staticmethod
    def module():
        return "broadcast_mid"


class BroadcastColBench(BroadcastMulBench):
    """BroadcastMulBench using the ``"col"`` shape case."""

    def __init__(self, mode, device, dtype, M, N, K):
        super().__init__(mode, device, dtype, "col", M, N, K)

    @staticmethod
    def module():
        return "broadcast_col"


class BroadcastThreeArgs(benchmark.Benchmark):
    """Sum three operands that broadcast to an ``[L, K, M, N]`` result.

    Operand shapes: d1 ``[M, N]``, d2 ``[K, M, 1]``, d3 ``[L, K, 1, 1]``.
    """

    def __init__(self, mode, device, dtype, M, N, K, L):
        super().__init__(mode, device, dtype)
        self.M, self.N, self.K, self.L = M, N, K, L

        # Create d1, d2, d3 in that order.
        self.d1, self.d2, self.d3 = (
            self.rand(
                shape, device=device, dtype=dtype, requires_grad=self.requires_grad
            )
            for shape in ([M, N], [K, M, 1], [L, K, 1, 1])
        )
        self.inputs = [self.d1, self.d2, self.d3]

    def forward(self, d1, d2, d3):
        # Left-associative: (d1 + d2) + d3, matching reference().
        return d1 + d2 + d3

    def reference(self):
        first, second, third = (self.numpy(t) for t in (self.d1, self.d2, self.d3))
        return first + second + third

    def config(self):
        return [self.M, self.N, self.K, self.L]

    @staticmethod
    def default_configs():
        return [[32, 16, 64, 128]]

    def memory_workload(self):
        """Return the sol/algorithmic workload as multiples of ``M*N*K*L*4``."""
        sol_count, algorithmic_count = (
            (1, 1) if self.mode == "fwd" else (1 + 1, 1 + (1 + 1 + 1))
        )
        buffer_size = self.M * self.N * self.K * self.L * 4
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def module():
        return "broadcast_3args"
# Registration of the fixed-shape benchmarks above is disabled; only the
# BroadcastBench variants generated by register_broadcast_ops() are registered.
# benchmark.register_benchmark_class(BroadcastRowBench)
# benchmark.register_benchmark_class(BroadcastMidBench)
# benchmark.register_benchmark_class(BroadcastColBench)
# benchmark.register_benchmark_class(BroadcastThreeArgs)


# TODO: merge this with elementwise bench
# A template class for elementwise operations.
# A derived class overrides the class variables below to customize its behavior.
class BroadcastBench(benchmark.Benchmark):
    """Template benchmark: two broadcasting binary ops whose results are summed.

    Computes ``binary_op(u(d1), u(d2)) + binary_op(u(d3), u(d4))``, where
    ``u`` is the unary op. Operand shapes are d1 ``[M, N]``, d2 ``[K, 1, N]``,
    d3 ``[M, N]`` and d4 ``[K, M, 1]``. Both partial results broadcast to
    ``[K, M, N]``.
    """

    # Customization class variables, set on subclasses by
    # register_broadcast_ops(). They are read through ``self.__class__`` so a
    # plain function stored here is not turned into a bound method.
    op_str = None  # module() returns "broadcast_" + op_str
    binary_op_pt_func = None  # torch binary op; None means addition
    binary_op_np_func = None  # NumPy reference binary op; None means addition
    unary_op_pt_func = None  # torch unary op; None means identity
    unary_op_np_func = None  # NumPy reference unary op; None means identity
    split_input = True  # False: the third operand is derived from d1, not d3

    def __init__(self, mode, device, dtype, M, N, K):
        super().__init__(mode, device, dtype)
        self.M = M
        self.N = N
        self.K = K
        self.d1 = self.rand(
            [M, N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d2 = self.rand(
            [K, 1, N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d3 = self.rand(
            [M, N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d4 = self.rand(
            [K, M, 1], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.d1, self.d2, self.d3, self.d4]

    def _eval(self, d1, d2, d3, d4, binary_op, unary_op):
        """Evaluate the benchmark expression with the given op callables.

        ``binary_op`` defaults to addition and ``unary_op`` to identity when
        None. With ``split_input`` False, the third operand is
        ``unary_op(d1 + 0.001)`` and ``d3`` is ignored, so d1 feeds both
        halves of the expression.
        """
        if binary_op is None:
            binary_op = operator.add
        if unary_op is None:

            def unary_op(x):
                return x

        if self.split_input:
            d1, d2, d3, d4 = (unary_op(d) for d in (d1, d2, d3, d4))
        else:
            d1, d2, d3, d4 = (
                unary_op(d1),
                unary_op(d2),
                unary_op(d1 + 0.001),
                unary_op(d4),
            )
        return binary_op(d1, d2) + binary_op(d3, d4)

    def forward(self, d1, d2, d3, d4):
        binary_op = self.__class__.binary_op_pt_func
        unary_op = self.__class__.unary_op_pt_func
        return self._eval(d1, d2, d3, d4, binary_op, unary_op)

    def reference(self):
        binary_op = self.__class__.binary_op_np_func
        unary_op = self.__class__.unary_op_np_func
        d1, d2, d3, d4 = [self.numpy(d) for d in [self.d1, self.d2, self.d3, self.d4]]
        return self._eval(d1, d2, d3, d4, binary_op, unary_op)

    def config(self):
        return [self.M, self.N, self.K]

    @classmethod
    def module(cls):
        return "broadcast_" + cls.op_str

    def memory_workload(self):
        """Return the sol/algorithmic workload as multiples of ``M*N*K*4``.

        The counts currently do not depend on ``split_input``.
        """
        if self.mode == "fwd":
            sol_count, algorithmic_count = 1, 1
        else:
            sol_count, algorithmic_count = 1, len(self.inputs)

        # NOTE(review): the "* 4" factor is presumably a 4-byte element
        # size; BroadcastMulBench omits it. Confirm the unit the harness expects.
        buffer_size = self.M * self.N * self.K * 4
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        return [[1 << 8, 1 << 7, 1 << 9]]


def _np_erf(x):
    """NumPy reference for erf. NumPy has no ``erf`` ufunc, so map math.erf."""
    import math

    # otypes is given explicitly so empty arrays also work and the result
    # dtype is fixed at float64.
    return np.vectorize(math.erf, otypes=[np.float64])(x)


def _register_op_classes(op_list, pt_attr, np_attr):
    """Register a split and a shared BroadcastBench subclass per op entry.

    Each entry is ``[name, pt_func]`` (the same callable also serves as the
    NumPy reference) or ``[name, pt_func, np_func]``. The torch and NumPy
    callables are stored on the new class under ``pt_attr`` and ``np_attr``.

    Raises:
        ValueError: if an entry does not have 2 or 3 elements.
    """
    for split_input, op_entry in itertools.product([True, False], op_list):
        if len(op_entry) == 2:
            op_name, op_pt_func = op_entry
            op_np_func = op_pt_func
        elif len(op_entry) == 3:
            op_name, op_pt_func, op_np_func = op_entry
        else:
            raise ValueError(f"malformed op entry: {op_entry!r}")

        split_str = "split" if split_input else "shared"
        op_str = split_str + "_" + op_name
        # Make a copy of BroadcastBench with the customization variables set.
        bm_cls = type(
            "BroadcastBench_" + op_str,
            (BroadcastBench,),
            {
                "op_str": op_str,
                pt_attr: op_pt_func,
                np_attr: op_np_func,
                "split_input": split_input,
            },
        )
        benchmark.register_benchmark_class(bm_cls)


def register_broadcast_ops():
    """Generate and register BroadcastBench subclasses for every op.

    Every binary op is registered first, then every unary op. For each op,
    one ``BroadcastBench_split_<op>`` and one ``BroadcastBench_shared_<op>``
    class are created.
    """
    binary_op_list = [
        ["mul", operator.mul],
        ["add", operator.add],
        ["sub", operator.sub],
        ["div", lambda a, b: a / (b + 1e-4)],
        ["pow", torch.pow, np.power],  # no fusion triggered
        ["max", torch.max, np.maximum],
        ["min", torch.min, np.minimum],
    ]

    unary_op_list = [
        ["erf", torch.erf, _np_erf],
        ["exp", torch.exp, np.exp],
        ["sin", torch.sin, np.sin],
        ["cos", torch.cos, np.cos],
    ]

    _register_op_classes(binary_op_list, "binary_op_pt_func", "binary_op_np_func")
    _register_op_classes(unary_op_list, "unary_op_pt_func", "unary_op_np_func")


register_broadcast_ops()