xref: /aosp_15_r20/external/pytorch/benchmarks/tensorexpr/broadcast.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport itertools
2*da0073e9SAndroid Build Coastguard Workerimport operator
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport numpy as np
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport torch
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerfrom . import benchmark
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker
class BroadcastMulBench(benchmark.Benchmark):
    """Two-operand broadcast benchmark over 3-D inputs.

    ``case`` selects the pair of input shapes:

    * ``"row"``: ``[M, N, 1]`` and ``[M, 1, K]``
    * ``"mid"``: ``[M, N, 1]`` and ``[1, N, K]``
    * ``"col"``: ``[M, 1, K]`` and ``[1, N, K]``

    Despite the class name, ``forward`` and ``reference`` both *add* the
    two operands.
    """

    def __init__(self, mode, device, dtype, case, M, N, K):
        """Store the configuration and create the two inputs.

        Raises:
            ValueError: if ``case`` is not ``"row"``, ``"mid"`` or ``"col"``.
        """
        super().__init__(mode, device, dtype)
        self.case = case
        self.M, self.N, self.K = M, N, K

        if case == "row":
            lhs_shape, rhs_shape = [M, N, 1], [M, 1, K]
        elif case == "mid":
            lhs_shape, rhs_shape = [M, N, 1], [1, N, K]
        elif case == "col":
            lhs_shape, rhs_shape = [M, 1, K], [1, N, K]
        else:
            raise ValueError(f"invalid case: {case}")

        def make_input(shape):
            # Both operands share every option except their shape.
            return self.rand(
                shape, device=device, dtype=dtype, requires_grad=self.requires_grad
            )

        self.d1 = make_input(lhs_shape)
        self.d2 = make_input(rhs_shape)
        self.inputs = [self.d1, self.d2]

    def forward(self, d1, d2):
        """Return the sum of the two operands."""
        return d1 + d2

    def reference(self):
        """Return the same sum computed on the ``self.numpy`` conversions."""
        lhs = self.numpy(self.d1)
        rhs = self.numpy(self.d2)
        return lhs + rhs

    def config(self):
        """Return ``[M, N, K]``."""
        return [self.M, self.N, self.K]

    @staticmethod
    def default_configs():
        """Return the default list of ``[M, N, K]`` configurations."""
        return [[128, 256, 128]]

    def memory_workload(self):
        """Return the ``sol`` and ``algorithmic`` workload figures.

        Each figure is ``M * N * K`` times a multiplier that depends on
        whether ``self.mode`` is ``"fwd"``.
        """
        forward_only = self.mode == "fwd"
        sol_count = 1 if forward_only else 2
        algorithmic_count = 1 if forward_only else 3

        # NOTE(review): the other benchmark classes in this file multiply the
        # buffer size by 4; this one does not — confirm which is intended.
        buffer_size = self.M * self.N * self.K
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }
72*da0073e9SAndroid Build Coastguard Worker
73*da0073e9SAndroid Build Coastguard Worker
class BroadcastRowBench(BroadcastMulBench):
    """``BroadcastMulBench`` fixed to the ``"row"`` case."""

    def __init__(self, mode, device, dtype, M, N, K):
        super().__init__(mode, device, dtype, case="row", M=M, N=N, K=K)

    @staticmethod
    def module():
        """Return the name this benchmark is identified by."""
        return "broadcast_row"
81*da0073e9SAndroid Build Coastguard Worker
82*da0073e9SAndroid Build Coastguard Worker
class BroadcastMidBench(BroadcastMulBench):
    """``BroadcastMulBench`` fixed to the ``"mid"`` case."""

    def __init__(self, mode, device, dtype, M, N, K):
        super().__init__(mode, device, dtype, case="mid", M=M, N=N, K=K)

    @staticmethod
    def module():
        """Return the name this benchmark is identified by."""
        return "broadcast_mid"
90*da0073e9SAndroid Build Coastguard Worker
91*da0073e9SAndroid Build Coastguard Worker
class BroadcastColBench(BroadcastMulBench):
    """``BroadcastMulBench`` fixed to the ``"col"`` case."""

    def __init__(self, mode, device, dtype, M, N, K):
        super().__init__(mode, device, dtype, case="col", M=M, N=N, K=K)

    @staticmethod
    def module():
        """Return the name this benchmark is identified by."""
        return "broadcast_col"
