1# Owner(s): ["oncall: jit"] 2 3import contextlib 4import os 5import sys 6import unittest 7 8import torch 9 10 11# Make the helper files in test/ importable 12pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 13sys.path.append(pytorch_test_dir) 14from torch.testing._internal.common_utils import ( 15 IS_FBCODE, 16 run_tests, 17 set_default_dtype, 18 suppress_warnings, 19) 20from torch.testing._internal.jit_metaprogramming_utils import ( 21 get_all_nn_module_tests, 22 get_nn_functional_compiled_fn_and_inputs, 23 get_nn_mod_test_name, 24 nn_functional_tests, 25 try_get_nn_module_compiled_mod_and_inputs, 26) 27from torch.testing._internal.jit_utils import enable_profiling_mode, JitTestCase 28 29 30def num_ifs_loops(graph): 31 graph_str = str(graph) 32 # only look at body of graph 33 graph_body = graph_str[0 : graph_str.find("return")] 34 return graph_body.count("prim::Loop") + graph_body.count("prim::If") 35 36 37def num_non_tensor_nodes(block): 38 num_non_tensor = 0 39 for node in block.nodes(): 40 kind = node.kind() 41 # GetAttr don't provide useful signal here, since they are non-optimizable except with freezing 42 # Constant is not executed, bailouts should be a separate tests, don't provide useful signal here 43 if kind == "prim::Constant" or "prim::Bailout" in kind or "GetAttr" in kind: 44 continue 45 for b in node.blocks(): 46 num_non_tensor += num_non_tensor_nodes(b) 47 tensor_out = False 48 for out in node.outputs(): 49 if "Tensor" in str(out.type()): 50 tensor_out = True 51 break 52 num_non_tensor += int(not tensor_out) 53 return num_non_tensor 54 55 56class TestComplexity(JitTestCase): 57 def setUp(self): 58 super().setUp() 59 self.grad_enabled = torch.is_grad_enabled() 60 torch.set_grad_enabled(False) 61 self._stack = contextlib.ExitStack() 62 self._stack.enter_context(set_default_dtype(torch.double)) 63 64 def tearDown(self): 65 self._stack.close() 66 torch.set_grad_enabled(self.grad_enabled) 67 super().tearDown() 68 69 @suppress_warnings 70 def test_generated_functional_tests(self): 71 with enable_profiling_mode(): 72 stats = [("Name", "Ifs/Loops", "non-tensor ops")] 73 for test in nn_functional_tests: 74 test_name = test[0] 75 76 fn, inputs = get_nn_functional_compiled_fn_and_inputs(*test) 77 for _ in range(6): 78 fn(*inputs) 79 80 g = torch.jit.last_executed_optimized_graph() 81 stats.append((test_name, num_ifs_loops(g), num_non_tensor_nodes(g))) 82 for line in stats: 83 print(line) 84 85 @suppress_warnings 86 @unittest.skipIf(IS_FBCODE, "Causes a RecursionError in fbcode") 87 def test_nn_module_tests(self): 88 with enable_profiling_mode(): 89 stats = [("Name", "Ifs/Loops", "non-tensor ops")] 90 for test in get_all_nn_module_tests(): 91 out = try_get_nn_module_compiled_mod_and_inputs(**test) 92 if not out: 93 continue 94 95 mod, inputs = out 96 test_name = get_nn_mod_test_name(**test) 97 for _ in range(6): 98 mod(*inputs) 99 100 g = torch.jit.last_executed_optimized_graph() 101 stats.append((test_name, num_ifs_loops(g), num_non_tensor_nodes(g))) 102 103 for line in stats: 104 print(line) 105 106 107if __name__ == "__main__": 108 run_tests() 109