xref: /aosp_15_r20/external/pytorch/benchmarks/tensorexpr/reduction.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from . import benchmark
2
3
class ReduceBench(benchmark.Benchmark):
    """Sum-reduce a rank-3 ``[M, N, K]`` input over a subset of its dims.

    ``case`` selects the dimensions handed to ``self.sum``:

    * ``"row"``  -> ``[1, 2]``
    * ``"mid"``  -> ``[0, 2]``
    * ``"col"``  -> ``[0, 1]``
    * ``"full"`` -> ``[0, 1, 2]``

    ``skip_input_transform`` is ``"s0"`` (apply ``self.add(x, 0.001)`` before
    the reduction) or ``"s1"`` (reduce the raw input).

    Raises:
        ValueError: if ``case`` or ``skip_input_transform`` is not one of the
            values listed above.
    """

    # Reduction dims for each supported ``case``; copied to a fresh list per
    # instance so no instance can mutate the shared table.
    _CASE_DIMS = {
        "row": (1, 2),
        "mid": (0, 2),
        "col": (0, 1),
        "full": (0, 1, 2),
    }

    def __init__(self, mode, device, dtype, case, M, N, K, skip_input_transform):
        super().__init__(mode, device, dtype)
        self.case = case
        self.M = M
        self.N = N
        self.K = K
        # Both string options are validated before the input tensor is
        # allocated, so a bad config fails without paying for an
        # M * N * K allocation. The skip option is checked first, as before.
        self._set_skip_input_transform(skip_input_transform)
        if case not in self._CASE_DIMS:
            raise ValueError(f"invalid case: {case}")
        self.dims = list(self._CASE_DIMS[case])

        self.inputs = [
            self.randn(
                [M, N, K], device=device, dtype=dtype, requires_grad=self.requires_grad
            )
        ]

    def forward(self, inputs):
        """Optionally apply ``self.add(inputs, 0.001)``, then sum over ``self.dims``."""
        if self.skip_input_transform:
            x = inputs
        else:
            x = self.add(inputs, 0.001)
        y = self.sum(x, self.dims)
        return y

    def config(self):
        """Return the config row: total size for ``"full"``, else ``M, N, K``."""
        if self.case == "full":
            return [self.M * self.N * self.K, self._skip_input_transform_str()]
        return [self.M, self.N, self.K, self._skip_input_transform_str()]

    @staticmethod
    def default_configs():
        return [
            [512, 64, 512, "s0"],
        ]

    @staticmethod
    def module():
        return "reduce"

    def memory_workload(self):
        """Model memory traffic in units of the ``M * N * K`` input size.

        Forward counts one pass over the input. Every other mode is modeled
        as two passes. The (smaller) reduced output is not counted.
        """
        passes = 1 if self.mode == "fwd" else 2
        buffer_size = self.M * self.N * self.K
        return {
            "sol": buffer_size * passes,
            "algorithmic": buffer_size * passes,
        }

    def _set_skip_input_transform(self, input_str):
        # In the test setting, s1 will skip the input transformation, and s0 will not.
        if input_str == "s0":
            self.skip_input_transform = False
        elif input_str == "s1":
            self.skip_input_transform = True
        else:
            raise ValueError(f"invalid skip_input_transform: {input_str}")

    def _skip_input_transform_str(self):
        # Inverse of _set_skip_input_transform, used when reporting configs.
        if self.skip_input_transform:
            return "s1"
        else:
            return "s0"
81
82
class ReduceRowBench(ReduceBench):
    """``ReduceBench`` pinned to the ``"row"`` case (sums over dims 1 and 2)."""

    @staticmethod
    def module():
        return "reduce_row"

    def __init__(self, mode, device, dtype, M, N, K, skip_input_transform):
        super().__init__(
            mode,
            device,
            dtype,
            case="row",
            M=M,
            N=N,
            K=K,
            skip_input_transform=skip_input_transform,
        )
90
91
class ReduceMidBench(ReduceBench):
    """``ReduceBench`` pinned to the ``"mid"`` case (sums over dims 0 and 2)."""

    @staticmethod
    def module():
        return "reduce_mid"

    def __init__(self, mode, device, dtype, M, N, K, skip_input_transform):
        super().__init__(
            mode,
            device,
            dtype,
            case="mid",
            M=M,
            N=N,
            K=K,
            skip_input_transform=skip_input_transform,
        )
