xref: /aosp_15_r20/external/pytorch/benchmarks/tensorexpr/softmax.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport scipy.special
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom . import benchmark
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Worker
class SoftmaxBench(benchmark.Benchmark):
    """Benchmark softmax over the last dimension of a random ``[M, N]`` input.

    ``forward`` adds ``0.001`` to the input and then applies softmax along
    ``dim=-1``. Because softmax is invariant to adding a constant along the
    reduction axis, the shift does not change the mathematical result.
    """

    def __init__(self, mode, device, dtype, M, N):
        """Create a single ``[M, N]`` random input tensor.

        Args:
            mode: Benchmark mode. ``memory_workload`` treats ``"fwd"`` as
                forward-only and any other value as forward + backward.
            device: Device the input tensor is allocated on.
            dtype: Element type of the input and of the softmax result.
            M: Number of rows of the input.
            N: Number of columns of the input (the softmax reduction axis).
        """
        super().__init__(mode, device, dtype)
        self.M = M
        self.N = N
        self.dtype = dtype
        # ``self.requires_grad`` comes from the base class; presumably it is
        # derived from ``mode`` -- confirm in benchmark.Benchmark.
        self.inputs = [
            self.randn(
                [M, N], device=device, dtype=dtype, requires_grad=self.requires_grad
            )
        ]

    def forward(self, inputs):
        """Return ``softmax(inputs + 0.001)`` along the last dimension.

        Args:
            inputs: The ``[M, N]`` input tensor. Despite the plural name this
                is presumably the single element of ``self.inputs``, unpacked
                by the benchmark runner -- confirm against the caller.
        """
        x = self.add(inputs, 0.001)
        y = self.softmax(x, dim=-1, dtype=self.dtype)
        return y

    def reference(self):
        """Return a SciPy softmax of the input along the last axis.

        The ``+ 0.001`` shift from ``forward`` is omitted because softmax is
        invariant to adding a constant along the reduction axis.
        """
        # ``self.inputs`` is a one-element list; convert the tensor itself so
        # the reference has the same ``[M, N]`` shape as ``forward``'s output.
        # Converting the list would add a leading dimension or fail outright.
        return scipy.special.softmax(self.numpy(self.inputs[0]), axis=-1)

    def config(self):
        """Return the ``[M, N]`` problem size for reporting."""
        return [self.M, self.N]

    @staticmethod
    def module():
        """Return the name this benchmark is registered under."""
        return "softmax"

    def memory_workload(self):
        """Return estimated memory traffic for one benchmark iteration.

        Counts are in units of elements of an ``M * N`` buffer; no dtype size
        is applied here.

        * ``"sol"``: one input read plus one output write per pass.
        * ``"algorithmic"``: three input reads plus one output write per pass
          (presumably the max, sum-of-exp and normalize sweeps of a naive
          softmax -- confirm).

        In any mode other than ``"fwd"``, a second pass of the same cost is
        added for the backward computation.
        """
        if self.mode == "fwd":
            sol_count = 1 + 1
            algorithmic_count = 3 + 1
        else:
            # Backward is modeled with the same per-pass cost as forward.
            sol_count = (1 + 1) + (1 + 1)
            algorithmic_count = (3 + 1) + (3 + 1)

        buffer_size = self.M * self.N
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Return the default ``[M, N]`` shapes to benchmark.

        The shapes cover a small matrix, many rows with a short reduction axis,
        and few rows with a long reduction axis.
        """
        return [
            [480, 20],
            [1 << 15, 32],
            [128, 1 << 16],
        ]


# Register the class with the tensorexpr benchmark framework.
benchmark.register_benchmark_class(SoftmaxBench)
57