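"""Benchmark torch.nn.MultiheadAttention against a "composite" MHA built
directly on torch.nn.functional.scaled_dot_product_attention.

Both modules are timed eagerly and, for dense (non-nested) inputs, under
torch.compile, across a grid of batch sizes, head counts, sequence lengths,
dtypes and padding percentages. Results are printed as a PrettyTable and can
optionally be written to CSV, e.g. (output filename is illustrative):

    python sdp.py --save-path results.csv
"""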
import argparse
import itertools
import random
import warnings
from dataclasses import dataclass
from pathlib import Path
from pprint import pprint
from typing import List, Optional

import numpy as np
from prettytable import PrettyTable
from tqdm import tqdm

import torch
import torch.utils.benchmark as benchmark
from torch.backends.cuda import sdp_kernel


warnings.filterwarnings("ignore")


@dataclass(frozen=True)
class ExperimentConfig:
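    """Parameters for a single benchmark run; the enable_* flags select which
    scaled_dot_product_attention backends sdp_kernel() may dispatch to."""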
    batch_size: int
    num_heads: int
    max_sequence_len: int
    embed_dimension: int
    dtype: torch.dtype
    pad_percentage: Optional[float]
    enable_math: bool
    enable_flash: bool
    enable_mem_efficient: bool
    enable_cudnn: bool

    def get_entries(self) -> List:
        return [
            self.batch_size,
            self.num_heads,
            self.max_sequence_len,
            self.embed_dimension,
            self.dtype,
            self.pad_percentage,
            self.enable_math,
            self.enable_flash,
            self.enable_mem_efficient,
            self.enable_cudnn,
        ]

    @classmethod
    def get_entry_names(cls) -> List[str]:
        return [
            "batch_size",
            "num_heads",
            "max_sequence_len",
            "embed_dimension",
            "dtype",
            "pad_percentage",
            "enable_math",
            "enable_flash",
            "enable_mem_efficient",
            "enable_cudnn",
        ]


@dataclass(frozen=True)
class ExperimentResults:
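    """Mean runtimes in microseconds; the compiled_* entries are None when
    torch.compile is skipped (nested-tensor inputs, i.e. pad_percentage set)."""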
    nn_mha_time: float
    compiled_nn_mha_time: Optional[float]
    composite_mha_time: float
    compiled_composite_mha_time: Optional[float]

    def get_entries(self) -> List:
        return [
            f"{self.nn_mha_time:.2f}",
            f"{self.compiled_nn_mha_time:.2f}"
            if self.compiled_nn_mha_time is not None
            else None,
            f"{self.composite_mha_time:.2f}",
            f"{self.compiled_composite_mha_time:.2f}"
            if self.compiled_composite_mha_time is not None
            else None,
        ]

    @classmethod
    def get_entry_names(cls) -> List[str]:
        return [
            "nn_mha_time (\u00B5s)",
            "compiled_nn_mha_time (\u00B5s)",
            "composite_mha_time (\u00B5s)",
            "compiled_composite_mha_time (\u00B5s)",
        ]


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    results: ExperimentResults

    def get_entries(self) -> List:
        return self.config.get_entries() + self.results.get_entries()


class CompositeMHA(torch.nn.Module):
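    """Multihead self-attention implemented directly on
    torch.nn.functional.scaled_dot_product_attention, reusing the projection
    weights of an nn.MultiheadAttention module (see
    build_composite_mha_from_nn_mha). Masks are not supported and query, key
    and value must be the same tensor."""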
    def __init__(self, num_heads, in_proj_weight, in_proj_bias, out_proj):
        super().__init__()
        self.in_proj_weight = in_proj_weight
        self.in_proj_bias = in_proj_bias
        self.out_proj = out_proj
        self.num_heads = num_heads

    def forward(self, query, key, value, mask):
        if not (query is key and key is value):
            raise NotImplementedError(
                "query, key and value must be the same Tensor for now."
            )
        if mask is not None:
            raise NotImplementedError("mask is currently not supported.")

        query_projected = torch.nn.functional.linear(
            query, self.in_proj_weight, self.in_proj_bias
        )

        batch_size = query_projected.size(0)
        # The packed projection stacks q, k and v, so its last dim is 3 * embed_dim
        embed_dim = query_projected.size(2)
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)

        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        attn = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=False,
        )

        attn = attn.transpose(1, 2).reshape(batch_size, -1, self.num_heads * head_dim)
        # Match return signature of nn.MHA
        return self.out_proj(attn), None


def build_composite_mha_from_nn_mha(pt):
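    """Build a CompositeMHA that shares weights with an existing
    torch.nn.MultiheadAttention module; requires a packed in_proj and
    batch_first=True."""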
    assert pt._qkv_same_embed_dim
    in_proj_weight = pt.in_proj_weight
    assert in_proj_weight is not None
    assert pt.batch_first
    return CompositeMHA(pt.num_heads, pt.in_proj_weight, pt.in_proj_bias, pt.out_proj)


def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
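    """Generate a random input batch.

    With pad_percentage=None, return a dense (batch, seq, embed) tensor and
    None. Otherwise sample per-sequence lengths around
    max_sequence_len * (1 - pad_percentage) and return a nested tensor
    together with the list of lengths.
    """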
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )
    # Really slow but should work
    seq_len_list = [
        int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        for _ in range(batch_size)
    ]
    # Make one random element use the full max sequence length
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    # print(f"Theoretical padding: {pad_percentage} actual: {1 - (sum(seq_len_list) / (batch_size * max_sequence_len))}")
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension, dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )


def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
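    """Time f(*args, **kwargs) with torch.utils.benchmark.Timer and return the
    mean of a blocked_autorange measurement, converted to microseconds."""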
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6


def assert_close_tensors(tensor_a, tensor_b):
    # First order sanity check. Not a replacement for rigorous tests.
    if tensor_a.is_nested and tensor_b.is_nested:
        for a, b in zip(tensor_a.unbind(), tensor_b.unbind()):
            assert torch.allclose(a, b, atol=1e-2, rtol=1e-2)
    else:
        assert torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)


def run_single_experiment(config: ExperimentConfig) -> Experiment:
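    """Build nn.MultiheadAttention and its CompositeMHA counterpart for the
    given config, sanity-check that their outputs match, and time both (plus
    torch.compile'd variants for dense inputs) under the configured
    sdp_kernel backends."""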
    with sdp_kernel(
        enable_math=config.enable_math,
        enable_flash=config.enable_flash,
        enable_mem_efficient=config.enable_mem_efficient,
        enable_cudnn=config.enable_cudnn,
    ), torch.inference_mode():
        dropout_p = 0.0
        mask = None

        nn_mha = torch.nn.MultiheadAttention(
            embed_dim=config.embed_dimension,
            num_heads=config.num_heads,
            batch_first=True,
            dropout=dropout_p,
        )
        nn_mha = nn_mha.eval().to("cuda", config.dtype)
        composite_mha = build_composite_mha_from_nn_mha(nn_mha)
        qkv, lengths = generate_rand_batch(
            config.batch_size,
            config.max_sequence_len,
            config.embed_dimension,
            config.pad_percentage,
            config.dtype,
        )
        nn_mha_output, _ = nn_mha(qkv, qkv, qkv, mask)
        composite_mha_output, _ = composite_mha(qkv, qkv, qkv, mask)

        # First order sanity check
        assert_close_tensors(nn_mha_output, composite_mha_output)

        nn_mha_time = benchmark_torch_function_in_microseconds(
            nn_mha, qkv, qkv, qkv, mask
        )
        composite_mha_time = benchmark_torch_function_in_microseconds(
            composite_mha, qkv, qkv, qkv, mask
        )

        # TorchDynamo will error on NestedTensors
        if config.pad_percentage is None:
            compiled_nn_mha = torch.compile(nn_mha)
            compiled_composite_mha = torch.compile(composite_mha)

            compiled_nn_mha_time = benchmark_torch_function_in_microseconds(
                compiled_nn_mha, qkv, qkv, qkv, mask
            )

            compiled_composite_mha_time = benchmark_torch_function_in_microseconds(
                compiled_composite_mha,
                qkv,
                qkv,
                qkv,
                mask,
            )
        else:
            compiled_nn_mha_time = None
            compiled_composite_mha_time = None

        results = ExperimentResults(
            nn_mha_time,
            compiled_nn_mha_time,
            composite_mha_time,
            compiled_composite_mha_time,
        )
        return Experiment(config, results)


# Could return generator
def generate_experiments(
    batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
) -> List[ExperimentConfig]:
    configs = []
    for bsz, n_heads, seq_len, embed_dim, dtype, padding in itertools.product(
        batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
    ):
        configs.append(
            ExperimentConfig(
                batch_size=bsz,
                num_heads=n_heads,
                max_sequence_len=seq_len,
                embed_dimension=embed_dim,
                dtype=dtype,
                pad_percentage=padding,
                enable_math=False,
                enable_flash=True,
                enable_mem_efficient=True,
                enable_cudnn=True,
            )
        )
    return configs


def main(save_path: Optional[Path]):
    seed = 123
    # Seed python's random (used for sequence lengths), numpy and torch
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Run one timing experiment comparing nn_mha vs composite_mha
    config = ExperimentConfig(
        batch_size=128,
        num_heads=8,
        max_sequence_len=512,
        embed_dimension=512,
        dtype=torch.float16,
        pad_percentage=None,
        enable_math=False,
        enable_flash=True,
        enable_mem_efficient=True,
        enable_cudnn=True,
    )

    experiment = run_single_experiment(config)
    pprint(experiment)

    table = PrettyTable()
    table.float_format = ".3"
    table.field_names = (
        ExperimentConfig.get_entry_names() + ExperimentResults.get_entry_names()
    )

    # Run a bunch of experiments
    batch_sizes = [256]
    num_heads = [32]
    max_seq_lens = [256]
    embed_dims = [512]
    dtypes = [torch.bfloat16, torch.float16, torch.float32]
    pad_percentages = [None, 0.9]

    experiment_configs = generate_experiments(
        batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
    )

    experiments: List[Experiment] = []
    for experiment_config in tqdm(experiment_configs):
        experiment = run_single_experiment(experiment_config)
        experiments.append(experiment)
        table.add_row(experiment.get_entries())

    print(table)

    csv_string = table.get_csv_string()
    if save_path is not None:
        with open(save_path, "w") as csvfile:
            csvfile.write(csv_string)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save-path", "--save_path", type=str, help="Path to save the results"
    )

    args = parser.parse_args()
    save_path = Path(args.save_path) if args.save_path else None
    main(save_path)