99*da0073e9SAndroid Build Coastguard Worker
100*da0073e9SAndroid Build Coastguard Worker
class BroadcastThreeArgs(benchmark.Benchmark):
    """Three-operand broadcast-add benchmark.

    The inputs are requested with shapes ``[M, N]``, ``[K, M, 1]`` and
    ``[L, K, 1, 1]`` and are summed together.
    """

    def __init__(self, mode, device, dtype, M, N, K, L):
        super().__init__(mode, device, dtype)
        self.M, self.N, self.K, self.L = M, N, K, L

        def make_input(shape):
            # All operands share every option except their shape.
            return self.rand(
                shape, device=device, dtype=dtype, requires_grad=self.requires_grad
            )

        self.d1 = make_input([M, N])
        self.d2 = make_input([K, M, 1])
        self.d3 = make_input([L, K, 1, 1])

        self.inputs = [self.d1, self.d2, self.d3]

    def forward(self, d1, d2, d3):
        """Return ``d1 + d2 + d3``."""
        return d1 + d2 + d3

    def reference(self):
        """Return the same sum computed on the ``self.numpy`` conversions."""
        first, second, third = (
            self.numpy(self.d1),
            self.numpy(self.d2),
            self.numpy(self.d3),
        )
        return first + second + third

    def config(self):
        """Return ``[M, N, K, L]``."""
        return [self.M, self.N, self.K, self.L]

    @staticmethod
    def default_configs():
        """Return the default list of ``[M, N, K, L]`` configurations."""
        return [[32, 16, 64, 128]]

    def memory_workload(self):
        """Return the ``sol`` and ``algorithmic`` workload figures.

        Each figure is ``M * N * K * L * 4`` times a multiplier that depends
        on whether ``self.mode`` is ``"fwd"``.
        """
        forward_only = self.mode == "fwd"
        sol_count = 1 if forward_only else 2
        algorithmic_count = 1 if forward_only else 4

        # NOTE(review): the factor 4 is presumably bytes per element — confirm.
        buffer_size = self.M * self.N * self.K * self.L * 4
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def module():
        """Return the name this benchmark is identified by."""
        return "broadcast_3args"
