"""Report how much of PyTorch's public, overridable API is covered by OpInfos
and by the functorch transform tests (vmap / vjp / jvp and their compositions).

This is a script: importing it computes and prints the coverage report.
"""
import copy
import enum
import os
import pprint
import unittest
from enum import Enum

# Importing these files make modifications to the op_db that we need
import test_ops  # noqa: F401
import test_vmap  # noqa: F401
from functorch_additional_op_db import additional_op_db

import torch
import torch._functorch.top_operators_github_usage as top_ops
from torch.testing._internal.common_device_type import toleranceOverride
from torch.testing._internal.common_methods_invocations import op_db


all_overridable = list(torch.overrides.get_testing_overrides().keys())

# (module, dotted module name, path of the module's .rst page relative to the
# PyTorch checkout root)
public_docs = [
    (torch.nn.functional, "torch.nn.functional", "docs/source/nn.functional.rst"),
    (torch.fft, "torch.fft", "docs/source/fft.rst"),
    (torch.special, "torch.special", "docs/source/special.rst"),
    (torch.linalg, "torch.linalg", "docs/source/linalg.rst"),
    (torch, "torch", "docs/source/torch.rst"),
    (torch.Tensor, "torch.Tensor", "docs/source/tensors.rst"),
]

# Checkout root used when neither an explicit ``pytorch_root`` nor the
# PYTORCH_ROOT environment variable is provided.
# NOTE(review): this is a developer-specific path kept only for backward
# compatibility; set PYTORCH_ROOT to point at your own checkout.
_DEFAULT_PYTORCH_ROOT = "/raid/rzou/pt/debug-cpu"

# torch.abs, Tensor.abs, Tensor.abs_ are all considered to be different


def get_public_overridable_apis(pytorch_root=None):
    """Return ``{"<module>.<name>": api}`` for documented, overridable APIs.

    API names are scraped from the .rst files listed in ``public_docs``.
    The files are read under ``pytorch_root``, which defaults to
    $PYTORCH_ROOT and falls back to ``_DEFAULT_PYTORCH_ROOT``. A name is
    kept only if the module has an attribute of that name and that attribute
    is a key of ``torch.overrides.get_testing_overrides()``.
    """
    if pytorch_root is None:
        pytorch_root = os.environ.get("PYTORCH_ROOT", _DEFAULT_PYTORCH_ROOT)
    results = {}
    all_overridable_apis = set(torch.overrides.get_testing_overrides().keys())
    for module, module_name, src in public_docs:
        with open(os.path.join(pytorch_root, src), encoding="utf-8") as f:
            lines = f.readlines()
        # APIs either begin with 4 spaces or ".. autofunction::"
        api_lines1 = [line.strip() for line in lines if line.startswith(" " * 4)]
        api_lines2 = [
            line.strip()[len(".. autofunction:: ") :]
            for line in lines
            if line.startswith(".. autofunction::")
        ]
        names = api_lines1 + api_lines2
        # tensors.rst lists methods as "Tensor.foo"; look them up as "foo".
        names = [
            name[len("Tensor.") :] if name.startswith("Tensor.") else name
            for name in names
        ]
        for name in names:
            if not hasattr(module, name):
                continue
            api = getattr(module, name)
            if api in all_overridable_apis:
                results[f"{module_name}.{name}"] = api
    return results


denylist = {
    "torch.Tensor.data_ptr",
    "torch.Tensor.dim",
    "torch.Tensor.element_size",
    "torch.Tensor.backward",
    "torch.Tensor.as_strided",
    "torch.Tensor.register_hook",
    "torch.Tensor.record_stream",
    "torch.Tensor.qscheme",
    "torch.Tensor.ndimension",
    "torch.Tensor.smm",
    "torch.Tensor.sspaddmm",
    "torch.Tensor.retain_grad",
    "torch.Tensor.sparse_mask",
    "torch.Tensor.sparse_dim",
    "torch.Tensor.dense_dim",
    "torch.Tensor.values",
    "torch.Tensor.indices",
    "torch.Tensor.numel",
    "torch.Tensor.size",
    "torch.Tensor.nelement",
    "torch.Tensor.q_scale",
    "torch.Tensor.q_zero_point",
    "torch.Tensor.q_per_channel_scales",
    "torch.Tensor.q_per_channel_zero_points",
    "torch.Tensor.q_per_channel_axis",
    "torch.Tensor.int_repr",
    "torch.Tensor.to_sparse",
    "torch.Tensor.is_inference",
    "torch.Tensor.storage",
    "torch.Tensor.storage_type",
}


