xref: /aosp_15_r20/external/pytorch/benchmarks/tensorexpr/pooling.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom . import benchmark
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker
class PoolingBench(benchmark.Benchmark):
    """2-D pooling benchmark (max or average) over one random input tensor.

    The pooling op itself comes from the base class (``self.max_pool2d`` /
    ``self.avg_pool2d``). It is always invoked with ``stride=1`` and the
    ``kernel_size`` given at construction.
    """

    def __init__(self, case, mode, device, dtype, kernel_size, N, C, H, W):
        """Configure the benchmark and allocate its input.

        Args:
            case: Which op ``forward`` runs: ``"maxpool"`` or ``"avgpool"``.
            mode: Forwarded to the base class. ``memory_workload`` treats
                ``"fwd"`` as forward-only and any other value as
                forward + backward.
            device: Device the input tensor is created on.
            dtype: Element type of the input tensor.
            kernel_size: Pooling window, passed unchanged to the pool op.
            N, C, H, W: Input dimensions; the input has shape ``[N, C, H, W]``.
        """
        super().__init__(mode, device)
        self.case = case
        self.kernel_size = kernel_size
        self.N = N
        self.C = C
        self.H = H
        self.W = W
        # ``self.requires_grad`` is not set in this class; presumably the base
        # class derives it from ``mode`` -- TODO confirm in benchmark.Benchmark.
        self.data = self.rand(
            [N, C, H, W], device=device, dtype=dtype, requires_grad=self.requires_grad
        )

    def forward(self):
        """Run one pooling pass over ``self.data`` and return the result.

        Raises:
            ValueError: If ``self.case`` is neither ``"maxpool"`` nor
                ``"avgpool"``.
        """
        if self.case == "maxpool":
            return self.max_pool2d(self.data, self.kernel_size, stride=1)
        if self.case == "avgpool":
            return self.avg_pool2d(self.data, self.kernel_size, stride=1)
        # Fail loudly instead of hitting an unbound local further down.
        raise ValueError(
            f"unsupported pooling case {self.case!r}; "
            "expected 'maxpool' or 'avgpool'"
        )

    def config(self):
        """Return ``[kernel_size, N, C, H, W]``.

        The order matches the entries of ``default_configs``.
        """
        return [self.kernel_size, self.N, self.C, self.H, self.W]

    def memory_workload(self):
        """Estimate memory traffic for one benchmark iteration.

        Counts are multiples of one input-sized buffer (``N*C*H*W``
        elements; no dtype size is applied here). The pooled output is
        approximated as input-sized. For ``stride=1`` it is actually slightly
        smaller per plane, e.g. ``(H-k+1)*(W-k+1)`` assuming no padding --
        NOTE(review): confirm padding used by the base-class pool op.

        Returns:
            dict with keys ``"sol"`` (speed-of-light estimate) and
            ``"algorithmic"``, both element counts.
        """
        if self.mode == "fwd":
            # Forward: read input + write output.
            sol_count = 1 + 1
            algorithmic_count = 1 + 1
        else:
            # Forward (2 buffers) plus backward traffic; the algorithmic
            # backward term charges one extra buffer over speed-of-light.
            sol_count = (1 + 1) + (1 + 1)
            algorithmic_count = (1 + 1) + (2 + 1)

        buffer_size = self.N * self.C * self.H * self.W
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Return the default list of ``[kernel_size, N, C, H, W]`` configs."""
        return [[3, 16, 32, 256, 256]]
44*da0073e9SAndroid Build Coastguard Worker
45*da0073e9SAndroid Build Coastguard Worker
class MaxPoolBench(PoolingBench):
    """``PoolingBench`` fixed to the ``"maxpool"`` case."""

    def __init__(self, *bench_args):
        # Supply the case tag up front; all remaining positional arguments
        # (mode, device, dtype, kernel_size, N, C, H, W) pass straight through.
        super().__init__("maxpool", *bench_args)

    @staticmethod
    def module():
        """Return this benchmark's name (presumably used by the registry)."""
        name = "maxpool"
        return name
53*da0073e9SAndroid Build Coastguard Worker
54*da0073e9SAndroid Build Coastguard Worker
class AvgPoolBench(PoolingBench):
    """``PoolingBench`` fixed to the ``"avgpool"`` case."""

    def __init__(self, *bench_args):
        # Supply the case tag up front; all remaining positional arguments
        # (mode, device, dtype, kernel_size, N, C, H, W) pass straight through.
        super().__init__("avgpool", *bench_args)

    @staticmethod
    def module():
        """Return this benchmark's name (presumably used by the registry)."""
        name = "avgpool"
        return name
62*da0073e9SAndroid Build Coastguard Worker
63*da0073e9SAndroid Build Coastguard Worker
# Register both pooling variants with the benchmark harness, max first.
for _bench_cls in (MaxPoolBench, AvgPoolBench):
    benchmark.register_benchmark_class(_bench_cls)
del _bench_cls
66