xref: /aosp_15_r20/external/pytorch/benchmarks/tensorexpr/reduction.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom . import benchmark
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker
class ReduceBench(benchmark.Benchmark):
    """Benchmark that sums an ``[M, N, K]`` tensor over a named set of axes.

    ``case`` selects the reduced axes:

    * ``"row"``  -> ``[1, 2]``
    * ``"mid"``  -> ``[0, 2]``
    * ``"col"``  -> ``[0, 1]``
    * ``"full"`` -> ``[0, 1, 2]``

    ``skip_input_transform`` is a string flag: ``"s0"`` adds a small constant
    to the input before the reduction, ``"s1"`` reduces the input as is.
    """

    def __init__(self, mode, device, dtype, case, M, N, K, skip_input_transform):
        """Store the configuration and build the single ``[M, N, K]`` input.

        Args:
            mode, device, dtype: forwarded unchanged to ``benchmark.Benchmark``.
            case: one of ``"row"``, ``"mid"``, ``"col"``, ``"full"``.
            M, N, K: extents of the three input dimensions.
            skip_input_transform: ``"s0"`` or ``"s1"``.

        Raises:
            ValueError: if ``case`` or ``skip_input_transform`` is not one of
                the accepted values.
        """
        super().__init__(mode, device, dtype)
        self.case = case
        self.M = M
        self.N = N
        self.K = K
        self._set_skip_input_transform(skip_input_transform)

        # Resolve (and validate) the reduction axes before creating the input,
        # so an invalid ``case`` fails without first allocating M * N * K
        # elements.
        if case == "row":
            self.dims = [1, 2]
        elif case == "mid":
            self.dims = [0, 2]
        elif case == "col":
            self.dims = [0, 1]
        elif case == "full":
            self.dims = [0, 1, 2]
        else:
            raise ValueError(f"invalid case: {case}")

        # NOTE(review): ``randn`` and ``requires_grad`` are not defined in this
        # class; presumably they are provided by ``benchmark.Benchmark`` -
        # confirm.
        self.inputs = [
            self.randn(
                [M, N, K], device=device, dtype=dtype, requires_grad=self.requires_grad
            )
        ]

    def forward(self, inputs):
        """Sum ``inputs`` over ``self.dims``, optionally adding 0.001 first."""
        if self.skip_input_transform:
            x = inputs
        else:
            # The extra add gives the reduction an element-wise producer.
            x = self.add(inputs, 0.001)
        y = self.sum(x, self.dims)
        return y

    def config(self):
        """Return the configuration values describing this instance.

        For the ``"full"`` case the three extents are collapsed into their
        product; otherwise ``M``, ``N`` and ``K`` are listed separately.  The
        skip flag string is always the last entry.
        """
        if self.case == "full":
            return [self.M * self.N * self.K, self._skip_input_transform_str()]
        return [self.M, self.N, self.K, self._skip_input_transform_str()]

    @staticmethod
    def default_configs():
        """Return the default ``[M, N, K, skip_flag]`` configurations."""
        return [
            # [512, 512, 512],
            [512, 64, 512, "s0"],
        ]

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "reduce"

    def memory_workload(self):
        """Return the modeled memory traffic as element counts.

        The ``[M, N, K]`` buffer is counted once in ``"fwd"`` mode and twice in
        any other mode; the reduced output is not included.  ``"sol"`` and
        ``"algorithmic"`` receive the same value.
        """
        # NOTE(review): the second pass presumably accounts for the backward
        # step - confirm.
        passes = 1 if self.mode == "fwd" else 2
        buffer_size = self.M * self.N * self.K
        traffic = buffer_size * passes
        return {
            "sol": traffic,
            "algorithmic": traffic,
        }

    def _set_skip_input_transform(self, input_str):
        """Translate the ``"s0"`` / ``"s1"`` flag into ``self.skip_input_transform``.

        Raises:
            ValueError: if ``input_str`` is neither ``"s0"`` nor ``"s1"``.
        """
        # In the test setting, s1 will skip the input transformation, and s0 will not.
        if input_str == "s0":
            self.skip_input_transform = False
        elif input_str == "s1":
            self.skip_input_transform = True
        else:
            raise ValueError(f"invalid skip_input_transform: {input_str}")

    def _skip_input_transform_str(self):
        """Return the string flag matching ``self.skip_input_transform``."""
        return "s1" if self.skip_input_transform else "s0"
