xref: /aosp_15_r20/external/pytorch/benchmarks/tensorexpr/concat.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import numpy as np
2
3import torch
4
5from . import benchmark
6
7
class Concat2D2InputBench(benchmark.Benchmark):
    """Benchmark torch.cat on two 2-D inputs.

    forward() adds a small epsilon to each input (giving the fuser a
    pointwise producer for each cat operand) and concatenates the results
    along ``concat_dim``.
    """

    def __init__(self, mode, device, dtype, I1_D1, I1_D2, I2_D1, I2_D2, concat_dim):
        """Create the two random inputs of shapes (I1_D1, I1_D2) and (I2_D1, I2_D2)."""
        super().__init__(mode, device, dtype)
        self.I1_D1 = I1_D1
        self.I1_D2 = I1_D2
        self.I2_D1 = I2_D1
        self.I2_D2 = I2_D2
        self.concat_dim = concat_dim
        self.input1 = self.randn(
            [I1_D1, I1_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.input2 = self.randn(
            [I2_D1, I2_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.input1, self.input2]

    def forward(self, input1, input2):
        """Add an epsilon to each input, then concatenate along concat_dim."""
        x1 = self.add(input1, 0.00001)
        x2 = self.add(input2, 0.00001)
        y = self.cat((x1, x2), dim=self.concat_dim)
        return y

    def reference(self):
        """NumPy reference for forward().

        Mirrors forward() exactly: the same 0.00001 epsilon is added to each
        input before concatenation (previously the epsilon was omitted, so
        strict comparisons against forward()'s output could fail).
        """
        return np.concatenate(
            (self.numpy(self.input1) + 0.00001, self.numpy(self.input2) + 0.00001),
            axis=self.concat_dim,
        )

    def config(self):
        """Return the config tuple identifying this benchmark instance."""
        return [self.I1_D1, self.I1_D2, self.I2_D1, self.I2_D2, self.concat_dim]

    @staticmethod
    def module():
        return "concat2d2input"

    def memory_workload(self):
        """Element-count estimates used to derive memory bandwidth numbers.

        sol/algorithmic counts follow the framework's convention of
        reads + writes per element; the backward case doubles the forward
        traffic (presumably one pass for fwd, one for grad — counts are
        heuristic, not measured).
        """
        if self.mode == "fwd":
            sol_count = 1 + 1
            algorithmic_count = 3 + 1
        else:
            sol_count = (1 + 1) + (1 + 1)
            algorithmic_count = (3 + 1) + (3 + 1)

        # Total elements across both inputs.
        buffer_size = self.I1_D1 * self.I1_D2 + self.I2_D1 * self.I2_D2
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        # [I1_D1, I1_D2, I2_D1, I2_D2, concat_dim] — small shapes plus
        # large (2^13 / 2^15 row) cases to stress memory bandwidth.
        return [
            [1, 160, 1, 14, 1],
            [1, 580, 1, 174, 1],
            [20, 160, 20, 14, 1],
            [20, 580, 20, 174, 1],
            [8, 512, 8, 512, 1],
            [1 << 13, 1060, 1 << 13, 1040, 1],
            [1 << 13, 2000, 1 << 13, 1074, 1],
            [1 << 15, 1060, 1 << 15, 2670, 1],
            [1 << 15, 5120, 1 << 15, 2512, 1],
        ]
70
71
72benchmark.register_benchmark_class(Concat2D2InputBench)
73
74
class ConcatGraphOptBench(benchmark.Benchmark):
    """Benchmark torch.cat with a pointwise consumer (relu), with the JIT
    graph optimizations for cat enabled.

    __init__ flips the global JIT flags enabling CPU fusion and the
    conditional-free cat lowering, so this benchmark measures the optimized
    concat path.
    """

    def __init__(self, mode, device, dtype, I1_D1, I1_D2, I2_D1, I2_D2, concat_dim):
        """Create the two random inputs and enable the JIT cat optimizations."""
        super().__init__(mode, device, dtype)
        self.I1_D1 = I1_D1
        self.I1_D2 = I1_D2
        self.I2_D1 = I2_D1
        self.I2_D2 = I2_D2
        self.concat_dim = concat_dim
        self.input1 = self.randn(
            [I1_D1, I1_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.input2 = self.randn(
            [I2_D1, I2_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.input1, self.input2]
        # NOTE: these are process-global JIT switches, not scoped to this
        # benchmark instance.
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_cat_wo_conditionals(True)

    def forward(self, input1, input2):
        """add -> cat -> relu, giving the fuser work on both sides of the cat."""
        x1 = self.add(input1, 0.00001)
        x2 = self.add(input2, 0.00001)
        y = self.cat((x1, x2), dim=self.concat_dim)
        z = self.relu(y)
        return z

    def reference(self):
        """NumPy reference for forward().

        Mirrors forward() exactly: epsilon add, concatenate, then relu
        (np.maximum with 0). Previously both the epsilon and the relu were
        omitted, so the reference diverged from forward()'s output for any
        negative input values.
        """
        concatenated = np.concatenate(
            (self.numpy(self.input1) + 0.00001, self.numpy(self.input2) + 0.00001),
            axis=self.concat_dim,
        )
        return np.maximum(concatenated, 0)

    def config(self):
        """Return the config tuple identifying this benchmark instance."""
        return [self.I1_D1, self.I1_D2, self.I2_D1, self.I2_D2, self.concat_dim]

    @staticmethod
    def module():
        return "concatGraphOpt"

    def memory_workload(self):
        """Element-count estimates used to derive memory bandwidth numbers.

        sol/algorithmic counts follow the framework's convention of
        reads + writes per element; the backward case doubles the forward
        traffic (counts are heuristic, not measured).
        """
        if self.mode == "fwd":
            sol_count = 1 + 1
            algorithmic_count = 3 + 1
        else:
            sol_count = (1 + 1) + (1 + 1)
            algorithmic_count = (3 + 1) + (3 + 1)

        # Total elements across both inputs.
        buffer_size = self.I1_D1 * self.I1_D2 + self.I2_D1 * self.I2_D2
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        # [I1_D1, I1_D2, I2_D1, I2_D2, concat_dim] — only the large,
        # bandwidth-bound shapes are interesting for the optimized path.
        return [
            [1 << 13, 1060, 1 << 13, 1040, 1],
            [1 << 13, 2000, 1 << 13, 1074, 1],
            [1 << 15, 1060, 1 << 15, 2670, 1],
            [1 << 15, 5120, 1 << 15, 2512, 1],
        ]
135
136
137benchmark.register_benchmark_class(ConcatGraphOptBench)
138