def get_method_only_ops_we_care_about():
    """Return names of out-of-place Tensor methods with no ``torch.<name>`` twin.

    Keys in the module-level ``denylist`` and in-place methods (trailing
    underscore) are skipped.
    """
    apis = get_public_overridable_apis()
    result = []
    for key in apis:
        if not key.startswith("torch.Tensor") or key in denylist:
            continue
        api = key.split(".")[2]
        # filter out in-place
        if api.endswith("_"):
            continue
        if f"torch.{api}" not in apis:
            result.append(api)
    return result


# Deduplicates torch.abs and Tensor.abs


def get_public_overridable_ops():
    """Like get_public_overridable_apis, minus Tensor methods that have a torch.* twin."""
    results = get_public_overridable_apis()
    # Iterate over a snapshot of the keys because we delete while looping.
    for key in list(results):
        if not key.startswith("torch.Tensor"):
            continue
        api = key.split(".")[2]
        if f"torch.{api}" in results:
            del results[key]
    return results


def get_public_overridable_outplace_ops():
    """get_public_overridable_ops without in-place variants (names ending in "_")."""
    results = get_public_overridable_ops()
    # NB: there are no dunder methods bcs we don't document those
    return {key: api for key, api in results.items() if not key.endswith("_")}


def get_public_overridable_outplace_we_care_about():
    """Out-of-place ops minus quantization ops, ``is_*`` queries and ``denylist``.

    Each key is checked against all conditions at once, so a key matching
    several of them (e.g. one containing both "quant" and ".is_") is removed
    exactly once instead of raising KeyError on a second ``del``.
    """
    results = get_public_overridable_outplace_ops()
    for key in list(results):
        is_quantization = "quant" in key or ".q_" in key
        # is_cpu, etc. It doesn't make sense to have OpInfos for these
        is_query = ".is_" in key
        if is_quantization or is_query or key in denylist:
            del results[key]
    return results


# e.g. nn.functional.softmax


def get_op(dotted_name):
    """Resolve a dotted name relative to ``torch``; return None if any part is missing."""
    mod = torch
    for name in dotted_name.split("."):
        if not hasattr(mod, name):
            return None
        mod = getattr(mod, name)
    return mod


# Maps function -> [OpInfo]


def get_ops_covered_by_opinfos():
    """Map every callable referenced by an OpInfo in ``op_db`` to its OpInfos.

    The function variant (resolved from the OpInfo name), method variant,
    in-place variant and every alias are all registered.
    """
    ops = {}
    for opinfo in op_db:
        func_op = get_op(opinfo.name)
        if func_op:
            ops.setdefault(func_op, []).append(opinfo)
        if opinfo.method_variant:
            ops.setdefault(opinfo.method_variant, []).append(opinfo)
        if opinfo.inplace_variant:
            ops.setdefault(opinfo.inplace_variant, []).append(opinfo)
        for alias in opinfo.aliases:
            ops.setdefault(alias.op, []).append(opinfo)
    return ops


factory_fns = {
    "tensor",
    "zeros",
    "ones",
    "randn",
    "arange",
    "rand",
    "empty",
    "randperm",
    "linspace",
    "logspace",
    "hann_window",
    "full",
    "eye",
    "blackman_window",
    "bartlett_window",
    "randint",
    "range",
}


