1# Owner(s): ["oncall: jit"] 2 3import torch 4from torch.testing import FileCheck 5from torch.testing._internal.jit_utils import JitTestCase 6 7 8if __name__ == "__main__": 9 raise RuntimeError( 10 "This test file is not meant to be run directly, use:\n\n" 11 "\tpython test/test_jit.py TestPythonBindings\n\n" 12 "instead." 13 ) 14 15 16class TestPythonBindings(JitTestCase): 17 def test_cu_get_functions(self): 18 @torch.jit.script 19 def test_get_python_cu_fn(x: torch.Tensor): 20 return 2 * x 21 22 cu = torch.jit._state._python_cu 23 self.assertTrue( 24 "test_get_python_cu_fn" in (str(fn.name) for fn in cu.get_functions()) 25 ) 26 27 def test_cu_create_function(self): 28 @torch.jit.script 29 def fn(x: torch.Tensor): 30 return 2 * x 31 32 cu = torch._C.CompilationUnit() 33 cu.create_function("test_fn", fn.graph) 34 35 inp = torch.randn(5) 36 37 self.assertEqual(inp * 2, cu.find_function("test_fn")(inp)) 38 self.assertEqual(cu.find_function("doesnt_exist"), None) 39 self.assertEqual(inp * 2, cu.test_fn(inp)) 40 with self.assertRaises(AttributeError): 41 cu.doesnt_exist(inp) 42 43 def test_invalidation(self): 44 @torch.jit.script 45 def test_invalidation_fn(x: torch.Tensor): 46 return 2 * x 47 48 gr = test_invalidation_fn.graph.copy() 49 n = gr.insertNode(gr.create("prim::profile")) 50 v = n.output() 51 # check that they work 52 str((n, v)) 53 torch._C._jit_pass_dce(gr) 54 with self.assertRaisesRegex(RuntimeError, "invalidated"): 55 str(n) 56 with self.assertRaisesRegex(RuntimeError, "invalidated"): 57 str(v) 58 59 def test_graph_iterator_keepalive(self): 60 @torch.jit.script 61 def test_iterator_keepalive_fn(x: torch.Tensor): 62 return 2 * x 63 64 # the list would segfault before because inlined_graph 65 # is temporary and had been deleted (see issue #50454) 66 n = test_iterator_keepalive_fn.inlined_graph.nodes() 67 list(n) 68 i = test_iterator_keepalive_fn.inlined_graph.inputs() 69 list(i) 70 o = test_iterator_keepalive_fn.inlined_graph.outputs() 71 list(o) 72 73 def test_aliasdb(self): 74 @torch.jit.script 75 def test_aliasdb_fn(x: torch.Tensor): 76 return 2 * x 77 78 gr = test_aliasdb_fn.graph.copy() 79 alias_db = gr.alias_db() 80 self.assertTrue("WILDCARD" in str(alias_db)) 81 self.assertTrue("digraph alias_db" in alias_db.to_graphviz_str()) 82 83 def test_graph_create(self): 84 gr = torch._C.Graph() 85 with self.assertRaises(ValueError): 86 gr.create("prim::Constant", [None]) 87 88 def test_add_input(self): 89 gr = torch._C.Graph() 90 foo_value = gr.addInput("foo") 91 assert foo_value in gr.inputs() 92 93 def test_canonicalize(self): 94 ir = """ 95graph(%p207 : Tensor, 96 %1 : Tensor, 97 %p407 : int): 98 %11 : Tensor = aten::view_expand_placeholder(%1) 99 %12 : Tensor = aten::pointwise_placeholder(%11, %p207, %p407) 100 %13 : Tensor = aten::view_expand_placeholder(%12) 101 %14 : Tensor = aten::pointwise_placeholder(%13) 102 return (%14) 103 """ 104 105 graph1 = torch._C.parse_ir(ir) 106 graph1 = torch._C._jit_pass_canonicalize(graph1, True) 107 108 graph2 = torch._C.parse_ir(ir) 109 graph2 = torch._C._jit_pass_canonicalize(graph2) 110 111 self.assertEqual(str(graph1), str(graph2)) 112 FileCheck().check("%p207").check_not("%14").run(graph1) 113 114 graph3 = torch._C.parse_ir(ir) 115 graph3 = torch._C._jit_pass_canonicalize(graph3, False) 116 FileCheck().check_not("%p207").run(graph3) 117