import argparse
import csv
import dataclasses
import os

from generate import (
    get_arch_name,
    run_llama2_7b_bf16,
    run_llama2_7b_int8,
    run_mixtral_8x7b_int8,
)

import torch
import torch.nn as nn
from torch._inductor.runtime.benchmarking import benchmarker
from torch.utils.flop_counter import FlopCounterMode


WARMUP_ITER = 5

A100_40G_BF16_TFLOPS = 312


@dataclasses.dataclass
class Experiment:
    name: str
    metric: str
    target: float
    actual: float
    dtype: str
    device: str
    arch: str  # GPU name for CUDA or CPU arch for CPU
    is_model: bool = False


class SimpleMLP(nn.Module):
    # Note: despite the "gelu" in the experiment name below, this module is
    # linear + layer_norm only; there is no activation between the layers.
    def __init__(self, input_dim, hidden_dim, output_dim, dtype):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                nn.Linear(input_dim, hidden_dim, dtype=dtype),
                nn.LayerNorm(hidden_dim, dtype=dtype),
                nn.Linear(hidden_dim, output_dim, dtype=dtype),
                nn.LayerNorm(output_dim, dtype=dtype),
            ]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def run_mlp_layer_norm_gelu(device: str = "cuda"):
    dtype_flops_utilization_map = {
        torch.bfloat16: "0.8",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    intermediate_size = 14336
    results = []
    for dtype, expected_flops_utilization in dtype_flops_utilization_map.items():
        flops_utilization = 0
        for D in input_shapes:
            mod = SimpleMLP(
                input_dim=D, hidden_dim=intermediate_size, output_dim=D, dtype=dtype
            ).to(device)

            # Match the benchmarked dtype (only bf16 is listed above for now).
            x = torch.randn(D, device=device, dtype=dtype)

            with FlopCounterMode(display=False) as mode:
                mod(x)

            flops = mode.get_total_flops()

            compiled_mod = torch.compile(mod, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled_mod(x)

            benchmark_fn = (
                benchmarker.benchmark_gpu
                if device == "cuda"
                else benchmarker.benchmark_cpu
            )
            us_per_iter = benchmark_fn(lambda: compiled_mod(x)) * 1000
            # Utilization is achieved TFLOPS over the A100-40G bf16 peak.
            achieved_tflops = flops / (us_per_iter / 1e6) / 1e12
            flops_utilization += achieved_tflops / A100_40G_BF16_TFLOPS

        flops_utilization = flops_utilization / len(input_shapes)
        dtype_str = str(dtype).replace("torch.", "")
        results.append(
            Experiment(
                "mlp_layer_norm_gelu",
                "flops_utilization",
                expected_flops_utilization,
                f"{flops_utilization:.02f}",
                dtype_str,
                device,
                get_arch_name(),
            )
        )
    return results


def run_layer_norm(device: str = "cuda"):
    dtype_memory_bandwidth_map = {
        torch.bfloat16: "950",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    BS = 4096
    results = []
    for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
        memory_bandwidth = 0
        for D in input_shapes:
            mod = nn.LayerNorm(D).to(device)

            x = torch.randn(BS, D, device=device, dtype=dtype)

            compiled_mod = torch.compile(mod, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled_mod(x)

            benchmark_fn = (
                benchmarker.benchmark_gpu
                if device == "cuda"
                else benchmarker.benchmark_cpu
            )
            us_per_iter = benchmark_fn(lambda: compiled_mod(x)) * 1000
            # layer_norm reads and writes the (BS, D) tensor once each,
            # hence 2 * BS * D elements of traffic per iteration.
            memory_bandwidth += (1e6 / us_per_iter) * 2 * BS * D * dtype.itemsize / 1e9

        memory_bandwidth = memory_bandwidth / len(input_shapes)
        dtype_str = str(dtype).replace("torch.", "")
        results.append(
            Experiment(
                "layer_norm",
                "memory_bandwidth(GB/s)",
                expected_memory_bandwidth,
                f"{memory_bandwidth:.02f}",
                dtype_str,
                device,
                get_arch_name(),
            )
        )
    return results
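
# A minimal sketch (not called by the benchmark, helper name is illustrative)
# of the bandwidth model used in run_layer_norm above: layer_norm streams the
# (BS, D) input once and writes the output once, so the model counts
# 2 * BS * D * itemsize bytes per iteration. For example, with BS=4096,
# D=4096, bf16 (2 bytes/elem) at 70 us/iter:
#   (1e6 / 70) * 2 * 4096 * 4096 * 2 / 1e9 ~= 959 GB/s
def _layer_norm_gbps(us_per_iter: float, bs: int, d: int, itemsize: int) -> float:
    """Mirror of the memory_bandwidth(GB/s) metric in run_layer_norm."""
    iters_per_sec = 1e6 / us_per_iter
    bytes_per_iter = 2 * bs * d * itemsize  # one read + one write of (BS, D)
    return iters_per_sec * bytes_per_iter / 1e9
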

@torch._inductor.config.patch(coordinate_descent_tuning=True)
def run_gather_gemv(device: str = "cuda"):
    E = 8  # number of expert weight matrices to gather from
    dtype_memory_bandwidth_map = {
        torch.int8: "990",
        torch.bfloat16: "1060",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    results = []
    for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
        memory_bandwidth = 0
        for D in input_shapes:

            def gather_gemv(W, score_idxs, x):
                return W[score_idxs].to(x.dtype) @ x

            W = torch.randn(E, D, D, device=device).to(dtype=dtype)
            x = torch.randn(D, device=device, dtype=torch.bfloat16)
            score_idxs = torch.tensor([3, 5], device=device)

            compiled_fn = torch.compile(gather_gemv, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled_fn(W, score_idxs, x)

            benchmark_fn = (
                benchmarker.benchmark_gpu
                if device == "cuda"
                else benchmarker.benchmark_cpu
            )
            us_per_iter = benchmark_fn(lambda: compiled_fn(W, score_idxs, x)) * 1000
            # Two (D, D) expert matrices are gathered per call, so the dominant
            # traffic is 2 * D * D weight elements at the storage dtype.
            memory_bandwidth += (1e6 / us_per_iter) * 2 * D * D * dtype.itemsize / 1e9

        memory_bandwidth = memory_bandwidth / len(input_shapes)
        dtype_str = str(dtype).replace("torch.", "")
        results.append(
            Experiment(
                "gather_gemv",
                "memory_bandwidth(GB/s)",
                expected_memory_bandwidth,
                f"{memory_bandwidth:.02f}",
                dtype_str,
                device,
                get_arch_name(),
            )
        )
    return results


@torch._inductor.config.patch(coordinate_descent_tuning=True)
def run_gemv(device: str = "cuda"):
    dtype_memory_bandwidth_map = {
        torch.int8: "870",
        torch.bfloat16: "990",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    results = []
    for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
        memory_bandwidth = 0
        for D in input_shapes:

            def gemv(W, x):
                return W.to(x.dtype) @ x

            W = torch.randn(D, D, device=device).to(dtype=dtype)
            x = torch.randn(D, device=device, dtype=torch.bfloat16)

            compiled_fn = torch.compile(gemv, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled_fn(W, x)

            benchmark_fn = (
                benchmarker.benchmark_gpu
                if device == "cuda"
                else benchmarker.benchmark_cpu
            )
            us_per_iter = benchmark_fn(lambda: compiled_fn(W, x)) * 1000
            # A gemv is dominated by the single read of the (D, D) weight
            # matrix at its storage dtype.
            memory_bandwidth += (1e6 / us_per_iter) * D * D * dtype.itemsize / 1e9

        memory_bandwidth = memory_bandwidth / len(input_shapes)
        dtype_str = str(dtype).replace("torch.", "")
        results.append(
            Experiment(
                "gemv",
                "memory_bandwidth(GB/s)",
                expected_memory_bandwidth,
                f"{memory_bandwidth:.02f}",
                dtype_str,
                device,
                get_arch_name(),
            )
        )
    return results


DEFAULT_OUTPUT_FILE = "gpt_fast_benchmark.csv"


def output_csv(output_file, headers, row):
    if os.path.exists(output_file):
        with open(output_file) as fd:
            lines = list(csv.reader(fd)) or [[]]
        if headers and len(headers) > len(lines[0]):
            # If a prior run failed, the header row may not be complete yet.
            lines[0] = headers
        else:
            headers = lines[0]
    else:
        lines = [headers]

    if output_file != DEFAULT_OUTPUT_FILE:
        output_dir = os.path.dirname(output_file)
        if output_dir:  # a bare filename has no directory to create
            os.makedirs(output_dir, exist_ok=True)
    lines.append([(f"{x:.6f}" if isinstance(x, float) else x) for x in row])
    with open(output_file, "w") as fd:
        writer = csv.writer(fd, lineterminator="\n")
        for line in lines:
            writer.writerow(list(line) + ["0"] * (len(headers) - len(line)))
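
# Illustration of the CSV layout output_csv produces; the values below are
# made up and the arch string depends on the machine:
#
#   name,metric,target,actual,dtype,device,arch,is_model
#   gemv,memory_bandwidth(GB/s),870,912.34,int8,cuda,NVIDIA A100-SXM4-40GB,False
#
# Rows shorter than the header are right-padded with "0" so the file stays
# rectangular even when an earlier, narrower header is widened later.
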

all_experiments = (
    # End-to-end GPT model benchmarks: Llama 2, Mixtral, etc.
    run_llama2_7b_bf16,
    run_llama2_7b_int8,
    run_mixtral_8x7b_int8,
    # Micro-benchmarks. A tuple (rather than a set) keeps the run order
    # deterministic.
    run_mlp_layer_norm_gelu,
    run_layer_norm,
    run_gather_gemv,
    run_gemv,
)


def main(output_file=DEFAULT_OUTPUT_FILE):
    results = []

    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    except AssertionError:
        # This happens when torch is compiled with CUDA disabled entirely.
        device = "cpu"

    for func in all_experiments:
        lst = func(device)
        for x in lst:
            results.append(dataclasses.astuple(x))

    headers = [field.name for field in dataclasses.fields(Experiment)]

    for row in results:
        output_csv(output_file, headers, row)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run experiments.")
    parser.add_argument(
        "--output",
        default=DEFAULT_OUTPUT_FILE,
        help="Path of the CSV file to write the benchmark results to",
    )
    args = parser.parse_args()

    main(output_file=args.output)
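
# Example invocations (substitute the name this file is saved as; the model
# benchmarks additionally require the checkpoints that generate.py expects):
#
#   python benchmark.py                          # writes gpt_fast_benchmark.csv
#   python benchmark.py --output /tmp/results/gpt_fast_benchmark.csv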