def get_top_ops(torch_threshold, nn_fn_threshold, with_counts=False):
    """Return the most-used torch.* and nn.functional.* ops by GitHub usage.

    Takes the first ``torch_threshold`` torch ops and the first
    ``nn_fn_threshold`` nn.functional ops (after removing non-operators).
    The result is sorted by usage count, descending. It contains
    ``(name, count)`` pairs if ``with_counts`` else just names.
    """
    top_ops_denylist = {
        # These are either not real "operators", factory functions
        # that trivially work, or not-documented ops.
        "load",
        "no_grad",
        "save",
        "from_numpy",
        "manual_seed",
        "set_grad_enabled",
        "set_default_tensor_type",
        "set_num_threads",
        "set_printoptions",
        "numel",
        "set_default_dtype",
        "sparse_coo_tensor",
        "set_rng_state",
        "get_rng_state",
        "get_default_dtype",
        "initial_seed",
        "get_num_threads",
        "quantize_per_tensor",
        "hann_window",
        "is_tensor",
        "as_tensor",
        "equal",
        "enable_grad",
        "seed",
        "is_storage",
        "is_floating_point",
        "nn.functional.torch",
        "set_flush_denormal",
        "set_num_interop_threads",
        "dequantize",
        "get_num_interop_threads",
        "nn.functional.math",
        "nn.functional.threshold_",
        "nn.functional.selu_",
        "nn.functional.elu_",
        "nn.functional.rrelu_",
        "nn.functional.leaky_relu_",
        "nn.functional.hardtanh_",
        "nn.functional.has_torch_function",
        "nn.functional.has_torch_function_unary",
        "nn.functional.has_torch_function_variadic",
        "nn.functional.handle_torch_function",
        "nn.functional.adaptive_max_pool1d_with_indices",
        "nn.functional.adaptive_max_pool2d_with_indices",
        "nn.functional.adaptive_max_pool3d_with_indices",
        "nn.functional.fractional_max_pool2d_with_indices",
        "nn.functional.fractional_max_pool3d_with_indices",
        "is_complex",
        "grad",
        "quantize_per_channel",
        "nn.functional.max_pool2d_with_indices",
        "nn.functional.max_pool3d_with_indices",
        "nn.functional.max_pool1d_with_indices",
        "nn.functional.celu_",
        "nn.functional.grad",
        "nn.functional.relu_",
        "nn.functional.boolean_dispatch",
        "nn.functional.assert_int_or_pair",
        "fft",  # is namespace
    }

    torch_ops = [op for op in top_ops.top_torch if op[0] not in top_ops_denylist]
    nn_fn_ops = [
        op
        for op in top_ops.get_nn_functional_top_list()
        if op[0] not in top_ops_denylist
    ]

    ops = torch_ops[:torch_threshold] + nn_fn_ops[:nn_fn_threshold]

    # Now, sort by priority
    ops.sort(reverse=True, key=lambda op: op[1])
    if not with_counts:
        ops = [op[0] for op in ops]
    return ops


def get_ops_percentage(torch_threshold, nn_fn_threshold):
    """Fraction of all (non-denylisted) GitHub op usages covered by the top-op subset."""
    data = top_ops.top_torch + top_ops.get_nn_functional_top_list()

    def get_num_usages(opname):
        # Ignore this, this is heavily inflated
        if opname == "t":
            return 0
        result = [op[1] for op in data if op[0] == opname]
        assert len(result) == 1
        return result[0]

    # get all operators that are not in the denylist
    all_ops = get_top_ops(999999, 999999)
    total_op_usages = sum(get_num_usages(op) for op in all_ops)

    # get subset of all operators
    subset_ops = get_top_ops(torch_threshold, nn_fn_threshold)
    subset_op_usages = sum(get_num_usages(op) for op in subset_ops)
    return subset_op_usages / total_op_usages


def get_top_ops_not_covered_by_opinfo(torch_threshold=0, nn_fn_threshold=0):
    """Top ops (see get_top_ops) with no OpInfo name/alias and not a factory function."""
    ops = get_top_ops(torch_threshold, nn_fn_threshold)

    ops_with_opinfo = set()
    for opinfo in op_db:
        ops_with_opinfo.add(opinfo.name)
        ops_with_opinfo.update(alias.name for alias in opinfo.aliases)

    return [
        op
        for op in ops
        if op not in ops_with_opinfo and op not in denylist and op not in factory_fns
    ]


