1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport contextlib 4*da0073e9SAndroid Build Coastguard Workerimport os 5*da0073e9SAndroid Build Coastguard Workerimport sys 6*da0073e9SAndroid Build Coastguard Workerimport unittest 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Workerimport torch 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Worker# Make the helper files in test/ importable 12*da0073e9SAndroid Build Coastguard Workerpytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 13*da0073e9SAndroid Build Coastguard Workersys.path.append(pytorch_test_dir) 14*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 15*da0073e9SAndroid Build Coastguard Worker IS_FBCODE, 16*da0073e9SAndroid Build Coastguard Worker run_tests, 17*da0073e9SAndroid Build Coastguard Worker set_default_dtype, 18*da0073e9SAndroid Build Coastguard Worker suppress_warnings, 19*da0073e9SAndroid Build Coastguard Worker) 20*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_metaprogramming_utils import ( 21*da0073e9SAndroid Build Coastguard Worker get_all_nn_module_tests, 22*da0073e9SAndroid Build Coastguard Worker get_nn_functional_compiled_fn_and_inputs, 23*da0073e9SAndroid Build Coastguard Worker get_nn_mod_test_name, 24*da0073e9SAndroid Build Coastguard Worker nn_functional_tests, 25*da0073e9SAndroid Build Coastguard Worker try_get_nn_module_compiled_mod_and_inputs, 26*da0073e9SAndroid Build Coastguard Worker) 27*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import enable_profiling_mode, JitTestCase 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Worker 30*da0073e9SAndroid Build Coastguard Workerdef num_ifs_loops(graph): 31*da0073e9SAndroid Build Coastguard Worker graph_str = str(graph) 32*da0073e9SAndroid Build Coastguard Worker # only look at body of graph 33*da0073e9SAndroid Build Coastguard Worker graph_body = graph_str[0 : graph_str.find("return")] 34*da0073e9SAndroid Build Coastguard Worker return graph_body.count("prim::Loop") + graph_body.count("prim::If") 35*da0073e9SAndroid Build Coastguard Worker 36*da0073e9SAndroid Build Coastguard Worker 37*da0073e9SAndroid Build Coastguard Workerdef num_non_tensor_nodes(block): 38*da0073e9SAndroid Build Coastguard Worker num_non_tensor = 0 39*da0073e9SAndroid Build Coastguard Worker for node in block.nodes(): 40*da0073e9SAndroid Build Coastguard Worker kind = node.kind() 41*da0073e9SAndroid Build Coastguard Worker # GetAttr don't provide useful signal here, since they are non-optimizable except with freezing 42*da0073e9SAndroid Build Coastguard Worker # Constant is not executed, bailouts should be a separate tests, don't provide useful signal here 43*da0073e9SAndroid Build Coastguard Worker if kind == "prim::Constant" or "prim::Bailout" in kind or "GetAttr" in kind: 44*da0073e9SAndroid Build Coastguard Worker continue 45*da0073e9SAndroid Build Coastguard Worker for b in node.blocks(): 46*da0073e9SAndroid Build Coastguard Worker num_non_tensor += num_non_tensor_nodes(b) 47*da0073e9SAndroid Build Coastguard Worker tensor_out = False 48*da0073e9SAndroid Build Coastguard Worker for out in node.outputs(): 49*da0073e9SAndroid Build Coastguard Worker if "Tensor" in str(out.type()): 50*da0073e9SAndroid Build Coastguard Worker tensor_out = True 51*da0073e9SAndroid Build Coastguard Worker break 52*da0073e9SAndroid Build Coastguard Worker num_non_tensor += int(not tensor_out) 53*da0073e9SAndroid Build Coastguard Worker return num_non_tensor 54*da0073e9SAndroid Build Coastguard Worker 55*da0073e9SAndroid Build Coastguard Worker 56*da0073e9SAndroid Build Coastguard Workerclass TestComplexity(JitTestCase): 57*da0073e9SAndroid Build Coastguard Worker def setUp(self): 58*da0073e9SAndroid Build Coastguard Worker super().setUp() 59*da0073e9SAndroid Build Coastguard Worker self.grad_enabled = torch.is_grad_enabled() 60*da0073e9SAndroid Build Coastguard Worker torch.set_grad_enabled(False) 61*da0073e9SAndroid Build Coastguard Worker self._stack = contextlib.ExitStack() 62*da0073e9SAndroid Build Coastguard Worker self._stack.enter_context(set_default_dtype(torch.double)) 63*da0073e9SAndroid Build Coastguard Worker 64*da0073e9SAndroid Build Coastguard Worker def tearDown(self): 65*da0073e9SAndroid Build Coastguard Worker self._stack.close() 66*da0073e9SAndroid Build Coastguard Worker torch.set_grad_enabled(self.grad_enabled) 67*da0073e9SAndroid Build Coastguard Worker super().tearDown() 68*da0073e9SAndroid Build Coastguard Worker 69*da0073e9SAndroid Build Coastguard Worker @suppress_warnings 70*da0073e9SAndroid Build Coastguard Worker def test_generated_functional_tests(self): 71*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode(): 72*da0073e9SAndroid Build Coastguard Worker stats = [("Name", "Ifs/Loops", "non-tensor ops")] 73*da0073e9SAndroid Build Coastguard Worker for test in nn_functional_tests: 74*da0073e9SAndroid Build Coastguard Worker test_name = test[0] 75*da0073e9SAndroid Build Coastguard Worker 76*da0073e9SAndroid Build Coastguard Worker fn, inputs = get_nn_functional_compiled_fn_and_inputs(*test) 77*da0073e9SAndroid Build Coastguard Worker for _ in range(6): 78*da0073e9SAndroid Build Coastguard Worker fn(*inputs) 79*da0073e9SAndroid Build Coastguard Worker 80*da0073e9SAndroid Build Coastguard Worker g = torch.jit.last_executed_optimized_graph() 81*da0073e9SAndroid Build Coastguard Worker stats.append((test_name, num_ifs_loops(g), num_non_tensor_nodes(g))) 82*da0073e9SAndroid Build Coastguard Worker for line in stats: 83*da0073e9SAndroid Build Coastguard Worker print(line) 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Worker @suppress_warnings 86*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_FBCODE, "Causes a RecursionError in fbcode") 87*da0073e9SAndroid Build Coastguard Worker def test_nn_module_tests(self): 88*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode(): 89*da0073e9SAndroid Build Coastguard Worker stats = [("Name", "Ifs/Loops", "non-tensor ops")] 90*da0073e9SAndroid Build Coastguard Worker for test in get_all_nn_module_tests(): 91*da0073e9SAndroid Build Coastguard Worker out = try_get_nn_module_compiled_mod_and_inputs(**test) 92*da0073e9SAndroid Build Coastguard Worker if not out: 93*da0073e9SAndroid Build Coastguard Worker continue 94*da0073e9SAndroid Build Coastguard Worker 95*da0073e9SAndroid Build Coastguard Worker mod, inputs = out 96*da0073e9SAndroid Build Coastguard Worker test_name = get_nn_mod_test_name(**test) 97*da0073e9SAndroid Build Coastguard Worker for _ in range(6): 98*da0073e9SAndroid Build Coastguard Worker mod(*inputs) 99*da0073e9SAndroid Build Coastguard Worker 100*da0073e9SAndroid Build Coastguard Worker g = torch.jit.last_executed_optimized_graph() 101*da0073e9SAndroid Build Coastguard Worker stats.append((test_name, num_ifs_loops(g), num_non_tensor_nodes(g))) 102*da0073e9SAndroid Build Coastguard Worker 103*da0073e9SAndroid Build Coastguard Worker for line in stats: 104*da0073e9SAndroid Build Coastguard Worker print(line) 105*da0073e9SAndroid Build Coastguard Worker 106*da0073e9SAndroid Build Coastguard Worker 107*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 108*da0073e9SAndroid Build Coastguard Worker run_tests() 109