# xref: /aosp_15_r20/external/pytorch/benchmarks/tensorexpr/softmax.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
import scipy.special

from . import benchmark

class SoftmaxBench(benchmark.Benchmark):
    """Benchmark a softmax over the last dimension of an (M, N) tensor.

    The forward pass adds a small constant offset to the input before the
    softmax (so the op chain is add -> softmax rather than softmax alone);
    the reference result is computed with ``scipy.special.softmax``.
    """

    def __init__(self, mode, device, dtype, M, N):
        """Set up an (M, N) random input tensor on the target device.

        Args:
            mode: benchmark mode string (e.g. "fwd"); consumed by the base
                class and read back in ``memory_workload``.
            device: device string for the input tensor.
            dtype: element dtype for both the input and the softmax output.
            M: number of rows (batch-like dimension).
            N: number of columns (the softmax dimension).
        """
        super().__init__(mode, device, dtype)
        self.M = M
        self.N = N
        # Kept locally as well so forward() can pass it to softmax
        # explicitly (the base class may also store it — can't tell here).
        self.dtype = dtype
        self.inputs = [
            self.randn(
                [M, N], device=device, dtype=dtype, requires_grad=self.requires_grad
            )
        ]

    def forward(self, inputs):
        """Add a small offset, then softmax along the last dimension."""
        x = self.add(inputs, 0.001)
        y = self.softmax(x, dim=-1, dtype=self.dtype)
        return y

    def reference(self):
        """SciPy reference implementation used to validate the result."""
        return scipy.special.softmax(self.numpy(self.inputs), axis=-1)

    def config(self):
        """Return the [M, N] shape this instance was configured with."""
        return [self.M, self.N]

    @staticmethod
    def module():
        """Name under which this benchmark is registered/reported."""
        return "softmax"

    def memory_workload(self):
        """Estimate memory traffic in elements for the bandwidth model.

        "sol" models ideal fused traffic (one read + one write per pass);
        "algorithmic" adds the extra passes a non-fused softmax performs
        (presumably the max/sum reductions — counts kept as in original).
        """
        if self.mode == "fwd":
            sol_count = 1 + 1
            algorithmic_count = 3 + 1
        else:
            # Backward doubles the forward traffic estimates.
            sol_count = (1 + 1) + (1 + 1)
            algorithmic_count = (3 + 1) + (3 + 1)

        buffer_size = self.M * self.N
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Default [M, N] shapes: tall/wide mixes exercising both extremes."""
        return [
            [480, 20],
            [1 << 15, 32],
            [128, 1 << 16],
        ]
54
55
56benchmark.register_benchmark_class(SoftmaxBench)
57