def get_covered_ops(ops_list, invert=False):
    """Filter the ``{name: op}`` dict to ops that have an OpInfo (or lack one if ``invert``)."""
    ops_covered_by_opinfo = get_ops_covered_by_opinfos()
    keep_covered = not invert
    return {
        key: op
        for key, op in ops_list.items()
        if (op in ops_covered_by_opinfo) == keep_covered
    }


class Status(Enum):
    Correct = 0
    Fast = 1


tests = {
    "test_vmap_exhaustive",
    "test_op_has_batch_rule",
    "test_vjp",
    "test_vmapvjp",
    "test_vmapvjp_has_batch_rule",
    "test_jvp",
    "test_vmapjvp",
}


def is_decorateinfo_skip_or_xfail(decorateinfo):
    """Return True if a single-decorator DecorateInfo skips or xfails its test.

    Tolerance overrides still run the test (with looser tolerances), so they
    are not counted. Every other decorator, ``unittest.expectedFailure``
    included, is assumed to be a skip or an xfail.
    """
    assert len(decorateinfo.decorators) == 1
    actual_decorator = decorateinfo.decorators[0]
    return not isinstance(actual_decorator, toleranceOverride)


def get_all_tested_ops():
    """Names of every OpInfo attached to an overridable out-of-place op we care about."""
    overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about()
    op_to_opinfo = get_ops_covered_by_opinfos()
    result = set()
    for op in get_covered_ops(overridable_outplace_we_care_about).values():
        result.update(opinfo.name for opinfo in op_to_opinfo[op])
    return result


def get_skipped_or_xfailed_ops_for(test_name):
    """Names of OpInfos (for ops we care about) that skip or xfail ``test_name``."""
    overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about()
    op_to_opinfo = get_ops_covered_by_opinfos()
    result = set()
    for op in get_covered_ops(overridable_outplace_we_care_about).values():
        for opinfo in op_to_opinfo[op]:
            for decorator in opinfo.decorators:
                if getattr(decorator, "test_name", None) != test_name:
                    continue
                if is_decorateinfo_skip_or_xfail(decorator):
                    result.add(opinfo.name)
    return result


def get_statuses(for_subset=None, invert=False):
    """Map each covered op name to the set of ``tests`` it passes (fails if ``invert``).

    ``for_subset`` optionally restricts the ops to names (without the
    leading "torch.") contained in it. A test counts as failing for an op
    when any of the op's OpInfos skips or xfails it; tolerance overrides do
    not count as failures.
    """
    overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about()
    if for_subset is not None:
        overridable_outplace_we_care_about = {
            k: v
            for k, v in overridable_outplace_we_care_about.items()
            # Removes "torch."
            if k[6:] in for_subset
        }
    op_to_opinfo = get_ops_covered_by_opinfos()

    def get_covered_tests(op):
        passing = copy.copy(tests)
        for opinfo in op_to_opinfo[op]:
            for decorator in opinfo.decorators:
                if not hasattr(decorator, "test_name"):
                    continue
                if decorator.test_name in passing and is_decorateinfo_skip_or_xfail(
                    decorator
                ):
                    passing.remove(decorator.test_name)
        return passing

    result = {}
    for name, op in get_covered_ops(overridable_outplace_we_care_about).items():
        successful_tests = get_covered_tests(op)
        failed_tests = tests - successful_tests
        result[name] = failed_tests if invert else successful_tests
    return result


def transpose_statuses(for_subset=None, invert=False):
    """Invert get_statuses into ``{test_name: set of op names}``."""
    statuses = get_statuses(for_subset, invert=invert)
    result = {test: set() for test in tests}
    for op, supported in statuses.items():
        for test in supported:
            result[test].add(op)
    return result


overridable_apis = get_public_overridable_apis()

overridable_ops = get_public_overridable_ops()

overridable_outplace_ops = get_public_overridable_outplace_ops()

overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about()

tested_overridable_outplace_ops = get_covered_ops(overridable_outplace_we_care_about)
untested_overridable_outplace_ops = get_covered_ops(
    overridable_outplace_we_care_about, invert=True
)

print(f"Overridable public APIs: {len(overridable_apis)}")
print(f"Overridable public ops: {len(overridable_ops)}")
print(f"Overridable public outplace ops: {len(overridable_outplace_ops)}")
print(
    f"Overridable public outplace ops we care about: {len(overridable_outplace_we_care_about)}"
)
print(
    f"OpInfo-tested overridable public outplace ops: {len(tested_overridable_outplace_ops)}"
)


def remove_torch(name):
    """Strip the leading "torch." from a fully-qualified API name."""
    assert name[:6] == "torch.", f"expected a 'torch.' prefix: {name!r}"
    return name[6:]


def get_list_of_all_tests():
    """Names (without "torch.") of all OpInfo-tested overridable out-of-place ops."""
    return {remove_torch(test) for test in tested_overridable_outplace_ops}


mytest = {
    "test_vmap_exhaustive",
    "test_op_has_batch_rule",
    "test_vjp",
    "test_vmapvjp",
    "test_vmapvjp_has_batch_rule",
}

print("*" * 80)
all_tests = get_list_of_all_tests()
for test in mytest:
    result = get_skipped_or_xfailed_ops_for(test)
    diff = len(all_tests - result)
    print(f"{test}: {diff}")


def get_jvp_coverage(subset=None):
    """Print "test_jvp, <passing>, <autograd but no forward AD>, <total ops>".

    Counts are taken over tested overridable out-of-place ops, optionally
    restricted to names (without "torch.") in ``subset``. Autograd / forward
    AD support is read from each op's first OpInfo.
    """
    # - number that support autograd
    # - number that support forward_ad (in pytorch core)
    # - number that support functorch.jvp
    op_to_opinfo = get_ops_covered_by_opinfos()
    ops_dct = tested_overridable_outplace_ops
    if subset is not None:
        ops_dct = {
            name: op for name, op in ops_dct.items() if remove_torch(name) in subset
        }
    supports_autograd_ops_dct = {
        name: op_to_opinfo[fn]
        for name, fn in ops_dct.items()
        if op_to_opinfo[fn][0].supports_autograd
    }
    supports_forwardad_ops_dct = {
        name: op_to_opinfo[fn]
        for name, fn in ops_dct.items()
        if op_to_opinfo[fn][0].supports_forward_ad
    }

    ops = {remove_torch(test) for test in ops_dct}
    supports_autograd = {remove_torch(test) for test in supports_autograd_ops_dct}
    supports_forward_ad = {remove_torch(test) for test in supports_forwardad_ops_dct}
    assert supports_forward_ad.issubset(supports_autograd)
    assert supports_autograd.issubset(ops)

    failed_ops = get_skipped_or_xfailed_ops_for("test_jvp")

    coverage = len(supports_forward_ad - failed_ops)
    no_forward_ad = len(supports_autograd) - len(supports_forward_ad)
    print(f"test_jvp, {coverage}, {no_forward_ad}, {len(ops)}")


get_jvp_coverage()
get_jvp_coverage(get_top_ops(100, 25))
for op in get_top_ops(100, 25):
    print(op)
print("*" * 80)

statuses = transpose_statuses()
for test in tests:
    print(f"{test} coverage {len(statuses[test])}")

method_only_ops = get_method_only_ops_we_care_about()

top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(100, 25)
print("=" * 80)
for op in top_ops_not_covered_by_opinfo:
    print(f"{op}, {top_ops.usage_count[op]}")


def remove_from_set(parent, to_remove):
    """Remove every element of ``to_remove`` from the set ``parent``, ignoring absent ones."""
    parent.difference_update(to_remove)


