xref: /aosp_15_r20/external/pytorch/test/test_jit_disabled.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import sys
4import os
5import contextlib
6import subprocess
7from torch.testing._internal.common_utils import TestCase, run_tests, TemporaryFileName
8
9
@contextlib.contextmanager
def _jit_disabled():
    """Temporarily set ``PYTORCH_JIT=0`` in this process's environment.

    Child processes started inside the ``with`` block inherit the setting.
    On exit (normal or via an exception) the previous state is restored
    exactly. If ``PYTORCH_JIT`` was not set before entering, it is removed
    again rather than being left behind as ``"1"``.
    """
    previous = os.environ.get("PYTORCH_JIT")
    os.environ["PYTORCH_JIT"] = "0"
    try:
        yield
    finally:
        if previous is None:
            # The variable did not exist before; don't leak a value into the
            # rest of the process.
            os.environ.pop("PYTORCH_JIT", None)
        else:
            os.environ["PYTORCH_JIT"] = previous
18
19
class TestJitDisabled(TestCase):
    """
    These tests are separate from the rest of the JIT tests because we need
    to run a new subprocess and `import torch` with the correct environment
    variables set.
    """

    def _run_script(self, fname, jit_enabled):
        """Run ``fname`` in a fresh interpreter and return its stdout (bytes).

        ``PYTORCH_JIT`` is set explicitly in the child's environment instead
        of being inherited. This ensures the "enabled" run really is enabled
        even if the parent was started with ``PYTORCH_JIT=0``. It also means
        the parent's ``os.environ`` is never mutated.
        """
        env = dict(os.environ)
        env["PYTORCH_JIT"] = "1" if jit_enabled else "0"
        return subprocess.check_output([sys.executable, fname], env=env)

    def compare_enabled_disabled(self, src):
        """
        Runs the script in `src` with PYTORCH_JIT enabled and disabled and
        compares their stdout for equality.
        """
        # Write `src` out to a temporary file so our source inspection logic
        # works correctly.
        with TemporaryFileName() as fname:
            with open(fname, 'w') as f:
                f.write(src)
            # The file must be closed (and therefore flushed) before the child
            # interpreters read it. Otherwise they may execute an empty script,
            # and both runs would trivially print nothing.
            out_disabled = self._run_script(fname, jit_enabled=False)
            out_enabled = self._run_script(fname, jit_enabled=True)
            self.assertEqual(out_disabled, out_enabled)

    def test_attribute(self):
        _program_string = """
import torch

class Foo(torch.jit.ScriptModule):
    def __init__(self, x):
        super().__init__()
        self.x = torch.jit.Attribute(x, torch.Tensor)

    def forward(self, input):
        return input

s = Foo(torch.ones(2, 3))
print(s.x)
"""
        self.compare_enabled_disabled(_program_string)

    def test_script_module_construction(self):
        _program_string = """
import torch

class AModule(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, input):
        pass

AModule()
print("Didn't throw exception")
"""
        self.compare_enabled_disabled(_program_string)

    def test_recursive_script(self):
        _program_string = """
import torch

class AModule(torch.nn.Module):
    def forward(self, input):
        pass

sm = torch.jit.script(AModule())
print("Didn't throw exception")
"""
        self.compare_enabled_disabled(_program_string)
89
if __name__ == "__main__":
    # Executed directly (not imported): hand off to run_tests() from
    # torch.testing._internal.common_utils.
    run_tests()
92