from . import benchmark


class ReduceBench(benchmark.Benchmark):
    """Sum-reduction benchmark over a 3-D input of shape ``[M, N, K]``.

    ``case`` selects the axes that are reduced:

    * ``"row"``  -> dims ``[1, 2]``
    * ``"mid"``  -> dims ``[0, 2]``
    * ``"col"``  -> dims ``[0, 1]``
    * ``"full"`` -> dims ``[0, 1, 2]``

    ``skip_input_transform`` is the string ``"s0"`` (apply the ``+ 0.001``
    input transform before reducing) or ``"s1"`` (reduce the raw input).

    NOTE(review): ``self.randn``, ``self.add``, ``self.sum``,
    ``self.requires_grad`` and ``self.mode`` are not defined in this file;
    presumably they are provided by ``benchmark.Benchmark`` -- confirm there.
    """

    def __init__(self, mode, device, dtype, case, M, N, K, skip_input_transform):
        """Store the configuration and allocate the benchmark input.

        Raises:
            ValueError: if ``skip_input_transform`` is not ``"s0"``/``"s1"``
                or ``case`` is not one of ``"row"``, ``"mid"``, ``"col"``,
                ``"full"``.
        """
        super().__init__(mode, device, dtype)
        self.case = case
        self.M = M
        self.N = N
        self.K = K
        self._set_skip_input_transform(skip_input_transform)

        # Resolve (and validate) the reduction axes *before* allocating the
        # input, so an invalid ``case`` fails without first building an
        # M * N * K tensor.
        if case == "row":
            self.dims = [1, 2]
        elif case == "mid":
            self.dims = [0, 2]
        elif case == "col":
            self.dims = [0, 1]
        elif case == "full":
            self.dims = [0, 1, 2]
        else:
            raise ValueError(f"invalid case: {case}")

        self.inputs = [
            self.randn(
                [M, N, K], device=device, dtype=dtype, requires_grad=self.requires_grad
            )
        ]

    def forward(self, inputs):
        """Optionally add 0.001 to ``inputs``, then sum over ``self.dims``."""
        if self.skip_input_transform:
            x = inputs
        else:
            x = self.add(inputs, 0.001)
        y = self.sum(x, self.dims)
        return y

    def config(self):
        """Return the configuration list describing this instance.

        For the ``"full"`` case the three sizes are collapsed into their
        product; otherwise ``M``, ``N`` and ``K`` are listed separately.  The
        last entry is always the ``"s0"``/``"s1"`` flag.
        """
        if self.case == "full":
            return [self.M * self.N * self.K, self._skip_input_transform_str()]
        return [self.M, self.N, self.K, self._skip_input_transform_str()]

    @staticmethod
    def default_configs():
        """Return the default list of configurations (``[M, N, K, flag]``)."""
        return [
            # [512, 512, 512],
            [512, 64, 512, "s0"],
        ]

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "reduce"

    def memory_workload(self):
        """Return the modeled memory workload as a dict.

        Both the ``"sol"`` and ``"algorithmic"`` entries are
        ``M * N * K`` multiplied by 1 when ``self.mode == "fwd"`` and by 2
        for any other mode.
        """
        if self.mode == "fwd":
            sol_count = 1
            algorithmic_count = 1
        else:
            # NOTE(review): presumably one pass for forward plus one for
            # backward -- confirm against the base class' mode semantics.
            sol_count = 1 + 1
            algorithmic_count = 1 + 1

        buffer_size = self.M * self.N * self.K
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    def _set_skip_input_transform(self, input_str):
        """Translate the ``"s0"``/``"s1"`` flag into ``self.skip_input_transform``.

        Raises:
            ValueError: if ``input_str`` is neither ``"s0"`` nor ``"s1"``.
        """
        # In the test setting, s1 will skip the input transformation, and s0 will not.
        if input_str == "s0":
            self.skip_input_transform = False
        elif input_str == "s1":
            self.skip_input_transform = True
        else:
            raise ValueError(f"invalid skip_input_transform: {input_str}")

    def _skip_input_transform_str(self):
        """Return ``"s1"`` if the input transform is skipped, else ``"s0"``."""
        if self.skip_input_transform:
            return "s1"
        else:
            return "s0"