def print_coverage_info(th=100, nn=25):
    """Print failing-test coverage for the top ``th`` torch and ``nn`` nn.functional ops."""
    print("=" * 80)
    print(f"top {th}, {nn} coverage")
    statuses = transpose_statuses(get_top_ops(th, nn), invert=True)
    top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(th, nn)

    # testing problems
    exemptions = {
        "torch.nn.functional.dropout",  # randomness
    }

    # Allowed exemptions
    vmap_exemptions = {
        "torch.randn_like",  # randomness
        "torch.rand_like",  # randomness
        "torch.allclose",  # number output
        "torch.unique",  # dynamic
        "torch.nonzero",  # dynamic
        "torch.masked_select",  # dynamic
        "torch.prod",  # dynamic (backward)
        "torch.norm",  # norm with nuc is not commonly used; we support the other cases.
        "torch.svd",  # There isn't a bug, it is just nondeterministic so we can't test it.
        "torch.nn.functional.embedding",  # We support everything except the sparse option.
    }
    for vmap_test in (
        "test_vmap_exhaustive",
        "test_vmapvjp",
        "test_vmapvjp_has_batch_rule",
        "test_op_has_batch_rule",
        "test_vmapjvp",
    ):
        remove_from_set(statuses[vmap_test], vmap_exemptions)
    for test in tests:
        remove_from_set(statuses[test], exemptions)

    print(f"total ops in set: {th + nn}")
    print(f"tested by OpInfo: {th + nn - len(top_ops_not_covered_by_opinfo)}")
    for test in tests:
        if test in {"test_jvp", "test_vmapjvp"}:
            continue
        print(f"{test} failing coverage {len(statuses[test])}")

    # We don't care about these yet
    del statuses["test_jvp"]
    del statuses["test_vmapjvp"]

    pprint.pprint(statuses)


def get_name_to_opinfo_map():
    """Map every OpInfo name and alias name (op_db + additional_op_db) to its OpInfos."""
    dct = {}
    for op in op_db + additional_op_db:
        dct.setdefault(op.name, []).append(op)
        for alias in op.aliases:
            dct.setdefault(alias.name, []).append(op)
    return dct


NAME_TO_OPINFO = get_name_to_opinfo_map()


class Support(enum.Enum):
    NO = 0
    YES = 1
    UNKNOWN = 2


FACTORY_FNS = {
    "tensor",
    "zeros",
    "ones",
    "randn",
    "arange",
    "rand",
    "empty",
    "range",
    "full",
    "randperm",
    "eye",
    "randint",
    "linspace",
    "logspace",
}

VJP_EXEMPTIONS = {
    "nn.functional.dropout",  # not actually problem, randomness testing artifact
    "nn.functional.dropout2d",  # not actually problem, randomness testing artifact
    "nn.functional.rrelu",  # not actually problem, randomness testing artifact
    "bernoulli",  # not actually problem, randomness testing artifact
    "normal",  # not actually problem, randomness testing artifact
}

VMAP_EXEMPTIONS = {
    "randn_like",  # randomness
    "rand_like",  # randomness
    "allclose",  # number output
    "unique",  # dynamic
    "nonzero",  # dynamic
    "masked_select",  # dynamic
    "prod",  # dynamic (backward)
    "norm",  # norm with nuc is not commonly used; we support the other cases.
    "svd",  # There isn't a bug, it is just nondeterministic so we can't test it.
    "nn.functional.embedding",  # We support everything except the sparse option.
    "nn.functional.dropout",  # randomness
    "nn.functional.dropout2d",  # randomness
    "bernoulli",  # randomness
    "multinomial",  # randomness
    "normal",  # randomness
}

JVP_EXEMPTIONS = {
    "nn.functional.dropout",  # not actually problem, randomness testing artifact
    "nn.functional.dropout2d",  # not actually problem, randomness testing artifact
    "nn.functional.rrelu",  # not actually problem, randomness testing artifact
    "normal",  # not actually problem, randomness testing artifact
    "bernoulli",  # not actually problem, randomness testing artifact
}


