xref: /aosp_15_r20/external/pytorch/benchmarks/tensorexpr/elementwise.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport itertools
2*da0073e9SAndroid Build Coastguard Workerimport operator
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport numpy as np
5*da0073e9SAndroid Build Coastguard Workerimport scipy.special
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerimport torch
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerfrom . import benchmark
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
# A template class for elementwise operations.
# A derived class overrides the class variables below to customize its behavior.
class ElementBench(benchmark.Benchmark):
    """Template benchmark that combines four 1-D inputs elementwise.

    The computation is ``binary(u(d1), u(d2)) + binary(u(d3), u(d4))`` where
    ``u`` is the configured unary op (identity when unset) and ``binary`` is
    the configured binary op (addition when unset).  Concrete benchmarks are
    produced by subclassing and overriding the class variables below.
    """

    # Customization points, overridden per generated subclass.
    # op_str: name used to build the module string; ops whose name contains
    #   "rand" are flagged as non-deterministic.
    op_str = None
    # Binary op applied to tensors (forward) / NumPy arrays (reference).
    binary_op_pt_func = None
    binary_op_np_func = None
    # Unary op applied to tensors (forward) / NumPy arrays (reference).
    unary_op_pt_func = None
    unary_op_np_func = None
    # True: each of the four inputs is used as-is.
    # False: the last three operands are derived from the first input.
    split_input = True

    def __init__(self, mode, device, dtype, N):
        """Create four random input vectors of length ``N``."""
        super().__init__(mode, device, dtype)
        self.N = N
        operands = [
            self.rand(
                [N], device=device, dtype=dtype, requires_grad=self.requires_grad
            )
            for _ in range(4)
        ]
        self.d1, self.d2, self.d3, self.d4 = operands
        self.inputs = list(operands)
        self.deterministic = "rand" not in self.op_str

    def _eval(self, d1, d2, d3, d4, binary_op, unary_op):
        """Shared evaluation used by both ``forward`` and ``reference``.

        A falsy ``binary_op`` falls back to addition and a falsy ``unary_op``
        falls back to the identity.
        """
        combine = binary_op or (lambda x, y: x + y)
        transform = unary_op or (lambda x: x)

        if self.split_input:
            # Independent operands: transform them in d1..d4 order.
            d1, d2, d3, d4 = [transform(t) for t in (d1, d2, d3, d4)]
        else:
            # Shared operand: d2..d4 are offsets of d1; d1 is transformed
            # last, after the shifted copies have been produced.
            d2, d3, d4 = [transform(d1 + shift) for shift in (0.001, 0.002, 0.003)]
            d1 = transform(d1)

        left = combine(d1, d2)
        right = combine(d3, d4)
        return left + right

    def forward(self, d1, d2, d3, d4):
        """Evaluate the benchmark with the PyTorch-side ops."""
        klass = self.__class__
        return self._eval(
            d1, d2, d3, d4, klass.binary_op_pt_func, klass.unary_op_pt_func
        )

    def reference(self):
        """Evaluate the same expression with the NumPy-side ops on the stored inputs."""
        klass = self.__class__
        np_inputs = [self.numpy(t) for t in (self.d1, self.d2, self.d3, self.d4)]
        return self._eval(
            *np_inputs, klass.binary_op_np_func, klass.unary_op_np_func
        )

    def config(self):
        """Return the configuration of this instance: ``[N]``."""
        return [self.N]

    @classmethod
    def module(cls):
        """Return the module name, ``"element_"`` followed by ``op_str``."""
        return "element_" + cls.op_str

    def memory_workload(self):
        """Return the "sol" and "algorithmic" workload, in multiples of ``N``.

        Ops with "rand" in their name always count one buffer; the
        shared-input variants always count two; only the split-input variants
        depend on the mode and on the number of inputs.
        """
        n_inputs = len(self.inputs)

        if "rand" in self.op_str:
            sol_count = algorithmic_count = 1
        elif not self.split_input:
            sol_count = algorithmic_count = 1 + 1
        elif self.mode == "fwd":
            sol_count = algorithmic_count = n_inputs + 1
        else:
            sol_count = (n_inputs + 1) + (1 + n_inputs)
            algorithmic_count = (n_inputs + 1) + ((2 + 1) * n_inputs)

        return {
            "sol": self.N * sol_count,
            "algorithmic": self.N * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Return the default configuration list: a single run with N = 2**25."""
        return [[1 << 25]]
117*da0073e9SAndroid Build Coastguard Worker
118*da0073e9SAndroid Build Coastguard Worker
def _register_op_variants(op_list, pt_attr, np_attr):
    """Create and register a "split" and a "shared" ElementBench subclass per op.

    Args:
        op_list: entries of the form ``[name, func]`` (the same callable is
            used for both PyTorch and NumPy) or ``[name, pt_func, np_func]``.
        pt_attr: name of the ElementBench class variable receiving the
            PyTorch-side callable.
        np_attr: name of the ElementBench class variable receiving the
            NumPy-side callable.

    Raises:
        ValueError: if an entry has neither two nor three elements.
    """
    for split_input, op_entry in itertools.product([True, False], op_list):
        if len(op_entry) == 2:
            base_name, op_pt_func = op_entry
            op_np_func = op_pt_func
        elif len(op_entry) == 3:
            base_name, op_pt_func, op_np_func = op_entry
        else:
            # Fail loudly instead of silently reusing the previous
            # iteration's callables for a malformed entry.
            raise ValueError(
                "op entry must be [name, func] or [name, pt_func, np_func], "
                f"got {len(op_entry)} elements"
            )

        split_str = "split" if split_input else "shared"
        op_str = split_str + "_" + base_name

        # Make a copy of ElementBench and customize it through class variables.
        bm_cls = type("ElementBench_" + op_str, (ElementBench,), {})
        bm_cls.op_str = op_str
        setattr(bm_cls, pt_attr, op_pt_func)
        setattr(bm_cls, np_attr, op_np_func)
        bm_cls.split_input = split_input
        benchmark.register_benchmark_class(bm_cls)


def register_element_ops():
    """Register one ElementBench subclass per (split/shared, op) combination.

    Binary-op variants are registered first, then unary-op variants; within
    each group all "split" variants precede the "shared" ones.
    """
    binary_op_list = [
        ["mul", operator.mul],
        ["add", operator.add],
        ["sub", operator.sub],
        # The small offset keeps the divisor away from zero.
        ["div", lambda a, b: a / (b + 1e-4)],
        [
            "pow",
            torch.pow,
            np.power,
        ],  # no fusion triggered
        ["max", torch.max, np.maximum],
        ["min", torch.min, np.minimum],
    ]

    unary_op_list = [
        ["erf", torch.erf, scipy.special.erf],
        ["exp", torch.exp, np.exp],
        ["sin", torch.sin, np.sin],
        ["cos", torch.cos, np.cos],
        ["rand_like", torch.rand_like, lambda x: np.random.rand(*x.shape)],
    ]

    _register_op_variants(binary_op_list, "binary_op_pt_func", "binary_op_np_func")
    _register_op_variants(unary_op_list, "unary_op_pt_func", "unary_op_np_func")
173*da0073e9SAndroid Build Coastguard Worker
174*da0073e9SAndroid Build Coastguard Worker
# Register every generated ElementBench variant when this module is imported.
register_element_ops()
177*da0073e9SAndroid Build Coastguard Worker
178*da0073e9SAndroid Build Coastguard Worker
class SimpleElementBench(benchmark.Benchmark):
    """Minimal elementwise benchmark: adds two scalar constants to one vector."""

    def __init__(self, mode, device, dtype, N):
        """Create a single random input vector of length ``N``."""
        super().__init__(mode, device, dtype)
        self.N = N
        self.data = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.data]

    def forward(self, data):
        """Return ``(data + 0.001) + 0.002``."""
        a = data + 0.001
        b = a + 0.002
        return b

    def reference(self):
        """Compute the expected result on the NumPy copy of ``self.data``.

        The previous body was copied from ElementBench and read attributes
        (``d1``..``d4``, ``_eval``, ``*_np_func``) that this class never
        defines.  ``forward`` only uses ``+`` with Python floats, so it is
        reused here on the NumPy array to keep both paths in lockstep.
        """
        return self.forward(self.numpy(self.data))

    def config(self):
        """Return the configuration of this instance: ``[N]``."""
        return [self.N]

    @staticmethod
    def input_iterable():
        """Return True.  NOTE(review): meaning is defined by the harness - confirm."""
        return True

    @classmethod
    def module(cls):
        """Return the module name of this benchmark."""
        return "simple_element"

    def memory_workload(self):
        """Return the "sol" and "algorithmic" workload, in multiples of ``N``.

        Both counts are 2 regardless of ``self.mode`` (the original had two
        identical branches) - presumably one read of the input and one write
        of the output; TODO confirm the intended non-fwd count.
        """
        sol_count = 2
        algorithmic_count = 2

        buffer_size = self.N
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Return the default configuration list: a single run with N = 2**25."""
        return [[1 << 25]]
228*da0073e9SAndroid Build Coastguard Worker
229*da0073e9SAndroid Build Coastguard Worker
# Register SimpleElementBench via benchmark.register_benchmark_class at import time.
benchmark.register_benchmark_class(SimpleElementBench)
231*da0073e9SAndroid Build Coastguard Worker
232*da0073e9SAndroid Build Coastguard Worker
class DynamicSimpleElementBench(benchmark.DynamicShape, SimpleElementBench):
    """SimpleElementBench variant whose input length comes from ``rand_shape``."""

    @classmethod
    def module(cls):
        """Return the module name of this benchmark."""
        return "simple_dynamic_element"

    def __init__(self, mode, device, dtype, N):
        """Initialise both bases explicitly, DynamicShape first."""
        benchmark.DynamicShape.__init__(self)
        SimpleElementBench.__init__(self, mode, device, dtype, N)

    def instantiate_input(self):
        """Replace ``self.inputs`` with a fresh random vector.

        The length is whatever ``self.rand_shape([self.N])`` returns for the
        single requested dimension.
        """
        (sampled_len,) = self.rand_shape([self.N])
        fresh = self.rand(
            [sampled_len],
            device=self.device,
            dtype=self.dtype,
            requires_grad=self.requires_grad,
        )
        self.inputs = [fresh]
247*da0073e9SAndroid Build Coastguard Worker        self.inputs = [data]
248*da0073e9SAndroid Build Coastguard Worker
# Register DynamicSimpleElementBench via benchmark.register_benchmark_class at import time.
benchmark.register_benchmark_class(DynamicSimpleElementBench)
250*da0073e9SAndroid Build Coastguard Workerbenchmark.register_benchmark_class(DynamicSimpleElementBench)
251