"""Suggest xfail/skip entries from a captured pytest failure report.

Instructions:

1. pytest -n 8 test/test_vmap.py test/test_ops.py test/test_aotdispatch.py > result.txt
2. python test/xfail_suggester.py
"""

import re

import torch


# The captured pytest output is read from the current working directory.
with open("result.txt") as f:
    lines = f.readlines()

failed = [line for line in lines if line.startswith("FAILED")]
# Raw string so that \w and \S reach the regex engine instead of being treated
# as (invalid) string escapes; the dot of ".py" is escaped so it only matches a
# literal dot.
p = re.compile(r"FAILED test/test_\w+\.py::\w+::(\S+)")


def get_failed_test(line):
    """Return the test name captured from a ``FAILED ...`` line.

    Args:
        line: One line of the pytest output.

    Returns:
        The text following the second ``::`` (up to the first whitespace), or
        ``None`` when the line does not match the ``FAILED`` pattern.
    """
    m = p.match(line)
    return None if m is None else m.group(1)


# Known test-name prefixes; whatever follows the prefix is parsed as
# "<op>[_<variant>]_<device>_<dtype>".
base_names = {
    "test_grad_",
    "test_vjp_",
    "test_vmapvjp_",
    "test_vmapvjp_has_batch_rule_",
    "test_vjpvmap_",
    "test_jvp_",
    "test_vmapjvp_",
    "test_vmapjvpall_has_batch_rule_",
    "test_vmapjvpall_",
    "test_jvpvjp_",
    "test_vjpvjp_",
    "test_decomposition_",
    "test_make_fx_exhaustive_",
    "test_vmap_exhaustive_",
    "test_op_has_batch_rule_",
    "test_vmap_autograd_grad_",
}

# Sorted names of every test reported as FAILED.
failed_tests = sorted(
    name for name in map(get_failed_test, lines) if name is not None
)

suggested_xfails = {}


def remove_device_dtype(test):
    """Drop the last two ``_``-separated components of ``test``.

    Returns an empty string when ``test`` has fewer than three components.
    """
    return "_".join(test.split("_")[:-2])


def belongs_to_base(test, base):
    """Tell whether ``base`` is the most specific known prefix of ``test``.

    Returns ``True`` only if ``test`` starts with ``base`` and does not start
    with any longer entry of ``base_names`` (e.g. ``test_vmapvjp_`` must not
    claim tests that belong to ``test_vmapvjp_has_batch_rule_``).
    """
    if not test.startswith(base):
        return False
    return not any(
        test.startswith(longer)
        for longer in base_names
        if len(longer) > len(base)
    )


def parse_namespace(base):
    """Split a leading namespace marker off ``base``.

    Returns:
        ``(namespace, rest)`` where ``namespace`` is the dotted name mapped
        from the matching heading and ``rest`` is ``base`` without that
        heading, or ``(None, base)`` when no heading matches.
    """
    mappings = {
        "nn_functional_": "nn.functional",
        "fft_": "fft",
        "linalg_": "linalg",
        "_masked_": "_masked",
        "sparse_": "sparse",
        "special_": "special",
    }
    for heading, namespace in mappings.items():
        if base.startswith(heading):
            return namespace, base[len(heading) :]
    return None, base


def get_torch_module(namespace):
    """Return the module object for ``namespace`` (``torch`` itself for ``None``)."""
    if namespace is None:
        return torch
    if namespace == "nn.functional":
        # A dotted name cannot be resolved with a single getattr on torch.
        return torch.nn.functional
    return getattr(torch, namespace)
def parse_base(base):
    """Split an op string into ``(namespace, api, variant)``.

    The namespace heading is stripped with ``parse_namespace``; the remainder
    is matched against the attribute names of the corresponding torch module,
    longest name first. When no attribute matches, the whole remainder is
    returned as ``api`` with an empty ``variant``.
    """
    namespace, rest = parse_namespace(base)

    apis = sorted(dir(get_torch_module(namespace)), key=len, reverse=True)

    api, variant = rest, ""
    for candidate in apis:
        # Require an exact match or a "_" boundary: a bare prefix match would
        # pick e.g. "t" for "to_sparse" and then cut one character too many
        # when slicing off the variant.
        if rest == candidate or rest.startswith(candidate + "_"):
            api = candidate
            variant = rest[len(candidate) + 1 :]
            break
    print(base, namespace, api, variant)
    return namespace, api, variant


def any_starts_with(strs, thing):
    """Return ``True`` if any string in ``strs`` starts with ``thing``."""
    return any(s.startswith(thing) for s in strs)


def get_suggested_xfails(base, tests):
    """Build the suggestion lines for the failed tests under one prefix.

    Args:
        base: Test-name prefix (an entry of ``base_names``).
        tests: Iterable of failed test names.

    Returns:
        A list of strings, one per distinct op/variant, in sorted order:
        ``xfail(...)`` when both the ``_cpu_float32`` and ``_cuda_float32``
        variants failed, ``xfail(..., device_type=...)`` when only one of them
        did, and ``skip(...)`` otherwise.
    """
    suffixes = {test[len(base) :] for test in tests if belongs_to_base(test, base)}

    result = []
    # Sorted so the output is reproducible; set iteration order is not.
    for op_name in sorted({remove_device_dtype(suffix) for suffix in suffixes}):
        namespace, api, variant = parse_base(op_name)
        if namespace is not None:
            api = f"{namespace}.{api}"

        on_cpu = f"{op_name}_cpu_float32" in suffixes
        on_cuda = f"{op_name}_cuda_float32" in suffixes
        if on_cpu and on_cuda:
            result.append(f"xfail('{api}', '{variant}'),")
        elif on_cpu:
            result.append(f"xfail('{api}', '{variant}', device_type='cpu'),")
        elif on_cuda:
            result.append(f"xfail('{api}', '{variant}', device_type='cuda'),")
        else:
            # Closing parenthesis added so the line can be pasted like the
            # xfail(...) suggestions.
            result.append(f"skip('{api}', '{variant}'),")
    return result


# Report the suggestions grouped by test-name prefix.
result = {base: get_suggested_xfails(base, failed_tests) for base in base_names}
for k, v in result.items():
    print("=" * 50)
    print(k)
    print("=" * 50)
    print("\n".join(v))