import argparse
import os
import sys
import textwrap

import pandas as pd


# Hack to have something similar to DISABLED_TEST. These models are flaky.

flaky_models = {
    "yolov3",
    "gluon_inception_v3",
    "detectron2_maskrcnn_r_101_c4",
    "XGLMForCausalLM",  # discovered in https://github.com/pytorch/pytorch/pull/128148
}


def get_field(csv, model_name: str, field: str):
    """Return the value of ``field`` for the row whose "name" is ``model_name``.

    Args:
        csv: DataFrame with a "name" column (as loaded by ``pd.read_csv``).
        model_name: Value to look up in the "name" column.
        field: Column whose value should be returned.

    Returns:
        The scalar cell value, or ``None`` when the lookup does not resolve
        to exactly one cell (no matching row, several matching rows, or a
        missing column).
    """
    try:
        return csv.loc[csv["name"] == model_name][field].item()
    except (KeyError, ValueError):
        # KeyError: the "name" column or `field` column does not exist.
        # ValueError: Series.item() was given zero or more than one row.
        return None


def check_accuracy(actual_csv, expected_csv, expected_filename):
    """Compare per-model accuracy statuses against the expected baseline.

    Prints one status line per model in ``actual_csv`` and collects the
    models whose status differs from the baseline. Models listed in
    ``flaky_models`` are reported but never collected.

    Args:
        actual_csv: DataFrame of observed results ("name", "accuracy").
        expected_csv: DataFrame of baseline results ("name", "accuracy").
        expected_filename: Path of the baseline file, used only in the
            returned message.

    Returns:
        A ``(changed, msg)`` tuple. ``changed`` is the list of regressed
        models if there are any, otherwise the list of improved models
        (empty when nothing changed). ``msg`` is a human-readable summary
        with update instructions, or ``""`` when nothing changed.
    """
    failed = []
    improved = []

    for model in actual_csv["name"]:
        accuracy = get_field(actual_csv, model, "accuracy")
        expected_accuracy = get_field(expected_csv, model, "accuracy")

        if accuracy == expected_accuracy:
            # Matching the baseline is fine even when the baseline is a failure.
            status = "PASS" if expected_accuracy == "pass" else "XFAIL"
            print(f"{model:34} {status}")
            continue
        elif model in flaky_models:
            # Flaky models are reported but do not affect the result.
            if accuracy == "pass":
                # model passed but marked xfailed
                status = "PASS_BUT_FLAKY:"
            else:
                # model failed but marked passed
                status = "FAIL_BUT_FLAKY:"
        elif accuracy != "pass":
            status = "FAIL:"
            failed.append(model)
        else:
            status = "IMPROVED:"
            improved.append(model)
        print(
            f"{model:34} {status:9} accuracy={accuracy}, expected={expected_accuracy}"
        )

    msg = ""
    if failed:
        msg += textwrap.dedent(
            f"""
            Error: {len(failed)} models have accuracy status regressed:
                {' '.join(failed)}

            """
        )
    if improved:
        msg += textwrap.dedent(
            f"""
            Improvement: {len(improved)} models have accuracy status improved:
                {' '.join(improved)}

            """
        )
    if failed or improved:
        # SHA1 is read from the environment; fall back to a placeholder the
        # user is expected to substitute by hand.
        sha = os.getenv("SHA1", "{your CI commit sha}")
        msg += textwrap.dedent(
            f"""
            If this change is expected, you can update `{expected_filename}` to reflect the new baseline.
            from pytorch/pytorch root, run
            `python benchmarks/dynamo/ci_expected_accuracy/update_expected.py {sha}`
            and then `git add` the resulting local changes to expected CSVs to your commit.
            """
        )
    return failed or improved, msg


def main():
    """Parse CLI arguments, run the accuracy check, and exit 1 on any change."""
    parser = argparse.ArgumentParser(
        description="Compare an accuracy results CSV against the expected baseline."
    )
    parser.add_argument(
        "--actual", type=str, required=True, help="path to the actual results CSV"
    )
    parser.add_argument(
        "--expected", type=str, required=True, help="path to the expected results CSV"
    )
    args = parser.parse_args()

    actual = pd.read_csv(args.actual)
    expected = pd.read_csv(args.expected)

    # Both regressions and improvements fail the check: an improvement means
    # the baseline file is stale and has to be updated.
    failed, msg = check_accuracy(actual, expected, args.expected)
    if failed:
        print(msg)
        sys.exit(1)


if __name__ == "__main__":
    main()