xref: /aosp_15_r20/external/pytorch/benchmarks/inference/process_metrics.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
"""
This file takes a csv output from server.py, calculates the mean and
standard deviation of the warmup_latency, average_latency, throughput and
gpu_util columns, and writes them to the corresponding
`results/output_{batch_size}_{compile}.md` file, appending to the file if it
exists or creating a new one otherwise.
"""

import argparse
import os

import pandas as pd

# CSV columns (written by server.py) that get summarized. Their order must
# match the column order of the markdown table header below.
METRICS = ("warmup_latency", "average_latency", "throughput", "gpu_util")

_TABLE_HEADER = (
    "| Experiment | Warmup_latency (s) | Average_latency (s) | Throughput (samples/sec) | GPU Utilization (%) |\n"
    "| ---------- | ------------------ | ------------------- | ------------------------ | ------------------- |\n"
)


def _parse_csv_name(csv_path):
    """Extract ``(batch_size, compile_setting)`` from a results CSV filename.

    The basename is split on ``_``. The second field is parsed as the integer
    batch size. The last field, truncated at its first ``.``, is the compile
    setting (e.g. ``output_32_true.csv`` -> ``(32, "true")``).
    """
    fields = os.path.basename(csv_path).split("_")
    return int(fields[1]), fields[-1].split(".")[0]


def _summarize(df, metrics=METRICS):
    """Return ``(means, stds)`` dicts keyed by metric column name.

    ``stds`` holds pandas' default sample standard deviation (``ddof=1``).
    A CSV with a single row therefore yields NaN there.
    """
    means = {metric: df[metric].mean() for metric in metrics}
    stds = {metric: df[metric].std() for metric in metrics}
    return means, stds


def _format_row(name, means, stds, metrics=METRICS):
    """Render one markdown table row: the experiment name, then
    ``mean +/- std`` (3 decimals) for each metric, newline-terminated."""
    cells = "".join(f" {means[m]:.3f} +/- {stds[m]:.3f} |" for m in metrics)
    return f"| {name} |{cells}\n"


def main(argv=None):
    """Summarize one server.py CSV and append a row to its markdown report.

    Args:
        argv: Optional argument list to parse instead of ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Parse output files")
    # --csv is required: without it the input path cannot be built at all, so
    # fail with an argparse usage error rather than a TypeError.
    parser.add_argument("--csv", type=str, required=True, help="Path to csv file")
    parser.add_argument("--name", type=str, help="Name of experiment")
    parser.add_argument(
        "--results-dir",
        type=str,
        default="results",
        help="Directory holding the input csv and the output markdown files",
    )
    args = parser.parse_args(argv)

    # --csv is interpreted relative to the results directory.
    df = pd.read_csv(os.path.join(args.results_dir, args.csv))
    batch_size, compile_setting = _parse_csv_name(args.csv)
    means, stds = _summarize(df)

    output_md = os.path.join(
        args.results_dir, f"output_{batch_size}_{compile_setting}.md"
    )
    # Write the section title and table header only for a new (or empty)
    # report; otherwise just append another row to the existing table.
    write_header = not os.path.isfile(output_md) or os.path.getsize(output_md) == 0

    with open(output_md, "a", encoding="utf-8") as f:
        if write_header:
            f.write(f"## Batch Size {batch_size} Compile {compile_setting}\n\n")
            f.write(_TABLE_HEADER)
        f.write(_format_row(args.name, means, stds))


if __name__ == "__main__":
    main()