xref: /aosp_15_r20/external/pytorch/benchmarks/tensorexpr/broadcast.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import itertools
2import operator
3
4import numpy as np
5
6import torch
7
8from . import benchmark
9
10
class BroadcastMulBench(benchmark.Benchmark):
    """Two-operand broadcast benchmark whose operands broadcast to ``[M, N, K]``.

    Despite the class name, the measured op is addition (``d1 + d2``).
    ``case`` selects which dimensions of each operand are broadcast:

    * ``"row"``: ``[M, N, 1] + [M, 1, K]``
    * ``"mid"``: ``[M, N, 1] + [1, N, K]``
    * ``"col"``: ``[M, 1, K] + [1, N, K]``

    Any other ``case`` raises ``ValueError``.
    """

    def __init__(self, mode, device, dtype, case, M, N, K):
        super().__init__(mode, device, dtype)
        self.case = case
        self.M, self.N, self.K = M, N, K

        # Decide the operand shapes first; allocation is identical for every case.
        if case == "row":
            lhs_shape, rhs_shape = [M, N, 1], [M, 1, K]
        elif case == "mid":
            lhs_shape, rhs_shape = [M, N, 1], [1, N, K]
        elif case == "col":
            lhs_shape, rhs_shape = [M, 1, K], [1, N, K]
        else:
            raise ValueError(f"invalid case: {case}")

        def make(shape):
            return self.rand(
                shape, device=device, dtype=dtype, requires_grad=self.requires_grad
            )

        self.d1 = make(lhs_shape)
        self.d2 = make(rhs_shape)
        self.inputs = [self.d1, self.d2]

    def forward(self, d1, d2):
        return d1 + d2

    def reference(self):
        lhs = self.numpy(self.d1)
        rhs = self.numpy(self.d2)
        return lhs + rhs

    def config(self):
        return [self.M, self.N, self.K]

    @staticmethod
    def default_configs():
        # [M, N, K]
        return [[128, 256, 128]]

    def memory_workload(self):
        # (sol, algorithmic): (1, 1) in "fwd" mode, otherwise (1 + 1, 1 + (1 + 1)).
        sol_count, algorithmic_count = (1, 1) if self.mode == "fwd" else (2, 3)

        # Element count of the broadcast [M, N, K] result.
        # NOTE(review): unlike the other memory_workload methods in this file,
        # this value is not multiplied by 4 -- confirm which unit the harness
        # expects.
        elements = self.M * self.N * self.K
        return {
            "sol": elements * sol_count,
            "algorithmic": elements * algorithmic_count,
        }
72
73
class BroadcastRowBench(BroadcastMulBench):
    """``BroadcastMulBench`` fixed to ``case="row"``: ``[M, N, 1] + [M, 1, K]``."""

    def __init__(self, mode, device, dtype, M, N, K):
        super().__init__(mode, device, dtype, case="row", M=M, N=N, K=K)

    @staticmethod
    def module():
        # Presumably the name the harness uses to select this benchmark.
        return "broadcast_row"
81
82
class BroadcastMidBench(BroadcastMulBench):
    """``BroadcastMulBench`` fixed to ``case="mid"``: ``[M, N, 1] + [1, N, K]``."""

    def __init__(self, mode, device, dtype, M, N, K):
        super().__init__(mode, device, dtype, case="mid", M=M, N=N, K=K)

    @staticmethod
    def module():
        # Presumably the name the harness uses to select this benchmark.
        return "broadcast_mid"
90
91
class BroadcastColBench(BroadcastMulBench):
    """``BroadcastMulBench`` fixed to ``case="col"``: ``[M, 1, K] + [1, N, K]``."""

    def __init__(self, mode, device, dtype, M, N, K):
        super().__init__(mode, device, dtype, case="col", M=M, N=N, K=K)

    @staticmethod
    def module():
        # Presumably the name the harness uses to select this benchmark.
        return "broadcast_col"
