from . import benchmark


class PoolingBench(benchmark.Benchmark):
    """Benchmark a 2D pooling op (max or average) on a random 4-D input.

    The input tensor is created via ``self.rand`` with shape ``[N, C, H, W]``,
    and pooling is always run with ``stride=1``.

    ``rand``, ``max_pool2d``, ``avg_pool2d``, ``requires_grad`` and ``mode``
    are not defined in this file. They are presumably supplied by
    ``benchmark.Benchmark`` (TODO confirm against that module).
    """

    def __init__(self, case, mode, device, dtype, kernel_size, N, C, H, W):
        """Store the benchmark parameters and allocate the input tensor.

        Args:
            case: Selects the op run by ``forward()``. It accepts
                ``"maxpool"`` or ``"avgpool"``.
            mode: Forwarded to the base-class constructor. ``memory_workload``
                treats ``"fwd"`` specially and every other value as
                forward+backward.
            device: Forwarded to the base-class constructor and to ``self.rand``.
            dtype: Element type passed to ``self.rand``.
            kernel_size: Pooling window size passed to the pooling op.
            N, C, H, W: Dimensions of the input tensor, in that order.
        """
        super().__init__(mode, device)
        self.case = case
        self.kernel_size = kernel_size
        self.N = N
        self.C = C
        self.H = H
        self.W = W
        # requires_grad is read from the instance after the base __init__ ran;
        # presumably the base class derives it from `mode` (TODO confirm).
        self.data = self.rand(
            [N, C, H, W], device=device, dtype=dtype, requires_grad=self.requires_grad
        )

    def forward(self):
        """Apply the selected pooling op to ``self.data`` and return the result.

        Raises:
            ValueError: If ``self.case`` is neither ``"maxpool"`` nor
                ``"avgpool"``.
        """
        if self.case == "maxpool":
            return self.max_pool2d(self.data, self.kernel_size, stride=1)
        if self.case == "avgpool":
            return self.avg_pool2d(self.data, self.kernel_size, stride=1)
        # Fail loudly instead of falling through with an unbound result.
        raise ValueError(f"unsupported pooling case: {self.case!r}")

    def config(self):
        """Return this run's shape parameters as ``[kernel_size, N, C, H, W]``.

        The order matches the entries of ``default_configs()`` and the trailing
        positional arguments of ``__init__``.
        """
        return [self.kernel_size, self.N, self.C, self.H, self.W]

    def memory_workload(self):
        """Estimate memory traffic for one run.

        Returns:
            A dict with keys ``"sol"`` and ``"algorithmic"``. Each value is
            ``N * C * H * W`` multiplied by a buffer-access count. The result is
            expressed in tensor elements; no dtype-size factor is applied here.
            Presumably the caller converts it to bytes (TODO confirm).
        """
        if self.mode == "fwd":
            # Two full-size buffer accesses, presumably one input read
            # plus one output write.
            sol_count = 1 + 1
            algorithmic_count = 1 + 1
        else:
            # Any non-"fwd" mode adds backward traffic. The algorithmic
            # estimate charges one more buffer access in the backward part
            # than the speed-of-light estimate does.
            sol_count = (1 + 1) + (1 + 1)
            algorithmic_count = (1 + 1) + (2 + 1)

        buffer_size = self.N * self.C * self.H * self.W
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Return the default configurations, each as ``[kernel_size, N, C, H, W]``."""
        return [[3, 16, 32, 256, 256]]


class MaxPoolBench(PoolingBench):
    """Max-pooling variant of ``PoolingBench``."""

    def __init__(self, *args):
        # args: mode, device, dtype, kernel_size, N, C, H, W.
        super().__init__("maxpool", *args)

    @staticmethod
    def module():
        """Return the identifier for this benchmark (presumably used by the registry)."""
        return "maxpool"


class AvgPoolBench(PoolingBench):
    """Average-pooling variant of ``PoolingBench``."""

    def __init__(self, *args):
        # args: mode, device, dtype, kernel_size, N, C, H, W.
        super().__init__("avgpool", *args)

    @staticmethod
    def module():
        """Return the identifier for this benchmark (presumably used by the registry)."""
        return "avgpool"

# Register both pooling variants with the benchmark framework.
benchmark.register_benchmark_class(MaxPoolBench)
benchmark.register_benchmark_class(AvgPoolBench)