xref: /aosp_15_r20/external/pytorch/benchmarks/tensorexpr/concat.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport numpy as np
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport torch
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerfrom . import benchmark
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker
class Concat2D2InputBench(benchmark.Benchmark):
    """Benchmark: add a small constant to two 2-D tensors, then concatenate.

    The constructor receives the two input shapes ``[I1_D1, I1_D2]`` and
    ``[I2_D1, I2_D2]`` plus the axis (``concat_dim``) along which the two
    tensors are joined.
    """

    # Offset added to every input element before concatenation.  Shared by
    # ``forward`` and ``reference`` so the two cannot drift apart.
    _EPSILON = 0.00001

    def __init__(self, mode, device, dtype, I1_D1, I1_D2, I2_D1, I2_D2, concat_dim):
        """Store the configuration and allocate the two random inputs."""
        super().__init__(mode, device, dtype)
        self.I1_D1 = I1_D1
        self.I1_D2 = I1_D2
        self.I2_D1 = I2_D1
        self.I2_D2 = I2_D2
        self.concat_dim = concat_dim
        self.input1 = self.randn(
            [I1_D1, I1_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.input2 = self.randn(
            [I2_D1, I2_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        # Passed positionally to ``forward`` (presumably by the benchmark
        # base class -- confirm against benchmark.Benchmark).
        self.inputs = [self.input1, self.input2]

    def forward(self, input1, input2):
        """Return ``cat((input1 + eps, input2 + eps), dim=concat_dim)``."""
        x1 = self.add(input1, self._EPSILON)
        x2 = self.add(input2, self._EPSILON)
        y = self.cat((x1, x2), dim=self.concat_dim)
        return y

    def reference(self):
        """Compute the expected ``forward`` result with NumPy.

        Adding the offset after concatenating is equivalent to adding it to
        each input first, because the addition is elementwise.
        """
        joined = np.concatenate(
            (self.numpy(self.input1), self.numpy(self.input2)),
            axis=self.concat_dim,
        )
        return joined + self._EPSILON

    def config(self):
        """Return the shape parameters and concat axis, in constructor order."""
        return [self.I1_D1, self.I1_D2, self.I2_D1, self.I2_D2, self.concat_dim]

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "concat2d2input"

    def memory_workload(self):
        """Return element-count estimates under the keys "sol" and "algorithmic".

        Both estimates are the combined element count of the two inputs
        multiplied by a per-element access count; every mode other than
        "fwd" uses twice the forward count.
        """
        # NOTE(review): the 1 + 1 and 3 + 1 splits presumably separate reads
        # from writes per element -- confirm with the benchmark harness.
        if self.mode == "fwd":
            sol_count = 1 + 1
            algorithmic_count = 3 + 1
        else:
            sol_count = (1 + 1) + (1 + 1)
            algorithmic_count = (3 + 1) + (3 + 1)

        buffer_size = self.I1_D1 * self.I1_D2 + self.I2_D1 * self.I2_D2
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Return the default configurations.

        Each entry is ``[I1_D1, I1_D2, I2_D1, I2_D2, concat_dim]``.  Every
        entry concatenates along axis 1 and keeps the first dimension of both
        inputs equal.
        """
        return [
            [1, 160, 1, 14, 1],
            [1, 580, 1, 174, 1],
            [20, 160, 20, 14, 1],
            [20, 580, 20, 174, 1],
            [8, 512, 8, 512, 1],
            [1 << 13, 1060, 1 << 13, 1040, 1],
            [1 << 13, 2000, 1 << 13, 1074, 1],
            [1 << 15, 1060, 1 << 15, 2670, 1],
            [1 << 15, 5120, 1 << 15, 2512, 1],
        ]
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard Worker
# Hand Concat2D2InputBench to the benchmark module's registration hook.
benchmark.register_benchmark_class(Concat2D2InputBench)
73*da0073e9SAndroid Build Coastguard Worker
74*da0073e9SAndroid Build Coastguard Worker
class ConcatGraphOptBench(benchmark.Benchmark):
    """Benchmark: add a constant to two 2-D tensors, concatenate, apply ReLU.

    The constructor receives the two input shapes ``[I1_D1, I1_D2]`` and
    ``[I2_D1, I2_D2]`` plus the axis (``concat_dim``) along which the two
    tensors are joined.  Constructing an instance also flips two
    ``torch._C`` JIT switches.
    """

    # Offset added to every input element before concatenation.  Shared by
    # ``forward`` and ``reference`` so the two cannot drift apart.
    _EPSILON = 0.00001

    def __init__(self, mode, device, dtype, I1_D1, I1_D2, I2_D1, I2_D2, concat_dim):
        """Store the configuration, allocate the inputs, set the JIT switches."""
        super().__init__(mode, device, dtype)
        self.I1_D1 = I1_D1
        self.I1_D2 = I1_D2
        self.I2_D1 = I2_D1
        self.I2_D2 = I2_D2
        self.concat_dim = concat_dim
        self.input1 = self.randn(
            [I1_D1, I1_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.input2 = self.randn(
            [I2_D1, I2_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        # Passed positionally to ``forward`` (presumably by the benchmark
        # base class -- confirm against benchmark.Benchmark).
        self.inputs = [self.input1, self.input2]
        # These calls are never undone here, so the settings outlive this
        # instance.
        # NOTE(review): judging by the names, they allow fusion on CPU and
        # select a conditional-free lowering of cat -- confirm against the
        # PyTorch JIT sources.
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_cat_wo_conditionals(True)

    def forward(self, input1, input2):
        """Return ``relu(cat((input1 + eps, input2 + eps), dim=concat_dim))``."""
        x1 = self.add(input1, self._EPSILON)
        x2 = self.add(input2, self._EPSILON)
        y = self.cat((x1, x2), dim=self.concat_dim)
        z = self.relu(y)
        return z

    def reference(self):
        """Compute the expected ``forward`` result with NumPy.

        Mirrors ``forward`` step by step: offset, concatenate, then ReLU
        (``max(x, 0)``).  Adding the offset after concatenating is
        equivalent because the addition is elementwise.
        """
        joined = np.concatenate(
            (self.numpy(self.input1), self.numpy(self.input2)),
            axis=self.concat_dim,
        )
        return np.maximum(joined + self._EPSILON, 0)

    def config(self):
        """Return the shape parameters and concat axis, in constructor order."""
        return [self.I1_D1, self.I1_D2, self.I2_D1, self.I2_D2, self.concat_dim]

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "concatGraphOpt"

    def memory_workload(self):
        """Return element-count estimates under the keys "sol" and "algorithmic".

        Both estimates are the combined element count of the two inputs
        multiplied by a per-element access count; every mode other than
        "fwd" uses twice the forward count.
        """
        # NOTE(review): these counts are the same as for a plain add + cat
        # and do not visibly account for the trailing ReLU -- confirm whether
        # the "algorithmic" figure should include it.
        if self.mode == "fwd":
            sol_count = 1 + 1
            algorithmic_count = 3 + 1
        else:
            sol_count = (1 + 1) + (1 + 1)
            algorithmic_count = (3 + 1) + (3 + 1)

        buffer_size = self.I1_D1 * self.I1_D2 + self.I2_D1 * self.I2_D2
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Return the default configurations.

        Each entry is ``[I1_D1, I1_D2, I2_D1, I2_D2, concat_dim]``.  Every
        entry concatenates along axis 1 and keeps the first dimension of both
        inputs equal.
        """
        return [
            [1 << 13, 1060, 1 << 13, 1040, 1],
            [1 << 13, 2000, 1 << 13, 1074, 1],
            [1 << 15, 1060, 1 << 15, 2670, 1],
            [1 << 15, 5120, 1 << 15, 2512, 1],
        ]
135*da0073e9SAndroid Build Coastguard Worker
136*da0073e9SAndroid Build Coastguard Worker
# Hand ConcatGraphOptBench to the benchmark module's registration hook.
benchmark.register_benchmark_class(ConcatGraphOptBench)
138