import copy
import enum
import pprint
import unittest
from enum import Enum

# Importing these files makes modifications to the op_db that we need
import test_ops  # noqa: F401

import test_vmap  # noqa: F401
from functorch_additional_op_db import additional_op_db

import torch
import torch._functorch.top_operators_github_usage as top_ops
from torch.testing._internal.common_device_type import toleranceOverride
from torch.testing._internal.common_methods_invocations import op_db


all_overridable = list(torch.overrides.get_testing_overrides().keys())

public_docs = [
    (torch.nn.functional, "torch.nn.functional", "docs/source/nn.functional.rst"),
    (torch.fft, "torch.fft", "docs/source/fft.rst"),
    (torch.special, "torch.special", "docs/source/special.rst"),
    (torch.linalg, "torch.linalg", "docs/source/linalg.rst"),
    (torch, "torch", "docs/source/torch.rst"),
    (torch.Tensor, "torch.Tensor", "docs/source/tensors.rst"),
]

# torch.abs, Tensor.abs, Tensor.abs_ are all considered to be different


def get_public_overridable_apis(pytorch_root="/raid/rzou/pt/debug-cpu"):
    results = {}
    all_overridable_apis = set(torch.overrides.get_testing_overrides().keys())
    for module, module_name, src in public_docs:
        with open(f"{pytorch_root}/{src}") as f:
            lines = f.readlines()
        # APIs either begin with 4 spaces or ".. autofunction::"
        api_lines1 = [line.strip() for line in lines if line.startswith(" " * 4)]
        api_lines2 = [
            line.strip()[len(".. autofunction:: ") :]
            for line in lines
            if line.startswith(".. autofunction::")
        ]
        lines = api_lines1 + api_lines2
        lines = [line[7:] if line.startswith("Tensor.") else line for line in lines]
        lines = [line for line in lines if hasattr(module, line)]
        for line in lines:
            api = getattr(module, line)
            if api in all_overridable_apis:
                results[f"{module_name}.{line}"] = api
    return results

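# Illustrative sketch (not executed): the returned dict maps each documented,
# overridable name to the callable it resolves to, e.g. roughly
# {"torch.nn.functional.softmax": torch.nn.functional.softmax,
#  "torch.Tensor.abs": torch.Tensor.abs, ...}.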

denylist = {
    "torch.Tensor.data_ptr",
    "torch.Tensor.dim",
    "torch.Tensor.element_size",
    "torch.Tensor.backward",
    "torch.Tensor.as_strided",
    "torch.Tensor.register_hook",
    "torch.Tensor.record_stream",
    "torch.Tensor.qscheme",
    "torch.Tensor.ndimension",
    "torch.Tensor.smm",
    "torch.Tensor.sspaddmm",
    "torch.Tensor.retain_grad",
    "torch.Tensor.sparse_mask",
    "torch.Tensor.sparse_dim",
    "torch.Tensor.dense_dim",
    "torch.Tensor.values",
    "torch.Tensor.indices",
    "torch.Tensor.numel",
    "torch.Tensor.size",
    "torch.Tensor.nelement",
    "torch.Tensor.q_scale",
    "torch.Tensor.q_zero_point",
    "torch.Tensor.q_per_channel_scales",
    "torch.Tensor.q_per_channel_zero_points",
    "torch.Tensor.q_per_channel_axis",
    "torch.Tensor.int_repr",
    "torch.Tensor.to_sparse",
    "torch.Tensor.is_inference",
    "torch.Tensor.storage",
    "torch.Tensor.storage_type",
}


def get_method_only_ops_we_care_about():
    apis = get_public_overridable_apis()
    result = []
    for key in apis.keys():
        if not key.startswith("torch.Tensor"):
            continue
        if key in denylist:
            continue
        api = key.split(".")[2]
        # filter out in-place
        if api.endswith("_"):
            continue
        if f"torch.{api}" not in apis.keys():
            result.append(api)
    return result


# Deduplicates torch.abs and Tensor.abs


def get_public_overridable_ops():
    results = get_public_overridable_apis()
    cpy = copy.deepcopy(results)
    for key in cpy.keys():
        if not key.startswith("torch.Tensor"):
            continue
        api = key.split(".")[2]
        if f"torch.{api}" in results.keys():
            del results[key]
    return results


def get_public_overridable_outplace_ops():
    results = get_public_overridable_ops()
    cpy = copy.deepcopy(results)
    for key in cpy.keys():
        # NB: there are no dunder methods because we don't document those
        if key.endswith("_"):
            del results[key]
    return results


def get_public_overridable_outplace_we_care_about():
    results = get_public_overridable_outplace_ops()
    cpy = copy.deepcopy(results)
    for key in cpy.keys():
        # quantization
        if "quant" in key or ".q_" in key:
            del results[key]

        # is_cpu, etc. It doesn't make sense to have OpInfos for these
        if ".is_" in key:
            del results[key]

        if key in denylist and key in results:
            del results[key]
    return results


# e.g. nn.functional.softmax


def get_op(dotted_name):
    names = dotted_name.split(".")
    mod = torch
    for name in names:
        if not hasattr(mod, name):
            return None
        mod = getattr(mod, name)
    return mod

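# Illustrative sketch (not executed):
#   get_op("nn.functional.softmax")  # -> torch.nn.functional.softmax
#   get_op("nn.functional.nope")     # -> None (hypothetical missing name)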

# Maps function -> [OpInfo]


def get_ops_covered_by_opinfos():
    ops = {}

    def safe_append(dct, key, val):
        if key in dct:
            dct[key].append(val)
        else:
            dct[key] = [val]

    for opinfo in op_db:
        func_op = get_op(opinfo.name)
        if func_op:
            safe_append(ops, func_op, opinfo)
        if opinfo.method_variant:
            safe_append(ops, opinfo.method_variant, opinfo)
        if opinfo.inplace_variant:
            safe_append(ops, opinfo.inplace_variant, opinfo)
        for alias in opinfo.aliases:
            safe_append(ops, alias.op, opinfo)
    return ops

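# Illustrative sketch (not executed): keys are the callables themselves, so
# the function, method, in-place, and alias variants all map to the same
# OpInfo list, e.g. ops[torch.abs] and ops[torch.Tensor.abs] would both
# contain the "abs" OpInfo.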

factory_fns = {
    "tensor",
    "zeros",
    "ones",
    "randn",
    "arange",
    "rand",
    "empty",
    "randperm",
    "linspace",
    "logspace",
    "hann_window",
    "full",
    "eye",
    "blackman_window",
    "bartlett_window",
    "randint",
    "range",
}


def get_top_ops(torch_threshold, nn_fn_threshold, with_counts=False):
    denylist = {
        # These are either not real "operators", factory functions
        # that trivially work, or undocumented ops.
        "load",
        "no_grad",
        "save",
        "from_numpy",
        "manual_seed",
        "set_grad_enabled",
        "set_default_tensor_type",
        "set_num_threads",
        "set_printoptions",
        "numel",
        "set_default_dtype",
        "sparse_coo_tensor",
        "set_rng_state",
        "get_rng_state",
        "get_default_dtype",
        "initial_seed",
        "get_num_threads",
        "quantize_per_tensor",
        "hann_window",
        "is_tensor",
        "as_tensor",
        "equal",
        "enable_grad",
        "seed",
        "is_storage",
        "is_floating_point",
        "nn.functional.torch",
        "set_flush_denormal",
        "set_num_interop_threads",
        "dequantize",
        "get_num_interop_threads",
        "nn.functional.math",
        "nn.functional.threshold_",
        "nn.functional.selu_",
        "nn.functional.elu_",
        "nn.functional.rrelu_",
        "nn.functional.leaky_relu_",
        "nn.functional.hardtanh_",
        "nn.functional.has_torch_function",
        "nn.functional.has_torch_function_unary",
        "nn.functional.has_torch_function_variadic",
        "nn.functional.handle_torch_function",
        "nn.functional.adaptive_max_pool1d_with_indices",
        "nn.functional.adaptive_max_pool2d_with_indices",
        "nn.functional.adaptive_max_pool3d_with_indices",
        "nn.functional.fractional_max_pool2d_with_indices",
        "nn.functional.fractional_max_pool3d_with_indices",
        "is_complex",
        "grad",
        "quantize_per_channel",
        "nn.functional.max_pool2d_with_indices",
        "nn.functional.max_pool3d_with_indices",
        "nn.functional.max_pool1d_with_indices",
        "nn.functional.celu_",
        "nn.functional.grad",
        "nn.functional.relu_",
        "nn.functional.boolean_dispatch",
        "nn.functional.assert_int_or_pair",
        "fft",  # is namespace
    }

    torch_ops = top_ops.top_torch
    nn_fn_ops = top_ops.get_nn_functional_top_list()
    torch_ops = [op for op in torch_ops if op[0] not in denylist]
    nn_fn_ops = [op for op in nn_fn_ops if op[0] not in denylist]

    ops = torch_ops[:torch_threshold] + nn_fn_ops[:nn_fn_threshold]

    # Now, sort by priority
    ops.sort(reverse=True, key=lambda op: op[1])
    if not with_counts:
        ops = [op[0] for op in ops]
    return ops

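# Illustrative sketch (not executed): get_top_ops(2, 1, with_counts=True)
# would return something like [("add", <count>), ("mul", <count>),
# ("nn.functional.linear", <count>)], sorted by descending usage count;
# without with_counts, just the names. (The op names here are hypothetical.)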

def get_ops_percentage(torch_threshold, nn_fn_threshold):
    data = top_ops.top_torch + top_ops.get_nn_functional_top_list()

    def get_num_usages(opname):
        # Ignore "t"; its usage count is heavily inflated
        if opname == "t":
            return 0
        result = [op[1] for op in data if op[0] == opname]
        assert len(result) == 1
        return result[0]

    # get all operators that are not in the denylist
    all_ops = get_top_ops(999999, 999999)
    total_op_usages = sum(get_num_usages(op) for op in all_ops)

    # get subset of all operators
    subset_ops = get_top_ops(torch_threshold, nn_fn_threshold)
    subset_op_usages = sum(get_num_usages(op) for op in subset_ops)
    return subset_op_usages / total_op_usages

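# Illustrative sketch (not executed): get_ops_percentage(100, 25) returns the
# fraction (in [0, 1]) of scraped GitHub usages accounted for by the top 100
# torch ops plus the top 25 nn.functional ops.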

def get_top_ops_not_covered_by_opinfo(torch_threshold=0, nn_fn_threshold=0):
    ops = get_top_ops(torch_threshold, nn_fn_threshold)

    ops_with_opinfo = []
    for op in op_db:
        ops_with_opinfo.append(op.name)
        ops_with_opinfo.extend([alias.name for alias in op.aliases])
    ops_with_opinfo = set(ops_with_opinfo)

    result = [op for op in ops if op not in ops_with_opinfo]
    result = [op for op in result if op not in denylist]
    result = [op for op in result if op not in factory_fns]
    return result


def get_covered_ops(ops_list, invert=False):
    ops_covered_by_opinfo = get_ops_covered_by_opinfos()
    overridable_outplace_ops = ops_list
    results = {}
    for key, op in overridable_outplace_ops.items():
        cond = op in ops_covered_by_opinfo
        if invert:
            cond = not cond
        if cond:
            results[key] = op
    return results

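# Illustrative sketch (not executed): with invert=False, only entries whose
# callable has at least one OpInfo survive; with invert=True, only the
# untested remainder survives.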

class Status(Enum):
    Correct = 0
    Fast = 1


tests = {
    "test_vmap_exhaustive",
    "test_op_has_batch_rule",
    "test_vjp",
    "test_vmapvjp",
    "test_vmapvjp_has_batch_rule",
    "test_jvp",
    "test_vmapjvp",
}


def is_decorateinfo_skip_or_xfail(decorateinfo):
    assert len(decorateinfo.decorators) == 1
    actual_decorator = decorateinfo.decorators[0]
    if isinstance(actual_decorator, toleranceOverride):
        return False
    # unittest.expectedFailure is an xfail; assume everything else is a skip
    return True


def get_all_tested_ops():
    overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about()
    op_to_opinfo = get_ops_covered_by_opinfos()
    result = set()
    for op in get_covered_ops(overridable_outplace_we_care_about).values():
        opinfos = op_to_opinfo[op]
        result.update(opinfo.name for opinfo in opinfos)
    return result


def get_skipped_or_xfailed_ops_for(test_name):
    overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about()
    op_to_opinfo = get_ops_covered_by_opinfos()
    result = set()
    for op in get_covered_ops(overridable_outplace_we_care_about).values():
        opinfos = op_to_opinfo[op]
        for opinfo in opinfos:
            for decorator in opinfo.decorators:
                if not hasattr(decorator, "test_name"):
                    continue
                if decorator.test_name != test_name:
                    continue
                if is_decorateinfo_skip_or_xfail(decorator):
                    result.add(opinfo.name)
    return result

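# Illustrative sketch (not executed):
#   get_skipped_or_xfailed_ops_for("test_vjp")
# returns the set of OpInfo names whose "test_vjp" entry carries a skip or
# expectedFailure decorator.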

def get_statuses(for_subset=None, invert=False):
    overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about()
    if for_subset is not None:
        overridable_outplace_we_care_about = {
            k: v
            for k, v in overridable_outplace_we_care_about.items()
            # Removes "torch."
            if k[6:] in for_subset
        }
    op_to_opinfo = get_ops_covered_by_opinfos()
    result = {}
    _ = get_covered_ops(overridable_outplace_we_care_about)

    def get_covered_tests(op):
        opinfos = op_to_opinfo[op]
        result = copy.deepcopy(tests)
        for opinfo in opinfos:
            for decorator in opinfo.decorators:
                if not hasattr(decorator, "test_name"):
                    continue
                if decorator.test_name in tests and decorator.test_name in result:
                    result.remove(decorator.test_name)
        return result

    def get_all_aliases(op):
        opinfos = op_to_opinfo[op]
        result = []
        for opinfo in opinfos:
            result.append(opinfo.name)
            result.extend(alias.name for alias in opinfo.aliases)
        return set(result)

    for name, op in get_covered_ops(overridable_outplace_we_care_about).items():
        successful_tests = get_covered_tests(op)
        failed_tests = tests - successful_tests
        result[name] = failed_tests if invert else successful_tests
    return result


def transpose_statuses(for_subset=None, invert=False):
    statuses = get_statuses(for_subset, invert=invert)
    result = {}
    for test in tests:
        result[test] = set()
    for op, supported in statuses.items():
        for test in supported:
            result[test].add(op)
    return result

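# Illustrative sketch (not executed): transpose_statuses() flips
# {op_name: set_of_tests} into {test_name: set_of_ops}; e.g.
# statuses["test_vjp"] is the set of ops whose test_vjp passes (or fails,
# with invert=True).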

overridable_apis = get_public_overridable_apis()

overridable_ops = get_public_overridable_ops()

overridable_outplace_ops = get_public_overridable_outplace_ops()

overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about()

tested_overridable_outplace_ops = get_covered_ops(overridable_outplace_we_care_about)
untested_overridable_outplace_ops = get_covered_ops(
    overridable_outplace_we_care_about, invert=True
)

# print("List of OpInfos we need:")
# for key in untested_overridable_outplace_ops.keys():
#     print(key)
# print("-" * 80)
# print("")

print(f"Overridable public APIs: {len(overridable_apis)}")
print(f"Overridable public ops: {len(overridable_ops)}")
print(f"Overridable public outplace ops: {len(overridable_outplace_ops)}")
print(
    f"Overridable public outplace ops we care about: {len(overridable_outplace_we_care_about)}"
)
print(
    f"OpInfo-tested overridable public outplace ops: {len(tested_overridable_outplace_ops)}"
)


def remove_torch(name):
    assert name[:6] == "torch."
    return name[6:]


def get_list_of_all_tests():
    all_tests = list(tested_overridable_outplace_ops.keys())
    return {remove_torch(test) for test in all_tests}


mytest = {
    "test_vmap_exhaustive",
    "test_op_has_batch_rule",
    "test_vjp",
    "test_vmapvjp",
    "test_vmapvjp_has_batch_rule",
}

print("*" * 80)
all_tests = get_list_of_all_tests()
for test in mytest:
    result = get_skipped_or_xfailed_ops_for(test)
    diff = len(all_tests - result)
    print(f"{test}: {diff}")


def get_jvp_coverage(subset=None):
    # - number that support autograd
    # - number that support forward_ad (in pytorch core)
    # - number that support functorch.jvp
    op_to_opinfo = get_ops_covered_by_opinfos()
    ops_dct = tested_overridable_outplace_ops
    if subset is not None:
        ops_dct = {
            name: op for name, op in ops_dct.items() if remove_torch(name) in subset
        }
    supports_autograd_ops_dct = {
        name: op_to_opinfo[fn]
        for name, fn in ops_dct.items()
        if op_to_opinfo[fn][0].supports_autograd
    }
    supports_forwardad_ops_dct = {
        name: op_to_opinfo[fn]
        for name, fn in ops_dct.items()
        if op_to_opinfo[fn][0].supports_forward_ad
    }

    ops = {remove_torch(name) for name in ops_dct.keys()}
    supports_autograd = {
        remove_torch(name) for name in supports_autograd_ops_dct.keys()
    }
    supports_forward_ad = {
        remove_torch(name) for name in supports_forwardad_ops_dct.keys()
    }
    assert supports_forward_ad.issubset(supports_autograd)
    assert supports_autograd.issubset(ops)

    failed_ops = get_skipped_or_xfailed_ops_for("test_jvp")

    coverage = len(supports_forward_ad - failed_ops)
    no_forward_ad = len(supports_autograd) - len(supports_forward_ad)
    print(f"test_jvp, {coverage}, {no_forward_ad}, {len(ops)}")


get_jvp_coverage()
get_jvp_coverage(get_top_ops(100, 25))
for op in get_top_ops(100, 25):
    print(op)
print("*" * 80)

# result = get_skipped_or_xfailed_ops_for('test_vmap_exhaustive')
# result = get_skipped_or_xfailed_ops_for('test_op_has_batch_rule')
# result = get_skipped_or_xfailed_ops_for('test_vjp')
# result = get_skipped_or_xfailed_ops_for('test_vmapvjp')
# result = get_skipped_or_xfailed_ops_for('test_vmapvjp_has_batch_rule')
# import pdb; pdb.set_trace()

statuses = transpose_statuses()
for test in tests:
    print(f"{test} coverage {len(statuses[test])}")

method_only_ops = get_method_only_ops_we_care_about()
# for op in method_only_ops:
#     print(f'    {op},')

top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(100, 25)
print("=" * 80)
for op in top_ops_not_covered_by_opinfo:
    print(f"{op}, {top_ops.usage_count[op]}")

# print("top ops not covered by opinfo: ")
# top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(200, 50)
# for op in top_ops_not_covered_by_opinfo:
#     print(f'{op}, {top_ops.usage_count[op]}')

# print("top ops not covered by opinfo: ")
# top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(220, 92)
# for op in top_ops_not_covered_by_opinfo:
#     print(f'{op}, {top_ops.usage_count[op]}')

# print("top ops not covered by opinfo: ")
# top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(999, 999)
# for op in top_ops_not_covered_by_opinfo:
#     print(f'{op}, {top_ops.usage_count[op]}')


def remove_from_set(parent, to_remove):
    for to_remove_elt in to_remove:
        if to_remove_elt in parent:
            parent.remove(to_remove_elt)


def print_coverage_info(th=100, nn=25):
    print("=" * 80)
    print(f"top {th}, {nn} coverage")
    statuses = transpose_statuses(get_top_ops(th, nn), invert=True)
    top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(th, nn)

    # testing problems
    exemptions = {
        "torch.nn.functional.dropout",  # randomness
    }

    # Allowed exemptions
    vmap_exemptions = {
        "torch.randn_like",  # randomness
        "torch.rand_like",  # randomness
        "torch.allclose",  # number output
        "torch.unique",  # dynamic
        "torch.nonzero",  # dynamic
        "torch.masked_select",  # dynamic
        "torch.prod",  # dynamic (backward)
        "torch.norm",  # norm with nuc is not commonly used; we support the other cases.
        "torch.svd",  # There isn't a bug; it's just nondeterministic, so we can't test it.
        "torch.nn.functional.embedding",  # We support everything except the sparse option.
    }
    remove_from_set(statuses["test_vmap_exhaustive"], vmap_exemptions)
    remove_from_set(statuses["test_vmapvjp"], vmap_exemptions)
    remove_from_set(statuses["test_vmapvjp_has_batch_rule"], vmap_exemptions)
    remove_from_set(statuses["test_op_has_batch_rule"], vmap_exemptions)
    remove_from_set(statuses["test_vmapjvp"], vmap_exemptions)
    for test in tests:
        remove_from_set(statuses[test], exemptions)

    print(f"total ops in set: {th + nn}")
    print(f"tested by OpInfo: {th + nn - len(top_ops_not_covered_by_opinfo)}")
    for test in tests:
        if test in {"test_jvp", "test_vmapjvp"}:
            continue
        print(f"{test} failing coverage {len(statuses[test])}")

    # We don't care about these yet
    del statuses["test_jvp"]
    del statuses["test_vmapjvp"]

    pprint.pprint(statuses)


def get_name_to_opinfo_map():
    dct = {}

    def add(name, op):
        dct.setdefault(name, []).append(op)

    for op in op_db + additional_op_db:
        add(op.name, op)
        for alias in op.aliases:
            add(alias.name, op)
    return dct


NAME_TO_OPINFO = get_name_to_opinfo_map()


class Support(enum.Enum):
    NO = 0
    YES = 1
    UNKNOWN = 2


FACTORY_FNS = {
    "tensor",
    "zeros",
    "ones",
    "randn",
    "arange",
    "rand",
    "empty",
    "range",
    "full",
    "randperm",
    "eye",
    "randint",
    "linspace",
    "logspace",
}

VJP_EXEMPTIONS = {
    "nn.functional.dropout",  # not actually a problem; randomness testing artifact
    "nn.functional.dropout2d",  # not actually a problem; randomness testing artifact
    "nn.functional.rrelu",  # not actually a problem; randomness testing artifact
    "bernoulli",  # not actually a problem; randomness testing artifact
    "normal",  # not actually a problem; randomness testing artifact
}

VMAP_EXEMPTIONS = {
    "randn_like",  # randomness
    "rand_like",  # randomness
    "allclose",  # number output
    "unique",  # dynamic
    "nonzero",  # dynamic
    "masked_select",  # dynamic
    "prod",  # dynamic (backward)
    "norm",  # norm with nuc is not commonly used; we support the other cases.
    "svd",  # There isn't a bug; it's just nondeterministic, so we can't test it.
    "nn.functional.embedding",  # We support everything except the sparse option.
    "nn.functional.dropout",  # randomness
    "nn.functional.dropout2d",  # randomness
    "bernoulli",  # randomness
    "multinomial",  # randomness
    "normal",  # randomness
}

JVP_EXEMPTIONS = {
    "nn.functional.dropout",  # not actually a problem; randomness testing artifact
    "nn.functional.dropout2d",  # not actually a problem; randomness testing artifact
    "nn.functional.rrelu",  # not actually a problem; randomness testing artifact
    "normal",  # not actually a problem; randomness testing artifact
    "bernoulli",  # not actually a problem; randomness testing artifact
}


class Operator:
    def __init__(self, name):
        self.name = name
        self.opinfos = NAME_TO_OPINFO.get(name, None)
        assert self.opinfos is None or len(self.opinfos) > 0

    def has_opinfo(self):
        return self.opinfos is not None

    def __repr__(self):
        return f'Operator("{self.name}")'

    def __hash__(self):
        return hash(self.name)

    def no_opinfos_skip_test(self, test_name):
        """Returns NO if any opinfos have a skip or xfail for the test"""
        if not self.has_opinfo():
            return Support.UNKNOWN
        for opinfo in self.opinfos:
            for decorator in opinfo.decorators:
                if not hasattr(decorator, "test_name"):
                    continue
                if decorator.test_name != test_name:
                    continue
                if is_decorateinfo_skip_or_xfail(decorator):
                    return Support.NO
        return Support.YES

    def any_opinfo_attr(self, attr):
        if not self.has_opinfo():
            raise RuntimeError
        return any(getattr(opinfo, attr) for opinfo in self.opinfos)

    def all_opinfo_attr(self, attr):
        if not self.has_opinfo():
            raise RuntimeError
        return all(getattr(opinfo, attr) for opinfo in self.opinfos)

    def supports_vjp(self):
        if self.name in FACTORY_FNS:
            return Support.YES
        if self.name in VJP_EXEMPTIONS:
            return Support.YES
        return self.no_opinfos_skip_test("test_vjp")

    def supports_vmap(self):
        if self.name in FACTORY_FNS:
            return Support.YES
        if self.name in VMAP_EXEMPTIONS:
            return Support.YES
        return self.no_opinfos_skip_test("test_vmap_exhaustive")

    def supports_fast_vmap(self):
        if self.name in FACTORY_FNS:
            return Support.YES
        if self.name in VMAP_EXEMPTIONS:
            return Support.YES
        return self.no_opinfos_skip_test("test_op_has_batch_rule")

    def supports_vmapvjp(self):
        if self.name in FACTORY_FNS:
            return Support.YES
        if self.name in VMAP_EXEMPTIONS:
            return Support.YES
        return self.no_opinfos_skip_test("test_vmapvjp")

    def supports_fast_vmapvjp(self):
        if self.name in FACTORY_FNS:
            return Support.YES
        if self.name in VMAP_EXEMPTIONS:
            return Support.YES
        return self.no_opinfos_skip_test("test_vmapvjp_has_batch_rule")

    def supports_jvp(self):
        if self.name in FACTORY_FNS:
            return Support.YES
        if self.name in JVP_EXEMPTIONS:
            return Support.YES
        if not self.has_opinfo():
            return Support.UNKNOWN
        if self.any_opinfo_attr("supports_autograd") and not self.all_opinfo_attr(
            "supports_forward_ad"
        ):
            return Support.NO
        return self.no_opinfos_skip_test("test_jvp")

    def supports_jvpvjp(self):
        if self.name in FACTORY_FNS:
            return Support.YES
        exemptions = {
            # we have support (see OpInfo), testing artifact
            "nn.functional.dropout2d",
            "nn.functional.dropout",
            # exception: we don't even support double backward for this
            "nn.functional.hardswish",
            "bernoulli",  # this isn't differentiable
            "normal",  # not differentiable
        }
        if self.name in exemptions:
            return Support.YES
        return self.no_opinfos_skip_test("test_jvpvjp")

    def _supports_vmapjvp_base(self, test):
        if self.name in FACTORY_FNS:
            return Support.YES
        VMAPJVP_EXEMPTIONS = {
            "prod",  # dynamic (backward)
            "nn.functional.batch_norm",  # testing problem
            "normal",  # not actually a problem; randomness testing artifact
            "bernoulli",  # not actually a problem; randomness testing artifact
            "nn.functional.dropout2d",  # not actually a problem; randomness testing artifact
            "nn.functional.dropout",  # not actually a problem; randomness testing artifact
            # Not a problem.
            # It's just that the max_norm testing mutates inputs...
            # (we have our own functorch variant of the OpInfo without max_norm)
            "nn.functional.embedding",
        }
        if self.name in VMAPJVP_EXEMPTIONS:
            return Support.YES
        if not self.has_opinfo():
            return Support.UNKNOWN
        if self.any_opinfo_attr("supports_autograd") and not self.all_opinfo_attr(
            "supports_forward_ad"
        ):
            return Support.NO
        return self.no_opinfos_skip_test(test)

    def supports_vmapjvp(self):
        return self._supports_vmapjvp_base("test_vmapjvpall")

    def supports_fast_vmapjvp(self):
        return self._supports_vmapjvp_base("test_vmapjvpall_has_batch_rule")

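# Illustrative sketch (not executed), assuming "sin" has an OpInfo in op_db:
#   op = Operator("sin")
#   op.has_opinfo()    # True
#   op.supports_vjp()  # Support.YES unless some OpInfo skips/xfails test_vjp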

class OperatorSet:
    def __init__(self, operators):
        self.data = set(operators)

    @classmethod
    def from_names(cls, names):
        return OperatorSet([Operator(name) for name in names])

    @classmethod
    def from_top_ops_threshold(cls, torch_threshold, nn_fn_threshold):
        names = get_top_ops(torch_threshold, nn_fn_threshold)
        return cls.from_names(names)

    @classmethod
    def from_top125(cls):
        return cls.from_top_ops_threshold(100, 25)

    @classmethod
    def from_top160(cls):
        return cls.from_top_ops_threshold(107, 53)

    @classmethod
    def all(cls):
        dct = get_public_overridable_outplace_we_care_about()
        names = dct.keys()
        names_sanitized = []
        for n in names:
            torch_tensor = "torch.Tensor."
            torch_dot = "torch."
            if n.startswith(torch_tensor):
                names_sanitized.append(n[len(torch_tensor) :])
            elif n.startswith(torch_dot):
                names_sanitized.append(n[len(torch_dot) :])
            else:
                raise AssertionError
        return cls.from_names(names_sanitized)

    def query(self, operator_method, filter=(Support.NO, Support.YES, Support.UNKNOWN)):
        result = {}
        for key in filter:
            result[key] = set()
        for op in self.data:
            support_status = operator_method(op)
            if support_status in filter:
                result[support_status].add(op)
        return result

    def summary(self):
        checks = [
            "supports_vjp",
            "supports_vmap",
            "supports_fast_vmap",
            "supports_vmapvjp",
            "supports_fast_vmapvjp",
            "supports_jvp",
            "supports_vmapjvp",
            "supports_fast_vmapjvp",
            "supports_jvpvjp",
        ]
        result = ["test, yes, no, unknown"]
        for check in checks:
            accessor = getattr(Operator, check)
            all_results = self.query(accessor)
            yes_amt = len(all_results[Support.YES])
            no_amt = len(all_results[Support.NO])
            unknown_amt = len(all_results[Support.UNKNOWN])
            result.append(f"{check}, {yes_amt}, {no_amt}, {unknown_amt}")
        return "\n".join(result)

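# Illustrative sketch (not executed):
#   opset = OperatorSet.from_names(["sin", "cos"])
#   res = opset.query(Operator.supports_vmap, (Support.NO,))
#   # res maps Support.NO to the subset of operators with that status
#   print(opset.summary())  # CSV-style lines, one per supports_* check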

opset = OperatorSet.all()
has_no_opinfo = opset.query(Operator.has_opinfo, (False,))

print("=" * 30 + " Summary " + "=" * 30)
print(f"% of usages on github: {get_ops_percentage(99999, 99999):.1%}")
print(opset.summary())

# sanity checks
result = opset.query(Operator.supports_vjp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)

print("=" * 30 + " Top 60 Summary " + "=" * 30)
print(f"% of usages on github: {get_ops_percentage(35, 25):.1%}")
opset = OperatorSet.from_top_ops_threshold(35, 25)
# result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)
# result = opset.query(Operator.supports_jvp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)
# result = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)
# result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)
# result = opset.query(Operator.supports_fast_vmapjvp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)
print(opset.summary())

print("=" * 30 + " Top 125 Summary " + "=" * 30)
print(f"% of usages on github: {get_ops_percentage(100, 25):.1%}")
opset = OperatorSet.from_top125()
# result = opset.query(Operator.supports_vmap, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)
# result = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)
print("supports_vjp")
result = opset.query(Operator.supports_vjp, (Support.NO, Support.UNKNOWN))
pprint.pprint(result)
print("supports_jvp")
result = opset.query(Operator.supports_jvp, (Support.NO, Support.UNKNOWN))
pprint.pprint(result)
print("supports_vmapjvp")
result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
pprint.pprint(result)
print("supports_jvpvjp")
result = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
pprint.pprint(result)
# result = opset.query(Operator.supports_fast_vmapjvp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)
print(opset.summary())

# print("=" * 30 + " Top 160 Summary " + "=" * 30)
# opset = OperatorSet.from_top160()
# result = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)
# print(opset.summary())

# Print list of everything in order
# all_ops = get_top_ops(999999, 999999, with_counts=True)
# for op, count in all_ops:
#     print(f'{op}, {count}')