# This is a copy of rnn_attention from MLPerf, with some common sizes hardcoded
# for benchmarking and some control flow stripped out.
# https://github.com/mlperf/training/blob/master/rnn_translator/pytorch/seq2seq/models/attention.py

import torch

from . import benchmark


class BahdanauAttention(benchmark.Benchmark):
    def __init__(self, mode, device, dtype, b, t_q, t_k, n):
        super().__init__(mode, device, dtype)
        self.b = b  # batch size
        self.t_q = t_q  # number of query time steps
        self.t_k = t_k  # number of key time steps
        self.n = n  # hidden (attention) dimension
        self.att_query = self.rand(
            [b, t_q, n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.att_keys = self.rand(
            [b, t_k, n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.normalize_bias = self.rand(
            [n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.linear_att = self.rand(
            [n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [
            self.att_query,
            self.att_keys,
            self.normalize_bias,
            self.linear_att,
        ]
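        # The order above must match the signature of `forward`, since
        # `reference` below calls `self.forward(*self.inputs)`.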

    def forward(self, att_query, att_keys, normalize_bias, linear_att):
        """
        Calculate Bahdanau attention scores.

        :param att_query: b x t_q x n
        :param att_keys: b x t_k x n

        :return: b x t_q x t_k scores
        """
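        # Additive (Bahdanau) attention: for every query position i and key
        # position j, score[i, j] = dot(linear_att, tanh(query_i + key_j + bias)).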
        b, t_k, n = att_keys.size()
        t_q = att_query.size(1)

        # Broadcast the query across key positions and the keys across query
        # positions so both have shape b x t_q x t_k x n.
        att_query = att_query.unsqueeze(2).expand(b, t_q, t_k, n)
        att_keys = att_keys.unsqueeze(1).expand(b, t_q, t_k, n)
        sum_qk = att_query + att_keys + normalize_bias
        # Apply tanh, then contract the hidden dimension against the scoring vector.
        out = torch.tanh(sum_qk).matmul(linear_att)
        return out

    def reference(self):
        # Eager-mode result, converted to numpy, used to check correctness.
        return self.numpy(self.forward(*self.inputs))

    def config(self):
        return [self.b, self.t_q, self.t_k, self.n]

    @staticmethod
    def module():
        return "attention"

    def memory_workload(self):
        def memsize(t):
            # Tensor footprint in bytes.
            return t.numel() * t.element_size()

        input_size = (
            memsize(self.att_query)
            + memsize(self.att_keys)
            + memsize(self.normalize_bias)
            + memsize(self.linear_att)
        )
        # Output is b x t_q x t_k at a hardcoded 4 bytes per element (float32).
        output_size = 4 * torch.Size([self.b, self.t_q, self.t_k]).numel()
        io_size = input_size + output_size

        # If the matmul is not fused with the pointwise ops, the b x t_q x t_k x n
        # intermediate `sum_qk` must be written out and then read back.
        intermediate_size = (
            2 * 4 * torch.Size([self.b, self.t_q, self.t_k, self.n]).numel()
        )
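        # For the `nvidia` config [128, 10, 128, 1024] below, that intermediate
        # alone is 128 * 10 * 128 * 1024 elements * 4 bytes = 640 MiB; one write
        # plus one read adds 1.25 GiB of traffic, versus roughly 70 MiB of true
        # input/output.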
        return {"sol": io_size, "algorithmic": io_size + intermediate_size}

    @staticmethod
    def default_configs():
        # Each config is [b, t_q, t_k, n].
        mlperf_inference = [1280, 1, 66, 1024]
        nvidia = [128, 10, 128, 1024]
        return [mlperf_inference, nvidia]


benchmark.register_benchmark_class(BahdanauAttention)
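
# A minimal standalone sketch (not part of the benchmark harness): it repeats
# the score computation from `forward` with plain torch tensors and small,
# assumed sizes so the shape algebra can be sanity-checked in isolation.
# Because of the relative import above, this guard is only reached when the
# file is run as a module (e.g. `python -m <package>.attention`).
if __name__ == "__main__":
    b, t_q, t_k, n = 4, 5, 7, 16  # illustrative sizes, not a benchmark config
    q = torch.randn(b, t_q, n)
    k = torch.randn(b, t_k, n)
    bias = torch.randn(n)
    v = torch.randn(n)
    # unsqueeze + broadcasting is equivalent to the explicit expand() in forward.
    scores = torch.tanh(q.unsqueeze(2) + k.unsqueeze(1) + bias).matmul(v)
    assert scores.shape == (b, t_q, t_k)
    print("scores shape:", tuple(scores.shape))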