xref: /aosp_15_r20/external/pytorch/test/jit/test_python_bindings.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import torch
4from torch.testing import FileCheck
5from torch.testing._internal.jit_utils import JitTestCase
6
7
if __name__ == "__main__":
    # This module only defines test cases; the usage text below points to the
    # umbrella runner that is meant to collect them.
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TestPythonBindings\n\n"
        "instead."
    )
    raise RuntimeError(_usage)
14
15
16class TestPythonBindings(JitTestCase):
17    def test_cu_get_functions(self):
18        @torch.jit.script
19        def test_get_python_cu_fn(x: torch.Tensor):
20            return 2 * x
21
22        cu = torch.jit._state._python_cu
23        self.assertTrue(
24            "test_get_python_cu_fn" in (str(fn.name) for fn in cu.get_functions())
25        )
26
27    def test_cu_create_function(self):
28        @torch.jit.script
29        def fn(x: torch.Tensor):
30            return 2 * x
31
32        cu = torch._C.CompilationUnit()
33        cu.create_function("test_fn", fn.graph)
34
35        inp = torch.randn(5)
36
37        self.assertEqual(inp * 2, cu.find_function("test_fn")(inp))
38        self.assertEqual(cu.find_function("doesnt_exist"), None)
39        self.assertEqual(inp * 2, cu.test_fn(inp))
40        with self.assertRaises(AttributeError):
41            cu.doesnt_exist(inp)
42
43    def test_invalidation(self):
44        @torch.jit.script
45        def test_invalidation_fn(x: torch.Tensor):
46            return 2 * x
47
48        gr = test_invalidation_fn.graph.copy()
49        n = gr.insertNode(gr.create("prim::profile"))
50        v = n.output()
51        # check that they work
52        str((n, v))
53        torch._C._jit_pass_dce(gr)
54        with self.assertRaisesRegex(RuntimeError, "invalidated"):
55            str(n)
56        with self.assertRaisesRegex(RuntimeError, "invalidated"):
57            str(v)
58
59    def test_graph_iterator_keepalive(self):
60        @torch.jit.script
61        def test_iterator_keepalive_fn(x: torch.Tensor):
62            return 2 * x
63
64        # the list would segfault before because inlined_graph
65        # is temporary and had been deleted (see issue #50454)
66        n = test_iterator_keepalive_fn.inlined_graph.nodes()
67        list(n)
68        i = test_iterator_keepalive_fn.inlined_graph.inputs()
69        list(i)
70        o = test_iterator_keepalive_fn.inlined_graph.outputs()
71        list(o)
72
73    def test_aliasdb(self):
74        @torch.jit.script
75        def test_aliasdb_fn(x: torch.Tensor):
76            return 2 * x
77
78        gr = test_aliasdb_fn.graph.copy()
79        alias_db = gr.alias_db()
80        self.assertTrue("WILDCARD" in str(alias_db))
81        self.assertTrue("digraph alias_db" in alias_db.to_graphviz_str())
82
83    def test_graph_create(self):
84        gr = torch._C.Graph()
85        with self.assertRaises(ValueError):
86            gr.create("prim::Constant", [None])
87
88    def test_add_input(self):
89        gr = torch._C.Graph()
90        foo_value = gr.addInput("foo")
91        assert foo_value in gr.inputs()
92
93    def test_canonicalize(self):
94        ir = """
95graph(%p207 : Tensor,
96      %1 : Tensor,
97      %p407 : int):
98  %11 : Tensor = aten::view_expand_placeholder(%1)
99  %12 : Tensor = aten::pointwise_placeholder(%11, %p207, %p407)
100  %13 : Tensor = aten::view_expand_placeholder(%12)
101  %14 : Tensor = aten::pointwise_placeholder(%13)
102  return (%14)
103        """
104
105        graph1 = torch._C.parse_ir(ir)
106        graph1 = torch._C._jit_pass_canonicalize(graph1, True)
107
108        graph2 = torch._C.parse_ir(ir)
109        graph2 = torch._C._jit_pass_canonicalize(graph2)
110
111        self.assertEqual(str(graph1), str(graph2))
112        FileCheck().check("%p207").check_not("%14").run(graph1)
113
114        graph3 = torch._C.parse_ir(ir)
115        graph3 = torch._C._jit_pass_canonicalize(graph3, False)
116        FileCheck().check_not("%p207").run(graph3)
117