99
100
class BroadcastThreeArgs(benchmark.Benchmark):
    """Sum of three operands that broadcast to an ``[L, K, M, N]`` result.

    Operand shapes: ``d1: [M, N]``, ``d2: [K, M, 1]``, ``d3: [L, K, 1, 1]``.
    """

    def __init__(self, mode, device, dtype, M, N, K, L):
        super().__init__(mode, device, dtype)
        self.M, self.N, self.K, self.L = M, N, K, L

        def make(shape):
            return self.rand(
                shape, device=device, dtype=dtype, requires_grad=self.requires_grad
            )

        self.d1 = make([M, N])
        self.d2 = make([K, M, 1])
        self.d3 = make([L, K, 1, 1])
        self.inputs = [self.d1, self.d2, self.d3]

    def forward(self, d1, d2, d3):
        return d1 + d2 + d3

    def reference(self):
        first, second, third = (self.numpy(t) for t in (self.d1, self.d2, self.d3))
        return first + second + third

    def config(self):
        return [self.M, self.N, self.K, self.L]

    @staticmethod
    def default_configs():
        # [M, N, K, L]
        return [[32, 16, 64, 128]]

    def memory_workload(self):
        # (sol, algorithmic): (1, 1) in "fwd" mode, otherwise (1 + 1, 1 + (1 + 1 + 1)).
        sol_count, algorithmic_count = (1, 1) if self.mode == "fwd" else (2, 4)

        # NOTE(review): the factor 4 presumably converts elements of the
        # [L, K, M, N] result to bytes for a 4-byte dtype; self.dtype is not
        # consulted -- TODO confirm.
        buffer_size = self.M * self.N * self.K * self.L * 4
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def module():
        return "broadcast_3args"
152
153
154# benchmark.register_benchmark_class(BroadcastRowBench)
155# benchmark.register_benchmark_class(BroadcastMidBench)
156# benchmark.register_benchmark_class(BroadcastColBench)
157# benchmark.register_benchmark_class(BroadcastThreeArgs)
158
159
# TODO: merge this with elementwise bench
# Template class for broadcasting elementwise benchmarks. Derived classes
# customize it by overriding the class variables declared below.
class BroadcastBench(benchmark.Benchmark):
    """Template benchmark computing ``op(d1, d2) + op(d3, d4)`` with broadcasting.

    Operand shapes are ``d1, d3: [M, N]``, ``d2: [K, 1, N]`` and
    ``d4: [K, M, 1]``. For an elementwise ``op``, both pairwise results
    therefore broadcast to ``[K, M, N]``.

    The unary op (identity when unset) is applied to every operand first.
    The binary op (addition when unset) then combines each pair.
    """

    # Customization hooks; generated subclasses assign these.
    op_str = None  # suffix of the name returned by module()
    binary_op_pt_func = None  # torch binary op used by forward()
    binary_op_np_func = None  # NumPy binary op used by reference()
    unary_op_pt_func = None  # torch unary op used by forward()
    unary_op_np_func = None  # NumPy unary op used by reference()
    split_input = True  # False: the third operand is derived from d1, not d3

    def __init__(self, mode, device, dtype, M, N, K):
        super().__init__(mode, device, dtype)
        self.M, self.N, self.K = M, N, K

        def make(shape):
            return self.rand(
                shape, device=device, dtype=dtype, requires_grad=self.requires_grad
            )

        self.d1 = make([M, N])
        self.d2 = make([K, 1, N])
        self.d3 = make([M, N])
        self.d4 = make([K, M, 1])
        self.inputs = [self.d1, self.d2, self.d3, self.d4]

    def _eval(self, d1, d2, d3, d4, binary_op, unary_op):
        # Unset (falsy) hooks fall back to addition / identity.
        combine = binary_op if binary_op else (lambda x, y: x + y)
        transform = unary_op if unary_op else (lambda x: x)

        if self.split_input:
            operands = [transform(t) for t in (d1, d2, d3, d4)]
        else:
            # Shared-input variant: d3 is ignored and the third operand is
            # built from d1 instead.
            operands = [
                transform(d1),
                transform(d2),
                transform(d1 + 0.001),
                transform(d4),
            ]
        x1, x2, x3, x4 = operands
        return combine(x1, x2) + combine(x3, x4)

    def forward(self, d1, d2, d3, d4):
        cls = self.__class__
        return self._eval(
            d1, d2, d3, d4, cls.binary_op_pt_func, cls.unary_op_pt_func
        )

    def reference(self):
        cls = self.__class__
        arrays = [self.numpy(t) for t in (self.d1, self.d2, self.d3, self.d4)]
        return self._eval(*arrays, cls.binary_op_np_func, cls.unary_op_np_func)

    def config(self):
        return [self.M, self.N, self.K]

    @classmethod
    def module(cls):
        return "broadcast_" + cls.op_str

    def memory_workload(self):
        input_count = len(self.inputs)
        # The split_input and shared variants currently use identical counts.
        sol_count = 1
        algorithmic_count = 1 if self.mode == "fwd" else input_count

        # NOTE(review): the factor 4 presumably converts elements of the
        # [K, M, N] result to bytes for a 4-byte dtype -- TODO confirm.
        buffer_size = self.M * self.N * self.K * 4
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        # [M, N, K] == [1 << 8, 1 << 7, 1 << 9]
        return [[256, 128, 512]]
