xref: /aosp_15_r20/external/pytorch/test/jit/test_ignore_context_manager.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import os
4import sys
5import unittest
6
7import torch
8
9
# Make the helper files in test/ importable: resolve this file's real path,
# strip two path components (the file name and its jit/ directory) to reach
# test/, and add that directory to the end of sys.path.
pytorch_test_dir = os.path.dirname(
    os.path.dirname(
        os.path.realpath(__file__),
    ),
)
sys.path.extend([pytorch_test_dir])
13from torch.jit.frontend import _IS_ASTUNPARSE_INSTALLED
14from torch.testing._internal.jit_utils import JitTestCase
15
16
# Refuse to run as a standalone script. The error message points the user at
# the test/test_jit.py driver instead. Plain imports of this module are unaffected.
if __name__ == "__main__":
    raise RuntimeError(
        "".join(
            (
                "This test file is not meant to be run directly, use:\n\n",
                "\tpython test/test_jit.py TESTNAME\n\n",
                "instead.",
            )
        )
    )
22    )
23
24
class TestIgnoreContextManager(JitTestCase):
    """TorchScript tests for ``torch.jit._IgnoreContextManager``.

    Every test builds a tiny ``torch.nn.Module`` whose ``forward`` wraps a
    list comprehension with an ``if`` filter in the context manager
    (presumably because TorchScript cannot compile that construct directly).
    Each keyword argument of the context manager tags a local variable as an
    input (``"inp:<type>"``) or an output (``"out:<type>"``) of the wrapped
    region. The module is scripted, and the scripted result is checked against
    a hand-computed constant and against eager execution.

    All tests are skipped unless the ``astunparse`` package is installed.
    """

    def _script_and_compare(self, module_cls, expected):
        # Instantiate ``module_cls``, script that same instance, and require
        # the scripted call to equal both ``expected`` and the eager call.
        eager_module = module_cls()
        scripted_module = torch.jit.script(eager_module)
        self.assertEqual(scripted_module(), expected)
        self.assertEqual(scripted_module(), eager_module())

    @unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
    def test_with_ignore_context_manager_with_inp_out(self):
        # Inputs a, b; outputs c, d. kept == [2], so c = 2 + 4 + 5 = 11 and
        # the module returns 11 + 9 = 20.
        class TwoInTwoOut(torch.nn.Module):
            def forward(self):
                a: int = 4
                b: int = 5
                c: int = 0
                d: int = 6
                with torch.jit._IgnoreContextManager(
                    a="inp:int", b="inp:int", c="out:int", d="out:int"
                ):
                    kept = [2 for idx in range(a) if idx > 2]
                    c = kept[0] + a + b
                    d = 9
                return c + d

        # Inputs a, b; single output c = 2 + 4 + 5 = 11.
        class TwoInOneOut(torch.nn.Module):
            def forward(self):
                a: int = 4
                b: int = 5
                c: int = 0
                with torch.jit._IgnoreContextManager(
                    a="inp:int", b="inp:int", c="out:int"
                ):
                    kept = [2 for idx in range(a) if idx > 2]
                    c = kept[0] + a + b
                return c

        # Input a; b is declared as an output and overwritten with 2 + 4 = 6.
        class OneInOneOut(torch.nn.Module):
            def forward(self):
                a: int = 4
                b: int = 5
                with torch.jit._IgnoreContextManager(a="inp:int", b="out:int"):
                    kept = [2 for idx in range(a) if idx > 2]
                    b = kept[0] + a
                return b

        self._script_and_compare(TwoInTwoOut, 20)
        self._script_and_compare(TwoInOneOut, 11)
        self._script_and_compare(OneInOneOut, 6)

    @unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
    def test_with_ignore_context_manager_with_just_inp(self):
        # Inputs only: the wrapped region produces no outputs, so forward
        # returns its untouched local a (4).
        class InputsOnly(torch.nn.Module):
            def forward(self):
                a: int = 4
                b: int = 5
                with torch.jit._IgnoreContextManager(a="inp:int", b="inp:int"):
                    kept = [2 + b for idx in range(a) if idx > 2]
                return a

        self._script_and_compare(InputsOnly, 4)

    @unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
    def test_with_ignore_context_manager_with_just_out(self):
        # Output only: the region builds c = [2, 2, 2, 2]; after c[0] = 3 the
        # module returns 3 + 2 = 5.
        class OutputOnly(torch.nn.Module):
            def forward(self):
                with torch.jit._IgnoreContextManager(c="out:List[int]"):
                    c = [2 for idx in range(7) if idx > 2]
                c[0] = 3
                return c[0] + c[1]

        self._script_and_compare(OutputOnly, 5)
106