# Owner(s): ["oncall: jit"]

import os
import sys
import unittest

import torch


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.jit.frontend import _IS_ASTUNPARSE_INSTALLED
from torch.testing._internal.jit_utils import JitTestCase


# This module only defines test cases; running it as a script is rejected
# with a message pointing at the aggregating runner instead.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestIgnoreContextManager(JitTestCase):
    """Tests for ``torch.jit._IgnoreContextManager`` used inside scripted modules.

    Each test defines a small ``torch.nn.Module`` whose ``forward`` wraps some
    code in ``torch.jit._IgnoreContextManager``, scripts the module with
    ``torch.jit.script``, and checks the scripted result against both a
    hard-coded expected value and the result of calling the unscripted
    ``model()`` directly.

    The context manager's keyword arguments name local variables of
    ``forward`` and tag each one as ``"inp:<type>"`` or ``"out:<type>"``,
    where the type is given as a string (``"int"``, ``"List[int]"``).
    Every test is skipped when the ``astunparse`` package is unavailable.
    """

    @unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
    def test_with_ignore_context_manager_with_inp_out(self):
        """Mix "inp:" and "out:" variables in three variants (A, B, C).

        In every variant, values assigned inside the block to "out:" variables
        are read after the block, and the asserted results depend on them.
        """

        # A: two inputs (a, b) and two outputs (c, d).
        class A(torch.nn.Module):
            def forward(self):
                a: int = 4
                b: int = 5
                c: int = 0
                d: int = 6
                with torch.jit._IgnoreContextManager(
                    a="inp:int", b="inp:int", c="out:int", d="out:int"
                ):
                    # Only i == 3 passes the filter over range(4), so l == [2].
                    l = [2 for i in range(a) if i > 2]
                    c = l[0] + a + b  # 2 + 4 + 5 == 11
                    d = 9
                return c + d  # 11 + 9 == 20

        model = A()
        s = torch.jit.script(model)
        self.assertEqual(s(), model())
        self.assertEqual(s(), 20)

        # B: two inputs (a, b) and a single output (c).
        class B(torch.nn.Module):
            def forward(self):
                a: int = 4
                b: int = 5
                c: int = 0
                with torch.jit._IgnoreContextManager(
                    a="inp:int", b="inp:int", c="out:int"
                ):
                    l = [2 for i in range(a) if i > 2]  # == [2]
                    c = l[0] + a + b  # 2 + 4 + 5 == 11
                return c

        model = B()
        s = torch.jit.script(model)
        self.assertEqual(s(), 11)
        self.assertEqual(s(), model())

        # C: one input (a); b is initialised before the block (to 5) and is
        # tagged "out", so the value reassigned inside the block is returned.
        class C(torch.nn.Module):
            def forward(self):
                a: int = 4
                b: int = 5
                with torch.jit._IgnoreContextManager(a="inp:int", b="out:int"):
                    l = [2 for i in range(a) if i > 2]  # == [2]
                    b = l[0] + a  # 2 + 4 == 6
                return b

        model = C()
        s = torch.jit.script(model)
        self.assertEqual(s(), 6)
        self.assertEqual(s(), model())

    @unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
    def test_with_ignore_context_manager_with_just_inp(self):
        """Declare only "inp:" variables and no outputs.

        The list built inside the block is never used afterwards, so
        ``forward`` returns ``a`` unchanged (4).
        """

        class A(torch.nn.Module):
            def forward(self):
                a: int = 4
                b: int = 5
                with torch.jit._IgnoreContextManager(a="inp:int", b="inp:int"):
                    # Reads both inputs; the result l is intentionally unused.
                    l = [2 + b for i in range(a) if i > 2]
                return a

        model = A()
        s = torch.jit.script(model)
        self.assertEqual(s(), 4)
        self.assertEqual(s(), model())

    @unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
    def test_with_ignore_context_manager_with_just_out(self):
        """Declare only an "out:" variable, using a non-scalar type string.

        ``c`` is first created inside the block and annotated as
        ``"List[int]"``, then indexed and mutated after the block.
        """

        class A(torch.nn.Module):
            def forward(self):
                with torch.jit._IgnoreContextManager(c="out:List[int]"):
                    # i in 3..6 passes the filter over range(7): c == [2, 2, 2, 2].
                    c = [2 for i in range(7) if i > 2]
                c[0] = 3
                return c[0] + c[1]  # 3 + 2 == 5

        model = A()
        s = torch.jit.script(model)
        self.assertEqual(s(), 5)
        self.assertEqual(s(), model())