152*da0073e9SAndroid Build Coastguard Worker
153*da0073e9SAndroid Build Coastguard Worker
154*da0073e9SAndroid Build Coastguard Worker# benchmark.register_benchmark_class(BroadcastRowBench)
155*da0073e9SAndroid Build Coastguard Worker# benchmark.register_benchmark_class(BroadcastMidBench)
156*da0073e9SAndroid Build Coastguard Worker# benchmark.register_benchmark_class(BroadcastColBench)
157*da0073e9SAndroid Build Coastguard Worker# benchmark.register_benchmark_class(BroadcastThreeArgs)
158*da0073e9SAndroid Build Coastguard Worker
159*da0073e9SAndroid Build Coastguard Worker
160*da0073e9SAndroid Build Coastguard Worker# TODO: merge this with elementwise bench
161*da0073e9SAndroid Build Coastguard Worker# A template class for elementwise operations.
# A derived class overrides the class variables below to customize its behavior.
class BroadcastBench(benchmark.Benchmark):
    """Template benchmark mixing four inputs of different broadcast shapes.

    Subclasses customize the behavior through the class variables below.
    When a binary op is left unset, addition is used; when a unary op is
    left unset, the identity is used.
    """

    # Customization class variables.
    op_str = None  # suffix used by module()
    binary_op_pt_func = None  # binary op used by forward()
    binary_op_np_func = None  # binary op used by reference()
    unary_op_pt_func = None  # unary op used by forward()
    unary_op_np_func = None  # unary op used by reference()
    split_input = True  # False: the third operand is derived from d1

    def __init__(self, mode, device, dtype, M, N, K):
        super().__init__(mode, device, dtype)
        self.M, self.N, self.K = M, N, K

        def make_input(shape):
            # All operands share every option except their shape.
            return self.rand(
                shape, device=device, dtype=dtype, requires_grad=self.requires_grad
            )

        self.d1 = make_input([M, N])
        self.d2 = make_input([K, 1, N])
        self.d3 = make_input([M, N])
        self.d4 = make_input([K, M, 1])
        self.inputs = [self.d1, self.d2, self.d3, self.d4]

    def _eval(self, d1, d2, d3, d4, binary_op, unary_op):
        """Compute ``binary(u(d1), u(d2)) + binary(u(x3), u(d4))``.

        ``x3`` is ``d3`` when ``split_input`` is true and ``d1 + 0.001``
        otherwise. Falsy ``binary_op`` / ``unary_op`` fall back to addition
        and identity respectively.
        """
        combine = binary_op if binary_op else (lambda x, y: x + y)
        transform = unary_op if unary_op else (lambda x: x)

        # Keep the original evaluation order: d1, d2, third operand, d4.
        lhs_a = transform(d1)
        rhs_a = transform(d2)
        lhs_b = transform(d3 if self.split_input else d1 + 0.001)
        rhs_b = transform(d4)

        first = combine(lhs_a, rhs_a)
        second = combine(lhs_b, rhs_b)
        return first + second

    def forward(self, d1, d2, d3, d4):
        """Evaluate the expression with the ``*_pt_func`` class variables."""
        klass = self.__class__
        return self._eval(
            d1, d2, d3, d4, klass.binary_op_pt_func, klass.unary_op_pt_func
        )

    def reference(self):
        """Evaluate the expression with the ``*_np_func`` class variables."""
        klass = self.__class__
        binary_op = klass.binary_op_np_func
        unary_op = klass.unary_op_np_func
        arrays = [self.numpy(t) for t in (self.d1, self.d2, self.d3, self.d4)]
        return self._eval(*arrays, binary_op, unary_op)

    def config(self):
        """Return ``[M, N, K]``."""
        return [self.M, self.N, self.K]

    @classmethod
    def module(cls):
        """Return ``"broadcast_"`` followed by the class's ``op_str``."""
        return "broadcast_" + cls.op_str

    def memory_workload(self):
        """Return the ``sol`` and ``algorithmic`` workload figures.

        ``sol`` always uses a multiplier of 1; ``algorithmic`` uses 1 in
        ``"fwd"`` mode and the number of inputs otherwise. ``split_input``
        does not influence either multiplier.
        """
        input_count = len(self.inputs)
        sol_count = 1
        algorithmic_count = 1 if self.mode == "fwd" else input_count

        # NOTE(review): the factor 4 is presumably bytes per element — confirm.
        buffer_size = self.M * self.N * self.K * 4
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Return the default list of ``[M, N, K]`` configurations."""
        return [[256, 128, 512]]
263*da0073e9SAndroid Build Coastguard Worker
264*da0073e9SAndroid Build Coastguard Worker
def _numpy_erf():
    """Return an elementwise ``erf`` usable on NumPy arrays.

    NumPy itself does not provide ``erf``. SciPy's ufunc is preferred; when
    SciPy is not installed, ``math.erf`` is vectorized as a fallback.
    """
    try:
        from scipy.special import erf
    except ImportError:
        import math

        erf = np.vectorize(math.erf, otypes=[np.float64])
    return erf


def _register_broadcast_variants(op_list, pt_attr, np_attr):
    """Register a split and a shared ``BroadcastBench`` subclass per op.

    Args:
        op_list: entries of the form ``[name, func]`` (the same callable is
            used for both PyTorch and NumPy) or ``[name, pt_func, np_func]``.
        pt_attr: name of the class variable that receives the PyTorch callable.
        np_attr: name of the class variable that receives the NumPy callable.

    Raises:
        ValueError: if an entry has neither two nor three elements.
    """
    for split_input, op_spec in itertools.product([True, False], op_list):
        if len(op_spec) == 2:
            op_name, op_pt_func = op_spec
            op_np_func = op_pt_func
        elif len(op_spec) == 3:
            op_name, op_pt_func, op_np_func = op_spec
        else:
            # Without this guard a malformed entry would silently reuse the
            # callables left over from the previous iteration.
            raise ValueError(f"invalid op specification: {op_spec!r}")

        split_str = "split" if split_input else "shared"
        op_str = split_str + "_" + op_name

        # Make a copy of BroadcastBench and customize it.
        bm_cls = type("BroadcastBench_" + op_str, (BroadcastBench,), {})
        bm_cls.op_str = op_str
        setattr(bm_cls, pt_attr, op_pt_func)
        setattr(bm_cls, np_attr, op_np_func)
        bm_cls.split_input = split_input
        benchmark.register_benchmark_class(bm_cls)


def register_broadcast_ops():
    """Generate and register the binary and unary broadcast benchmarks.

    For every op, two ``BroadcastBench`` subclasses are registered: one with
    ``split_input=True`` and one with ``split_input=False``. Binary ops are
    registered first, then unary ops.
    """
    binary_op_list = [
        ["mul", operator.mul],
        ["add", operator.add],
        ["sub", operator.sub],
        ["div", lambda a, b: a / (b + 1e-4)],
        [
            "pow",
            torch.pow,
            np.power,
        ],  # no fusion triggered
        ["max", torch.max, np.maximum],
        ["min", torch.min, np.minimum],
    ]

    unary_op_list = [
        # NumPy has no ``erf``; ``np.erf`` would raise AttributeError here.
        ["erf", torch.erf, _numpy_erf()],
        ["exp", torch.exp, np.exp],
        ["sin", torch.sin, np.sin],
        ["cos", torch.cos, np.cos],
    ]

    _register_broadcast_variants(
        binary_op_list, "binary_op_pt_func", "binary_op_np_func"
    )
    _register_broadcast_variants(
        unary_op_list, "unary_op_pt_func", "unary_op_np_func"
    )
318*da0073e9SAndroid Build Coastguard Worker
319*da0073e9SAndroid Build Coastguard Worker
# Runs at module import: registers every generated broadcast benchmark class.
register_broadcast_ops()
321