import itertools
from dataclasses import asdict, dataclass
from functools import partial
from typing import Callable, List, Union

import numpy as np
from tabulate import tabulate
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
from torch.nn.attention.bias import CausalBias, CausalVariant
from torch.nn.parameter import Parameter


def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    # warmup
    for _ in range(5):
        func(*args, **kwargs)
    # torch.utils.benchmark.Timer handles CUDA synchronization for us
    t0 = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "func": func},
    )
    return t0.adaptive_autorange(min_run_time=0.1).median * 1e6


@dataclass(frozen=True)
class ExperimentConfig:
    batch_size: int
    num_heads: int
    q_seq_len: int
    k_seq_len: int
    embed_dim: int
    dtype: torch.dtype

    @property
    def head_dim(self) -> int:
        return self.embed_dim // self.num_heads

    def asdict(self):
        dict_obj = asdict(self)
        dict_obj["head_dim"] = self.head_dim
        return dict_obj


@dataclass(frozen=True)
class ExperimentResults:
    materialized_mask_time: float
    attn_mask_subclass_time: float

    def get_entries(self) -> List:
        return [
            f"{self.materialized_mask_time:.2f}",
            f"{self.attn_mask_subclass_time:.2f}",
        ]


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    results: ExperimentResults

    def get_entries(self) -> List:
        return list(self.config.asdict().values()) + self.results.get_entries()


def generate_inputs(
    batch_size, q_sequence_length, kv_sequence_length, embed_dim, dtype, device
):
    q_shape = (batch_size, q_sequence_length, embed_dim)
    kv_shape = (batch_size, kv_sequence_length, embed_dim)

    make_q = partial(torch.rand, q_shape, device=device, dtype=dtype)
    make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype)
    return make_q(), make_kv(), make_kv()


class CompositeMHA(torch.nn.Module):
    """Minimal multi-head attention built directly on scaled_dot_product_attention."""

    def __init__(self, num_heads, embed_dim, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.head_dim = embed_dim // num_heads
        self.embed_dim = embed_dim
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        self.q_proj_weight = Parameter(
            torch.empty((embed_dim, embed_dim), **factory_kwargs)
        )
        self.k_proj_weight = Parameter(
            torch.empty((embed_dim, embed_dim), **factory_kwargs)
        )
        self.v_proj_weight = Parameter(
            torch.empty((embed_dim, embed_dim), **factory_kwargs)
        )
        self.out_proj = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
        self.num_heads = num_heads

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Union[torch.Tensor, CausalBias],
    ):
        query_projected = F.linear(query, self.q_proj_weight)
        key_projected = F.linear(key, self.k_proj_weight)
        value_projected = F.linear(value, self.v_proj_weight)

        # Split the projected embeddings into heads: (batch, heads, seq, head_dim)
        query = query_projected.view(
            query_projected.size(0), -1, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key = key_projected.view(
            key_projected.size(0), -1, self.num_heads, self.head_dim
        ).transpose(1, 2)
        value = value_projected.view(
            value_projected.size(0), -1, self.num_heads, self.head_dim
        ).transpose(1, 2)

        attn = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=mask,
            dropout_p=0.0,
        )

        # Merge heads back together and apply the output projection
        attn = attn.transpose(1, 2).reshape(query.size(0), -1, self.embed_dim)
        return F.linear(attn, self.out_proj)

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.q_proj_weight)
        nn.init.xavier_uniform_(self.k_proj_weight)
        nn.init.xavier_uniform_(self.v_proj_weight)
        nn.init.constant_(self.out_proj, 0.0)

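# Illustrative sketch (not executed by this benchmark; names and shapes here are
# assumptions for illustration only): the two attn_mask variants that
# run_single_experiment feeds into scaled_dot_product_attention.
#
#   bias = CausalBias(CausalVariant.LOWER_RIGHT, q_len, kv_len)
#   dense = bias._materialize(device)  # explicit (q_len, kv_len) mask tensor
#   out_subclass = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
#   out_dense = F.scaled_dot_product_attention(q, k, v, attn_mask=dense)
#
# Passing the CausalBias subclass lets SDPA see the causal structure directly
# instead of applying a materialized mask tensor; the time difference between
# these two paths is exactly what this benchmark measures.
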
def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
    device = torch.device("cuda")
    composite_mha = CompositeMHA(
        config.num_heads, config.embed_dim, device, config.dtype
    )
    composite_mha.reset_parameters()
    query, key, value = generate_inputs(
        config.batch_size,
        config.q_seq_len,
        config.k_seq_len,
        config.embed_dim,
        config.dtype,
        device,
    )
    # Symbolic causal bias vs. its explicitly materialized dense mask tensor
    attn_mask = CausalBias(
        CausalVariant.LOWER_RIGHT, config.q_seq_len, config.k_seq_len
    )
    attn_mask_tensor = attn_mask._materialize(device)

    materialized_mask_time = benchmark_torch_function_in_microseconds(
        composite_mha, query, key, value, attn_mask_tensor
    )
    attn_mask_subclass_time = benchmark_torch_function_in_microseconds(
        composite_mha, query, key, value, attn_mask
    )
    # Sanity check: both mask representations must produce the same output
    torch.testing.assert_close(
        composite_mha(query, key, value, attn_mask_tensor),
        composite_mha(query, key, value, attn_mask),
    )

    return ExperimentResults(
        materialized_mask_time=materialized_mask_time,
        attn_mask_subclass_time=attn_mask_subclass_time,
    )


def generate_experiment_configs() -> List[ExperimentConfig]:
    batch_sizes = [1, 8, 16, 128]
    num_heads = [16, 32]
    q_kv_seq_lens = [(128, 256), (256, 416), (512, 4097), (1024, 2048), (1, 2048)]
    embed_dims = [2048, 4096]
    dtypes = [
        torch.bfloat16,
    ]
    all_configs = []
    for bsz, heads, (q_seq_len, kv_seq_len), embed_dim, dtype in itertools.product(
        batch_sizes, num_heads, q_kv_seq_lens, embed_dims, dtypes
    ):
        all_configs.append(
            ExperimentConfig(
                batch_size=bsz,
                num_heads=heads,
                q_seq_len=q_seq_len,
                k_seq_len=kv_seq_len,
                embed_dim=embed_dim,
                dtype=dtype,
            )
        )

    return all_configs


def calculate_speedup(results: ExperimentResults) -> float:
    # > 1.0 means the CausalBias subclass path is faster than the materialized mask
    return results.materialized_mask_time / results.attn_mask_subclass_time


def print_results(results: List[Experiment]):
    # Calculate speedups
    speedups = [calculate_speedup(r.results) for r in results]

    # Find indices of max and min speedups
    max_speedup_index = np.argmax(speedups)
    min_speedup_index = np.argmin(speedups)

    # Get the config dictionaries
    max_config_dict = results[max_speedup_index].config.asdict()
    min_config_dict = results[min_speedup_index].config.asdict()

    # Create table data; the "Average" row leaves the config columns blank
    table_data = [
        {
            "Type": "Average",
            "Speedup": np.mean(speedups),
            **dict.fromkeys(max_config_dict),
        },
        {"Type": "Max", "Speedup": speedups[max_speedup_index], **max_config_dict},
        {"Type": "Min", "Speedup": speedups[min_speedup_index], **min_config_dict},
    ]

    # Print table
    print(tabulate(table_data, headers="keys", tablefmt="pretty"))

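# Illustrative sketch (hypothetical values, not run by this script): timing a
# single configuration by hand instead of the full sweep in main():
#
#   cfg = ExperimentConfig(
#       batch_size=8, num_heads=16, q_seq_len=256, k_seq_len=416,
#       embed_dim=2048, dtype=torch.bfloat16,
#   )
#   res = run_single_experiment(cfg)
#   print(calculate_speedup(res))
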
def main():
    seed = 123
    np.random.seed(seed)
    torch.manual_seed(seed)
    results = []
    # Time every config with a materialized causal mask vs. the CausalBias subclass
    for config in tqdm(generate_experiment_configs()):
        results.append(Experiment(config, run_single_experiment(config)))

    print_results(results)


if __name__ == "__main__":
    main()
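# Note: run_single_experiment hard-codes torch.device("cuda"), so a CUDA-capable
# GPU (and a CUDA build of PyTorch) is required to run this benchmark.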