"""Suggest ``xfail``/``skip`` OpInfo entries from a pytest failure log.

Instructions:
1. pytest -n 8 test/test_vmap.py test/test_ops.py test/test_aotdispatch.py > result.txt
2. python test/xfail_suggester.py
"""

import re

import torch

# Matches a pytest summary line such as
#   "FAILED test/test_ops.py::TestOperatorsCPU::test_grad_sin_cpu_float32 - ..."
# and captures the test name ("test_grad_sin_cpu_float32").
p = re.compile(r"FAILED test/test_\w+\.py::\w+::(\S+)")


def get_failed_test(line):
    """Return the failed test name from a pytest ``FAILED`` line, or None."""
    m = p.match(line)
    if m is None:
        return None
    return m.group(1)


# Test-name prefixes that suggestions are grouped by.
base_names = {
    "test_grad_",
    "test_vjp_",
    "test_vmapvjp_",
    "test_vmapvjp_has_batch_rule_",
    "test_vjpvmap_",
    "test_jvp_",
    "test_vmapjvp_",
    "test_vmapjvpall_has_batch_rule_",
    "test_vmapjvpall_",
    "test_jvpvjp_",
    "test_vjpvjp_",
    "test_decomposition_",
    "test_make_fx_exhaustive_",
    "test_vmap_exhaustive_",
    "test_op_has_batch_rule_",
    "test_vmap_autograd_grad_",
}


def remove_device_dtype(test):
    """Drop the last two underscore-separated components of ``test``.

    Intended to strip a ``_<device>_<dtype>`` suffix, e.g.
    "sin_cpu_float32" -> "sin".
    """
    return "_".join(test.split("_")[:-2])


def belongs_to_base(test, base):
    """Return True if ``test`` starts with ``base`` but with no longer base name.

    This keeps e.g. "test_vmapvjp_has_batch_rule_*" out of "test_vmapvjp_".
    """
    if not test.startswith(base):
        return False
    longer_bases = (b for b in base_names if len(b) > len(base))
    return not any(test.startswith(b) for b in longer_bases)


def parse_namespace(base):
    """Split a known namespace prefix off ``base``.

    Returns ``(namespace, rest)``, e.g. "nn_functional_conv2d" ->
    ("nn.functional", "conv2d"). Returns ``(None, base)`` when no prefix
    matches.
    """
    mappings = {
        "nn_functional_": "nn.functional",
        "fft_": "fft",
        "linalg_": "linalg",
        "_masked_": "_masked",
        "sparse_": "sparse",
        "special_": "special",
    }
    for heading, namespace in mappings.items():
        if base.startswith(heading):
            return namespace, base[len(heading) :]
    return None, base


def get_torch_module(namespace):
    """Return the torch (sub)module for ``namespace`` (None means ``torch``).

    Raises AttributeError if torch has no such attribute.
    """
    if namespace is None:
        return torch
    if namespace == "nn.functional":
        return torch.nn.functional
    return getattr(torch, namespace)


def parse_base(base):
    """Split a test stem into ``(namespace, api, variant)``.

    ``api`` is the longest attribute name of the namespace module that
    equals the stem or is followed by "_". The remainder after that "_" is
    ``variant``. If nothing matches, ``api`` is the whole stem and
    ``variant`` is "".
    """
    namespace, rest = parse_namespace(base)
    # Longest names first so e.g. "index_put" wins over "index".
    apis = sorted(dir(get_torch_module(namespace)), key=len, reverse=True)
    api = rest
    variant = ""
    for candidate in apis:
        # Require a whole-word match. A plain prefix test would pick in-place
        # names like "sin_" for "sin_foo" and then chop a character off the
        # variant.
        if rest == candidate or rest.startswith(candidate + "_"):
            api = candidate
            variant = rest[len(candidate) + 1 :]
            break
    print(base, namespace, api, variant)
    return namespace, api, variant


def any_starts_with(strs, thing):
    """Return True if any string in ``strs`` starts with ``thing``."""
    return any(s.startswith(thing) for s in strs)


def get_suggested_xfails(base, tests):
    """Build xfail/skip suggestion lines for failed ``tests`` under ``base``.

    ``tests`` are full failed test names. Names are grouped by op stem (test
    name minus ``base`` and the device/dtype suffix), one line per stem:

    - float32 failed on both CPU and CUDA: plain ``xfail``
    - float32 failed on one device only: device-specific ``xfail``
    - otherwise (e.g. only other dtypes failed): ``skip``

    Lines are returned in sorted stem order so output is reproducible.
    """
    suffixes = [test[len(base) :] for test in tests if belongs_to_base(test, base)]
    failed = set(suffixes)
    result = []
    for stem in sorted({remove_device_dtype(test) for test in suffixes}):
        namespace, api, variant = parse_base(stem)
        if namespace is not None:
            api = f"{namespace}.{api}"
        on_cpu = stem + "_cpu_float32" in failed
        on_cuda = stem + "_cuda_float32" in failed
        if on_cpu and on_cuda:
            result.append(f"xfail('{api}', '{variant}'),")
        elif on_cpu:
            result.append(f"xfail('{api}', '{variant}', device_type='cpu'),")
        elif on_cuda:
            result.append(f"xfail('{api}', '{variant}', device_type='cuda'),")
        else:
            result.append(f"skip('{api}', '{variant}'),")
    return result


def main(path="result.txt"):
    """Read pytest output from ``path`` and print suggestions per base name."""
    with open(path) as f:
        lines = f.readlines()
    failed_tests = sorted(
        name for name in map(get_failed_test, lines) if name is not None
    )
    result = {
        base: get_suggested_xfails(base, failed_tests) for base in sorted(base_names)
    }
    for k, v in result.items():
        print("=" * 50)
        print(k)
        print("=" * 50)
        print("\n".join(v))


if __name__ == "__main__":
    main()