263
264
def _numpy_erf(x):
    """Element-wise error function for NumPy arrays (reference for ``torch.erf``).

    NumPy provides no ``erf`` function (``np.erf`` raises ``AttributeError``),
    so apply the standard-library :func:`math.erf` to every element.
    The result is always ``float64``; declaring ``otypes`` also lets empty
    inputs work.
    """
    import math  # function-scope import keeps the module import block unchanged

    return np.vectorize(math.erf, otypes=[np.float64])(x)


def _split_op_entry(entry):
    """Unpack an op-table entry into ``(name, torch_func, numpy_func)``.

    An entry is either ``[name, func]``, where one callable serves both
    torch and NumPy, or ``[name, torch_func, numpy_func]``.

    Raises:
        ValueError: if the entry has any other length.
    """
    if len(entry) == 2:
        name, pt_func = entry
        return name, pt_func, pt_func
    if len(entry) == 3:
        name, pt_func, np_func = entry
        return name, pt_func, np_func
    raise ValueError(f"op entry must have 2 or 3 items, got {len(entry)}: {entry!r}")


def _register_variants(op_list, pt_attr, np_attr):
    """Register a ``split`` and a ``shared`` BroadcastBench subclass per op.

    For every entry in ``op_list``, create ``BroadcastBench_<split|shared>_<name>``.
    Each class gets ``op_str``, ``split_input``, and the torch/NumPy callables
    stored under the class attributes named by ``pt_attr`` / ``np_attr``.
    All ``split`` variants are registered before all ``shared`` ones.
    """
    for split_input, entry in itertools.product([True, False], op_list):
        name, pt_func, np_func = _split_op_entry(entry)
        op_str = ("split" if split_input else "shared") + "_" + name
        bm_cls = type(
            "BroadcastBench_" + op_str,
            (BroadcastBench,),
            {
                "op_str": op_str,
                pt_attr: pt_func,
                np_attr: np_func,
                "split_input": split_input,
            },
        )
        benchmark.register_benchmark_class(bm_cls)


def register_broadcast_ops():
    """Create and register every generated ``BroadcastBench`` variant.

    Each binary and unary op below gets two subclasses:

    * ``split_<op>`` uses the four inputs independently.
    * ``shared_<op>`` derives the third operand from the first
      (see ``BroadcastBench._eval``).

    Binary-op variants are registered first, then unary-op variants.
    """
    binary_op_list = [
        ["mul", operator.mul],
        ["add", operator.add],
        ["sub", operator.sub],
        # Offset the divisor to stay away from division by zero.
        ["div", lambda a, b: a / (b + 1e-4)],
        ["pow", torch.pow, np.power],  # no fusion triggered
        ["max", torch.max, np.maximum],
        ["min", torch.min, np.minimum],
    ]

    unary_op_list = [
        ["erf", torch.erf, _numpy_erf],
        ["exp", torch.exp, np.exp],
        ["sin", torch.sin, np.sin],
        ["cos", torch.cos, np.cos],
    ]

    _register_variants(binary_op_list, "binary_op_pt_func", "binary_op_np_func")
    _register_variants(unary_op_list, "unary_op_pt_func", "unary_op_np_func")
318
319
# Runs at import time: creating and registering the generated BroadcastBench
# subclasses is a side effect of importing this module.
register_broadcast_ops()
321