import itertools
from dataclasses import asdict, dataclass
from functools import partial
from typing import Callable, List, Union

import numpy as np
from tabulate import tabulate
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
from torch.nn.attention.bias import CausalBias, CausalVariant
from torch.nn.parameter import Parameter


def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
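    """Return the median latency of ``func(*args, **kwargs)`` in microseconds.

    The call is warmed up a few times, then measured with
    torch.utils.benchmark.Timer.adaptive_autorange.
    """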
    # warmup
    for _ in range(5):
        func(*args, **kwargs)
    t0 = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "func": func},
    )
    return t0.adaptive_autorange(min_run_time=0.1).median * 1e6


@dataclass(frozen=True)
class ExperimentConfig:
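    """Shape and dtype parameters for a single benchmark run."""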
    batch_size: int
    num_heads: int
    q_seq_len: int
    k_seq_len: int
    embed_dim: int
    dtype: torch.dtype

    @property
    def head_dim(self) -> int:
        return self.embed_dim // self.num_heads

    def asdict(self):
        dict_obj = asdict(self)
        dict_obj["head_dim"] = self.head_dim
        return dict_obj


@dataclass(frozen=True)
class ExperimentResults:
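    """Median latencies, in microseconds, for the two attn_mask variants."""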
    materialized_mask_time: float
    attn_mask_subclass_time: float

    def get_entries(self) -> List:
        return [
            f"{self.materialized_mask_time:.2f}",
            f"{self.attn_mask_subclass_time:.2f}",
        ]


@dataclass(frozen=True)
class Experiment:
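    """Pairs an ExperimentConfig with the ExperimentResults it produced."""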
    config: ExperimentConfig
    results: ExperimentResults

    def get_entries(self) -> List:
        # Flatten the config via asdict() so its values line up with the result entries.
        return list(self.config.asdict().values()) + self.results.get_entries()


def generate_inputs(
    batch_size, q_sequence_length, kv_sequence_length, embed_dim, dtype, device
):
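    """Create random query, key, and value tensors of shape (batch, seq, embed_dim)."""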
    q_shape = (batch_size, q_sequence_length, embed_dim)
    kv_shape = (batch_size, kv_sequence_length, embed_dim)

    make_q = partial(torch.rand, q_shape, device=device, dtype=dtype)
    make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype)
    return make_q(), make_kv(), make_kv()


class CompositeMHA(torch.nn.Module):
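    """Multi-head attention built on F.scaled_dot_product_attention.

    The mask argument to forward() may be either a dense attn_mask tensor or a
    CausalBias subclass; this benchmark compares the two.
    """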
    def __init__(self, num_heads, embed_dim, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.head_dim = embed_dim // num_heads
        self.embed_dim = embed_dim
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        self.q_proj_weight = Parameter(
            torch.empty((embed_dim, embed_dim), **factory_kwargs)
        )
        self.k_proj_weight = Parameter(
            torch.empty((embed_dim, embed_dim), **factory_kwargs)
        )
        self.v_proj_weight = Parameter(
            torch.empty((embed_dim, embed_dim), **factory_kwargs)
        )
        self.out_proj = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
        self.num_heads = num_heads

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Union[torch.Tensor, CausalBias],
    ):
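        """Apply attention; ``mask`` may be a dense attn_mask tensor or a CausalBias."""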
        query_projected = F.linear(query, self.q_proj_weight)
        key_projected = F.linear(key, self.k_proj_weight)
        value_projected = F.linear(value, self.v_proj_weight)

        # Split the projected tensors into heads: (B, S, E) -> (B, H, S, E // H)
        query = query_projected.view(
            query_projected.size(0), -1, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key = key_projected.view(
            key_projected.size(0), -1, self.num_heads, self.head_dim
        ).transpose(1, 2)
        value = value_projected.view(
            value_projected.size(0), -1, self.num_heads, self.head_dim
        ).transpose(1, 2)

        attn = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=mask,
            dropout_p=0.0,
        )

        attn = attn.transpose(1, 2).reshape(query.size(0), -1, self.embed_dim)
        # Final output projection, analogous to nn.MultiheadAttention's out_proj
        return F.linear(attn, self.out_proj)

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.q_proj_weight)
        nn.init.xavier_uniform_(self.k_proj_weight)
        nn.init.xavier_uniform_(self.v_proj_weight)
        nn.init.constant_(self.out_proj, 0.0)


def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
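    """Time CompositeMHA with a materialized causal mask vs. the CausalBias subclass.

    The same lower-right causal pattern is used for both paths (the dense mask comes
    from CausalBias._materialize), and the two outputs are checked with assert_close.
    """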
    device = torch.device("cuda")
    composite_mha = CompositeMHA(
        config.num_heads, config.embed_dim, device, config.dtype
    )
    composite_mha.reset_parameters()
    query, key, value = generate_inputs(
        config.batch_size,
        config.q_seq_len,
        config.k_seq_len,
        config.embed_dim,
        config.dtype,
        device,
    )
    attn_mask = CausalBias(
        CausalVariant.LOWER_RIGHT, config.q_seq_len, config.k_seq_len
    )
    attn_mask_tensor = attn_mask._materialize(device)

    materialized_mask_time = benchmark_torch_function_in_microseconds(
        composite_mha, query, key, value, attn_mask_tensor
    )
    attn_mask_subclass_time = benchmark_torch_function_in_microseconds(
        composite_mha, query, key, value, attn_mask
    )
    torch.testing.assert_close(
        composite_mha(query, key, value, attn_mask_tensor),
        composite_mha(query, key, value, attn_mask),
    )

    return ExperimentResults(
        materialized_mask_time=materialized_mask_time,
        attn_mask_subclass_time=attn_mask_subclass_time,
    )


def generate_experiment_configs() -> List[ExperimentConfig]:
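    """Cross product of batch size, heads, (q, kv) seq lens, embed dim, and dtype."""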
    batch_sizes = [1, 8, 16, 128]
    num_heads = [16, 32]
    q_kv_seq_lens = [(128, 256), (256, 416), (512, 4097), (1024, 2048), (1, 2048)]
    embed_dims = [2048, 4096]
    dtypes = [
        torch.bfloat16,
    ]
    all_configs = []
    for bsz, heads, (q_seq_len, kv_seq_len), embed_dim, dtype in itertools.product(
        batch_sizes, num_heads, q_kv_seq_lens, embed_dims, dtypes
    ):
        all_configs.append(
            ExperimentConfig(
                batch_size=bsz,
                num_heads=heads,
                q_seq_len=q_seq_len,
                k_seq_len=kv_seq_len,
                embed_dim=embed_dim,
                dtype=dtype,
            )
        )

    return all_configs


def calculate_speedup(results: ExperimentResults) -> float:
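    """Speedup of the CausalBias path; values above 1.0 mean the subclass is faster."""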
    return results.materialized_mask_time / results.attn_mask_subclass_time


def print_results(results: List[Experiment]):
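    """Tabulate the average speedup and the max/min cases with their configs."""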
    # Calculate speedups
    speedups = [calculate_speedup(r.results) for r in results]

    # Find indices of max and min speedups
    max_speedup_index = np.argmax(speedups)
    min_speedup_index = np.argmin(speedups)

    # Get the config dictionaries
    max_config_dict = results[max_speedup_index].config.asdict()
    min_config_dict = results[min_speedup_index].config.asdict()

    # Create table data; the Average row leaves the config columns blank
    table_data = [
        {
            "Type": "Average",
            "Speedup": np.mean(speedups),
            **dict.fromkeys(max_config_dict),
        },
        {"Type": "Max", "Speedup": speedups[max_speedup_index], **max_config_dict},
        {"Type": "Min", "Speedup": speedups[min_speedup_index], **min_config_dict},
    ]

    # Print table
    print(tabulate(table_data, headers="keys", tablefmt="pretty"))


def main():
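    """Run every generated config and print the speedup summary."""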
    seed = 123
    np.random.seed(seed)
    torch.manual_seed(seed)
    results = []
    # Time each config with a materialized causal mask vs. the CausalBias subclass
    for config in tqdm(generate_experiment_configs()):
        results.append(Experiment(config, run_single_experiment(config)))

    print_results(results)


if __name__ == "__main__":
    main()