xref: /aosp_15_r20/external/pytorch/benchmarks/transformer/sdp.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport argparse
2*da0073e9SAndroid Build Coastguard Workerimport itertools
3*da0073e9SAndroid Build Coastguard Workerimport random
4*da0073e9SAndroid Build Coastguard Workerimport warnings
5*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import dataclass
6*da0073e9SAndroid Build Coastguard Workerfrom pathlib import Path
7*da0073e9SAndroid Build Coastguard Workerfrom pprint import pprint
8*da0073e9SAndroid Build Coastguard Workerfrom typing import List, Optional
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Workerimport numpy as np
11*da0073e9SAndroid Build Coastguard Workerfrom prettytable import PrettyTable
12*da0073e9SAndroid Build Coastguard Workerfrom tqdm import tqdm
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Workerimport torch
15*da0073e9SAndroid Build Coastguard Workerimport torch.utils.benchmark as benchmark
16*da0073e9SAndroid Build Coastguard Workerfrom torch.backends.cuda import sdp_kernel
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Worker
# Silence all warnings so they do not interleave with the benchmark tables
# printed by this script.
warnings.simplefilter("ignore")
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker
@dataclass(frozen=True)
class ExperimentConfig:
    """Immutable description of one SDPA benchmark configuration.

    The ``enable_*`` flags are forwarded to ``sdp_kernel`` and control which
    scaled-dot-product-attention backends may be selected.
    """

    batch_size: int
    num_heads: int
    max_sequence_len: int
    embed_dimension: int
    dtype: torch.dtype
    # Target padding fraction; a falsy value (None or 0) means a dense batch.
    pad_percentage: Optional[float]
    enable_math: bool
    enable_flash: bool
    enable_mem_efficient: bool
    enable_cudnn: bool

    @classmethod
    def get_entry_names(cls) -> List[str]:
        """Return the table column names, one per field, in declaration order."""
        return [
            "batch_size",
            "num_heads",
            "max_sequence_len",
            "embed_dimension",
            "dtype",
            "pad_percentage",
            "enable_math",
            "enable_flash",
            "enable_mem_efficient",
            "enable_cudnn",
        ]

    def get_entries(self) -> List:
        """Return the field values, aligned with ``get_entry_names()``."""
        # Column names are exactly the attribute names, so derive values from them.
        return [getattr(self, column) for column in self.get_entry_names()]
63*da0073e9SAndroid Build Coastguard Worker
64*da0073e9SAndroid Build Coastguard Worker
@dataclass(frozen=True)
class ExperimentResults:
    """Mean timings, in microseconds, measured for one experiment.

    The ``compiled_*`` timings are ``None`` when the compiled variants were not
    benchmarked.
    """

    nn_mha_time: float
    compiled_nn_mha_time: Optional[float]
    composite_mha_time: float
    compiled_composite_mha_time: Optional[float]

    @staticmethod
    def _format_time(value: Optional[float]) -> Optional[str]:
        # ".2f" = two decimal places (the previous ":2f" meant "width 2" with
        # six decimals). Compare against None explicitly so that a legitimate
        # 0.0 timing is still reported instead of being shown as missing.
        return None if value is None else f"{value:.2f}"

    def get_entries(self) -> List:
        """Return the formatted timings, aligned with ``get_entry_names()``."""
        return [
            self._format_time(self.nn_mha_time),
            self._format_time(self.compiled_nn_mha_time),
            self._format_time(self.composite_mha_time),
            self._format_time(self.compiled_composite_mha_time),
        ]

    @classmethod
    def get_entry_names(cls) -> List[str]:
        """Return the table column names for the timings."""
        return [
            "nn_mha_time (\u00B5s)",
            "compiled_nn_mha_time (\u00B5s)",
            "composite_mha_time (\u00B5s)",
            "compiled_composite_mha_time (\u00B5s)",
        ]
90*da0073e9SAndroid Build Coastguard Worker
91*da0073e9SAndroid Build Coastguard Worker
@dataclass(frozen=True)
class Experiment:
    """A benchmark configuration paired with the timings it produced."""

    config: ExperimentConfig
    results: ExperimentResults

    def get_entries(self) -> List:
        """Return one table row: config values followed by formatted timings."""
        row = list(self.config.get_entries())
        row.extend(self.results.get_entries())
        return row