class Operator:
    """An op name plus the OpInfos registered under it (None if there are none).

    Operators compare equal and hash by name, so a set of Operators holds at
    most one entry per name.
    """

    def __init__(self, name):
        self.name = name
        self.opinfos = NAME_TO_OPINFO.get(name, None)
        assert self.opinfos is None or len(self.opinfos) > 0

    def has_opinfo(self):
        return self.opinfos is not None

    def __repr__(self):
        return f'Operator("{self.name}")'

    def __eq__(self, other):
        if not isinstance(other, Operator):
            return NotImplemented
        return self.name == other.name

    def __hash__(self):
        return hash(self.name)

    def no_opinfos_skip_test(self, test_name):
        """Returns NO if any opinfos have a skip or xfail for the test"""
        if not self.has_opinfo():
            return Support.UNKNOWN
        for opinfo in self.opinfos:
            for decorator in opinfo.decorators:
                if getattr(decorator, "test_name", None) != test_name:
                    continue
                if is_decorateinfo_skip_or_xfail(decorator):
                    return Support.NO
        return Support.YES

    def any_opinfo_attr(self, attr):
        if not self.has_opinfo():
            raise RuntimeError(f"{self!r} has no OpInfo")
        return any(getattr(opinfo, attr) for opinfo in self.opinfos)

    def all_opinfo_attr(self, attr):
        if not self.has_opinfo():
            raise RuntimeError(f"{self!r} has no OpInfo")
        return all(getattr(opinfo, attr) for opinfo in self.opinfos)

    def _exempt_or_no_skips(self, exemptions, test_name):
        # Factory functions and exempted ops are YES; otherwise defer to the
        # skip/xfail decorators on the OpInfos for `test_name`.
        if self.name in FACTORY_FNS or self.name in exemptions:
            return Support.YES
        return self.no_opinfos_skip_test(test_name)

    def supports_vjp(self):
        return self._exempt_or_no_skips(VJP_EXEMPTIONS, "test_vjp")

    def supports_vmap(self):
        return self._exempt_or_no_skips(VMAP_EXEMPTIONS, "test_vmap_exhaustive")

    def supports_fast_vmap(self):
        return self._exempt_or_no_skips(VMAP_EXEMPTIONS, "test_op_has_batch_rule")

    def supports_vmapvjp(self):
        return self._exempt_or_no_skips(VMAP_EXEMPTIONS, "test_vmapvjp")

    def supports_fast_vmapvjp(self):
        return self._exempt_or_no_skips(VMAP_EXEMPTIONS, "test_vmapvjp_has_batch_rule")

    def supports_jvp(self):
        if self.name in FACTORY_FNS or self.name in JVP_EXEMPTIONS:
            return Support.YES
        if not self.has_opinfo():
            return Support.UNKNOWN
        # Differentiable but lacking forward-mode AD on some OpInfo: no jvp.
        if self.any_opinfo_attr("supports_autograd") and not self.all_opinfo_attr(
            "supports_forward_ad"
        ):
            return Support.NO
        return self.no_opinfos_skip_test("test_jvp")

    def supports_jvpvjp(self):
        exemptions = {
            # we have support (see OpInfo), testing artifact
            "nn.functional.dropout2d",
            "nn.functional.dropout",
            # exception: we dont even support double backward for this
            "nn.functional.hardswish",
            "bernoulli",  # this isn't differentiable
            "normal",  # not differentiable
        }
        return self._exempt_or_no_skips(exemptions, "test_jvpvjp")

    def _supports_vmapjvp_base(self, test):
        if self.name in FACTORY_FNS:
            return Support.YES
        VMAPJVP_EXEMPTIONS = {
            "prod",  # dynamic (backward)
            "nn.functional.batch_norm",  # testing problem
            "normal",  # not actually problem, randomness testing artifact
            "bernoulli",  # not actually problem, randomness testing artifact
            "nn.functional.dropout2d",  # not actually problem, randomness testing artifact
            "nn.functional.dropout",  # not actually problem, randomness testing artifact
            # Not a problem.
            # It's just that the max_norm testing mutates inputs...
            # (we have our own functorch variant of the OpInfo without max_norm)
            "nn.functional.embedding",
        }
        if self.name in VMAPJVP_EXEMPTIONS:
            return Support.YES
        if not self.has_opinfo():
            return Support.UNKNOWN
        if self.any_opinfo_attr("supports_autograd") and not self.all_opinfo_attr(
            "supports_forward_ad"
        ):
            return Support.NO
        return self.no_opinfos_skip_test(test)

    def supports_vmapjvp(self):
        return self._supports_vmapjvp_base("test_vmapjvpall")

    def supports_fast_vmapjvp(self):
        return self._supports_vmapjvp_base("test_vmapjvpall_has_batch_rule")