99
100
class ReduceColBench(ReduceBench):
    """``ReduceBench`` pinned to the ``"col"`` case (sums over dims 0 and 1)."""

    @staticmethod
    def module():
        return "reduce_col"

    def __init__(self, mode, device, dtype, M, N, K, skip_input_transform):
        super().__init__(
            mode,
            device,
            dtype,
            case="col",
            M=M,
            N=N,
            K=K,
            skip_input_transform=skip_input_transform,
        )
108
109
class ReduceFullBench(ReduceBench):
    """``ReduceBench`` in the ``"full"`` case over an ``[M, 1, 1]`` input.

    Only ``M`` is configurable (``N`` and ``K`` are pinned to 1), and the
    config row reports the total element count.
    """

    @staticmethod
    def module():
        return "reduce_full"

    @staticmethod
    def default_configs():
        # 2**24 elements, input transform skipped.
        return [[2 ** 24, "s1"]]

    def __init__(self, mode, device, dtype, M, skip_input_transform):
        super().__init__(
            mode,
            device,
            dtype,
            case="full",
            M=M,
            N=1,
            K=1,
            skip_input_transform=skip_input_transform,
        )

    def config(self):
        element_count = self.M * self.N * self.K
        transform_tag = self._skip_input_transform_str()
        return [element_count, transform_tag]
126
127
class Reduce2DBench(benchmark.Benchmark):
    """
    A benchmark class to validate 2 dimensional reduction performance.
    Only a simple add is fused to induce the fuser and isolate reduction perf.

    The input is a single ``[dim0, dim1]`` tensor, summed over ``red_dim``.

    Raises:
        ValueError: if ``red_dim`` is neither 0 nor 1.
    """

    def __init__(self, mode, device, dtype, red_dim, dim0, dim1):
        super().__init__(mode, device, dtype)
        # Validate before allocating the input. The default config is
        # 640 x 524288 elements, so a bad red_dim should not pay for that first.
        if red_dim not in (0, 1):
            raise ValueError(f"invalid reduction dimension: {red_dim}")
        self.red_dim = red_dim
        self.dim0 = dim0
        self.dim1 = dim1

        self.inputs = [
            self.randn(
                [dim0, dim1],
                device=device,
                dtype=dtype,
                requires_grad=self.requires_grad,
            )
        ]

    def forward(self, inputs):
        """Apply ``self.add(inputs, 0.001)``, then sum over ``self.red_dim``."""
        x = self.add(inputs, 0.001)
        y = self.sum(x, [self.red_dim])
        return y

    def config(self):
        return [self.red_dim, self.dim0, self.dim1]

    @staticmethod
    def default_configs():
        return [
            [1, 640, 524288],
        ]

    @staticmethod
    def module():
        return "reduce2d"

    @staticmethod
    def input_iterable():
        # Flag read by the benchmark harness; its semantics are defined in
        # benchmark.py (NOTE(review): confirm there).
        return True

    def memory_workload(self):
        """Model memory traffic for the forward pass only.

        Counts the full input once plus the reduced output, whose length is
        the size of the dimension that is *not* reduced.
        """
        assert self.mode == "fwd", "Only the forward operation is modeled!"

        output_size = self.dim1 if self.red_dim == 0 else self.dim0
        buffer_size = self.dim0 * self.dim1 + output_size
        return {
            "sol": buffer_size,
            "algorithmic": buffer_size,
        }
186
187
class Reduce2DInnerBench(Reduce2DBench):
    """``Reduce2DBench`` that always reduces dim 1, the last dim of the 2D input.

    Because ``red_dim`` is fixed, configs omit the leading ``red_dim`` entry
    that ``Reduce2DBench`` reports.
    """

    def __init__(self, mode, device, dtype, dim0, dim1):
        super().__init__(mode, device, dtype, red_dim=1, dim0=dim0, dim1=dim1)

    @staticmethod
    def module():
        return "reduce2d_inner"

    @staticmethod
    def default_configs():
        _fixed_red_dim, *shape = Reduce2DBench.default_configs()[0]
        return [shape]

    def config(self):
        # Drop the fixed red_dim that the parent puts first.
        full_row = super().config()
        _fixed_red_dim, *shape = full_row
        return shape