81*da0073e9SAndroid Build Coastguard Worker
82*da0073e9SAndroid Build Coastguard Worker
class ReduceRowBench(ReduceBench):
    """``ReduceBench`` fixed to the ``"row"`` case (sum over axes 1 and 2)."""

    def __init__(self, mode, device, dtype, M, N, K, skip_input_transform):
        """Forward every argument to ``ReduceBench`` with ``case="row"``."""
        super().__init__(
            mode,
            device,
            dtype,
            case="row",
            M=M,
            N=N,
            K=K,
            skip_input_transform=skip_input_transform,
        )

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "reduce_row"
90*da0073e9SAndroid Build Coastguard Worker
91*da0073e9SAndroid Build Coastguard Worker
class ReduceMidBench(ReduceBench):
    """``ReduceBench`` fixed to the ``"mid"`` case (sum over axes 0 and 2)."""

    def __init__(self, mode, device, dtype, M, N, K, skip_input_transform):
        """Forward every argument to ``ReduceBench`` with ``case="mid"``."""
        super().__init__(
            mode,
            device,
            dtype,
            case="mid",
            M=M,
            N=N,
            K=K,
            skip_input_transform=skip_input_transform,
        )

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "reduce_mid"
99*da0073e9SAndroid Build Coastguard Worker
100*da0073e9SAndroid Build Coastguard Worker
class ReduceColBench(ReduceBench):
    """``ReduceBench`` fixed to the ``"col"`` case (sum over axes 0 and 1)."""

    def __init__(self, mode, device, dtype, M, N, K, skip_input_transform):
        """Forward every argument to ``ReduceBench`` with ``case="col"``."""
        super().__init__(
            mode,
            device,
            dtype,
            case="col",
            M=M,
            N=N,
            K=K,
            skip_input_transform=skip_input_transform,
        )

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "reduce_col"
108*da0073e9SAndroid Build Coastguard Worker
109*da0073e9SAndroid Build Coastguard Worker
class ReduceFullBench(ReduceBench):
    """``ReduceBench`` fixed to the ``"full"`` case on an ``[M, 1, 1]`` input."""

    def __init__(self, mode, device, dtype, M, skip_input_transform):
        """Forward to ``ReduceBench`` with ``case="full"`` and ``N = K = 1``."""
        super().__init__(
            mode,
            device,
            dtype,
            case="full",
            M=M,
            N=1,
            K=1,
            skip_input_transform=skip_input_transform,
        )

    def config(self):
        """Return ``[total element count, skip flag string]``."""
        total_elements = self.M * self.N * self.K
        skip_flag = self._skip_input_transform_str()
        return [total_elements, skip_flag]

    @staticmethod
    def default_configs():
        """Return the default ``[M, skip_flag]`` configurations."""
        return [[1 << 24, "s1"]]

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "reduce_full"
126*da0073e9SAndroid Build Coastguard Worker
127*da0073e9SAndroid Build Coastguard Worker
class Reduce2DBench(benchmark.Benchmark):
    """
    A benchmark class to validate 2 dimensional reduction performance.
    Only a simple add is fused to induce the fuser and isolate reduction perf.
    """

    def __init__(self, mode, device, dtype, red_dim, dim0, dim1):
        """Store the configuration and build the single ``[dim0, dim1]`` input.

        Args:
            mode, device, dtype: forwarded unchanged to ``benchmark.Benchmark``.
            red_dim: axis to reduce over; must be 0 or 1.
            dim0, dim1: extents of the two input dimensions.

        Raises:
            ValueError: if ``red_dim`` is neither 0 nor 1.
        """
        super().__init__(mode, device, dtype)
        self.red_dim = red_dim
        self.dim0 = dim0
        self.dim1 = dim1

        # Reject a bad axis before creating the input, so the error does not
        # cost a dim0 * dim1 allocation first.
        if red_dim != 0 and red_dim != 1:
            raise ValueError(f"invalid reduction dimension: {red_dim}")

        # NOTE(review): ``randn`` and ``requires_grad`` are not defined in this
        # class; presumably they are provided by ``benchmark.Benchmark`` -
        # confirm.
        self.inputs = [
            self.randn(
                [dim0, dim1],
                device=device,
                dtype=dtype,
                requires_grad=self.requires_grad,
            )
        ]

    def forward(self, inputs):
        """Add 0.001 to ``inputs`` and sum the result over ``self.red_dim``."""
        x = self.add(inputs, 0.001)
        y = self.sum(x, [self.red_dim])
        return y

    def config(self):
        """Return ``[red_dim, dim0, dim1]`` for this instance."""
        return [self.red_dim, self.dim0, self.dim1]

    @staticmethod
    def default_configs():
        """Return the default ``[red_dim, dim0, dim1]`` configurations."""
        return [
            [1, 640, 524288],
        ]

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "reduce2d"

    @staticmethod
    def input_iterable():
        """Return ``True``; how the flag is used is up to the benchmark framework."""
        return True

    def memory_workload(self):
        """Return the modeled forward memory traffic as element counts.

        The count is the ``dim0 * dim1`` input plus the reduced output, which
        has ``dim1`` elements when axis 0 is reduced and ``dim0`` otherwise.
        ``"sol"`` and ``"algorithmic"`` receive the same value.

        Raises:
            AssertionError: if ``self.mode`` is not ``"fwd"``.
        """
        # Raised explicitly (rather than via ``assert``) so the guard is still
        # enforced when Python runs with -O.
        if self.mode != "fwd":
            raise AssertionError("Only the forward operation is modeled!")

        output_size = self.dim1 if self.red_dim == 0 else self.dim0
        buffer_size = self.dim0 * self.dim1 + output_size
        return {
            "sol": buffer_size,
            "algorithmic": buffer_size,
        }