class OperatorSet:
    """A set of Operators with helpers to query and summarize their support status."""

    def __init__(self, operators):
        self.data = set(operators)

    @classmethod
    def from_names(cls, names):
        return cls([Operator(name) for name in names])

    @classmethod
    def from_top_ops_threshold(cls, torch_threshold, nn_fn_threshold):
        names = get_top_ops(torch_threshold, nn_fn_threshold)
        return cls.from_names(names)

    @classmethod
    def from_top125(cls):
        return cls.from_top_ops_threshold(100, 25)

    @classmethod
    def from_top160(cls):
        return cls.from_top_ops_threshold(107, 53)

    @classmethod
    def all(cls):
        """All overridable out-of-place ops we care about, with module prefixes stripped."""
        torch_tensor = "torch.Tensor."
        torch_dot = "torch."
        names_sanitized = []
        for n in get_public_overridable_outplace_we_care_about():
            if n.startswith(torch_tensor):
                names_sanitized.append(n[len(torch_tensor) :])
            elif n.startswith(torch_dot):
                names_sanitized.append(n[len(torch_dot) :])
            else:
                raise AssertionError(f"unexpected API name: {n!r}")
        return cls.from_names(names_sanitized)

    def query(self, operator_method, filter=(Support.NO, Support.YES, Support.UNKNOWN)):
        """Group operators by ``operator_method(op)``, keeping only results in ``filter``."""
        result = {key: set() for key in filter}
        for op in self.data:
            support_status = operator_method(op)
            if support_status in filter:
                result[support_status].add(op)
        return result

    def summary(self):
        """CSV-style "test, yes, no, unknown" counts for each supports_* check."""
        checks = [
            "supports_vjp",
            "supports_vmap",
            "supports_fast_vmap",
            "supports_vmapvjp",
            "supports_fast_vmapvjp",
            "supports_jvp",
            "supports_vmapjvp",
            "supports_fast_vmapjvp",
            "supports_jvpvjp",
        ]
        result = ["test, yes, no, unknown"]
        for check in checks:
            all_results = self.query(getattr(Operator, check))
            yes_amt = len(all_results[Support.YES])
            no_amt = len(all_results[Support.NO])
            unknown_amt = len(all_results[Support.UNKNOWN])
            result.append(f"{check}, {yes_amt}, {no_amt}, {unknown_amt}")
        return "\n".join(result)


opset = OperatorSet.all()
has_no_opinfo = opset.query(Operator.has_opinfo, (False,))

print("=" * 30 + " Summary " + "=" * 30)
print(f"% of usages on github: {get_ops_percentage(99999, 99999)}")
print(opset.summary())

# sanity checks
result = opset.query(Operator.supports_vjp, (Support.NO, Support.UNKNOWN))

print("=" * 30 + " Top 60 Summary " + "=" * 30)
print(f"% of usages on github: {get_ops_percentage(35, 25)}")
opset = OperatorSet.from_top_ops_threshold(35, 25)
print(opset.summary())

print("=" * 30 + " Top 125 Summary " + "=" * 30)
print(f"% of usages on github: {get_ops_percentage(100, 25)}")
opset = OperatorSet.from_top125()
for check_name in (
    "supports_vjp",
    "supports_jvp",
    "supports_vmapjvp",
    "supports_jvpvjp",
):
    print(check_name)
    result = opset.query(getattr(Operator, check_name), (Support.NO, Support.UNKNOWN))
    pprint.pprint(result)
print(opset.summary())