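"""Benchmark harness for gpt-fast: runs the end-to-end model benchmarks from
generate.py (Llama 2 7B, Mixtral 8x7B) alongside several torch.compile
micro-benchmarks, and appends one CSV row per result.

Example invocation (the output filename is only an example):

    python benchmark.py --output results.csv
"""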
import argparse
import csv
import dataclasses
import os

from generate import (
    get_arch_name,
    run_llama2_7b_bf16,
    run_llama2_7b_int8,
    run_mixtral_8x7b_int8,
)

import torch
import torch.nn as nn
from torch._inductor.runtime.benchmarking import benchmarker
from torch.utils.flop_counter import FlopCounterMode


WARMUP_ITER = 5

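# Peak dense bf16 Tensor Core throughput of an A100 40GB, used as the
# baseline for the flops_utilization metric below.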
A100_40G_BF16_TFLOPS = 312


@dataclasses.dataclass
class Experiment:
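    """A single benchmark result, serialized as one row of the output CSV."""
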
    name: str
    metric: str
    target: float
    actual: float
    dtype: str
    device: str
    arch: str  # GPU name for CUDA or CPU arch for CPU
    is_model: bool = False


class SimpleMLP(nn.Module):
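    """A small Linear + LayerNorm stack driven by the mlp_layer_norm_gelu benchmark."""
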
    def __init__(self, input_dim, hidden_dim, output_dim, dtype):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                nn.Linear(input_dim, hidden_dim, dtype=dtype),
                nn.LayerNorm(hidden_dim, dtype=dtype),
                nn.Linear(hidden_dim, output_dim, dtype=dtype),
                nn.LayerNorm(output_dim, dtype=dtype),
            ]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def run_mlp_layer_norm_gelu(device: str = "cuda"):
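    """Benchmark a compiled SimpleMLP forward pass across several input sizes and
    report a FLOPS-utilization score relative to A100_40G_BF16_TFLOPS."""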
    dtype_flops_utilization_map = {
        torch.bfloat16: "0.8",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    intermediate_size = 14336
    results = []
    for dtype, expected_flops_utilization in dtype_flops_utilization_map.items():
        flops_utilization = 0
        for D in input_shapes:
            mod = SimpleMLP(
                input_dim=D, hidden_dim=intermediate_size, output_dim=D, dtype=dtype
            ).to(device)

            x = torch.randn(D, device=device, dtype=dtype)

            with FlopCounterMode(display=False) as mode:
                mod(x)

            flops = mode.get_total_flops()

            compiled_mod = torch.compile(mod, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled_mod(x)

            benchmark_fn = (
                benchmarker.benchmark_gpu
                if device == "cuda"
                else benchmarker.benchmark_cpu
            )
            us_per_iter = benchmark_fn(lambda: compiled_mod(x)) * 1000
            flops_utilization += us_per_iter * flops / 1e9 / A100_40G_BF16_TFLOPS

        flops_utilization = flops_utilization / len(input_shapes)
        dtype_str = str(dtype).replace("torch.", "")
        results.append(
            Experiment(
                "mlp_layer_norm_gelu",
                "flops_utilization",
                expected_flops_utilization,
                f"{flops_utilization:.02f}",
                dtype_str,
                device,
                get_arch_name(),
            )
        )
    return results


def run_layer_norm(device: str = "cuda"):
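    """Benchmark a compiled nn.LayerNorm and report achieved memory bandwidth in
    GB/s, counting one read and one write of the (BS, D) input per iteration."""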
    dtype_memory_bandwidth_map = {
        torch.bfloat16: "950",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    BS = 4096
    results = []
    for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
        memory_bandwidth = 0
        for D in input_shapes:
            mod = nn.LayerNorm(D).to(device)

            x = torch.randn(BS, D, device=device, dtype=dtype)

            compiled_mod = torch.compile(mod, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled_mod(x)

            benchmark_fn = (
                benchmarker.benchmark_gpu
                if device == "cuda"
                else benchmarker.benchmark_cpu
            )
            us_per_iter = benchmark_fn(lambda: compiled_mod(x)) * 1000
            memory_bandwidth += (1e6 / us_per_iter) * 2 * BS * D * dtype.itemsize / 1e9

        memory_bandwidth = memory_bandwidth / len(input_shapes)
        dtype_str = str(dtype).replace("torch.", "")
        results.append(
            Experiment(
                "layer_norm",
                "memory_bandwidth(GB/s)",
                expected_memory_bandwidth,
                f"{memory_bandwidth:.02f}",
                dtype_str,
                device,
                get_arch_name(),
            )
        )
    return results


@torch._inductor.config.patch(coordinate_descent_tuning=True)
def run_gather_gemv(device: str = "cuda"):
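    """Benchmark a compiled gather-then-GEMV: pick two of E expert weight matrices
    by index and multiply them with a vector, reporting memory bandwidth in GB/s.
    The access pattern mirrors Mixtral-style top-2 expert selection."""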
    E = 8
    dtype_memory_bandwidth_map = {
        torch.int8: "990",
        torch.bfloat16: "1060",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    results = []
    for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
        memory_bandwidth = 0
        for D in input_shapes:

            def gather_gemv(W, score_idxs, x):
                return W[score_idxs].to(x.dtype) @ x

            W = torch.randn(E, D, D, device=device).to(dtype=dtype)
            x = torch.randn(D, device=device, dtype=torch.bfloat16)
            score_idxs = torch.tensor([3, 5], device=device)

            compiled_fn = torch.compile(gather_gemv, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled_fn(W, score_idxs, x)

            benchmark_fn = (
                benchmarker.benchmark_gpu
                if device == "cuda"
                else benchmarker.benchmark_cpu
            )
            us_per_iter = benchmark_fn(lambda: compiled_fn(W, score_idxs, x)) * 1000
            memory_bandwidth += (1e6 / us_per_iter) * 2 * D * D * dtype.itemsize / 1e9

        memory_bandwidth = memory_bandwidth / len(input_shapes)
        dtype_str = str(dtype).replace("torch.", "")
        results.append(
            Experiment(
                "gather_gemv",
                "memory_bandwidth(GB/s)",
                expected_memory_bandwidth,
                f"{memory_bandwidth:.02f}",
                dtype_str,
                device,
                get_arch_name(),
            )
        )
    return results


@torch._inductor.config.patch(coordinate_descent_tuning=True)
def run_gemv(device: str = "cuda"):
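    """Benchmark a compiled GEMV where the weight matrix is cast to the activation
    dtype on the fly, reporting achieved memory bandwidth in GB/s."""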
    dtype_memory_bandwidth_map = {
        torch.int8: "870",
        torch.bfloat16: "990",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    results = []
    for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
        memory_bandwidth = 0
        for D in input_shapes:

            def gemv(W, x):
                return W.to(x.dtype) @ x

            W = torch.randn(D, D, device=device).to(dtype=dtype)
            x = torch.randn(D, device=device, dtype=torch.bfloat16)

            compiled_fn = torch.compile(gemv, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled_fn(W, x)

            benchmark_fn = (
                benchmarker.benchmark_gpu
                if device == "cuda"
                else benchmarker.benchmark_cpu
            )
            us_per_iter = benchmark_fn(lambda: compiled_fn(W, x)) * 1000
            memory_bandwidth += (1e6 / us_per_iter) * D * D * dtype.itemsize / 1e9

        memory_bandwidth = memory_bandwidth / len(input_shapes)
        dtype_str = str(dtype).replace("torch.", "")
        results.append(
            Experiment(
                "gemv",
                "memory_bandwidth(GB/s)",
                expected_memory_bandwidth,
                f"{memory_bandwidth:.02f}",
                dtype_str,
                device,
                get_arch_name(),
            )
        )
    return results


DEFAULT_OUTPUT_FILE = "gpt_fast_benchmark.csv"


def output_csv(output_file, headers, row):
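    """Append one result row to output_file, reusing (or repairing) the header
    row when the file already exists."""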
    if os.path.exists(output_file):
        with open(output_file) as fd:
            lines = list(csv.reader(fd)) or [[]]
            if headers and len(headers) > len(lines[0]):
                # If a prior run failed, the header row may not be filled in yet.
                lines[0] = headers
            else:
                headers = lines[0]
    else:
        lines = [headers]

    if output_file != DEFAULT_OUTPUT_FILE:
        output_dir = os.path.dirname(output_file)
        if output_dir:  # os.makedirs("") raises, so skip bare filenames
            os.makedirs(output_dir, exist_ok=True)
    lines.append([(f"{x:.6f}" if isinstance(x, float) else x) for x in row])
    with open(output_file, "w") as fd:
        writer = csv.writer(fd, lineterminator="\n")
        for line in lines:
            writer.writerow(list(line) + ["0"] * (len(headers) - len(line)))


# A tuple (rather than a set) keeps the CSV row order deterministic across runs.
all_experiments = (
    # End-to-end GPT model benchmarks: Llama, Mixtral, etc.
    run_llama2_7b_bf16,
    run_llama2_7b_int8,
    run_mixtral_8x7b_int8,
    # Micro-benchmarks.
    run_mlp_layer_norm_gelu,
    run_layer_norm,
    run_gather_gemv,
    run_gemv,
)


def main(output_file=DEFAULT_OUTPUT_FILE):
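    """Run every experiment on the detected device and write each result as a
    row of the output CSV."""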
    results = []

    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    except AssertionError:
        # This can happen when torch is compiled with CUDA support disabled entirely.
        device = "cpu"

    for func in all_experiments:
        lst = func(device)
        for x in lst:
            results.append(dataclasses.astuple(x))

    headers = [field.name for field in dataclasses.fields(Experiment)]

    for row in results:
        output_csv(output_file, headers, row)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run experiments.")
    parser.add_argument(
        "--output",
        default=DEFAULT_OUTPUT_FILE,
        help="The CSV file in which to save the benchmark results",
    )
    args = parser.parse_args()

    main(output_file=args.output)