"""Benchmark ``torch.nn.MultiheadAttention`` against a composite MHA module.

The composite module reuses the projection weights of an ``nn.MultiheadAttention``
instance and calls ``torch.nn.functional.scaled_dot_product_attention`` directly.
Timings are printed as a table and optionally written to a CSV file.
"""
import argparse
import itertools
import random
import warnings
from dataclasses import dataclass
from pathlib import Path
from pprint import pprint
from typing import List, Optional

import numpy as np
from prettytable import PrettyTable
from tqdm import tqdm

import torch
import torch.utils.benchmark as benchmark
from torch.backends.cuda import sdp_kernel


warnings.filterwarnings("ignore")


@dataclass(frozen=True)
class ExperimentConfig:
    """Immutable description of one benchmark configuration."""

    batch_size: int
    num_heads: int
    max_sequence_len: int
    embed_dimension: int
    dtype: torch.dtype
    # ``None`` (or any falsy value) makes ``generate_rand_batch`` build a dense batch.
    pad_percentage: Optional[float]
    # The four flags below are forwarded verbatim to ``sdp_kernel``.
    enable_math: bool
    enable_flash: bool
    enable_mem_efficient: bool
    enable_cudnn: bool

    def get_entries(self) -> List:
        """Return the field values in the same order as ``get_entry_names``."""
        # Derive the values from the name list so the two can never drift apart.
        return [getattr(self, name) for name in self.get_entry_names()]

    @classmethod
    def get_entry_names(cls) -> List[str]:
        """Return the column names used for the configuration part of a table row."""
        return [
            "batch_size",
            "num_heads",
            "max_sequence_len",
            "embed_dimension",
            "dtype",
            "pad_percentage",
            "enable_math",
            "enable_flash",
            "enable_mem_efficient",
            "enable_cudnn",
        ]


@dataclass(frozen=True)
class ExperimentResults:
    """Timings collected for one configuration.

    The ``compiled_*`` fields are ``None`` when the compiled variants were not run.
    """

    nn_mha_time: float
    compiled_nn_mha_time: Optional[float]
    composite_mha_time: float
    compiled_composite_mha_time: Optional[float]

    @staticmethod
    def _format_time(value: Optional[float]) -> Optional[str]:
        """Format one timing for the table, passing ``None`` through unchanged."""
        # Compare against None explicitly: a legitimate 0.0 timing must not be
        # reported as "missing".
        if value is None:
            return None
        # NOTE(review): "2f" is a minimum field width of 2 with the default six
        # decimals; ".2f" may have been intended. Kept as-is so the emitted
        # table/CSV values do not change - confirm before altering.
        return f"{value:2f}"

    def get_entries(self) -> List:
        """Return the formatted timings in the same order as ``get_entry_names``."""
        return [
            self._format_time(self.nn_mha_time),
            self._format_time(self.compiled_nn_mha_time),
            self._format_time(self.composite_mha_time),
            self._format_time(self.compiled_composite_mha_time),
        ]

    @classmethod
    def get_entry_names(cls) -> List[str]:
        """Return the column names used for the timing part of a table row."""
        return [
            "nn_mha_time (\u00B5s)",
            "compiled_nn_mha_time (\u00B5s)",
            "composite_mha_time (\u00B5s)",
            "compiled_composite_mha_time (\u00B5s)",
        ]


@dataclass(frozen=True)
class Experiment:
    """A configuration paired with the timings measured for it."""

    config: ExperimentConfig
    results: ExperimentResults

    def get_entries(self) -> List:
        """Return one full table row: configuration values followed by timings."""
        return self.config.get_entries() + self.results.get_entries()