204
205
class Reduce2DOuterBench(Reduce2DBench):
    """``Reduce2DBench`` that always reduces dim 0, the first dim of the 2D input.

    Because ``red_dim`` is fixed, configs omit the leading ``red_dim`` entry
    that ``Reduce2DBench`` reports.
    """

    def __init__(self, mode, device, dtype, dim0, dim1):
        super().__init__(mode, device, dtype, red_dim=0, dim0=dim0, dim1=dim1)

    @staticmethod
    def module():
        return "reduce2d_outer"

    @staticmethod
    def default_configs():
        _fixed_red_dim, *shape = Reduce2DBench.default_configs()[0]
        return [shape]

    def config(self):
        # Drop the fixed red_dim that the parent puts first.
        full_row = super().config()
        _fixed_red_dim, *shape = full_row
        return shape
222
223
# Register the static-shape reduction benchmarks, in their original order.
for _bench_cls in (
    ReduceRowBench,
    ReduceMidBench,
    ReduceColBench,
    Reduce2DInnerBench,
    Reduce2DOuterBench,
    ReduceFullBench,
):
    benchmark.register_benchmark_class(_bench_cls)
del _bench_cls
230
231
class DynamicReduce2DBench(benchmark.DynamicShape, Reduce2DBench):
    """
    Dynamic-shape variant of ``Reduce2DBench``.

    The reduction (add of 0.001 followed by a sum over ``red_dim``) is
    inherited unchanged. ``instantiate_input`` replaces ``self.inputs`` with a
    new 2D tensor whose shape comes from
    ``self.rand_shape([self.dim0, self.dim1])``.
    """

    def __init__(self, mode, device, dtype, red_dim, dim0, dim1):
        # Each base is initialized explicitly, DynamicShape first, rather than
        # through a cooperative super() chain.
        benchmark.DynamicShape.__init__(self)
        Reduce2DBench.__init__(
            self, mode, device, dtype, red_dim=red_dim, dim0=dim0, dim1=dim1
        )

    @staticmethod
    def module():
        return "dynamicreduce2d"

    def instantiate_input(self):
        rows, cols = self.rand_shape([self.dim0, self.dim1])
        sample = self.randn(
            [rows, cols],
            device=self.device,
            dtype=self.dtype,
            requires_grad=self.requires_grad,
        )
        self.inputs = [sample]
257
258
class DynamicReduce2DInnerBench(DynamicReduce2DBench):
    """``DynamicReduce2DBench`` that always reduces dim 1 of the 2D input.

    Because ``red_dim`` is fixed, configs omit the leading ``red_dim`` entry.
    """

    def __init__(self, mode, device, dtype, dim0, dim1):
        super().__init__(mode, device, dtype, red_dim=1, dim0=dim0, dim1=dim1)

    @staticmethod
    def module():
        return "reduce2d_dynamic_inner"

    @staticmethod
    def default_configs():
        _fixed_red_dim, *shape = DynamicReduce2DBench.default_configs()[0]
        return [shape]

    def config(self):
        # Drop the fixed red_dim that the parent puts first.
        full_row = super().config()
        _fixed_red_dim, *shape = full_row
        return shape
275
276
class DynamicReduce2DOuterBench(DynamicReduce2DBench):
    """``DynamicReduce2DBench`` that always reduces dim 0 of the 2D input.

    Because ``red_dim`` is fixed, configs omit the leading ``red_dim`` entry.
    """

    def __init__(self, mode, device, dtype, dim0, dim1):
        super().__init__(mode, device, dtype, red_dim=0, dim0=dim0, dim1=dim1)

    @staticmethod
    def module():
        return "reduce2d_dynamic_outer"

    @staticmethod
    def default_configs():
        _fixed_red_dim, *shape = DynamicReduce2DBench.default_configs()[0]
        return [shape]

    def config(self):
        # Drop the fixed red_dim that the parent puts first.
        full_row = super().config()
        _fixed_red_dim, *shape = full_row
        return shape
293
294
# Register the dynamic-shape reduction benchmarks, in their original order.
for _bench_cls in (DynamicReduce2DInnerBench, DynamicReduce2DOuterBench):
    benchmark.register_benchmark_class(_bench_cls)
del _bench_cls
297