xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/check_accuracy.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport argparse
2*da0073e9SAndroid Build Coastguard Workerimport os
3*da0073e9SAndroid Build Coastguard Workerimport sys
4*da0073e9SAndroid Build Coastguard Workerimport textwrap
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport pandas as pd
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Worker
# Hack to have something similar to DISABLED_TEST. These models are flaky.
# When a model listed here does not match its expected accuracy,
# check_accuracy() reports it as PASS_BUT_FLAKY / FAIL_BUT_FLAKY and does not
# add it to the failed or improved lists.

flaky_models = {
    "yolov3",
    "gluon_inception_v3",
    "detectron2_maskrcnn_r_101_c4",
    "XGLMForCausalLM",  # discovered in https://github.com/pytorch/pytorch/pull/128148
}
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Worker
def get_field(csv, model_name: str, field: str):
    """Return the value of ``field`` for the row whose "name" is ``model_name``.

    Args:
        csv: A pandas DataFrame with a "name" column.
        model_name: The value to look up in the "name" column.
        field: The column whose value should be returned.

    Returns:
        The scalar stored in ``field`` for the single matching row, or
        ``None`` when a required column is missing or when the number of
        matching rows is not exactly one.
    """
    try:
        return csv.loc[csv["name"] == model_name][field].item()
    except (KeyError, ValueError):
        # KeyError: the "name" or ``field`` column does not exist.
        # ValueError: ``.item()`` needs exactly one element, so zero or
        # several matching rows end up here.
        # Anything else is a genuine bug and is allowed to propagate instead
        # of being silently turned into "no value".
        return None
24*da0073e9SAndroid Build Coastguard Worker
25*da0073e9SAndroid Build Coastguard Worker
def check_accuracy(actual_csv, expected_csv, expected_filename):
    """Compare per-model accuracy status between an actual and an expected CSV.

    One line is printed for every name in ``actual_csv["name"]``:

    * ``PASS`` / ``XFAIL`` when the actual status equals the expected one,
    * ``PASS_BUT_FLAKY:`` / ``FAIL_BUT_FLAKY:`` when the statuses differ and
      the model is listed in ``flaky_models`` (not counted as a change),
    * ``IMPROVED:`` when the statuses differ and the model now passes,
    * ``FAIL:`` when the statuses differ and the model does not pass.

    Args:
        actual_csv: DataFrame of observed results ("name", "accuracy").
        expected_csv: DataFrame of baseline results ("name", "accuracy").
        expected_filename: Path of the baseline file; only used in the
            message explaining how to update the baseline.

    Returns:
        A ``(changed, msg)`` tuple. ``changed`` is the list of failed models
        if there are any, otherwise the list of improved models (empty when
        nothing changed). ``msg`` is a summary with update instructions, or
        ``""`` when nothing changed.
    """
    failed = []
    improved = []

    for model in actual_csv["name"]:
        accuracy = get_field(actual_csv, model, "accuracy")
        expected_accuracy = get_field(expected_csv, model, "accuracy")

        # Matching statuses: either a real pass or an expected failure.
        if accuracy == expected_accuracy:
            if expected_accuracy == "pass":
                status = "PASS"
            else:
                status = "XFAIL"
            print(f"{model:34}  {status}")
            continue

        passed = accuracy == "pass"
        if model in flaky_models:
            # Flaky models are reported but never recorded as a change.
            status = "PASS_BUT_FLAKY:" if passed else "FAIL_BUT_FLAKY:"
        elif passed:
            status = "IMPROVED:"
            improved.append(model)
        else:
            status = "FAIL:"
            failed.append(model)
        print(
            f"{model:34}  {status:9} accuracy={accuracy}, expected={expected_accuracy}"
        )

    msg = ""
    if not (failed or improved):
        # Nothing changed: no summary to build.
        return failed or improved, msg

    if failed:
        msg += textwrap.dedent(
            f"""
            Error: {len(failed)} models have accuracy status regressed:
                {' '.join(failed)}

            """
        )
    if improved:
        msg += textwrap.dedent(
            f"""
            Improvement: {len(improved)} models have accuracy status improved:
                {' '.join(improved)}

            """
        )
    sha = os.getenv("SHA1", "{your CI commit sha}")
    msg += textwrap.dedent(
        f"""
        If this change is expected, you can update `{expected_filename}` to reflect the new baseline.
        from pytorch/pytorch root, run
        `python benchmarks/dynamo/ci_expected_accuracy/update_expected.py {sha}`
        and then `git add` the resulting local changes to expected CSVs to your commit.
        """
    )
    return failed or improved, msg
83*da0073e9SAndroid Build Coastguard Worker
84*da0073e9SAndroid Build Coastguard Worker
def main():
    """Command-line entry point.

    Reads the CSV files given by the required ``--actual`` and ``--expected``
    options, runs ``check_accuracy`` on them and, when any model's status
    changed (regressed or improved), prints the summary message and exits
    with status 1.
    """
    cli = argparse.ArgumentParser()
    for flag in ("--actual", "--expected"):
        cli.add_argument(flag, type=str, required=True)
    args = cli.parse_args()

    actual_df = pd.read_csv(args.actual)
    expected_df = pd.read_csv(args.expected)

    changed, report = check_accuracy(actual_df, expected_df, args.expected)
    if not changed:
        return

    print(report)
    sys.exit(1)
98*da0073e9SAndroid Build Coastguard Worker
99*da0073e9SAndroid Build Coastguard Worker
# Run the check only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
102