import argparse
import itertools
import random
import warnings
from dataclasses import dataclass
from pathlib import Path
from pprint import pprint
from typing import List, Optional

import numpy as np
from prettytable import PrettyTable
from tqdm import tqdm

import torch
import torch.utils.benchmark as benchmark
from torch.backends.cuda import sdp_kernel


warnings.filterwarnings("ignore")


@dataclass(frozen=True)
class ExperimentConfig:
    batch_size: int
    num_heads: int
    max_sequence_len: int
    embed_dimension: int
    dtype: torch.dtype
    pad_percentage: Optional[float]
    enable_math: bool
    enable_flash: bool
    enable_mem_efficient: bool
    enable_cudnn: bool

    def get_entries(self) -> List:
        return [
            self.batch_size,
            self.num_heads,
            self.max_sequence_len,
            self.embed_dimension,
            self.dtype,
            self.pad_percentage,
            self.enable_math,
            self.enable_flash,
            self.enable_mem_efficient,
            self.enable_cudnn,
        ]

    @classmethod
    def get_entry_names(cls) -> List[str]:
        return [
            "batch_size",
            "num_heads",
            "max_sequence_len",
            "embed_dimension",
            "dtype",
            "pad_percentage",
            "enable_math",
            "enable_flash",
            "enable_mem_efficient",
            "enable_cudnn",
        ]


@dataclass(frozen=True)
class ExperimentResults:
    nn_mha_time: float
    compiled_nn_mha_time: Optional[float]
    composite_mha_time: float
    compiled_composite_mha_time: Optional[float]

    def get_entries(self) -> List:
        return [
            f"{self.nn_mha_time:.2f}",
            f"{self.compiled_nn_mha_time:.2f}" if self.compiled_nn_mha_time else None,
            f"{self.composite_mha_time:.2f}",
            f"{self.compiled_composite_mha_time:.2f}"
            if self.compiled_composite_mha_time
            else None,
        ]

    @classmethod
    def get_entry_names(cls) -> List[str]:
        return [
            "nn_mha_time (\u00B5s)",
            "compiled_nn_mha_time (\u00B5s)",
            "composite_mha_time (\u00B5s)",
            "compiled_composite_mha_time (\u00B5s)",
        ]


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    results: ExperimentResults

    def get_entries(self) -> List:
        return self.config.get_entries() + self.results.get_entries()


class CompositeMHA(torch.nn.Module):
    def __init__(self, num_heads, in_proj_weight, in_proj_bias, out_proj):
        super().__init__()
        self.in_proj_weight = in_proj_weight
        self.in_proj_bias = in_proj_bias
        self.out_proj = out_proj
        self.num_heads = num_heads

    def forward(self, query, key, value, mask):
        if not (query is key and key is value):
            raise NotImplementedError(
                "query, key and value must be the same Tensor for now."
            )
        if mask is not None:
            raise NotImplementedError("mask is currently not supported.")

        # Packed projection: one matmul produces q, k, and v concatenated
        # along the last dimension.
        query_projected = torch.nn.functional.linear(
            query, self.in_proj_weight, self.in_proj_bias
        )

        batch_size = query_projected.size(0)
        # The projected last dim is 3 * embed_dim (q, k, v packed together).
        embed_dim = query_projected.size(2)
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)

        # (batch, seq_len, embed_dim) -> (batch, num_heads, seq_len, head_dim)
        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

        # The output of sdp has shape (batch, num_heads, seq_len, head_dim).
        attn = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=False,
        )

        attn = attn.transpose(1, 2).reshape(batch_size, -1, self.num_heads * head_dim)
        # Match the (output, attn_weights) return signature of nn.MultiheadAttention.
        return self.out_proj(attn), None
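

# Illustrative shape walkthrough of the packed-projection decomposition above.
# Nothing in the benchmark calls this helper; the sizes (embed_dim=512,
# num_heads=8, so head_dim = 512 // 8 = 64) are assumptions chosen for the
# example.
def _composite_mha_shape_walkthrough():
    batch_size, seq_len, embed_dim, num_heads = 2, 16, 512, 8
    head_dim = embed_dim // num_heads  # 64
    x = torch.randn(batch_size, seq_len, embed_dim)
    in_proj_weight = torch.randn(3 * embed_dim, embed_dim)
    # One matmul yields q, k, and v packed along the last dim: (2, 16, 1536).
    projected = torch.nn.functional.linear(x, in_proj_weight)
    q, k, v = projected.chunk(3, -1)  # each (2, 16, 512)
    # Split heads: (batch, seq_len, embed_dim) -> (batch, num_heads, seq_len, head_dim)
    q = q.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
    assert q.shape == (batch_size, num_heads, seq_len, head_dim)  # (2, 8, 16, 64)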


def build_composite_mha_from_nn_mha(pt):
    assert pt._qkv_same_embed_dim
    in_proj_weight = pt.in_proj_weight
    assert in_proj_weight is not None
    assert pt.batch_first
    return CompositeMHA(pt.num_heads, pt.in_proj_weight, pt.in_proj_bias, pt.out_proj)


def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )
    # Really slow, but it should work.
    seq_len_list = [
        int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        for _ in range(batch_size)
    ]
    # Make one random element the max length so the batch spans the full range.
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    # print(f"Theoretical padding: {pad_percentage} actual: {1 - (sum(seq_len_list) / (batch_size * max_sequence_len))}")
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension, dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )


def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6


def assert_close_tensors(tensor_a, tensor_b):
    # First-order sanity check. Not a replacement for rigorous tests.
    if tensor_a.is_nested and tensor_b.is_nested:
        for a, b in zip(tensor_a.unbind(), tensor_b.unbind()):
            assert torch.allclose(a, b, atol=1e-2, rtol=1e-2)
    else:
        assert torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
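

# Hypothetical usage of the timing helper above: any callable and its arguments
# can be passed through, e.g.
#
#   a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
#   t_us = benchmark_torch_function_in_microseconds(torch.mm, a, a)
#
# torch.utils.benchmark.Timer synchronizes CUDA and performs warmup runs, which
# is why this script needs no manual torch.cuda.synchronize() calls.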


def run_single_experiment(config: ExperimentConfig) -> Experiment:
    with sdp_kernel(
        enable_math=config.enable_math,
        enable_flash=config.enable_flash,
        enable_mem_efficient=config.enable_mem_efficient,
        enable_cudnn=config.enable_cudnn,
    ), torch.inference_mode():
        dropout_p = 0.0
        mask = None

        nn_mha = torch.nn.MultiheadAttention(
            embed_dim=config.embed_dimension,
            num_heads=config.num_heads,
            batch_first=True,
            dropout=dropout_p,
        )
        nn_mha = nn_mha.eval().to("cuda", config.dtype)
        composite_mha = build_composite_mha_from_nn_mha(nn_mha)
        qkv, lengths = generate_rand_batch(
            config.batch_size,
            config.max_sequence_len,
            config.embed_dimension,
            config.pad_percentage,
            config.dtype,
        )
        nn_mha_output, _ = nn_mha(qkv, qkv, qkv, mask)
        composite_mha_output, _ = composite_mha(qkv, qkv, qkv, mask)

        # First-order sanity check
        assert_close_tensors(nn_mha_output, composite_mha_output)

        nn_mha_time = benchmark_torch_function_in_microseconds(
            nn_mha, qkv, qkv, qkv, mask
        )
        composite_mha_time = benchmark_torch_function_in_microseconds(
            composite_mha, qkv, qkv, qkv, mask
        )

        # TorchDynamo will error on NestedTensors, so only compile for dense inputs.
        if config.pad_percentage is None:
            compiled_nn_mha = torch.compile(nn_mha)
            compiled_composite_mha = torch.compile(composite_mha)

            compiled_nn_mha_time = benchmark_torch_function_in_microseconds(
                compiled_nn_mha, qkv, qkv, qkv, mask
            )

            compiled_composite_mha_time = benchmark_torch_function_in_microseconds(
                compiled_composite_mha,
                qkv,
                qkv,
                qkv,
                mask,
            )
        else:
            compiled_nn_mha_time = None
            compiled_composite_mha_time = None

        results = ExperimentResults(
            nn_mha_time,
            compiled_nn_mha_time,
            composite_mha_time,
            compiled_composite_mha_time,
        )
        return Experiment(config, results)


# Could return a generator instead of materializing the full list.
def generate_experiments(
    batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
) -> List[ExperimentConfig]:
    configs = []
    for bsz, n_heads, seq_len, embed_dim, dtype, padding in itertools.product(
        batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
    ):
        configs.append(
            ExperimentConfig(
                batch_size=bsz,
                num_heads=n_heads,
                max_sequence_len=seq_len,
                embed_dimension=embed_dim,
                dtype=dtype,
                pad_percentage=padding,
                enable_math=False,
                enable_flash=True,
                enable_mem_efficient=True,
                enable_cudnn=True,
            )
        )
    return configs
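

# Illustrative sketch of what the sdp_kernel context in run_single_experiment
# controls: restricting scaled_dot_product_attention to a single backend lets
# each kernel be timed in isolation, e.g. math-only:
#
#   with sdp_kernel(enable_math=True, enable_flash=False,
#                   enable_mem_efficient=False, enable_cudnn=False):
#       out = torch.nn.functional.scaled_dot_product_attention(q, k, v)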


def main(save_path: Optional[Path]):
    seed = 123
    # generate_rand_batch draws padded lengths via the random module, so seed it too.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Run one timing experiment comparing nn_mha vs. composite_mha.
    config = ExperimentConfig(
        batch_size=128,
        num_heads=8,
        max_sequence_len=512,
        embed_dimension=512,
        dtype=torch.float16,
        pad_percentage=None,
        enable_math=False,
        enable_flash=True,
        enable_mem_efficient=True,
        enable_cudnn=True,
    )

    experiment = run_single_experiment(config)
    pprint(experiment)

    table = PrettyTable()
    table.float_format = ".3"
    table.field_names = (
        ExperimentConfig.get_entry_names() + ExperimentResults.get_entry_names()
    )

    # Run a grid of experiments over the cross-product of these settings.
    batch_sizes = [256]
    num_heads = [32]
    max_seq_lens = [256]
    embed_dims = [512]
    dtypes = [torch.bfloat16, torch.float16, torch.float32]
    pad_percentages = [None, 0.9]

    experiment_configs = generate_experiments(
        batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
    )

    experiments: List[Experiment] = []
    for experiment_config in tqdm(experiment_configs):
        experiment = run_single_experiment(experiment_config)
        experiments.append(experiment)
        table.add_row(experiment.get_entries())

    print(table)

    if save_path is not None:
        with open(save_path, "w") as csvfile:
            csvfile.write(table.get_csv_string())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save-path", "--save_path", type=str, help="Path to save the results"
    )

    args = parser.parse_args()
    save_path = Path(args.save_path) if args.save_path else None
    main(save_path)
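

# Example invocation (the script filename is hypothetical):
#
#   python sdp.py --save-path results.csv
#
# A CUDA-capable GPU is required: all modules and inputs are placed on "cuda".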