99*da0073e9SAndroid Build Coastguard Worker
100*da0073e9SAndroid Build Coastguard Worker
class CompositeMHA(torch.nn.Module):
    """Self-attention built from ``F.linear`` + ``F.scaled_dot_product_attention``.

    Takes a packed input-projection weight/bias (q, k and v stacked along the
    output dimension) and an output projection module, e.g. the ones owned by
    an ``nn.MultiheadAttention`` (see ``build_composite_mha_from_nn_mha``).
    Inputs are expected batch-first.
    """

    def __init__(self, num_heads, in_proj_weight, in_proj_bias, out_proj):
        super().__init__()
        self.in_proj_weight = in_proj_weight
        self.in_proj_bias = in_proj_bias
        self.out_proj = out_proj
        self.num_heads = num_heads

    def forward(self, query, key, value, mask):
        """Run unmasked self-attention on ``query``.

        ``query``, ``key`` and ``value`` must be the very same tensor and
        ``mask`` must be ``None``; otherwise ``NotImplementedError`` is raised.
        Returns ``(output, None)`` to mirror ``nn.MultiheadAttention``'s
        ``(attn_output, attn_weights)`` return signature.
        """
        is_self_attention = query is key and key is value
        if not is_self_attention:
            raise NotImplementedError(
                "query, key and value must be the same Tensor for now."
            )
        if mask is not None:
            raise NotImplementedError("mask is currently not supported.")

        packed = torch.nn.functional.linear(
            query, self.in_proj_weight, self.in_proj_bias
        )

        bsz = packed.size(0)
        # The last dim holds q, k and v side by side: 3 * num_heads * head_dim.
        head_dim = packed.size(2) // (self.num_heads * 3)

        q_part, k_part, v_part = packed.chunk(3, -1)
        q = self._split_heads(q_part, bsz, head_dim)
        k = self._split_heads(k_part, bsz, head_dim)
        v = self._split_heads(v_part, bsz, head_dim)

        # SDPA returns (batch, num_heads, seq_len, head_dim).
        attended = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
        )

        merged = attended.transpose(1, 2).reshape(bsz, -1, self.num_heads * head_dim)
        # Match return signature of nn.MHA
        return self.out_proj(merged), None

    def _split_heads(self, tensor, batch_size, head_dim):
        """Reshape (batch, seq, heads * head_dim) into (batch, heads, seq, head_dim)."""
        return tensor.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
144*da0073e9SAndroid Build Coastguard Worker
145*da0073e9SAndroid Build Coastguard Worker
def build_composite_mha_from_nn_mha(pt):
    """Wrap the parameters of ``nn.MultiheadAttention`` ``pt`` in a ``CompositeMHA``.

    ``pt`` must use a packed in-projection (``_qkv_same_embed_dim`` with a
    non-None ``in_proj_weight``) and be ``batch_first``. The returned module
    shares, rather than copies, ``pt``'s parameters and output projection.
    """
    # Preconditions are checked in a fixed order: packed projection, weight
    # present, batch-first layout.
    assert pt._qkv_same_embed_dim
    packed_weight = pt.in_proj_weight
    assert packed_weight is not None
    assert pt.batch_first
    return CompositeMHA(pt.num_heads, packed_weight, pt.in_proj_bias, pt.out_proj)
152*da0073e9SAndroid Build Coastguard Worker
153*da0073e9SAndroid Build Coastguard Worker
def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    """Create a random input batch for the attention benchmarks.

    Args:
        batch_size: Number of sequences in the batch.
        max_sequence_len: Length of the longest sequence.
        embed_dimension: Size of the last (feature) dimension.
        pad_percentage: Target fraction of each sequence that is padding. If
            falsy (``None`` or ``0``), a dense tensor is returned.
        dtype: Element dtype of the generated tensors.
        device: Device to allocate the tensors on.

    Returns:
        ``(batch, None)`` where ``batch`` is a dense tensor of shape
        ``(batch_size, max_sequence_len, embed_dimension)``; or
        ``(nested, seq_lens)`` where ``nested`` is a nested tensor whose i-th
        component has shape ``(seq_lens[i], embed_dimension)``.

    Sequence lengths are drawn from Python's ``random`` module, so seed it to
    make padded batches reproducible.
    """
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )

    def _noisy_length():
        # Length = max_sequence_len * (1 - p) with p ~ N(pad_percentage, 0.01).
        # Clamp to [1, max_sequence_len]: a sampled p >= 1 would otherwise give
        # a zero/negative length (torch.randn raises on negative sizes), and a
        # sampled p < 0 would exceed the advertised maximum length.
        raw = int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        return min(max(raw, 1), max_sequence_len)

    # Building one tensor per sequence is slow, but simple.
    seq_len_list = [_noisy_length() for _ in range(batch_size)]
    # Force one randomly chosen sequence to the maximum length.
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension, dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )
190*da0073e9SAndroid Build Coastguard Worker
191*da0073e9SAndroid Build Coastguard Worker
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    """Return the mean runtime of ``f(*args, **kwargs)`` in microseconds.

    Timing uses ``torch.utils.benchmark.Timer.blocked_autorange``, which picks
    the number of runs per measurement block automatically.
    """
    timer = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"f": f, "args": args, "kwargs": kwargs},
    )
    measurement = timer.blocked_autorange()
    mean_seconds = measurement.mean
    return mean_seconds * 1e6