class CompositeMHA(torch.nn.Module):
    """Self-attention built from a fused input projection, SDPA and an output projection."""

    def __init__(self, num_heads, in_proj_weight, in_proj_bias, out_proj):
        super().__init__()
        self.in_proj_weight = in_proj_weight
        self.in_proj_bias = in_proj_bias
        self.out_proj = out_proj
        self.num_heads = num_heads

    def forward(self, query, key, value, mask):
        """Run self-attention and return ``(output, None)``.

        Raises:
            NotImplementedError: if ``query``, ``key`` and ``value`` are not the
                very same tensor object, or if ``mask`` is not ``None``.
        """
        if not (query is key and key is value):
            raise NotImplementedError(
                "query, key and value must be the same Tensor for now."
            )
        if mask is not None:
            raise NotImplementedError("mask is currently not supported.")

        # One linear call produces q, k and v concatenated along the last dim.
        qkv_projected = torch.nn.functional.linear(
            query, self.in_proj_weight, self.in_proj_bias
        )

        batch_size = qkv_projected.size(0)
        # The last dim holds q, k and v side by side, hence the extra factor of 3.
        projected_dim = qkv_projected.size(2)
        head_dim = projected_dim // (self.num_heads * 3)

        query, key, value = qkv_projected.chunk(3, -1)

        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        attn = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=False,
        )

        attn = attn.transpose(1, 2).reshape(batch_size, -1, self.num_heads * head_dim)
        # Match return signature of nn.MHA
        return self.out_proj(attn), None


def build_composite_mha_from_nn_mha(pt):
    """Build a ``CompositeMHA`` that shares the projections of ``pt``.

    ``pt`` must use a single fused input projection (``_qkv_same_embed_dim``),
    have a non-``None`` ``in_proj_weight`` and be ``batch_first``; otherwise an
    ``AssertionError`` is raised.
    """
    assert pt._qkv_same_embed_dim
    in_proj_weight = pt.in_proj_weight
    assert in_proj_weight is not None
    assert pt.batch_first
    return CompositeMHA(pt.num_heads, in_proj_weight, pt.in_proj_bias, pt.out_proj)


def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    """Create a random input batch and, for padded batches, its sequence lengths.

    Returns:
        ``(tensor, None)`` with a dense ``(batch_size, max_sequence_len,
        embed_dimension)`` tensor when ``pad_percentage`` is falsy; otherwise
        ``(nested_tensor, seq_len_list)`` where each sequence length is drawn
        around ``max_sequence_len * (1 - pad_percentage)``.
    """
    if not pad_percentage:
        dense = torch.randn(
            batch_size,
            max_sequence_len,
            embed_dimension,
            dtype=dtype,
            device=device,
        )
        return dense, None

    # Really slow but should work
    seq_len_list = []
    for _ in range(batch_size):
        drawn = int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        # The Gaussian draw can leave the valid range when pad_percentage is
        # close to 0 or 1; clamp so torch.randn never sees a non-positive
        # length and no sequence exceeds the declared maximum.
        seq_len_list.append(max(1, min(drawn, max_sequence_len)))

    # Make random ele max length
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len

    sequences = [
        torch.randn(seq_len, embed_dimension, dtype=dtype, device=device)
        for seq_len in seq_len_list
    ]
    return torch.nested.nested_tensor(sequences), seq_len_list


def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    """Time ``f(*args, **kwargs)`` with ``benchmark.Timer`` and scale the mean by 1e6."""
    timer = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return timer.blocked_autorange().mean * 1e6


def assert_close_tensors(tensor_a, tensor_b):
    """Assert that two tensors are approximately equal.

    Nested tensors are compared component by component with looser tolerances
    (1e-2) than dense tensors (1e-3).
    """
    # First order sanity check. Not a replacement for rigorous tests.
    if tensor_a.is_nested and tensor_b.is_nested:
        for a, b in zip(tensor_a.unbind(), tensor_b.unbind()):
            assert torch.allclose(a, b, atol=1e-2, rtol=1e-2)
    else:
        assert torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)


