1# Owner(s): ["oncall: jit"] 2 3import sys 4import os 5import contextlib 6import subprocess 7from torch.testing._internal.common_utils import TestCase, run_tests, TemporaryFileName 8 9 10@contextlib.contextmanager 11def _jit_disabled(): 12 cur_env = os.environ.get("PYTORCH_JIT", "1") 13 os.environ["PYTORCH_JIT"] = "0" 14 try: 15 yield 16 finally: 17 os.environ["PYTORCH_JIT"] = cur_env 18 19 20class TestJitDisabled(TestCase): 21 """ 22 These tests are separate from the rest of the JIT tests because we need 23 run a new subprocess and `import torch` with the correct environment 24 variables set. 25 """ 26 27 def compare_enabled_disabled(self, src): 28 """ 29 Runs the script in `src` with PYTORCH_JIT enabled and disabled and 30 compares their stdout for equality. 31 """ 32 # Write `src` out to a temporary so our source inspection logic works 33 # correctly. 34 with TemporaryFileName() as fname: 35 with open(fname, 'w') as f: 36 f.write(src) 37 with _jit_disabled(): 38 out_disabled = subprocess.check_output([ 39 sys.executable, 40 fname]) 41 out_enabled = subprocess.check_output([ 42 sys.executable, 43 fname]) 44 self.assertEqual(out_disabled, out_enabled) 45 46 def test_attribute(self): 47 _program_string = """ 48import torch 49 50class Foo(torch.jit.ScriptModule): 51 def __init__(self, x): 52 super().__init__() 53 self.x = torch.jit.Attribute(x, torch.Tensor) 54 55 def forward(self, input): 56 return input 57 58s = Foo(torch.ones(2, 3)) 59print(s.x) 60""" 61 self.compare_enabled_disabled(_program_string) 62 63 def test_script_module_construction(self): 64 _program_string = """ 65import torch 66 67class AModule(torch.jit.ScriptModule): 68 @torch.jit.script_method 69 def forward(self, input): 70 pass 71 72AModule() 73print("Didn't throw exception") 74""" 75 self.compare_enabled_disabled(_program_string) 76 77 def test_recursive_script(self): 78 _program_string = """ 79import torch 80 81class AModule(torch.nn.Module): 82 def forward(self, input): 83 pass 84 85sm = torch.jit.script(AModule()) 86print("Didn't throw exception") 87""" 88 self.compare_enabled_disabled(_program_string) 89 90if __name__ == '__main__': 91 run_tests() 92