197*da0073e9SAndroid Build Coastguard Worker
198*da0073e9SAndroid Build Coastguard Worker
def assert_close_tensors(tensor_a, tensor_b):
    """Raise ``AssertionError`` unless the two tensors are approximately equal.

    First order sanity check. Not a replacement for rigorous tests.

    Nested tensors are compared component-wise with a looser tolerance
    (atol=rtol=1e-2) than dense tensors (atol=rtol=1e-3). The following all
    count as mismatches:
    - a nested tensor compared with a dense tensor;
    - nested tensors with different numbers of components;
    - tensors or components of different shapes (``allclose`` would otherwise
      broadcast).

    ``AssertionError`` is raised explicitly rather than via ``assert`` so the
    check still runs under ``python -O``.
    """
    if tensor_a.is_nested != tensor_b.is_nested:
        raise AssertionError("cannot compare a nested tensor with a dense tensor")

    if tensor_a.is_nested:
        components_a = tensor_a.unbind()
        components_b = tensor_b.unbind()
        if len(components_a) != len(components_b):
            raise AssertionError(
                "nested tensors have different numbers of components: "
                f"{len(components_a)} vs {len(components_b)}"
            )
        for index, (a, b) in enumerate(zip(components_a, components_b)):
            if a.shape != b.shape:
                raise AssertionError(
                    f"nested component {index} shapes differ: "
                    f"{tuple(a.shape)} vs {tuple(b.shape)}"
                )
            if not torch.allclose(a, b, atol=1e-2, rtol=1e-2):
                raise AssertionError(f"nested component {index} values differ")
        return

    if tensor_a.shape != tensor_b.shape:
        raise AssertionError(
            f"shapes differ: {tuple(tensor_a.shape)} vs {tuple(tensor_b.shape)}"
        )
    if not torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3):
        raise AssertionError("dense tensor values differ")
206*da0073e9SAndroid Build Coastguard Worker
207*da0073e9SAndroid Build Coastguard Worker
def run_single_experiment(config: ExperimentConfig) -> Experiment:
    """Benchmark ``nn.MultiheadAttention`` against ``CompositeMHA`` for one config.

    Both modules share the same weights and run on CUDA, inside
    ``torch.inference_mode`` and under ``sdp_kernel`` restricted to the
    backends enabled in ``config``. Their outputs are sanity-checked for
    closeness before timing. ``torch.compile`` variants are timed only for
    dense inputs, because TorchDynamo errors on NestedTensors.

    Returns:
        An ``Experiment`` pairing ``config`` with the measured mean timings
        (microseconds). The compiled timings are ``None`` for nested inputs.
    """
    with sdp_kernel(
        enable_math=config.enable_math,
        enable_flash=config.enable_flash,
        enable_mem_efficient=config.enable_mem_efficient,
        enable_cudnn=config.enable_cudnn,
    ), torch.inference_mode():
        dropout_p = 0.0
        mask = None

        nn_mha = torch.nn.MultiheadAttention(
            embed_dim=config.embed_dimension,
            num_heads=config.num_heads,
            batch_first=True,
            dropout=dropout_p,
        )
        nn_mha = nn_mha.eval().to("cuda", config.dtype)
        composite_mha = build_composite_mha_from_nn_mha(nn_mha)
        qkv, _ = generate_rand_batch(
            config.batch_size,
            config.max_sequence_len,
            config.embed_dimension,
            config.pad_percentage,
            config.dtype,
        )
        # These calls also serve as eager warm-up before timing.
        nn_mha_output, _ = nn_mha(qkv, qkv, qkv, mask)
        composite_mha_output, _ = composite_mha(qkv, qkv, qkv, mask)

        # First order sanity check
        assert_close_tensors(nn_mha_output, composite_mha_output)

        nn_mha_time = benchmark_torch_function_in_microseconds(
            nn_mha, qkv, qkv, qkv, mask
        )
        composite_mha_time = benchmark_torch_function_in_microseconds(
            composite_mha, qkv, qkv, qkv, mask
        )

        compiled_nn_mha_time = None
        compiled_composite_mha_time = None
        # TorchDynamo will error on NestedTensors. Checking the input itself
        # (rather than `pad_percentage is None`) also compiles the dense batch
        # that generate_rand_batch produces for pad_percentage == 0.
        if not qkv.is_nested:
            compiled_nn_mha = torch.compile(nn_mha)
            compiled_composite_mha = torch.compile(composite_mha)

            # Trigger compilation once up front, so the (slow) first call does
            # not skew the timer's block-size estimate.
            compiled_nn_mha(qkv, qkv, qkv, mask)
            compiled_composite_mha(qkv, qkv, qkv, mask)

            compiled_nn_mha_time = benchmark_torch_function_in_microseconds(
                compiled_nn_mha, qkv, qkv, qkv, mask
            )
            compiled_composite_mha_time = benchmark_torch_function_in_microseconds(
                compiled_composite_mha, qkv, qkv, qkv, mask
            )

        results = ExperimentResults(
            nn_mha_time,
            compiled_nn_mha_time,
            composite_mha_time,
            compiled_composite_mha_time,
        )
        return Experiment(config, results)
