import argparse
import itertools
import random
import warnings
from dataclasses import dataclass
from pathlib import Path
from pprint import pprint
from typing import List, Optional

import numpy as np
from prettytable import PrettyTable
from tqdm import tqdm

import torch
import torch.utils.benchmark as benchmark
from torch.backends.cuda import sdp_kernel


warnings.filterwarnings("ignore")


@dataclass(frozen=True)
class ExperimentConfig:
    """Parameters for one nn.MultiheadAttention vs. CompositeMHA timing run."""

    batch_size: int
    num_heads: int
    max_sequence_len: int
    embed_dimension: int
    dtype: torch.dtype
    # Approximate fraction of padding per sequence; None (or 0) makes
    # generate_rand_batch produce a dense, non-nested input.
    pad_percentage: Optional[float]
    # Backend flags forwarded to ``sdp_kernel``.
    enable_math: bool
    enable_flash: bool
    enable_mem_efficient: bool
    enable_cudnn: bool

    def get_entries(self) -> List:
        """Return the field values, ordered like ``get_entry_names``."""
        return [
            self.batch_size,
            self.num_heads,
            self.max_sequence_len,
            self.embed_dimension,
            self.dtype,
            self.pad_percentage,
            self.enable_math,
            self.enable_flash,
            self.enable_mem_efficient,
            self.enable_cudnn,
        ]

    @classmethod
    def get_entry_names(cls) -> List[str]:
        """Return the table column names for the config fields."""
        return [
            "batch_size",
            "num_heads",
            "max_sequence_len",
            "embed_dimension",
            "dtype",
            "pad_percentage",
            "enable_math",
            "enable_flash",
            "enable_mem_efficient",
            "enable_cudnn",
        ]


@dataclass(frozen=True)
class ExperimentResults:
    """Mean per-call latencies, in microseconds, for one experiment."""

    nn_mha_time: float
    # None when torch.compile was skipped (nested-tensor input).
    compiled_nn_mha_time: Optional[float]
    composite_mha_time: float
    # None when torch.compile was skipped (nested-tensor input).
    compiled_composite_mha_time: Optional[float]

    def get_entries(self) -> List:
        """Return the timings as strings with two decimals (None if not measured)."""

        def fmt(value: Optional[float]) -> Optional[str]:
            # ``.2f`` = two decimal places; the previous ``2f`` was a
            # minimum-width-2 spec that printed six decimals.
            return None if value is None else f"{value:.2f}"

        return [
            fmt(self.nn_mha_time),
            fmt(self.compiled_nn_mha_time),
            fmt(self.composite_mha_time),
            fmt(self.compiled_composite_mha_time),
        ]

    @classmethod
    def get_entry_names(cls) -> List[str]:
        """Return the table column names for the timing fields."""
        return [
            "nn_mha_time (\u00B5s)",
            "compiled_nn_mha_time (\u00B5s)",
            "composite_mha_time (\u00B5s)",
            "compiled_composite_mha_time (\u00B5s)",
        ]


@dataclass(frozen=True)
class Experiment:
    """A configuration paired with the timings measured for it."""

    config: ExperimentConfig
    results: ExperimentResults

    def get_entries(self) -> List:
        """Return config entries followed by result entries (one table row)."""
        return self.config.get_entries() + self.results.get_entries()


class CompositeMHA(torch.nn.Module):
    """Self-attention assembled from ``linear`` + ``scaled_dot_product_attention``.

    Reuses the packed input projection and output projection of an existing
    ``nn.MultiheadAttention`` (see ``build_composite_mha_from_nn_mha``) so the
    two modules can be checked for agreement and timed against each other.
    """

    def __init__(self, num_heads, in_proj_weight, in_proj_bias, out_proj):
        super().__init__()
        self.in_proj_weight = in_proj_weight
        self.in_proj_bias = in_proj_bias
        self.out_proj = out_proj
        self.num_heads = num_heads

    def forward(self, query, key, value, mask):
        """Compute self-attention over ``query``.

        Returns:
            ``(output, None)`` to match ``nn.MultiheadAttention``'s return
            signature.

        Raises:
            NotImplementedError: if query/key/value are not the same tensor
                object, or if ``mask`` is not None.
        """
        if not (query is key and key is value):
            raise NotImplementedError(
                "query, key and value must be the same Tensor for now."
            )

        if mask is not None:
            raise NotImplementedError("mask is currently not supported.")

        query_projected = torch.nn.functional.linear(
            query, self.in_proj_weight, self.in_proj_bias
        )

        batch_size = query_projected.size(0)
        # The packed projection emits Q, K and V concatenated on the last
        # dimension, so its width is 3 * num_heads * head_dim.
        embed_dim = query_projected.size(2)
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)

        # Split heads: (batch, seq, num_heads * head_dim)
        #           -> (batch, num_heads, seq, head_dim)
        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        attn = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=False,
        )

        # Merge heads back: (batch, seq, num_heads * head_dim)
        attn = attn.transpose(1, 2).reshape(batch_size, -1, self.num_heads * head_dim)
        # Match return signature of nn.MHA
        return self.out_proj(attn), None