class ReduceRowBench(ReduceBench):
    """``ReduceBench`` fixed to the ``"row"`` case (reduces dims 1 and 2)."""

    def __init__(self, mode, device, dtype, M, N, K, skip_input_transform):
        super().__init__(mode, device, dtype, "row", M, N, K, skip_input_transform)

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "reduce_row"


class ReduceMidBench(ReduceBench):
    """``ReduceBench`` fixed to the ``"mid"`` case (reduces dims 0 and 2)."""

    def __init__(self, mode, device, dtype, M, N, K, skip_input_transform):
        super().__init__(mode, device, dtype, "mid", M, N, K, skip_input_transform)

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "reduce_mid"


class ReduceColBench(ReduceBench):
    """``ReduceBench`` fixed to the ``"col"`` case (reduces dims 0 and 1)."""

    def __init__(self, mode, device, dtype, M, N, K, skip_input_transform):
        super().__init__(mode, device, dtype, "col", M, N, K, skip_input_transform)

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "reduce_col"


class ReduceFullBench(ReduceBench):
    """``ReduceBench`` fixed to the ``"full"`` case on an ``[M, 1, 1]`` input."""

    def __init__(self, mode, device, dtype, M, skip_input_transform):
        super().__init__(mode, device, dtype, "full", M, 1, 1, skip_input_transform)

    def config(self):
        """Return ``[total element count, "s0"/"s1" flag]``."""
        return [self.M * self.N * self.K, self._skip_input_transform_str()]

    @staticmethod
    def default_configs():
        """Return the default list of configurations (``[M, flag]``)."""
        return [
            [1 << 24, "s1"],
        ]

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "reduce_full"


class Reduce2DBench(benchmark.Benchmark):
    """
    A benchmark class to validate 2 dimensional reduction performance.
    Only a simple add is fused to induce the fuser and isolate reduction perf.
    """

    def __init__(self, mode, device, dtype, red_dim, dim0, dim1):
        """Store the configuration and allocate the ``[dim0, dim1]`` input.

        Raises:
            ValueError: if ``red_dim`` is neither 0 nor 1.
        """
        super().__init__(mode, device, dtype)

        # Reject a bad reduction axis before allocating the (potentially
        # very large) input tensor.
        if red_dim not in (0, 1):
            raise ValueError(f"invalid reduction dimension: {red_dim}")

        self.red_dim = red_dim
        self.dim0 = dim0
        self.dim1 = dim1

        self.inputs = [
            self.randn(
                [dim0, dim1],
                device=device,
                dtype=dtype,
                requires_grad=self.requires_grad,
            )
        ]

    def forward(self, inputs):
        """Add 0.001 to ``inputs`` and sum over ``self.red_dim``."""
        x = self.add(inputs, 0.001)
        y = self.sum(x, [self.red_dim])
        return y

    def config(self):
        """Return ``[red_dim, dim0, dim1]``."""
        return [self.red_dim, self.dim0, self.dim1]

    @staticmethod
    def default_configs():
        """Return the default list of configurations (``[red_dim, dim0, dim1]``)."""
        return [
            [1, 640, 524288],
        ]

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "reduce2d"

    @staticmethod
    def input_iterable():
        """Return ``True``.

        NOTE(review): the meaning of this flag is defined by the
        ``benchmark`` module, which is not visible here.
        """
        return True

    def memory_workload(self):
        """Return the modeled memory workload for the forward pass.

        The value is ``dim0 * dim1`` plus the size of the dimension that is
        *not* reduced, reported under both ``"sol"`` and ``"algorithmic"``.

        Raises:
            AssertionError: if ``self.mode`` is not ``"fwd"``.
        """
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped when Python runs with -O; the exception type and message
        # match the original assertion.
        if self.mode != "fwd":
            raise AssertionError("Only the forward operation is modeled!")

        buffer_size = self.dim0 * self.dim1
        if self.red_dim == 0:
            buffer_size += self.dim1
        else:
            buffer_size += self.dim0
        return {
            "sol": buffer_size,
            "algorithmic": buffer_size,
        }