273*da0073e9SAndroid Build Coastguard Worker
274*da0073e9SAndroid Build Coastguard Worker
# NOTE: this could yield configs lazily (as a generator) instead of building the full list.
def generate_experiments(
    batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
) -> List[ExperimentConfig]:
    """Build one ``ExperimentConfig`` per point of the Cartesian product of the inputs.

    Ordering follows ``itertools.product`` (the last argument varies fastest).
    Every config disables the math SDPA backend and enables the flash,
    memory-efficient and cuDNN backends.
    """
    grid = itertools.product(
        batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
    )
    return [
        ExperimentConfig(
            batch_size=bsz,
            num_heads=heads,
            max_sequence_len=seq_len,
            embed_dimension=dim,
            dtype=dtype,
            pad_percentage=pad,
            enable_math=False,
            enable_flash=True,
            enable_mem_efficient=True,
            enable_cudnn=True,
        )
        for bsz, heads, seq_len, dim, dtype, pad in grid
    ]
298*da0073e9SAndroid Build Coastguard Worker
299*da0073e9SAndroid Build Coastguard Worker
def main(save_path: Optional[Path]):
    """Run the SDPA benchmarks, print a results table and optionally save it.

    Args:
        save_path: Where to write the table as CSV, or ``None`` to skip saving.
    """
    seed = 123
    # generate_rand_batch draws padded sequence lengths from Python's `random`
    # module, so seed it as well; otherwise padded batches change run to run.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Run one timing experiment comparing nn_mha vs composite_mha
    config = ExperimentConfig(
        batch_size=128,
        num_heads=8,
        max_sequence_len=512,
        embed_dimension=512,
        dtype=torch.float16,
        pad_percentage=None,
        enable_math=False,
        enable_flash=True,
        enable_mem_efficient=True,
        enable_cudnn=True,
    )

    experiment = run_single_experiment(config)
    pprint(experiment)

    table = PrettyTable()
    table.float_format = ".3"
    table.field_names = (
        ExperimentConfig.get_entry_names() + ExperimentResults.get_entry_names()
    )

    # Run a bunch of experiments
    batch_sizes = [256]
    num_heads = [32]
    max_seq_lens = [256]
    embed_dims = [512]
    dtypes = [torch.bfloat16, torch.float16, torch.float32]
    pad_percentages = [None, 0.9]

    experiment_configs = generate_experiments(
        batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
    )

    experiments: List[Experiment] = []
    for experiment_config in tqdm(experiment_configs):
        experiment = run_single_experiment(experiment_config)
        experiments.append(experiment)
        table.add_row(experiment.get_entries())

    print(table)

    if save_path is not None:
        # The timing headers contain a non-ASCII "µ", so pin the encoding
        # instead of relying on the platform default. newline="" avoids
        # translating line terminators already present in the CSV text
        # (csv writers typically emit "\r\n").
        with open(save_path, "w", encoding="utf-8", newline="") as csvfile:
            csvfile.write(table.get_csv_string())
352*da0073e9SAndroid Build Coastguard Worker
353*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--save-path",
        "--save_path",
        type=str,
        help="Path to save the results",
    )
    parsed = cli.parse_args()
    # A missing or empty --save-path means "do not write a CSV file".
    main(Path(parsed.save_path) if parsed.save_path else None)
363