import numpy as np

import torch

from . import benchmark


class Concat2D2InputBench(benchmark.Benchmark):
    """Benchmark that adds a small constant to two 2-D inputs and concatenates them.

    The two inputs have shapes ``[I1_D1, I1_D2]`` and ``[I2_D1, I2_D2]`` and are
    joined along ``concat_dim``.
    """

    def __init__(self, mode, device, dtype, I1_D1, I1_D2, I2_D1, I2_D2, concat_dim):
        """Store the shape parameters and allocate both random inputs.

        Args:
            mode: Benchmark mode; ``memory_workload`` treats ``"fwd"`` specially.
            device: Forwarded to the base class and to ``self.randn``.
            dtype: Forwarded to the base class and to ``self.randn``.
            I1_D1: First dimension of the first input.
            I1_D2: Second dimension of the first input.
            I2_D1: First dimension of the second input.
            I2_D2: Second dimension of the second input.
            concat_dim: Dimension along which the inputs are concatenated.
        """
        super().__init__(mode, device, dtype)
        self.I1_D1 = I1_D1
        self.I1_D2 = I1_D2
        self.I2_D1 = I2_D1
        self.I2_D2 = I2_D2
        self.concat_dim = concat_dim
        self.input1 = self.randn(
            [I1_D1, I1_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.input2 = self.randn(
            [I2_D1, I2_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.input1, self.input2]

    def forward(self, input1, input2):
        """Add 0.00001 to each input, then concatenate along ``concat_dim``."""
        x1 = self.add(input1, 0.00001)
        x2 = self.add(input2, 0.00001)
        y = self.cat((x1, x2), dim=self.concat_dim)
        return y

    def reference(self):
        """Return the NumPy equivalent of ``forward`` applied to the stored inputs."""
        # Mirror forward(): the same constant is added to each operand before
        # the concatenation so the reference matches the benchmarked result.
        return np.concatenate(
            (
                np.add(self.numpy(self.input1), 0.00001),
                np.add(self.numpy(self.input2), 0.00001),
            ),
            axis=self.concat_dim,
        )

    def config(self):
        """Return ``[I1_D1, I1_D2, I2_D1, I2_D2, concat_dim]`` for this instance."""
        return [self.I1_D1, self.I1_D2, self.I2_D1, self.I2_D2, self.concat_dim]

    @staticmethod
    def module():
        """Return the name this benchmark is identified by."""
        return "concat2d2input"

    def memory_workload(self):
        """Return element-count estimates under the keys ``"sol"`` and ``"algorithmic"``.

        Both values are the combined element count of the two inputs scaled by
        a per-mode multiplier; every mode other than ``"fwd"`` doubles the
        ``"fwd"`` multipliers.
        """
        # NOTE(review): "sol" presumably stands for speed-of-light (ideal)
        # traffic and the 1 + 1 / 3 + 1 terms for buffer reads + writes --
        # confirm against the base benchmark class.
        if self.mode == "fwd":
            sol_count = 1 + 1
            algorithmic_count = 3 + 1
        else:
            sol_count = (1 + 1) + (1 + 1)
            algorithmic_count = (3 + 1) + (3 + 1)

        buffer_size = self.I1_D1 * self.I1_D2 + self.I2_D1 * self.I2_D2
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Return the default configurations.

        Each entry has five values, presumably in the same order as
        ``config()``: ``[I1_D1, I1_D2, I2_D1, I2_D2, concat_dim]``.
        """
        return [
            [1, 160, 1, 14, 1],
            [1, 580, 1, 174, 1],
            [20, 160, 20, 14, 1],
            [20, 580, 20, 174, 1],
            [8, 512, 8, 512, 1],
            [1 << 13, 1060, 1 << 13, 1040, 1],
            [1 << 13, 2000, 1 << 13, 1074, 1],
            [1 << 15, 1060, 1 << 15, 2670, 1],
            [1 << 15, 5120, 1 << 15, 2512, 1],
        ]


benchmark.register_benchmark_class(Concat2D2InputBench)


class ConcatGraphOptBench(benchmark.Benchmark):
    """Benchmark for add -> concatenate -> relu over two 2-D inputs.

    Same inputs as ``Concat2D2InputBench``, but ``forward`` applies ``relu`` to
    the concatenated result and construction flips two ``torch._C`` JIT
    switches.
    """

    def __init__(self, mode, device, dtype, I1_D1, I1_D2, I2_D1, I2_D2, concat_dim):
        """Store the shape parameters, allocate both inputs, and set JIT switches.

        Args:
            mode: Benchmark mode; ``memory_workload`` treats ``"fwd"`` specially.
            device: Forwarded to the base class and to ``self.randn``.
            dtype: Forwarded to the base class and to ``self.randn``.
            I1_D1: First dimension of the first input.
            I1_D2: Second dimension of the first input.
            I2_D1: First dimension of the second input.
            I2_D2: Second dimension of the second input.
            concat_dim: Dimension along which the inputs are concatenated.
        """
        super().__init__(mode, device, dtype)
        self.I1_D1 = I1_D1
        self.I1_D2 = I1_D2
        self.I2_D1 = I2_D1
        self.I2_D2 = I2_D2
        self.concat_dim = concat_dim
        self.input1 = self.randn(
            [I1_D1, I1_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.input2 = self.randn(
            [I2_D1, I2_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.input1, self.input2]
        # NOTE(review): private torch._C toggles that are set here and never
        # restored. Judging by their names they enable CPU fusion and a
        # conditional-free lowering of cat -- confirm against the PyTorch
        # sources before relying on this.
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_cat_wo_conditionals(True)

    def forward(self, input1, input2):
        """Add 0.00001 to each input, concatenate along ``concat_dim``, apply relu."""
        x1 = self.add(input1, 0.00001)
        x2 = self.add(input2, 0.00001)
        y = self.cat((x1, x2), dim=self.concat_dim)
        z = self.relu(y)
        return z

    def reference(self):
        """Return the NumPy equivalent of ``forward`` applied to the stored inputs."""
        # Mirror forward() step by step: add the constant, concatenate, then
        # clamp negatives to zero. Without the final clamp the reference would
        # disagree with forward() on every negative element.
        joined = np.concatenate(
            (
                np.add(self.numpy(self.input1), 0.00001),
                np.add(self.numpy(self.input2), 0.00001),
            ),
            axis=self.concat_dim,
        )
        return np.maximum(joined, 0)

    def config(self):
        """Return ``[I1_D1, I1_D2, I2_D1, I2_D2, concat_dim]`` for this instance."""
        return [self.I1_D1, self.I1_D2, self.I2_D1, self.I2_D2, self.concat_dim]

    @staticmethod
    def module():
        """Return the name this benchmark is identified by."""
        return "concatGraphOpt"

    def memory_workload(self):
        """Return element-count estimates under the keys ``"sol"`` and ``"algorithmic"``.

        Both values are the combined element count of the two inputs scaled by
        a per-mode multiplier; every mode other than ``"fwd"`` doubles the
        ``"fwd"`` multipliers.
        """
        # NOTE(review): the multipliers are identical to Concat2D2InputBench
        # even though forward() has an extra relu -- confirm this is intended.
        if self.mode == "fwd":
            sol_count = 1 + 1
            algorithmic_count = 3 + 1
        else:
            sol_count = (1 + 1) + (1 + 1)
            algorithmic_count = (3 + 1) + (3 + 1)

        buffer_size = self.I1_D1 * self.I1_D2 + self.I2_D1 * self.I2_D2
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Return the default configurations.

        Each entry has five values, presumably in the same order as
        ``config()``: ``[I1_D1, I1_D2, I2_D1, I2_D2, concat_dim]``.
        """
        return [
            [1 << 13, 1060, 1 << 13, 1040, 1],
            [1 << 13, 2000, 1 << 13, 1074, 1],
            [1 << 15, 1060, 1 << 15, 2670, 1],
            [1 << 15, 5120, 1 << 15, 2512, 1],
        ]


benchmark.register_benchmark_class(ConcatGraphOptBench)