xref: /aosp_15_r20/external/pytorch/test/functorch/xfail_suggester.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import re
2
3import torch
4
5
6"""
7Instructions:
8
91. pytest -n 8 test/test_vmap.py test/test_ops.py test/test_aotdispatch.py > result.txt
102. python test/xfail_suggester.py
11"""
12
# Raw pytest output produced by step 1 of the instructions above.
with open("result.txt") as f:
    lines = f.readlines()

# Not referenced elsewhere in this script.
failed = [line for line in lines if line.startswith("FAILED")]

# Captures the test-function name from pytest summary lines of the form
#   FAILED test/[<subdir>/...]test_<name>.py::<Class>::<test_name> ...
# Test files may sit in a subdirectory of test/ (e.g. test/functorch/), so
# zero or more directory components are accepted. The "." before "py" is
# escaped so it only matches a literal dot.
p = re.compile(r"FAILED test/(?:\w+/)*test_\w+\.py::\w+::(\S+)")
18
19
def get_failed_test(line):
    """Extract the test-function name from one line of pytest output.

    Matching uses the module-level compiled pattern ``p``, anchored at the
    start of ``line``. Returns the captured name, or ``None`` when the line
    does not match (for example, any line not starting with ``FAILED``).
    """
    found = p.match(line)
    return found.group(1) if found else None
25
26
# Test-name prefixes that failures are grouped under. Some prefixes extend
# others (e.g. "test_vmapvjp_" and "test_vmapvjp_has_batch_rule_");
# belongs_to_base() resolves that overlap by preferring the longest prefix.
base_names = {
    f"test_{stem}_"
    for stem in (
        "grad",
        "vjp",
        "vmapvjp",
        "vmapvjp_has_batch_rule",
        "vjpvmap",
        "jvp",
        "vmapjvp",
        "vmapjvpall_has_batch_rule",
        "vmapjvpall",
        "jvpvjp",
        "vjpvjp",
        "decomposition",
        "make_fx_exhaustive",
        "vmap_exhaustive",
        "op_has_batch_rule",
        "vmap_autograd_grad",
    )
}
45
# Sorted test-function names of every output line that get_failed_test()
# recognises; unrecognised lines (None) are dropped.
failed_tests = sorted(
    name for name in map(get_failed_test, lines) if name is not None
)

# Not referenced elsewhere in this script.
suggested_xfails = {}
51
52
def remove_device_dtype(test):
    """Strip the trailing ``_<device>_<dtype>`` suffix from a test name.

    The last two underscore-separated tokens are dropped, e.g.
    ``"addmm_decomposed_cpu_float32"`` -> ``"addmm_decomposed"``. Names with
    fewer than three tokens collapse to ``""``.
    """
    tokens = test.split("_")
    del tokens[-2:]
    return "_".join(tokens)
55
56
def belongs_to_base(test, base):
    """Return True if ``test`` should be grouped under prefix ``base``.

    A test belongs to ``base`` when it starts with ``base`` and does not also
    start with any longer entry of the module-level ``base_names``. For
    example, a ``test_vmapvjp_has_batch_rule_*`` test is not counted under
    ``test_vmapvjp_``.
    """
    if not test.startswith(base):
        return False
    longer_prefixes = (name for name in base_names if len(name) > len(base))
    return not any(test.startswith(name) for name in longer_prefixes)
65
66
def parse_namespace(base):
    """Split a flattened op name into ``(torch namespace, remainder)``.

    Namespaces appear in ``base`` with underscores in place of dots
    (``nn_functional_`` stands for ``nn.functional``). If ``base`` starts with
    one of the known prefixes, return the dotted namespace and the rest of the
    name; otherwise return ``(None, base)``.
    """
    prefix_to_namespace = (
        ("nn_functional_", "nn.functional"),
        ("fft_", "fft"),
        ("linalg_", "linalg"),
        ("_masked_", "_masked"),
        ("sparse_", "sparse"),
        ("special_", "special"),
    )
    for prefix, namespace in prefix_to_namespace:
        if base.startswith(prefix):
            return namespace, base[len(prefix) :]
    return None, base