def run_single_experiment(config: ExperimentConfig) -> Experiment:
    """Time ``nn.MultiheadAttention`` and ``CompositeMHA`` for one configuration.

    The modules are created on ``"cuda"``, their outputs are sanity-checked
    against each other, and both are timed eagerly. The ``torch.compile``
    variants are timed only for dense (non-nested) inputs.
    """
    with sdp_kernel(
        enable_math=config.enable_math,
        enable_flash=config.enable_flash,
        enable_mem_efficient=config.enable_mem_efficient,
        enable_cudnn=config.enable_cudnn,
    ), torch.inference_mode():
        dropout_p = 0.0
        mask = None

        nn_mha = torch.nn.MultiheadAttention(
            embed_dim=config.embed_dimension,
            num_heads=config.num_heads,
            batch_first=True,
            dropout=dropout_p,
        )
        nn_mha = nn_mha.eval().to("cuda", config.dtype)
        composite_mha = build_composite_mha_from_nn_mha(nn_mha)
        qkv, _lengths = generate_rand_batch(
            config.batch_size,
            config.max_sequence_len,
            config.embed_dimension,
            config.pad_percentage,
            config.dtype,
        )
        nn_mha_output, _ = nn_mha(qkv, qkv, qkv, mask)
        composite_mha_output, _ = composite_mha(qkv, qkv, qkv, mask)

        # First order sanity check
        assert_close_tensors(nn_mha_output, composite_mha_output)

        nn_mha_time = benchmark_torch_function_in_microseconds(
            nn_mha, qkv, qkv, qkv, mask
        )
        composite_mha_time = benchmark_torch_function_in_microseconds(
            composite_mha, qkv, qkv, qkv, mask
        )

        compiled_nn_mha_time = None
        compiled_composite_mha_time = None
        # TorchDynamo will error on NestedTensors, so key the decision on the
        # batch that was actually built rather than on the config value.
        if not qkv.is_nested:
            compiled_nn_mha = torch.compile(nn_mha)
            compiled_composite_mha = torch.compile(composite_mha)

            compiled_nn_mha_time = benchmark_torch_function_in_microseconds(
                compiled_nn_mha, qkv, qkv, qkv, mask
            )
            compiled_composite_mha_time = benchmark_torch_function_in_microseconds(
                compiled_composite_mha, qkv, qkv, qkv, mask
            )

        results = ExperimentResults(
            nn_mha_time,
            compiled_nn_mha_time,
            composite_mha_time,
            compiled_composite_mha_time,
        )
        return Experiment(config, results)


# Could return generator
def generate_experiments(
    batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
) -> List[ExperimentConfig]:
    """Return one ``ExperimentConfig`` per element of the Cartesian product of the inputs.

    Every configuration disables the math kernel and enables the flash,
    memory-efficient and cuDNN kernels.
    """
    return [
        ExperimentConfig(
            batch_size=bsz,
            num_heads=n_heads,
            max_sequence_len=seq_len,
            embed_dimension=embed_dim,
            dtype=dtype,
            pad_percentage=padding,
            enable_math=False,
            enable_flash=True,
            enable_mem_efficient=True,
            enable_cudnn=True,
        )
        for bsz, n_heads, seq_len, embed_dim, dtype, padding in itertools.product(
            batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
        )
    ]


def main(save_path: Optional[Path]):
    """Run the benchmark sweep, print the results and optionally save them as CSV."""
    seed = 123
    # generate_rand_batch draws sequence lengths from the stdlib ``random``
    # module, so it has to be seeded too for padded runs to be reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Run one timing experiment comparing nn_mha vs composite_mha
    config = ExperimentConfig(
        batch_size=128,
        num_heads=8,
        max_sequence_len=512,
        embed_dimension=512,
        dtype=torch.float16,
        pad_percentage=None,
        enable_math=False,
        enable_flash=True,
        enable_mem_efficient=True,
        enable_cudnn=True,
    )

    experiment = run_single_experiment(config)
    pprint(experiment)

    table = PrettyTable()
    table.float_format = ".3"
    table.field_names = (
        ExperimentConfig.get_entry_names() + ExperimentResults.get_entry_names()
    )

    # Run a bunch of experiments
    batch_sizes = [256]
    num_heads = [32]
    max_seq_lens = [256]
    embed_dims = [512]
    dtypes = [torch.bfloat16, torch.float16, torch.float32]
    pad_percentages = [None, 0.9]

    experiment_configs = generate_experiments(
        batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
    )

    experiments: List[Experiment] = []
    for experiment_config in tqdm(experiment_configs):
        experiment = run_single_experiment(experiment_config)
        experiments.append(experiment)
        table.add_row(experiment.get_entries())

    print(table)

    csv_string = table.get_csv_string()
    if save_path is not None:
        # The headers contain a micro sign, so pin the encoding instead of
        # relying on the platform default; newline="" leaves the CSV line
        # terminators untouched.
        with open(save_path, "w", encoding="utf-8", newline="") as csvfile:
            csvfile.write(csv_string)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save-path", "--save_path", type=str, help="Path to save the results"
    )

    args = parser.parse_args()
    save_path = Path(args.save_path) if args.save_path else None
    main(save_path)