def build_composite_mha_from_nn_mha(pt):
    """Build a ``CompositeMHA`` that shares the parameters of ``pt``.

    ``pt`` must use a packed in-projection (same embed dim for q/k/v) and
    ``batch_first=True``; these preconditions are asserted.
    """
    assert pt._qkv_same_embed_dim
    assert pt.in_proj_weight is not None
    assert pt.batch_first
    return CompositeMHA(pt.num_heads, pt.in_proj_weight, pt.in_proj_bias, pt.out_proj)


def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    """Create a random input batch, optionally ragged.

    Returns:
        ``(tensor, None)`` with a dense ``(batch_size, max_sequence_len,
        embed_dimension)`` tensor when ``pad_percentage`` is falsy; otherwise
        ``(nested_tensor, seq_len_list)`` where each sequence length is about
        ``max_sequence_len * (1 - pad_percentage)`` and one randomly chosen
        sequence has the full ``max_sequence_len``.
    """
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )

    # Builds each sequence separately: slow, but simple.
    seq_len_list = []
    for _ in range(batch_size):
        sampled = int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        # Gaussian noise can push a sample outside [1, max_sequence_len];
        # a non-positive length would make torch.randn fail.
        seq_len_list.append(min(max(sampled, 1), max_sequence_len))
    # Force one random sequence to the full length so the batch's longest
    # sequence is exactly max_sequence_len.
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension, dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )


def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    """Return the mean runtime of ``f(*args, **kwargs)`` in microseconds."""
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    # Measurement.mean is in seconds.
    return t0.blocked_autorange().mean * 1e6


def assert_close_tensors(tensor_a, tensor_b):
    """Assert two (possibly nested) tensors are approximately equal.

    First order sanity check. Not a replacement for rigorous tests.
    Nested tensors are compared component-wise with a looser tolerance.
    """
    if tensor_a.is_nested and tensor_b.is_nested:
        parts_a = tensor_a.unbind()
        parts_b = tensor_b.unbind()
        # zip() would silently ignore a mismatch in component counts.
        assert len(parts_a) == len(parts_b)
        for a, b in zip(parts_a, parts_b):
            assert torch.allclose(a, b, atol=1e-2, rtol=1e-2)
    else:
        assert torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)


