import dataclasses
import itertools
import platform
import time
from typing import Optional, Tuple

from mixtral_moe_model import ConditionalFeedForward, Transformer as MixtralMoE
from mixtral_moe_quantize import (
    ConditionalFeedForwardInt8,
    WeightOnlyInt8QuantHandler as MixtralMoEWeightOnlyInt8QuantHandler,
)
from model import Transformer as LLaMA
from quantize import WeightOnlyInt8QuantHandler as LLaMAWeightOnlyInt8QuantHandler

import torch
import torch._inductor.config


torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
torch._inductor.config.assert_indirect_indexing = False


@dataclasses.dataclass
class GPTModelConfig:
    name: str
    module: type
    mode: Optional[str]
    quantizer: type
    token_per_sec: float
    memory_bandwidth: float
    compilation_time: float


def device_sync(device):
    if "cuda" in device:
        torch.cuda.synchronize(device)
    elif "cpu" in device:
        pass
    else:
        print(f"device={device} is not yet supported")


def get_arch_name() -> str:
    if torch.cuda.is_available():
        return torch.cuda.get_device_name()
    else:
        # This returns x86_64 or arm64 (for aarch64)
        return platform.machine()


def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)


def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs


def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    probs = logits_to_probs(logits[0, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs


@torch.compile(fullgraph=True)
def prefill(
    model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
) -> torch.Tensor:
    # input_pos: [B, S]
    logits = model(x, input_pos)
    return sample(logits, **sampling_kwargs)[0]


@torch.compile(fullgraph=True, mode="reduce-overhead")
def decode_one_token(
    model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
    # input_pos: [B, 1]
    assert input_pos.shape[-1] == 1
    logits = model(x, input_pos)
    return sample(logits, **sampling_kwargs)


def decode_n_tokens(
    model: torch.nn.Module,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    **sampling_kwargs,
):
    new_tokens, new_probs = [], []
    for i in range(num_new_tokens):
        with torch.nn.attention.sdpa_kernel(
            torch.nn.attention.SDPBackend.MATH
        ):  # Actually better for Inductor to codegen attention here
            next_token, next_prob = decode_one_token(
                model, cur_token, input_pos, **sampling_kwargs
            )
            input_pos += 1
            new_tokens.append(next_token.clone())
            new_probs.append(next_prob.clone())
            cur_token = next_token.view(1, -1)

    return new_tokens, new_probs
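

# Illustrative only (not part of the benchmark flow): a small, hypothetical sanity check of
# the sampling trick used in multinomial_sample_one_no_sync. Dividing the probabilities by
# i.i.d. Exponential(1) noise and taking the argmax is equivalent to the Gumbel-max trick,
# so the empirical frequencies returned below should roughly match `probs`.
def _sampling_sanity_check(num_draws: int = 10000) -> torch.Tensor:
    probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
    draws = torch.stack(
        [multinomial_sample_one_no_sync(probs) for _ in range(num_draws)]
    )
    counts = torch.bincount(draws.flatten().long(), minlength=probs.numel())
    return counts.float() / num_draws  # Expected to be close to `probs`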


@torch.no_grad()
def generate(
    model: torch.nn.Module, prompt: torch.Tensor, max_new_tokens: int, **sampling_kwargs
) -> torch.Tensor:
    device, dtype = prompt.device, prompt.dtype
    T = prompt.size(0)
    T_new = T + max_new_tokens
    max_seq_length = min(T_new, model.config.block_size)

    with torch.device(device):
        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(T_new, dtype=dtype, device=device)
    empty[:T] = prompt
    seq = empty
    input_pos = torch.arange(0, T, device=device)

    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)
    seq[T] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)

    generated_tokens, _ = decode_n_tokens(
        model, next_token.view(1, -1), input_pos, max_new_tokens - 1, **sampling_kwargs
    )
    seq[T + 1 :] = torch.cat(generated_tokens)
    return seq


def _load_model(x: GPTModelConfig, device="cuda", precision=torch.bfloat16):
    with torch.device("meta"):
        model = x.module.from_name(x.name)
    model = model.to(dtype=precision)

    if x.mode == "int8":
        print("Using int8 weight-only quantization!")
        model = x.quantizer(model).convert_for_runtime()

    # Fill the state dict with random weights of the right shape and dtype; the benchmark
    # measures performance only, so no real checkpoint is loaded.
    state_dict = model.state_dict()
    for k, v in state_dict.items():
        state_dict[k] = torch.nn.Parameter(
            torch.randn(v.shape, device=device).to(dtype=v.dtype),
            requires_grad=v.requires_grad,
        )
    model.load_state_dict(state_dict, assign=True)
    return model.eval()


# Only count activated parameters and buffers.
def _get_model_size(model):
    model_size = 0
    for name, child in model.named_children():
        if not isinstance(child, torch.nn.Embedding):
            model_size += sum(
                p.numel() * p.dtype.itemsize
                for p in itertools.chain(child.parameters(), child.buffers())
            )

    # Remove the inactivated experts from the model size if this is a mixture-of-experts
    # architecture, since only activated experts are loaded.
    if hasattr(model.config, "num_experts"):
        config = model.config
        for submodule in model.modules():
            if isinstance(
                submodule, (ConditionalFeedForward, ConditionalFeedForwardInt8)
            ):
                model_size -= (
                    sum(
                        p.numel() * p.dtype.itemsize
                        for p in itertools.chain(
                            submodule.parameters(), submodule.buffers()
                        )
                    )
                    * (config.num_experts - config.num_activated_experts)
                    / config.num_experts
                )

    return model_size


def run_experiment(
    x: GPTModelConfig,
    num_samples: int = 5,
    max_new_tokens: int = 200,
    top_k: int = 200,
    temperature: float = 0.8,
    device: str = "cuda",
) -> Tuple[float, float, float]:
    print(f"Loading model {x.name}")
    t0 = time.time()
    model = _load_model(x, device=device)
    device_sync(device=device)  # MKG
    print(f"Time to load model: {time.time() - t0:.02f} seconds")

    prompt = torch.tensor(
        [1, 15043, 29892, 590, 1024, 338], device=device, dtype=torch.int32
    )
    prompt_length = prompt.size(0)

    torch.manual_seed(1234)
    model_size = _get_model_size(model)

    aggregate_metrics = {"tokens_per_sec": [], "memory_bandwidth": []}
    start = -1  # The first (i == -1) iteration is a warm-up that measures compilation time.
    compilation_time = None

    for i in range(start, num_samples):
        device_sync(device=device)  # MKG

        t0 = time.perf_counter()
        y = generate(
            model, prompt, max_new_tokens, temperature=temperature, top_k=top_k
        )

        if i == -1:
            compilation_time = time.perf_counter() - t0
            print(f"Compilation time: {compilation_time:.2f} seconds")
            continue

        device_sync(device=device)  # MKG
        t = time.perf_counter() - t0
        tokens_generated = y.size(0) - prompt_length
        tokens_sec = tokens_generated / t
        aggregate_metrics["tokens_per_sec"].append(tokens_sec)
        aggregate_metrics["memory_bandwidth"].append(model_size * tokens_sec / 1e9)

    token_per_sec = torch.mean(torch.tensor(aggregate_metrics["tokens_per_sec"])).item()
    memory_bandwidth = torch.mean(
        torch.tensor(aggregate_metrics["memory_bandwidth"])
    ).item()
    print(f"Average tokens/sec: {token_per_sec:.2f} tokens/sec")
    print(f"Average bandwidth achieved: {memory_bandwidth:.02f} GB/s")
    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
    return token_per_sec, memory_bandwidth, compilation_time


# token_per_sec and memory_bandwidth target numbers are for an A100-40GB, which differs from the more common A100-80GB.
def run_llama2_7b_bf16(device: str = "cuda"):
    from benchmark import Experiment

    model = GPTModelConfig(
        "Llama-2-7b-chat-hf",
        LLaMA,
        "bfloat16",
        LLaMAWeightOnlyInt8QuantHandler,
        94,
        1253,
        162,
    )
    token_per_sec, memory_bandwidth, compilation_time = run_experiment(
        model, device=device
    )
    return [
        Experiment(
            model.name,
            "token_per_sec",
            model.token_per_sec,
            f"{token_per_sec:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "memory_bandwidth(GB/s)",
            model.memory_bandwidth,
            f"{memory_bandwidth:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "compilation_time(s)",
            model.compilation_time,
            f"{compilation_time:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
    ]


# token_per_sec and memory_bandwidth target numbers are for an A100-40GB, which differs from the more common A100-80GB.
def run_llama2_7b_int8(device: str = "cuda"):
    from benchmark import Experiment

    model = GPTModelConfig(
        "Llama-2-7b-chat-hf",
        LLaMA,
        "int8",
        LLaMAWeightOnlyInt8QuantHandler,
        144,
        957,
        172,
    )
    token_per_sec, memory_bandwidth, compilation_time = run_experiment(
        model, device=device
    )
    return [
        Experiment(
            model.name,
            "token_per_sec",
            model.token_per_sec,
            f"{token_per_sec:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "memory_bandwidth(GB/s)",
            model.memory_bandwidth,
            f"{memory_bandwidth:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "compilation_time(s)",
            model.compilation_time,
            f"{compilation_time:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
    ]


# token_per_sec and memory_bandwidth target numbers are for an A100-40GB, which differs from the more common A100-80GB.
def run_mixtral_8x7b_int8(device: str = "cuda"):
    from benchmark import Experiment

    # The original number of layers was reduced from 32 to 16 to fit within the CI memory limit.
    model = GPTModelConfig(
        "Mixtral-8x7B-v0.1",
        MixtralMoE,
        "int8",
        MixtralMoEWeightOnlyInt8QuantHandler,
        175,
        1130,
        162,
    )
    token_per_sec, memory_bandwidth, compilation_time = run_experiment(
        model, device=device
    )
    return [
        Experiment(
            model.name,
            "token_per_sec",
            model.token_per_sec,
            f"{token_per_sec:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "memory_bandwidth(GB/s)",
            model.memory_bandwidth,
            f"{memory_bandwidth:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "compilation_time(s)",
            model.compilation_time,
            f"{compilation_time:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
    ]
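

# A minimal, hypothetical entry point for running these benchmarks standalone. The Experiment
# record they emit is defined in benchmark.py, so this is only a sketch of how the runners
# above might be invoked directly; adjust the device as needed.
if __name__ == "__main__":
    results = []
    for runner in (run_llama2_7b_bf16, run_llama2_7b_int8, run_mixtral_8x7b_int8):
        results.extend(runner(device="cuda"))
    for experiment in results:
        print(experiment)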