186*da0073e9SAndroid Build Coastguard Worker
187*da0073e9SAndroid Build Coastguard Worker
class Reduce2DInnerBench(Reduce2DBench):
    """``Reduce2DBench`` with the reduction axis fixed to 1."""

    def __init__(self, mode, device, dtype, dim0, dim1):
        """Forward to ``Reduce2DBench`` with ``red_dim=1``."""
        super().__init__(mode, device, dtype, red_dim=1, dim0=dim0, dim1=dim1)

    @staticmethod
    def default_configs():
        """Return the first parent default config without its leading ``red_dim``."""
        return [Reduce2DBench.default_configs()[0][1:]]

    def config(self):
        """Return the parent config without its leading ``red_dim`` entry."""
        return super().config()[1:]

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "reduce2d_inner"
204*da0073e9SAndroid Build Coastguard Worker
205*da0073e9SAndroid Build Coastguard Worker
class Reduce2DOuterBench(Reduce2DBench):
    """``Reduce2DBench`` with the reduction axis fixed to 0."""

    def __init__(self, mode, device, dtype, dim0, dim1):
        """Forward to ``Reduce2DBench`` with ``red_dim=0``."""
        super().__init__(mode, device, dtype, red_dim=0, dim0=dim0, dim1=dim1)

    @staticmethod
    def default_configs():
        """Return the first parent default config without its leading ``red_dim``."""
        return [Reduce2DBench.default_configs()[0][1:]]

    def config(self):
        """Return the parent config without its leading ``red_dim`` entry."""
        return super().config()[1:]

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "reduce2d_outer"
222*da0073e9SAndroid Build Coastguard Worker
223*da0073e9SAndroid Build Coastguard Worker
# Register the fixed-shape reduction benchmarks, in the same order as before.
for _bench_cls in (
    ReduceRowBench,
    ReduceMidBench,
    ReduceColBench,
    Reduce2DInnerBench,
    Reduce2DOuterBench,
    ReduceFullBench,
):
    benchmark.register_benchmark_class(_bench_cls)
