import dataclasses
import itertools
import platform
import time
from typing import Optional, Tuple

from mixtral_moe_model import ConditionalFeedForward, Transformer as MixtralMoE
from mixtral_moe_quantize import (
    ConditionalFeedForwardInt8,
    WeightOnlyInt8QuantHandler as MixtralMoEWeightOnlyInt8QuantHandler,
)
from model import Transformer as LLaMA
from quantize import WeightOnlyInt8QuantHandler as LLaMAWeightOnlyInt8QuantHandler

import torch
import torch._inductor.config


torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times; will be on by default in the future
torch._inductor.config.assert_indirect_indexing = False


@dataclasses.dataclass
class GPTModelConfig:
    name: str
    module: type
    mode: Optional[str]
    quantizer: type
    token_per_sec: float
    memory_bandwidth: float
    compilation_time: float


def device_sync(device):
    if "cuda" in device:
        torch.cuda.synchronize(device)
    elif "cpu" in device:
        pass
    else:
        print(f"device={device} is not yet supported")


def get_arch_name() -> str:
    if torch.cuda.is_available():
        return torch.cuda.get_device_name()
    else:
        # This returns x86_64 or arm64 (for aarch64)
        return platform.machine()


def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)


def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs


def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    probs = logits_to_probs(logits[0, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs


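# Usage sketch for `sample` (illustrative only; the shapes and values below are
# assumptions, not taken from the benchmark):
#
#   logits = torch.randn(1, 6, 32000)                 # [batch, seq_len, vocab]
#   next_token, probs = sample(logits, temperature=0.8, top_k=200)
#   # next_token is an int tensor of shape [1]; probs has shape [32000].

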
@torch.compile(fullgraph=True)
def prefill(
    model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
) -> torch.Tensor:
    # input_pos: [B, S]
    logits = model(x, input_pos)
    return sample(logits, **sampling_kwargs)[0]


@torch.compile(fullgraph=True, mode="reduce-overhead")
def decode_one_token(
    model: torch.nn.Module, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
    # input_pos: [B, 1]
    assert input_pos.shape[-1] == 1
    logits = model(x, input_pos)
    return sample(logits, **sampling_kwargs)


def decode_n_tokens(
    model: torch.nn.Module,
    cur_token: torch.Tensor,
    input_pos: torch.Tensor,
    num_new_tokens: int,
    **sampling_kwargs,
):
    new_tokens, new_probs = [], []
    for i in range(num_new_tokens):
        with torch.nn.attention.sdpa_kernel(
            torch.nn.attention.SDPBackend.MATH
        ):  # Actually better for Inductor to codegen attention here
            next_token, next_prob = decode_one_token(
                model, cur_token, input_pos, **sampling_kwargs
            )
            input_pos += 1
            new_tokens.append(next_token.clone())
            new_probs.append(next_prob.clone())
            cur_token = next_token.view(1, -1)

    return new_tokens, new_probs


@torch.no_grad()
def generate(
    model: torch.nn.Module, prompt: torch.Tensor, max_new_tokens: int, **sampling_kwargs
) -> torch.Tensor:
    device, dtype = prompt.device, prompt.dtype
    T = prompt.size(0)
    T_new = T + max_new_tokens
    max_seq_length = min(T_new, model.config.block_size)

    with torch.device(device):
        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(T_new, dtype=dtype, device=device)
    empty[:T] = prompt
    seq = empty
    input_pos = torch.arange(0, T, device=device)

    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)
    seq[T] = next_token

    input_pos = torch.tensor([T], device=device, dtype=torch.int)

    generated_tokens, _ = decode_n_tokens(
        model, next_token.view(1, -1), input_pos, max_new_tokens - 1, **sampling_kwargs
    )
    seq[T + 1 :] = torch.cat(generated_tokens)
    return seq


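# Usage sketch for `generate` (illustrative only; the token ids and sizes are
# made up, and `model` is assumed to come from `_load_model` below):
#
#   prompt = torch.tensor([1, 15043, 29892], device="cuda", dtype=torch.int32)
#   seq = generate(model, prompt, max_new_tokens=8, temperature=0.8, top_k=200)
#   # seq has shape [prompt.size(0) + 8]: the prompt followed by 8 sampled tokens.

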
def _load_model(x: GPTModelConfig, device="cuda", precision=torch.bfloat16):
    with torch.device("meta"):
        model = x.module.from_name(x.name)
    model = model.to(dtype=precision)

    if x.mode == "int8":
        print("Using int8 weight-only quantization!")
        model = x.quantizer(model).convert_for_runtime()

    # Materialize random weights with the right shapes and dtypes instead of
    # loading a real checkpoint; this benchmark only measures performance.
    state_dict = model.state_dict()
    for k, v in state_dict.items():
        state_dict[k] = torch.nn.Parameter(
            torch.randn(v.shape, device=device).to(dtype=v.dtype),
            requires_grad=v.requires_grad,
        )
    model.load_state_dict(state_dict, assign=True)
    return model.eval()


# Only count activated parameters and buffers.
def _get_model_size(model):
    model_size = 0
    for name, child in model.named_children():
        if not isinstance(child, torch.nn.Embedding):
            model_size += sum(
                p.numel() * p.dtype.itemsize
                for p in itertools.chain(child.parameters(), child.buffers())
            )

    # Remove the inactivated experts from the model size if this is a mixture-of-experts
    # architecture, since only the activated experts are loaded.
    if hasattr(model.config, "num_experts"):
        config = model.config
        for submodule in model.modules():
            if isinstance(
                submodule, (ConditionalFeedForward, ConditionalFeedForwardInt8)
            ):
                model_size -= (
                    sum(
                        p.numel() * p.dtype.itemsize
                        for p in itertools.chain(
                            submodule.parameters(), submodule.buffers()
                        )
                    )
                    * (config.num_experts - config.num_activated_experts)
                    / config.num_experts
                )

    return model_size


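# Worked example for the adjustment above (assuming a Mixtral-style config with
# num_experts=8 and num_activated_experts=2): each ConditionalFeedForward holds
# the weights of all 8 experts, but only 2 are read per token, so 6/8 of those
# expert bytes are subtracted from the reported model size.

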
def run_experiment(
    x: GPTModelConfig,
    num_samples: int = 5,
    max_new_tokens: int = 200,
    top_k: int = 200,
    temperature: float = 0.8,
    device: str = "cuda",
) -> Tuple[float, float, float]:
    print(f"Loading model {x.name}")
    t0 = time.time()
    model = _load_model(x, device=device)
    device_sync(device=device)  # MKG
    print(f"Time to load model: {time.time() - t0:.02f} seconds")

    prompt = torch.tensor(
        [1, 15043, 29892, 590, 1024, 338], device=device, dtype=torch.int32
    )
    prompt_length = prompt.size(0)

    torch.manual_seed(1234)
    model_size = _get_model_size(model)

    aggregate_metrics = {"tokens_per_sec": [], "memory_bandwidth": []}
    start = -1
    compilation_time = None

    for i in range(start, num_samples):
        device_sync(device=device)  # MKG

        t0 = time.perf_counter()
        y = generate(
            model, prompt, max_new_tokens, temperature=temperature, top_k=top_k
        )

        if i == -1:
            # The first iteration is a compilation warm-up and is excluded from the metrics.
            compilation_time = time.perf_counter() - t0
            print(f"Compilation time: {compilation_time:.2f} seconds")
            continue

        device_sync(device=device)  # MKG
        t = time.perf_counter() - t0
        tokens_generated = y.size(0) - prompt_length
        tokens_sec = tokens_generated / t
        aggregate_metrics["tokens_per_sec"].append(tokens_sec)
        aggregate_metrics["memory_bandwidth"].append(model_size * tokens_sec / 1e9)

    token_per_sec = torch.mean(torch.tensor(aggregate_metrics["tokens_per_sec"])).item()
    memory_bandwidth = torch.mean(
        torch.tensor(aggregate_metrics["memory_bandwidth"])
    ).item()
    print(f"Average tokens/sec: {token_per_sec:.2f} tokens/sec")
    print(f"Average bandwidth achieved: {memory_bandwidth:.02f} GB/s")
    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
    return token_per_sec, memory_bandwidth, compilation_time


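# Note on the memory_bandwidth metric above: during decoding each generated
# token reads the activated weights roughly once, so the achieved bandwidth is
# estimated as model_size (bytes) * tokens/sec / 1e9 GB/s. For example, an
# assumed 13.5e9-byte model decoding at 100 tokens/sec gives ~1350 GB/s.

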
# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_llama2_7b_bf16(device: str = "cuda"):
    from benchmark import Experiment

    model = GPTModelConfig(
        "Llama-2-7b-chat-hf",
        LLaMA,
        "bfloat16",
        LLaMAWeightOnlyInt8QuantHandler,
        94,
        1253,
        162,
    )
    token_per_sec, memory_bandwidth, compilation_time = run_experiment(
        model, device=device
    )
    return [
        Experiment(
            model.name,
            "token_per_sec",
            model.token_per_sec,
            f"{token_per_sec:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "memory_bandwidth(GB/s)",
            model.memory_bandwidth,
            f"{memory_bandwidth:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "compilation_time(s)",
            model.compilation_time,
            f"{compilation_time:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
    ]


# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_llama2_7b_int8(device: str = "cuda"):
    from benchmark import Experiment

    model = GPTModelConfig(
        "Llama-2-7b-chat-hf",
        LLaMA,
        "int8",
        LLaMAWeightOnlyInt8QuantHandler,
        144,
        957,
        172,
    )
    token_per_sec, memory_bandwidth, compilation_time = run_experiment(
        model, device=device
    )
    return [
        Experiment(
            model.name,
            "token_per_sec",
            model.token_per_sec,
            f"{token_per_sec:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "memory_bandwidth(GB/s)",
            model.memory_bandwidth,
            f"{memory_bandwidth:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "compilation_time(s)",
            model.compilation_time,
            f"{compilation_time:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
    ]


# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_mixtral_8x7b_int8(device: str = "cuda"):
    from benchmark import Experiment

    # We reduced the original number of layers from 32 to 16 to fit within the CI memory limit.
    model = GPTModelConfig(
        "Mixtral-8x7B-v0.1",
        MixtralMoE,
        "int8",
        MixtralMoEWeightOnlyInt8QuantHandler,
        175,
        1130,
        162,
    )
    token_per_sec, memory_bandwidth, compilation_time = run_experiment(
        model, device=device
    )
    return [
        Experiment(
            model.name,
            "token_per_sec",
            model.token_per_sec,
            f"{token_per_sec:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "memory_bandwidth(GB/s)",
            model.memory_bandwidth,
            f"{memory_bandwidth:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
        Experiment(
            model.name,
            "compilation_time(s)",
            model.compilation_time,
            f"{compilation_time:.02f}",
            model.mode,
            device,
            get_arch_name(),
            True,
        ),
    ]
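

# Standalone smoke-run sketch (an assumption for illustration; the canonical
# entry point is benchmark.py, which collects the Experiment lists returned by
# the run_* helpers above). Guarded so importing this module is unaffected.
if __name__ == "__main__":
    for experiment in run_llama2_7b_bf16(device="cuda"):
        print(experiment)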