class Reduce2DInnerBench(Reduce2DBench):
    """``Reduce2DBench`` fixed to reduce dimension 1."""

    def __init__(self, mode, device, dtype, dim0, dim1):
        super().__init__(mode, device, dtype, 1, dim0, dim1)

    @staticmethod
    def default_configs():
        """Return the parent's first default config without its ``red_dim`` entry."""
        parent_config = Reduce2DBench.default_configs()[0]
        return [parent_config[1:]]

    def config(self):
        """Return ``[dim0, dim1]`` (the parent config minus ``red_dim``)."""
        parent_config = super().config()
        return parent_config[1:]

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "reduce2d_inner"


class Reduce2DOuterBench(Reduce2DBench):
    """``Reduce2DBench`` fixed to reduce dimension 0."""

    def __init__(self, mode, device, dtype, dim0, dim1):
        super().__init__(mode, device, dtype, 0, dim0, dim1)

    @staticmethod
    def default_configs():
        """Return the parent's first default config without its ``red_dim`` entry."""
        parent_config = Reduce2DBench.default_configs()[0]
        return [parent_config[1:]]

    def config(self):
        """Return ``[dim0, dim1]`` (the parent config minus ``red_dim``)."""
        parent_config = super().config()
        return parent_config[1:]

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "reduce2d_outer"


benchmark.register_benchmark_class(ReduceRowBench)
benchmark.register_benchmark_class(ReduceMidBench)
benchmark.register_benchmark_class(ReduceColBench)
benchmark.register_benchmark_class(Reduce2DInnerBench)
benchmark.register_benchmark_class(Reduce2DOuterBench)
benchmark.register_benchmark_class(ReduceFullBench)


class DynamicReduce2DBench(benchmark.DynamicShape, Reduce2DBench):
    """
    A benchmark class to validate 2 dimensional reduction performance.
    Only a simple add is fused to induce the fuser and isolate reduction perf.
    """

    def __init__(self, mode, device, dtype, red_dim, dim0, dim1):
        # Both bases are initialised explicitly, DynamicShape first.
        benchmark.DynamicShape.__init__(self)
        Reduce2DBench.__init__(self, mode, device, dtype, red_dim, dim0, dim1)

    def instantiate_input(self):
        """Re-create ``self.inputs`` with a shape drawn from ``self.rand_shape``.

        NOTE(review): ``rand_shape`` is not defined in this file; presumably
        it comes from ``benchmark.DynamicShape`` -- confirm how it derives
        the new sizes from ``[self.dim0, self.dim1]``.
        """
        dim0, dim1 = self.rand_shape([self.dim0, self.dim1])

        self.inputs = [
            self.randn(
                [dim0, dim1],
                device=self.device,
                dtype=self.dtype,
                requires_grad=self.requires_grad,
            )
        ]

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "dynamicreduce2d"


class DynamicReduce2DInnerBench(DynamicReduce2DBench):
    """``DynamicReduce2DBench`` fixed to reduce dimension 1."""

    def __init__(self, mode, device, dtype, dim0, dim1):
        super().__init__(mode, device, dtype, 1, dim0, dim1)

    @staticmethod
    def default_configs():
        """Return the parent's first default config without its ``red_dim`` entry."""
        parent_config = DynamicReduce2DBench.default_configs()[0]
        return [parent_config[1:]]

    def config(self):
        """Return ``[dim0, dim1]`` (the parent config minus ``red_dim``)."""
        parent_config = super().config()
        return parent_config[1:]

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "reduce2d_dynamic_inner"


class DynamicReduce2DOuterBench(DynamicReduce2DBench):
    """``DynamicReduce2DBench`` fixed to reduce dimension 0."""

    def __init__(self, mode, device, dtype, dim0, dim1):
        super().__init__(mode, device, dtype, 0, dim0, dim1)

    @staticmethod
    def default_configs():
        """Return the parent's first default config without its ``red_dim`` entry."""
        parent_config = DynamicReduce2DBench.default_configs()[0]
        return [parent_config[1:]]

    def config(self):
        """Return ``[dim0, dim1]`` (the parent config minus ``red_dim``)."""
        parent_config = super().config()
        return parent_config[1:]

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "reduce2d_dynamic_outer"


benchmark.register_benchmark_class(DynamicReduce2DInnerBench)
benchmark.register_benchmark_class(DynamicReduce2DOuterBench)