xref: /aosp_15_r20/external/pytorch/benchmarks/gpt_fast/benchmark.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport argparse
2*da0073e9SAndroid Build Coastguard Workerimport csv
3*da0073e9SAndroid Build Coastguard Workerimport dataclasses
4*da0073e9SAndroid Build Coastguard Workerimport os
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerfrom generate import (
7*da0073e9SAndroid Build Coastguard Worker    get_arch_name,
8*da0073e9SAndroid Build Coastguard Worker    run_llama2_7b_bf16,
9*da0073e9SAndroid Build Coastguard Worker    run_llama2_7b_int8,
10*da0073e9SAndroid Build Coastguard Worker    run_mixtral_8x7b_int8,
11*da0073e9SAndroid Build Coastguard Worker)
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Workerimport torch
14*da0073e9SAndroid Build Coastguard Workerimport torch.nn as nn
15*da0073e9SAndroid Build Coastguard Workerfrom torch._inductor.runtime.benchmarking import benchmarker
16*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.flop_counter import FlopCounterMode
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Worker
# Number of un-timed iterations run before each measurement so that
# torch.compile compilation and warm-up effects are excluded from timings.
WARMUP_ITER = 5

# Peak bf16 throughput (TFLOPS) of an A100-40GB; used as the denominator
# when reporting FLOPS utilization in run_mlp_layer_norm_gelu.
A100_40G_BF16_TFLOPS = 312
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker
@dataclasses.dataclass
class Experiment:
    """One benchmark result row; converted to a tuple and written as CSV."""

    name: str  # experiment name, e.g. "gemv" or "mlp_layer_norm_gelu"
    metric: str  # metric label, e.g. "memory_bandwidth(GB/s)" or "flops_utilization"
    # NOTE(review): annotated float, but the callers in this file pass
    # strings (e.g. "990", f"{x:.02f}") — dataclasses do not enforce types,
    # so this works, but the annotations are misleading; confirm intent.
    target: float  # expected/target value for the metric
    actual: float  # measured value, pre-formatted by the caller
    dtype: str  # dtype under test with the "torch." prefix stripped, e.g. "bfloat16"
    device: str  # "cuda" or "cpu"
    arch: str  # GPU name for CUDA or CPU arch for CPU
    is_model: bool = False  # micro-benchmarks in this file leave this at the default False
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Worker
class SimpleMLP(nn.Module):
    """Two Linear+LayerNorm stages applied in sequence.

    NOTE(review): although this is benchmarked under the name
    "mlp_layer_norm_gelu", no GELU activation appears here — confirm
    that is intentional.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dtype):
        super().__init__()
        stages = [
            nn.Linear(input_dim, hidden_dim, dtype=dtype),
            nn.LayerNorm(hidden_dim, dtype=dtype),
            nn.Linear(hidden_dim, output_dim, dtype=dtype),
            nn.LayerNorm(output_dim, dtype=dtype),
        ]
        # ModuleList (not Sequential) preserves the original parameter names.
        self.layers = nn.ModuleList(stages)

    def forward(self, x):
        # Pass the input through each stage in registration order.
        for stage in self.layers:
            x = stage(x)
        return x
52*da0073e9SAndroid Build Coastguard Worker
53*da0073e9SAndroid Build Coastguard Worker
def run_mlp_layer_norm_gelu(device: str = "cuda"):
    """Benchmark a compiled SimpleMLP and report average FLOPS utilization.

    For each input size D, count the FLOPs of one eager forward pass with
    FlopCounterMode, benchmark the torch.compile'd module, and accumulate a
    utilization figure relative to A100-40G peak bf16 throughput.  Returns a
    list with one Experiment per dtype in the map.
    """
    dtype_flops_utilization_map = {
        torch.bfloat16: "0.8",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    intermediate_size = 14336
    results = []
    for dtype, expected_flops_utilization in dtype_flops_utilization_map.items():
        flops_utilization = 0
        for D in input_shapes:
            mod = SimpleMLP(
                input_dim=D, hidden_dim=intermediate_size, output_dim=D, dtype=dtype
            ).to(device)

            # Bug fix: the input must use the dtype under test rather than a
            # hard-coded torch.bfloat16.  Identical behavior today (the map
            # only contains bfloat16) but correct if more dtypes are added.
            x = torch.randn(D, device=device, dtype=dtype)

            # Count FLOPs of a single eager forward pass.
            with FlopCounterMode(display=False) as mode:
                mod(x)

            flops = mode.get_total_flops()

            compiled_mod = torch.compile(mod, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled_mod(x)

            benchmark_fn = (
                benchmarker.benchmark_gpu
                if device == "cuda"
                else benchmarker.benchmark_cpu
            )
            us_per_iter = benchmark_fn(lambda: compiled_mod(x)) * 1000
            flops_utilization += us_per_iter * flops / 1e9 / A100_40G_BF16_TFLOPS

        # Average utilization across all input sizes.
        flops_utilization = flops_utilization / len(input_shapes)
        dtype_str = str(dtype).replace("torch.", "")
        results.append(
            Experiment(
                # NOTE(review): name mentions "gelu" but SimpleMLP contains
                # no GELU layer — confirm the benchmark name is intended.
                "mlp_layer_norm_gelu",
                "flops_utilization",
                expected_flops_utilization,
                f"{flops_utilization:.02f}",
                dtype_str,
                device,
                get_arch_name(),
            )
        )
    return results
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Worker
def run_layer_norm(device: str = "cuda"):
    """Benchmark a compiled nn.LayerNorm and report achieved GB/s.

    Bandwidth counts one read plus one write of a (BS, D) tensor of the
    dtype under test per iteration, averaged across all input sizes.
    """
    dtype_memory_bandwidth_map = {
        torch.bfloat16: "950",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    BS = 4096
    results = []
    for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
        total_bw = 0
        for dim in input_shapes:
            # NOTE(review): the LayerNorm parameters stay in the default
            # fp32 while the input is `dtype` — confirm this mixed-precision
            # setup is intended (SimpleMLP above passes dtype explicitly).
            layer = nn.LayerNorm(dim).to(device)

            inp = torch.randn(BS, dim, device=device, dtype=dtype)

            compiled = torch.compile(layer, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled(inp)

            if device == "cuda":
                bench = benchmarker.benchmark_gpu
            else:
                bench = benchmarker.benchmark_cpu
            us_per_iter = bench(lambda: compiled(inp)) * 1000
            # 2x for one read and one write of the (BS, dim) tensor.
            total_bw += (1e6 / us_per_iter) * 2 * BS * dim * dtype.itemsize / 1e9

        avg_bw = total_bw / len(input_shapes)
        results.append(
            Experiment(
                "layer_norm",
                "memory_bandwidth(GB/s)",
                expected_memory_bandwidth,
                f"{avg_bw:.02f}",
                str(dtype).replace("torch.", ""),
                device,
                get_arch_name(),
            )
        )
    return results
145*da0073e9SAndroid Build Coastguard Worker
146*da0073e9SAndroid Build Coastguard Worker
@torch._inductor.config.patch(coordinate_descent_tuning=True)
def run_gather_gemv(device: str = "cuda"):
    """Benchmark compiled expert-gather + matrix-vector multiply (MoE-style).

    Two of E expert weight matrices are gathered by index, upcast to the
    bf16 input dtype, and applied to a vector.  Bandwidth counts reading
    the two gathered D x D matrices per iteration.
    """
    E = 8  # number of expert matrices to gather from
    dtype_memory_bandwidth_map = {
        torch.int8: "990",
        torch.bfloat16: "1060",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    results = []
    for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
        total_bw = 0
        for dim in input_shapes:

            def gather_gemv(W, score_idxs, x):
                # Select the indexed expert matrices, upcast, then GEMV.
                return W[score_idxs].to(x.dtype) @ x

            W = torch.randn(E, dim, dim, device=device).to(dtype=dtype)
            x = torch.randn(dim, device=device, dtype=torch.bfloat16)
            score_idxs = torch.tensor([3, 5], device=device)

            compiled = torch.compile(gather_gemv, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled(W, score_idxs, x)

            if device == "cuda":
                bench = benchmarker.benchmark_gpu
            else:
                bench = benchmarker.benchmark_cpu
            us_per_iter = bench(lambda: compiled(W, score_idxs, x)) * 1000
            # 2x for the two gathered dim x dim expert matrices.
            total_bw += (1e6 / us_per_iter) * 2 * dim * dim * dtype.itemsize / 1e9

        avg_bw = total_bw / len(input_shapes)
        results.append(
            Experiment(
                "gather_gemv",
                "memory_bandwidth(GB/s)",
                expected_memory_bandwidth,
                f"{avg_bw:.02f}",
                str(dtype).replace("torch.", ""),
                device,
                get_arch_name(),
            )
        )
    return results
194*da0073e9SAndroid Build Coastguard Worker
195*da0073e9SAndroid Build Coastguard Worker
@torch._inductor.config.patch(coordinate_descent_tuning=True)
def run_gemv(device: str = "cuda"):
    """Benchmark a compiled matrix-vector multiply and report achieved GB/s.

    The weight matrix is stored in the dtype under test and upcast to the
    bf16 input dtype inside the kernel; bandwidth counts one read of the
    D x D weight per iteration.
    """
    dtype_memory_bandwidth_map = {
        torch.int8: "870",
        torch.bfloat16: "990",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    results = []
    for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
        total_bw = 0
        for dim in input_shapes:

            def gemv(W, x):
                return W.to(x.dtype) @ x

            W = torch.randn(dim, dim, device=device).to(dtype=dtype)
            x = torch.randn(dim, device=device, dtype=torch.bfloat16)

            compiled = torch.compile(gemv, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled(W, x)

            if device == "cuda":
                bench = benchmarker.benchmark_gpu
            else:
                bench = benchmarker.benchmark_cpu
            us_per_iter = bench(lambda: compiled(W, x)) * 1000
            total_bw += (1e6 / us_per_iter) * dim * dim * dtype.itemsize / 1e9

        avg_bw = total_bw / len(input_shapes)
        results.append(
            Experiment(
                "gemv",
                "memory_bandwidth(GB/s)",
                expected_memory_bandwidth,
                f"{avg_bw:.02f}",
                str(dtype).replace("torch.", ""),
                device,
                get_arch_name(),
            )
        )
    return results
241*da0073e9SAndroid Build Coastguard Worker
242*da0073e9SAndroid Build Coastguard Worker
def output_csv(output_file, headers, row):
    """Append `row` to `output_file` as CSV, creating or repairing the header.

    If the file already exists, its first line is reused as the header unless
    the supplied `headers` is longer (a prior failed run may have written a
    partial header).  Rows shorter than the header are right-padded with "0".
    The whole file is rewritten on every call.
    """
    if os.path.exists(output_file):
        with open(output_file) as fd:
            lines = list(csv.reader(fd)) or [[]]
            if headers and len(headers) > len(lines[0]):
                # if prior results failed the header might not be filled in yet
                lines[0] = headers
            else:
                headers = lines[0]
    else:
        lines = [headers]

    # Bug fix: only create parent directories when the path actually has one.
    # The old guard (`output_file != DEFAULT_OUTPUT_FILE`) made any other
    # bare filename crash, because os.makedirs("") raises FileNotFoundError.
    dir_name = os.path.dirname(output_file)
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)

    # Normalize floats to 6 decimal places for stable CSV output.
    lines.append([(f"{x:.6f}" if isinstance(x, float) else x) for x in row])
    with open(output_file, "w") as fd:
        writer = csv.writer(fd, lineterminator="\n")
        for line in lines:
            writer.writerow(list(line) + ["0"] * (len(headers) - len(line)))
262*da0073e9SAndroid Build Coastguard Worker
263*da0073e9SAndroid Build Coastguard Worker
# Default location for benchmark results, used when --output is not given.
DEFAULT_OUTPUT_FILE = "gpt_fast_benchmark.csv"

# Bug fix: this was a set literal, so the experiments ran — and their rows
# were written — in a nondeterministic order that could change between runs.
# A tuple keeps iteration and membership semantics but fixes the order.
all_experiments = (
    # A list of GPT models: LlaMa, Mixtral, etc.
    run_llama2_7b_bf16,
    run_llama2_7b_int8,
    run_mixtral_8x7b_int8,
    # A list of micro-benchmarks.
    run_mlp_layer_norm_gelu,
    run_layer_norm,
    run_gather_gemv,
    run_gemv,
)
277*da0073e9SAndroid Build Coastguard Worker
278*da0073e9SAndroid Build Coastguard Worker
def main(output_file=DEFAULT_OUTPUT_FILE):
    """Run every registered experiment and append each result row to the CSV.

    The target device is detected once up front (previously re-detected on
    every loop iteration, which was redundant — the result never changes
    within a single run).
    """
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    except AssertionError:
        # This happens when torch is compiled with CUDA turning off completely
        device = "cpu"

    results = []
    for func in all_experiments:
        for experiment in func(device):
            results.append(dataclasses.astuple(experiment))

    headers = [field.name for field in dataclasses.fields(Experiment)]

    for row in results:
        output_csv(output_file, headers, row)
297*da0073e9SAndroid Build Coastguard Worker
298*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # CLI entry point: optionally override where the CSV results are written.
    arg_parser = argparse.ArgumentParser(description="Run experiments.")
    arg_parser.add_argument(
        "--output",
        default=DEFAULT_OUTPUT_FILE,
        help="Set the output CSV file to save the benchmark results",
    )
    cli_args = arg_parser.parse_args()

    main(output_file=cli_args.output)
309