80
81
def get_torch_module(namespace):
    """Return the ``torch`` (sub)module named by the dotted ``namespace``.

    ``None`` maps to the top-level ``torch`` module. Dotted paths such as
    ``"nn.functional"`` are resolved one attribute at a time, so any nested
    namespace works rather than only the special-cased ``nn.functional``.

    Raises:
        AttributeError: if some component of ``namespace`` does not exist.
    """
    module = torch
    if namespace is None:
        return module
    for attr in namespace.split("."):
        module = getattr(module, attr)
    return module
88
89
def parse_base(base):
    """Split a flattened test base name into ``(namespace, api, variant)``.

    ``base`` is a test name with its suite prefix and device/dtype suffix
    already removed (e.g. ``"nn_functional_conv2d"``).

    How the pieces are found:

    - The namespace prefix is peeled off with ``parse_namespace``.
    - The remainder is matched against the names that ``dir()`` reports for
      that torch module, longest first, to find the API name.
    - Whatever follows the API name, after the separating underscore, is the
      variant. The variant is ``""`` when nothing follows. It is also ``""``
      when no name matches, in which case the whole remainder is the API.

    Also prints the parsed pieces as a debugging aid.
    """
    namespace, rest = parse_namespace(base)

    # Longest names first so the most specific API name wins.
    apis = sorted(dir(get_torch_module(namespace)), key=len, reverse=True)

    api = rest
    variant = ""
    for candidate in apis:
        # Only accept a match that ends on a token boundary. Without this, a
        # short name such as "ne" would match the start of "new_zeros" and
        # the one-character separator skip below would cut into the name.
        if rest == candidate or rest.startswith(candidate + "_"):
            api = candidate
            variant = rest[len(candidate) + 1 :]
            break
    print(base, namespace, api, variant)
    return namespace, api, variant
105
106
def any_starts_with(strs, thing):
    """Return True if at least one string in ``strs`` begins with ``thing``.

    Stops at the first match; returns False for an empty ``strs``.
    """
    return any(candidate.startswith(thing) for candidate in strs)
112
113
def get_suggested_xfails(base, tests):
    """Build xfail/skip suggestion lines for the failures under one prefix.

    Args:
        base: a test-name prefix from ``base_names`` (e.g. ``"test_grad_"``).
        tests: failed test names; only those that ``belongs_to_base`` accepts
            for ``base`` are considered.

    Returns:
        A list of source lines, one per failing op/variant, ordered by op name
        so the output is stable between runs. Each line is one of:

        - ``xfail('<api>', '<variant>'),`` when both the ``_cpu_float32`` and
          ``_cuda_float32`` tests failed;
        - the same with ``device_type='cpu'`` or ``device_type='cuda'`` when
          only one of them failed;
        - ``skip('<api>', '<variant>'),`` when neither float32 test is among
          the failures.
    """
    suggestions = []
    own_tests = {test[len(base) :] for test in tests if belongs_to_base(test, base)}

    # Sorted: iterating a set directly gives a run-dependent order.
    for op_base in sorted({remove_device_dtype(test) for test in own_tests}):
        namespace, api, variant = parse_base(op_base)
        if namespace is not None:
            api = f"{namespace}.{api}"
        failed_on_cpu = f"{op_base}_cpu_float32" in own_tests
        failed_on_cuda = f"{op_base}_cuda_float32" in own_tests
        if failed_on_cpu and failed_on_cuda:
            suggestions.append(f"xfail('{api}', '{variant}'),")
        elif failed_on_cpu:
            suggestions.append(f"xfail('{api}', '{variant}', device_type='cpu'),")
        elif failed_on_cuda:
            suggestions.append(f"xfail('{api}', '{variant}', device_type='cuda'),")
        else:
            # Only some other device/dtype suffix failed; fall back to an
            # unconditional skip. (The original omitted the closing paren.)
            suggestions.append(f"skip('{api}', '{variant}'),")
    return suggestions
139
140
# Print the suggestions grouped by prefix. base_names is a set, so it is
# sorted to get the same section order on every run.
result = {
    prefix: get_suggested_xfails(prefix, failed_tests)
    for prefix in sorted(base_names)
}
for prefix, suggestions in result.items():
    print("=" * 50)
    print(prefix)
    print("=" * 50)
    print("\n".join(suggestions))
147