def run_single_experiment(config: ExperimentConfig) -> Experiment:
    """Time nn.MultiheadAttention against CompositeMHA for one configuration.

    Both modules share weights and their outputs are sanity-checked for
    closeness. Each is then timed eagerly, and, for dense inputs only, under
    ``torch.compile``. Runs on CUDA inside ``torch.inference_mode`` with the
    SDPA backends selected by ``config``.

    Returns:
        An ``Experiment`` pairing ``config`` with the measured timings.
    """
    with sdp_kernel(
        enable_math=config.enable_math,
        enable_flash=config.enable_flash,
        enable_mem_efficient=config.enable_mem_efficient,
        enable_cudnn=config.enable_cudnn,
    ), torch.inference_mode():
        dropout_p = 0.0
        mask = None

        nn_mha = torch.nn.MultiheadAttention(
            embed_dim=config.embed_dimension,
            num_heads=config.num_heads,
            batch_first=True,
            dropout=dropout_p,
        )
        nn_mha = nn_mha.eval().to("cuda", config.dtype)
        composite_mha = build_composite_mha_from_nn_mha(nn_mha)
        qkv, _lengths = generate_rand_batch(
            config.batch_size,
            config.max_sequence_len,
            config.embed_dimension,
            config.pad_percentage,
            config.dtype,
        )
        nn_mha_output, _ = nn_mha(qkv, qkv, qkv, mask)
        composite_mha_output, _ = composite_mha(qkv, qkv, qkv, mask)

        # First order sanity check
        assert_close_tensors(nn_mha_output, composite_mha_output)

        nn_mha_time = benchmark_torch_function_in_microseconds(
            nn_mha, qkv, qkv, qkv, mask
        )
        composite_mha_time = benchmark_torch_function_in_microseconds(
            composite_mha, qkv, qkv, qkv, mask
        )

        # TorchDynamo will error on NestedTensors. Check the actual input
        # layout rather than pad_percentage: a pad_percentage of 0 also
        # yields a dense tensor.
        if not qkv.is_nested:
            compiled_nn_mha = torch.compile(nn_mha)
            compiled_composite_mha = torch.compile(composite_mha)

            # Pay the one-time compilation cost before timing starts.
            compiled_nn_mha(qkv, qkv, qkv, mask)
            compiled_composite_mha(qkv, qkv, qkv, mask)

            compiled_nn_mha_time = benchmark_torch_function_in_microseconds(
                compiled_nn_mha, qkv, qkv, qkv, mask
            )
            compiled_composite_mha_time = benchmark_torch_function_in_microseconds(
                compiled_composite_mha,
                qkv,
                qkv,
                qkv,
                mask,
            )
        else:
            compiled_nn_mha_time = None
            compiled_composite_mha_time = None

        results = ExperimentResults(
            nn_mha_time,
            compiled_nn_mha_time,
            composite_mha_time,
            compiled_composite_mha_time,
        )
        return Experiment(config, results)


# Could return generator
def generate_experiments(
    batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
) -> List[ExperimentConfig]:
    """Return one config per point of the Cartesian product of the inputs.

    Every config uses the same backend selection: math disabled; flash,
    memory-efficient and cuDNN enabled.
    """
    return [
        ExperimentConfig(
            batch_size=bsz,
            num_heads=n_heads,
            max_sequence_len=seq_len,
            embed_dimension=embed_dim,
            dtype=dtype,
            pad_percentage=padding,
            enable_math=False,
            enable_flash=True,
            enable_mem_efficient=True,
            enable_cudnn=True,
        )
        for bsz, n_heads, seq_len, embed_dim, dtype, padding in itertools.product(
            batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
        )
    ]


def main(save_path: Optional[Path]):
    """Run the benchmark suite, print the results, and optionally save a CSV.

    Args:
        save_path: When not None, the results table is written there as CSV.
    """
    seed = 123
    # generate_rand_batch samples sequence lengths with the ``random`` module,
    # so it must be seeded too for padded runs to be reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Run one timing experiment comparing nn_mha vs composite_mha
    config = ExperimentConfig(
        batch_size=128,
        num_heads=8,
        max_sequence_len=512,
        embed_dimension=512,
        dtype=torch.float16,
        pad_percentage=None,
        enable_math=False,
        enable_flash=True,
        enable_mem_efficient=True,
        enable_cudnn=True,
    )

    experiment = run_single_experiment(config)
    pprint(experiment)

    table = PrettyTable()
    table.float_format = ".3"
    table.field_names = (
        ExperimentConfig.get_entry_names() + ExperimentResults.get_entry_names()
    )

    # Run a bunch of experiments
    batch_sizes = [256]
    num_heads = [32]
    max_seq_lens = [256]
    embed_dims = [512]
    dtypes = [torch.bfloat16, torch.float16, torch.float32]
    pad_percentages = [None, 0.9]

    experiment_configs = generate_experiments(
        batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
    )

    for experiment_config in tqdm(experiment_configs):
        experiment = run_single_experiment(experiment_config)
        table.add_row(experiment.get_entries())

    print(table)

    if save_path is not None:
        # newline="" keeps the csv module's line endings intact; utf-8 is
        # needed for the micro sign in the column headers.
        with open(save_path, "w", newline="", encoding="utf-8") as csvfile:
            csvfile.write(table.get_csv_string())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save-path", "--save_path", type=str, help="Path to save the results"
    )
    args = parser.parse_args()
    save_path = Path(args.save_path) if args.save_path else None
    main(save_path)