import itertools
import operator

import numpy as np
import scipy.special

import torch

from . import benchmark


# A template class for elementwise operations.
# A derived class overrides the class variables below to customize its behavior.
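# Concrete subclasses are generated programmatically in register_element_ops() via type(),
# which fills in op_str and the corresponding PyTorch/NumPy op functions.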
class ElementBench(benchmark.Benchmark):
    # List of customization class variables.
    op_str = None
    binary_op_pt_func = None
    binary_op_np_func = None
    unary_op_pt_func = None
    unary_op_np_func = None
    split_input = True

    def __init__(self, mode, device, dtype, N):
        super().__init__(mode, device, dtype)
        self.N = N
        self.d1 = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d2 = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d3 = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d4 = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.d1, self.d2, self.d3, self.d4]
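        # Ops that draw random numbers (e.g. rand_like) produce different output on every
        # run, so they are flagged as non-deterministic.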
        self.deterministic = "rand" not in self.op_str

    def _eval(self, d1, d2, d3, d4, binary_op, unary_op):
        if not binary_op:

            def binary_op(x, y):
                return x + y

        if not unary_op:

            def unary_op(x):
                return x

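        # split_input=True: each of the four input buffers is passed through unary_op
        # independently. split_input=False: d2-d4 are derived from d1 with small offsets,
        # so only a single distinct input buffer is read.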
        if self.split_input:
            d1 = unary_op(d1)
            d2 = unary_op(d2)
            d3 = unary_op(d3)
            d4 = unary_op(d4)
        else:
            d2 = unary_op(d1 + 0.001)
            d3 = unary_op(d1 + 0.002)
            d4 = unary_op(d1 + 0.003)
            d1 = unary_op(d1)
        a = binary_op(d1, d2)
        b = binary_op(d3, d4)
        c = a + b
        return c

    def forward(self, d1, d2, d3, d4):
        binary_op = self.__class__.binary_op_pt_func
        unary_op = self.__class__.unary_op_pt_func
        return self._eval(d1, d2, d3, d4, binary_op, unary_op)

    def reference(self):
        binary_op = self.__class__.binary_op_np_func
        unary_op = self.__class__.unary_op_np_func
        [d1, d2, d3, d4] = [self.numpy(d) for d in [self.d1, self.d2, self.d3, self.d4]]
        return self._eval(d1, d2, d3, d4, binary_op, unary_op)

    def config(self):
        return [self.N]

    @classmethod
    def module(cls):
        return "element_" + cls.op_str

    def memory_workload(self):
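        # Count the buffers that cross memory. In forward mode with split inputs, every
        # input buffer is read and one output is written; with a shared input only one
        # distinct buffer is read. Ops containing "rand" are counted as a single buffer.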
        input_count = len(self.inputs)
        if self.mode == "fwd":
            if self.split_input:
                sol_count = input_count + 1
                algorithmic_count = input_count + 1
            else:
                sol_count = 1 + 1
                algorithmic_count = 1 + 1
            if "rand" in self.op_str:
                sol_count = 1
                algorithmic_count = 1
        else:
            if self.split_input:
                sol_count = (input_count + 1) + (1 + input_count)
                algorithmic_count = (input_count + 1) + ((2 + 1) * input_count)
            else:
                sol_count = 1 + 1
                algorithmic_count = 1 + 1
            if "rand" in self.op_str:
                sol_count = 1
                algorithmic_count = 1

        buffer_size = self.N
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
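        # Single default config: N = 1 << 25, roughly 33.5 million elements per buffer.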
        return [[1 << 25]]


def register_element_ops():
    binary_op_list = [
        ["mul", operator.mul],
        ["add", operator.add],
        ["sub", operator.sub],
        ["div", lambda a, b: a / (b + 1e-4)],
        [
            "pow",
            torch.pow,
            np.power,
        ],  # no fusion triggered
        ["max", torch.max, np.maximum],
        ["min", torch.min, np.minimum],
    ]

    unary_op_list = [
        ["erf", torch.erf, scipy.special.erf],
        ["exp", torch.exp, np.exp],
        ["sin", torch.sin, np.sin],
        ["cos", torch.cos, np.cos],
        ["rand_like", torch.rand_like, lambda x: np.random.rand(*x.shape)],
    ]

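    # Each op is registered in two flavors: a "split" variant that reads four distinct
    # input buffers and a "shared" variant that derives all operands from a single buffer
    # (see ElementBench._eval). For example, the loops below generate a class named
    # ElementBench_split_mul.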
    for split_input, binary_op in itertools.product([True, False], binary_op_list):
        # Create a new subclass of ElementBench for this binary op.
        if len(binary_op) == 2:
            [op_str, op_pt_func] = binary_op
            op_np_func = op_pt_func
        elif len(binary_op) == 3:
            [op_str, op_pt_func, op_np_func] = binary_op
        split_str = "split" if split_input else "shared"
        op_str = split_str + "_" + op_str
        bm_cls = type("ElementBench_" + op_str, (ElementBench,), {})
        bm_cls.op_str = op_str
        bm_cls.binary_op_pt_func = op_pt_func
        bm_cls.binary_op_np_func = op_np_func
        bm_cls.split_input = split_input
        benchmark.register_benchmark_class(bm_cls)

    for split_input, unary_op in itertools.product([True, False], unary_op_list):
        # Create a new subclass of ElementBench for this unary op.
        if len(unary_op) == 2:
            [op_str, op_pt_func] = unary_op
            op_np_func = op_pt_func
        elif len(unary_op) == 3:
            [op_str, op_pt_func, op_np_func] = unary_op
        split_str = "split" if split_input else "shared"
        op_str = split_str + "_" + op_str
        bm_cls = type("ElementBench_" + op_str, (ElementBench,), {})
        bm_cls.op_str = op_str
        bm_cls.unary_op_pt_func = op_pt_func
        bm_cls.unary_op_np_func = op_np_func
        bm_cls.split_input = split_input
        benchmark.register_benchmark_class(bm_cls)


# benchmark.register_benchmark_class(ElementMulBench)
register_element_ops()


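# A minimal elementwise benchmark: two chained scalar additions on a single input buffer.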
class SimpleElementBench(benchmark.Benchmark):
    def __init__(self, mode, device, dtype, N):
        super().__init__(mode, device, dtype)
        self.N = N
        self.data = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.data]

    def forward(self, data):
        a = data + 0.001
        b = a + 0.002
        return b

    def reference(self):
        # Mirror forward() on the NumPy side using the single input buffer.
        data = self.numpy(self.data)
        a = data + 0.001
        b = a + 0.002
        return b

    def config(self):
        return [self.N]

    @staticmethod
    def input_iterable():
        return True

    @classmethod
    def module(cls):
        return "simple_element"

    def memory_workload(self):
        # Two buffers move regardless of mode: one input read and one output written.
        sol_count = 2
        algorithmic_count = 2

        buffer_size = self.N
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        return [[1 << 25]]


benchmark.register_benchmark_class(SimpleElementBench)


class DynamicSimpleElementBench(benchmark.DynamicShape, SimpleElementBench):
    def __init__(self, mode, device, dtype, N):
        benchmark.DynamicShape.__init__(self)
        SimpleElementBench.__init__(self, mode, device, dtype, N)

    @classmethod
    def module(cls):
        return "simple_dynamic_element"

    def instantiate_input(self):
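        # rand_shape() is provided by benchmark.DynamicShape; presumably it perturbs the
        # requested size so each instantiation sees a slightly different input shape.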
        (N,) = self.rand_shape([self.N])
        data = self.rand(
            [N], device=self.device, dtype=self.dtype, requires_grad=self.requires_grad
        )
        self.inputs = [data]


benchmark.register_benchmark_class(DynamicSimpleElementBench)