del _bench_cls  # do not leave the loop variable behind as a module attribute
230*da0073e9SAndroid Build Coastguard Worker
231*da0073e9SAndroid Build Coastguard Worker
class DynamicReduce2DBench(benchmark.DynamicShape, Reduce2DBench):
    """
    Variant of ``Reduce2DBench`` that mixes in ``benchmark.DynamicShape``.
    ``instantiate_input`` draws the input shape from ``rand_shape`` instead of
    using ``[dim0, dim1]`` directly; the computation is the same add + sum.
    """

    def __init__(self, mode, device, dtype, red_dim, dim0, dim1):
        """Initialise both bases explicitly, ``DynamicShape`` first."""
        benchmark.DynamicShape.__init__(self)
        Reduce2DBench.__init__(self, mode, device, dtype, red_dim, dim0, dim1)

    def instantiate_input(self):
        """Assign ``self.inputs`` a tensor whose shape comes from ``rand_shape``."""
        rows, cols = self.rand_shape([self.dim0, self.dim1])
        sample = self.randn(
            [rows, cols],
            device=self.device,
            dtype=self.dtype,
            requires_grad=self.requires_grad,
        )
        self.inputs = [sample]

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "dynamicreduce2d"
257*da0073e9SAndroid Build Coastguard Worker
258*da0073e9SAndroid Build Coastguard Worker
class DynamicReduce2DInnerBench(DynamicReduce2DBench):
    """``DynamicReduce2DBench`` with the reduction axis fixed to 1."""

    def __init__(self, mode, device, dtype, dim0, dim1):
        """Forward to ``DynamicReduce2DBench`` with ``red_dim=1``."""
        super().__init__(mode, device, dtype, red_dim=1, dim0=dim0, dim1=dim1)

    @staticmethod
    def default_configs():
        """Return the first parent default config without its leading ``red_dim``."""
        return [DynamicReduce2DBench.default_configs()[0][1:]]

    def config(self):
        """Return the parent config without its leading ``red_dim`` entry."""
        return super().config()[1:]

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "reduce2d_dynamic_inner"
275*da0073e9SAndroid Build Coastguard Worker
276*da0073e9SAndroid Build Coastguard Worker
class DynamicReduce2DOuterBench(DynamicReduce2DBench):
    """``DynamicReduce2DBench`` with the reduction axis fixed to 0."""

    def __init__(self, mode, device, dtype, dim0, dim1):
        """Forward to ``DynamicReduce2DBench`` with ``red_dim=0``."""
        super().__init__(mode, device, dtype, red_dim=0, dim0=dim0, dim1=dim1)

    @staticmethod
    def default_configs():
        """Return the first parent default config without its leading ``red_dim``."""
        return [DynamicReduce2DBench.default_configs()[0][1:]]

    def config(self):
        """Return the parent config without its leading ``red_dim`` entry."""
        return super().config()[1:]

    @staticmethod
    def module():
        """Return the name identifying this benchmark."""
        return "reduce2d_dynamic_outer"
293*da0073e9SAndroid Build Coastguard Worker
294*da0073e9SAndroid Build Coastguard Worker
# Register the dynamic-shape reduction benchmarks, in the same order as before.
for _dyn_bench_cls in (DynamicReduce2DInnerBench, DynamicReduce2DOuterBench):
    benchmark.register_benchmark_class(_dyn_bench_cls)
del _dyn_bench_cls  # do not leave the loop variable behind as a module attribute
297