xref: /aosp_15_r20/external/pytorch/benchmarks/tensorexpr/conv.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from . import benchmark
2
3
class ConvImplBench(benchmark.Benchmark):
    """Shared benchmark body for the 2-D convolution cases.

    ``case`` selects how the convolution layer is grouped: ``"conv"`` uses a
    single group, ``"depthwise_conv"`` uses one group per input channel.
    """

    def __init__(self, case, mode, device, dtype, kernel_size, N, iC, H, W, oC):
        """Create the input tensor and the convolution layer under test.

        Args:
            case: ``"conv"`` or ``"depthwise_conv"``.
            mode: forwarded to the base class; the workload estimates below
                distinguish ``"fwd"`` from ``"both"``.
            device: forwarded to the base class; the layer is moved with
                ``to_device`` when this is not ``"cpu"``.
            dtype: forwarded to the base class.
            kernel_size: kernel size handed to ``conv2d_layer``; the workload
                estimates count it as the side of a square kernel.
            N, iC, H, W: dimensions of the input tensor ``[N, iC, H, W]``.
            oC: number of output channels of the layer.

        Raises:
            ValueError: if ``case`` is not a supported case name.
        """
        super().__init__(mode, device, dtype)

        # Reject an unknown case before the (potentially large) input tensor
        # is allocated, so a bad configuration fails fast and cheaply.
        if case == "conv":
            groups = 1
        elif case == "depthwise_conv":
            groups = iC
        else:
            raise ValueError(f"invalid case: {case}")

        self.case = case
        self.kernel_size = kernel_size
        self.N = N
        self.iC = iC
        self.H = H
        self.W = W
        self.oC = oC
        self.groups = groups

        # NOTE(review): ``rand``, ``requires_grad``, ``conv2d_layer`` and
        # ``to_device`` come from the base class, which is not visible here.
        self.data = self.rand(
            [N, iC, H, W], device=device, requires_grad=self.requires_grad
        )
        self.conv = self.conv2d_layer(iC, oC, kernel_size, groups=self.groups)
        if device != "cpu":
            self.to_device(self.conv, device)

    def forward(self):
        """Apply the convolution layer to the stored input and return the result."""
        return self.conv(self.data)

    def config(self):
        """Return ``[kernel_size, N, iC, H, W, oC]`` for this instance."""
        return [self.kernel_size, self.N, self.iC, self.H, self.W, self.oC]

    def memory_workload(self):
        """Estimate the number of elements moved, as integer counts.

        Returns:
            dict with keys ``"sol"`` and ``"algorithmic"``; each value is the
            sum over the input (``"i"``), output (``"o"``) and kernel (``"k"``)
            buffers of ``buffer elements * access count``.
        """
        if self.mode == "fwd":
            sol_count = {"i": 1, "o": 1, "k": 1}
            algorithmic_count = {"i": 1, "o": 1, "k": 1}
        else:
            # Any mode other than "fwd" is counted like a forward + backward
            # run (unlike compute_workload, no error is raised here).
            # NOTE(review): presumably 2 = one forward + one backward access
            # and 3 = forward + two backward accesses per buffer -- confirm.
            sol_count = {"i": 2, "o": 2, "k": 2}
            algorithmic_count = {"i": 3, "o": 3, "k": 3}

        buffer_size = {
            "i": self.N * self.iC * self.H * self.W,
            # The output is counted with the same H x W extent as the input.
            # NOTE(review): whether the layer pads is not visible here.
            "o": self.N * self.oC * self.H * self.W,
            # groups is either 1 or iC, so the floor division is exact and
            # keeps the element count an int instead of a float.
            "k": self.oC
            * (self.iC // self.groups)
            * self.kernel_size
            * self.kernel_size,
        }

        sol_size = sum(buffer_size[key] * sol_count[key] for key in sol_count)
        algorithmic_size = sum(
            buffer_size[key] * algorithmic_count[key] for key in sol_count
        )
        return {"sol": sol_size, "algorithmic": algorithmic_size}

    def compute_workload(self):
        """Estimate the number of arithmetic operations, as an integer.

        Returns:
            ``2 * N * (iC // groups) * oC * kernel_size**2 * H * W`` multiplied
            by 1 for mode ``"fwd"`` or by 3 for mode ``"both"``.

        Raises:
            ValueError: if ``self.mode`` is neither ``"fwd"`` nor ``"both"``.
        """
        if self.mode == "fwd":
            count = 1
        elif self.mode == "both":
            count = 1 + (1 + 1)
        else:
            raise ValueError(f"invalid mode: {self.mode}")

        # Exact integer division: groups is either 1 or iC (see __init__).
        op_count = (
            self.N
            * (self.iC // self.groups)
            * self.oC
            * self.kernel_size
            * self.kernel_size
            * self.H
            * self.W
        )
        # NOTE(review): the factor 2 presumably counts a multiply and an add
        # per kernel tap -- confirm.
        op_count *= 2

        return op_count * count

    @staticmethod
    def default_configs():
        """Return the built-in list of configurations (a single entry).

        NOTE(review): each entry appears to follow the ``config()`` order
        ``[kernel_size, N, iC, H, W, oC]`` -- confirm against the harness.
        """
        return [
            [3, 64, 32, 128, 128, 64],
        ]
85
86
class ConvBench(ConvImplBench):
    """Benchmark entry that runs ``ConvImplBench`` with the ``"conv"`` case."""

    @staticmethod
    def module():
        """Return the name string of this benchmark, ``"conv"``."""
        return "conv"

    def __init__(self, *args):
        """Forward ``args`` to ``ConvImplBench`` with the case fixed to ``"conv"``."""
        case = "conv"
        super().__init__(case, *args)
94
95
class DepthwiseConvBench(ConvImplBench):
    """Benchmark entry that runs ``ConvImplBench`` with the ``"depthwise_conv"`` case."""

    @staticmethod
    def module():
        """Return the name string of this benchmark, ``"depthwise_conv"``."""
        return "depthwise_conv"

    def __init__(self, *args):
        """Forward ``args`` to ``ConvImplBench`` with the case fixed to ``"depthwise_conv"``."""
        case = "depthwise_conv"
        super().__init__(case, *args)
103
104
# Register both convolution benchmarks with the benchmark module, in the same
# order as before: plain convolution first, then depthwise.
for _bench_cls in (ConvBench, DepthwiseConvBench):
    benchmark.register_benchmark_class(_bench_cls)
107