xref: /aosp_15_r20/external/pytorch/test/dynamo/test_modes.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"]
2*da0073e9SAndroid Build Coastguard Workerfrom unittest.mock import patch
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport torch
5*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case
6*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing
7*da0073e9SAndroid Build Coastguard Workerfrom torch._C import (
8*da0073e9SAndroid Build Coastguard Worker    _len_torch_function_stack,
9*da0073e9SAndroid Build Coastguard Worker    _pop_torch_function_stack,
10*da0073e9SAndroid Build Coastguard Worker    _push_on_torch_function_stack,
11*da0073e9SAndroid Build Coastguard Worker)
12*da0073e9SAndroid Build Coastguard Workerfrom torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode
13*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._device import DeviceContext
14*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._python_dispatch import TorchDispatchMode
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Workerclass TorchDispatchModeTests(torch._dynamo.test_case.TestCase):
18*da0073e9SAndroid Build Coastguard Worker    @classmethod
19*da0073e9SAndroid Build Coastguard Worker    def setUpClass(cls):
20*da0073e9SAndroid Build Coastguard Worker        super().setUpClass()
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Worker    @classmethod
23*da0073e9SAndroid Build Coastguard Worker    def tearDownClass(cls):
24*da0073e9SAndroid Build Coastguard Worker        super().tearDownClass()
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Worker    def test_skip_torch_dispatch_modes(self):
27*da0073e9SAndroid Build Coastguard Worker        class RewriteAddToMul(TorchDispatchMode):
28*da0073e9SAndroid Build Coastguard Worker            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
29*da0073e9SAndroid Build Coastguard Worker                if func is torch.ops.aten.add.Tensor:
30*da0073e9SAndroid Build Coastguard Worker                    func = torch.ops.aten.mul.Tensor
31*da0073e9SAndroid Build Coastguard Worker                return func(*args, **kwargs)
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker        def fn(x):
34*da0073e9SAndroid Build Coastguard Worker            return x + x
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
37*da0073e9SAndroid Build Coastguard Worker
38*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([3.0])
39*da0073e9SAndroid Build Coastguard Worker        with RewriteAddToMul():
40*da0073e9SAndroid Build Coastguard Worker            eager_res = fn(x)
41*da0073e9SAndroid Build Coastguard Worker            compiled_res = torch._dynamo.optimize(cnt)(fn)(x)
42*da0073e9SAndroid Build Coastguard Worker
43*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(eager_res, compiled_res)
44*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 0)
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Worker
47*da0073e9SAndroid Build Coastguard Workerclass TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
48*da0073e9SAndroid Build Coastguard Worker    @classmethod
49*da0073e9SAndroid Build Coastguard Worker    def setUpClass(cls):
50*da0073e9SAndroid Build Coastguard Worker        cls.default_device_old = torch.get_default_device()
51*da0073e9SAndroid Build Coastguard Worker        super().setUpClass()
52*da0073e9SAndroid Build Coastguard Worker
53*da0073e9SAndroid Build Coastguard Worker    @classmethod
54*da0073e9SAndroid Build Coastguard Worker    def tearDownClass(cls):
55*da0073e9SAndroid Build Coastguard Worker        torch.set_default_device(cls.default_device_old)
56*da0073e9SAndroid Build Coastguard Worker        super().tearDownClass()
57*da0073e9SAndroid Build Coastguard Worker
58*da0073e9SAndroid Build Coastguard Worker    def setUp(self):
59*da0073e9SAndroid Build Coastguard Worker        torch.set_default_device(None)
60*da0073e9SAndroid Build Coastguard Worker
61*da0073e9SAndroid Build Coastguard Worker    def tearDown(self):
62*da0073e9SAndroid Build Coastguard Worker        torch.set_default_device(None)
63*da0073e9SAndroid Build Coastguard Worker
64*da0073e9SAndroid Build Coastguard Worker    def _run_torch_function_mode_guard_test(self):
65*da0073e9SAndroid Build Coastguard Worker        class TestMode1(BaseTorchFunctionMode):
66*da0073e9SAndroid Build Coastguard Worker            pass
67*da0073e9SAndroid Build Coastguard Worker
68*da0073e9SAndroid Build Coastguard Worker        class TestMode2(BaseTorchFunctionMode):
69*da0073e9SAndroid Build Coastguard Worker            pass
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
72*da0073e9SAndroid Build Coastguard Worker
73*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend=cnt.__call__)
74*da0073e9SAndroid Build Coastguard Worker        def fn(x):
75*da0073e9SAndroid Build Coastguard Worker            return x + 1
76*da0073e9SAndroid Build Coastguard Worker
77*da0073e9SAndroid Build Coastguard Worker        inp = torch.ones(2, 2)
78*da0073e9SAndroid Build Coastguard Worker        fn(inp)
79*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
80*da0073e9SAndroid Build Coastguard Worker
81*da0073e9SAndroid Build Coastguard Worker        with TestMode1():
82*da0073e9SAndroid Build Coastguard Worker            fn(inp)
83*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 2)
84*da0073e9SAndroid Build Coastguard Worker
85*da0073e9SAndroid Build Coastguard Worker        with TestMode1(), TestMode2():
86*da0073e9SAndroid Build Coastguard Worker            fn(inp)
87*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 3)
88*da0073e9SAndroid Build Coastguard Worker
89*da0073e9SAndroid Build Coastguard Worker        with TestMode2(), TestMode1():
90*da0073e9SAndroid Build Coastguard Worker            fn(inp)
91*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 4)
92*da0073e9SAndroid Build Coastguard Worker
93*da0073e9SAndroid Build Coastguard Worker        with TestMode1():
94*da0073e9SAndroid Build Coastguard Worker            fn(inp)
95*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 4)
96*da0073e9SAndroid Build Coastguard Worker
97*da0073e9SAndroid Build Coastguard Worker    def _run_ignored_mode_types_test(self):
98*da0073e9SAndroid Build Coastguard Worker        class IgnoredMode(BaseTorchFunctionMode):
99*da0073e9SAndroid Build Coastguard Worker            pass
100*da0073e9SAndroid Build Coastguard Worker
101*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend=cnt.__call__, fullgraph=True)
104*da0073e9SAndroid Build Coastguard Worker        def fn(x):
105*da0073e9SAndroid Build Coastguard Worker            return x + 1
106*da0073e9SAndroid Build Coastguard Worker
107*da0073e9SAndroid Build Coastguard Worker        inp = torch.ones(2, 2)
108*da0073e9SAndroid Build Coastguard Worker
109*da0073e9SAndroid Build Coastguard Worker        with patch(
110*da0073e9SAndroid Build Coastguard Worker            "torch._dynamo.variables.torch_function.IGNORED_MODES", {IgnoredMode}
111*da0073e9SAndroid Build Coastguard Worker        ):
112*da0073e9SAndroid Build Coastguard Worker            # initial compile
113*da0073e9SAndroid Build Coastguard Worker            fn(inp)
114*da0073e9SAndroid Build Coastguard Worker
115*da0073e9SAndroid Build Coastguard Worker            # no recompile, mode ignored
116*da0073e9SAndroid Build Coastguard Worker            # note: the ref stack is length 0, and the stack we are checking against has length 2
117*da0073e9SAndroid Build Coastguard Worker            # we want to check both ref stack len > runtime stack, and ref stack len < runtime stack
118*da0073e9SAndroid Build Coastguard Worker            with IgnoredMode(), IgnoredMode():
119*da0073e9SAndroid Build Coastguard Worker                fn(inp)
120*da0073e9SAndroid Build Coastguard Worker
121*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnt.frame_count, 1)
122*da0073e9SAndroid Build Coastguard Worker
123*da0073e9SAndroid Build Coastguard Worker            # recompile due to new mode on the stack
124*da0073e9SAndroid Build Coastguard Worker            with BaseTorchFunctionMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode():
125*da0073e9SAndroid Build Coastguard Worker                fn(inp)
126*da0073e9SAndroid Build Coastguard Worker
127*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnt.frame_count, 2)
128*da0073e9SAndroid Build Coastguard Worker
129*da0073e9SAndroid Build Coastguard Worker            # recompile
130*da0073e9SAndroid Build Coastguard Worker            # tests both ref stack len > runtime stack len for the above guard check
131*da0073e9SAndroid Build Coastguard Worker            # and ref stack len < runtime stack len for the initial zero mode case
132*da0073e9SAndroid Build Coastguard Worker            with BaseTorchFunctionMode(), IgnoredMode(), BaseTorchFunctionMode():
133*da0073e9SAndroid Build Coastguard Worker                fn(inp)
134*da0073e9SAndroid Build Coastguard Worker
135*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnt.frame_count, 3)
136*da0073e9SAndroid Build Coastguard Worker
137*da0073e9SAndroid Build Coastguard Worker            # no recompile
138*da0073e9SAndroid Build Coastguard Worker            with IgnoredMode(), IgnoredMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode():
139*da0073e9SAndroid Build Coastguard Worker                fn(inp)
140*da0073e9SAndroid Build Coastguard Worker
141*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnt.frame_count, 3)
142*da0073e9SAndroid Build Coastguard Worker
143*da0073e9SAndroid Build Coastguard Worker        # This is tricky, basically the ignored modes are baked into the guard
144*da0073e9SAndroid Build Coastguard Worker        # IgnoredMode will be ignored forever by that guard.
145*da0073e9SAndroid Build Coastguard Worker        # This is okay since we don't expect to be modifying IGNORED_MODES
146*da0073e9SAndroid Build Coastguard Worker        # in the middle of execution except for the purposes of testing.
147*da0073e9SAndroid Build Coastguard Worker        torch._dynamo.reset()
148*da0073e9SAndroid Build Coastguard Worker
149*da0073e9SAndroid Build Coastguard Worker        with IgnoredMode():
150*da0073e9SAndroid Build Coastguard Worker            fn(inp)
151*da0073e9SAndroid Build Coastguard Worker
152*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 4)
153*da0073e9SAndroid Build Coastguard Worker
154*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch("enable_cpp_guard_manager", False)
155*da0073e9SAndroid Build Coastguard Worker    def test_torch_function_mode_guards_ignored_types_py(self):
156*da0073e9SAndroid Build Coastguard Worker        self._run_ignored_mode_types_test()
157*da0073e9SAndroid Build Coastguard Worker
158*da0073e9SAndroid Build Coastguard Worker    def test_torch_function_mode_guards_ignored_types_cpp(self):
159*da0073e9SAndroid Build Coastguard Worker        self._run_ignored_mode_types_test()
160*da0073e9SAndroid Build Coastguard Worker
161*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch("enable_cpp_guard_manager", False)
162*da0073e9SAndroid Build Coastguard Worker    def test_torch_function_mode_guards_py(self):
163*da0073e9SAndroid Build Coastguard Worker        self._run_torch_function_mode_guard_test()
164*da0073e9SAndroid Build Coastguard Worker
165*da0073e9SAndroid Build Coastguard Worker    def test_torch_function_mode_guards_cpp(self):
166*da0073e9SAndroid Build Coastguard Worker        self._run_torch_function_mode_guard_test()
167*da0073e9SAndroid Build Coastguard Worker
168*da0073e9SAndroid Build Coastguard Worker    def test_stack_state_mutation_default_device(self):
169*da0073e9SAndroid Build Coastguard Worker        m = BaseTorchFunctionMode()
170*da0073e9SAndroid Build Coastguard Worker        m1 = BaseTorchFunctionMode()
171*da0073e9SAndroid Build Coastguard Worker        with m, m1:
172*da0073e9SAndroid Build Coastguard Worker
173*da0073e9SAndroid Build Coastguard Worker            @torch.compile(fullgraph=True)
174*da0073e9SAndroid Build Coastguard Worker            def fn(x):
175*da0073e9SAndroid Build Coastguard Worker                torch.set_default_device("cpu")
176*da0073e9SAndroid Build Coastguard Worker                _pop_torch_function_stack()
177*da0073e9SAndroid Build Coastguard Worker
178*da0073e9SAndroid Build Coastguard Worker            fn(torch.ones(2, 2))
179*da0073e9SAndroid Build Coastguard Worker            _push_on_torch_function_stack(m1)
180*da0073e9SAndroid Build Coastguard Worker
181*da0073e9SAndroid Build Coastguard Worker            stack = _get_current_function_mode_stack()
182*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(stack[0], DeviceContext)
183*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(stack[0].device, torch.device("cpu"))
184*da0073e9SAndroid Build Coastguard Worker            self.assertIs(stack[1], m)
185*da0073e9SAndroid Build Coastguard Worker            self.assertIs(stack[2], m1)
186*da0073e9SAndroid Build Coastguard Worker
187*da0073e9SAndroid Build Coastguard Worker    def test_stack_state_clear_default_device(self):
188*da0073e9SAndroid Build Coastguard Worker        @torch.compile(fullgraph=True)
189*da0073e9SAndroid Build Coastguard Worker        def fn(x):
190*da0073e9SAndroid Build Coastguard Worker            torch.set_default_device(None)
191*da0073e9SAndroid Build Coastguard Worker            return x + 1
192*da0073e9SAndroid Build Coastguard Worker
193*da0073e9SAndroid Build Coastguard Worker        fn(torch.ones(2, 2))
194*da0073e9SAndroid Build Coastguard Worker        stack = _get_current_function_mode_stack()
195*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(stack), 0)
196*da0073e9SAndroid Build Coastguard Worker
197*da0073e9SAndroid Build Coastguard Worker        m = BaseTorchFunctionMode()
198*da0073e9SAndroid Build Coastguard Worker        m1 = BaseTorchFunctionMode()
199*da0073e9SAndroid Build Coastguard Worker
200*da0073e9SAndroid Build Coastguard Worker        # Stack populated, add device
201*da0073e9SAndroid Build Coastguard Worker        with m, m1:
202*da0073e9SAndroid Build Coastguard Worker
203*da0073e9SAndroid Build Coastguard Worker            @torch.compile(fullgraph=True)
204*da0073e9SAndroid Build Coastguard Worker            def fn(x):
205*da0073e9SAndroid Build Coastguard Worker                torch.set_default_device("cpu")
206*da0073e9SAndroid Build Coastguard Worker                torch.set_default_device(None)
207*da0073e9SAndroid Build Coastguard Worker                torch.set_default_device("cpu")
208*da0073e9SAndroid Build Coastguard Worker                return x + 1
209*da0073e9SAndroid Build Coastguard Worker
210*da0073e9SAndroid Build Coastguard Worker            fn(torch.ones(2, 2))
211*da0073e9SAndroid Build Coastguard Worker            stack = _get_current_function_mode_stack()
212*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(stack[0].device, torch.device("cpu"))
213*da0073e9SAndroid Build Coastguard Worker            self.assertIs(stack[1], m)
214*da0073e9SAndroid Build Coastguard Worker            self.assertIs(stack[2], m1)
215*da0073e9SAndroid Build Coastguard Worker
216*da0073e9SAndroid Build Coastguard Worker        # Stack populated, remove device
217*da0073e9SAndroid Build Coastguard Worker        torch.set_default_device("cpu")
218*da0073e9SAndroid Build Coastguard Worker        with m, m1:
219*da0073e9SAndroid Build Coastguard Worker
220*da0073e9SAndroid Build Coastguard Worker            @torch.compile(fullgraph=True)
221*da0073e9SAndroid Build Coastguard Worker            def fn(x):
222*da0073e9SAndroid Build Coastguard Worker                torch.set_default_device(None)
223*da0073e9SAndroid Build Coastguard Worker                return x + 1
224*da0073e9SAndroid Build Coastguard Worker
225*da0073e9SAndroid Build Coastguard Worker            fn(torch.ones(2, 2))
226*da0073e9SAndroid Build Coastguard Worker            stack = _get_current_function_mode_stack()
227*da0073e9SAndroid Build Coastguard Worker            self.assertIs(stack[0], m)
228*da0073e9SAndroid Build Coastguard Worker            self.assertIs(stack[1], m1)
229*da0073e9SAndroid Build Coastguard Worker
230*da0073e9SAndroid Build Coastguard Worker        @torch.compile(fullgraph=True)
231*da0073e9SAndroid Build Coastguard Worker        def fn(x):
232*da0073e9SAndroid Build Coastguard Worker            torch.set_default_device("cpu")
233*da0073e9SAndroid Build Coastguard Worker            torch.set_default_device("cpu")
234*da0073e9SAndroid Build Coastguard Worker            return x + 1
235*da0073e9SAndroid Build Coastguard Worker
236*da0073e9SAndroid Build Coastguard Worker        fn(torch.ones(2, 2))
237*da0073e9SAndroid Build Coastguard Worker        stack = _get_current_function_mode_stack()
238*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(stack[0].device, torch.device("cpu"))
239*da0073e9SAndroid Build Coastguard Worker        torch.set_default_device(None)
240*da0073e9SAndroid Build Coastguard Worker
241*da0073e9SAndroid Build Coastguard Worker    def test_pop_torch_function_mode(self):
242*da0073e9SAndroid Build Coastguard Worker        m = BaseTorchFunctionMode()
243*da0073e9SAndroid Build Coastguard Worker        with m:
244*da0073e9SAndroid Build Coastguard Worker
245*da0073e9SAndroid Build Coastguard Worker            @torch.compile(fullgraph=True)
246*da0073e9SAndroid Build Coastguard Worker            def fn(x):
247*da0073e9SAndroid Build Coastguard Worker                _pop_torch_function_stack()
248*da0073e9SAndroid Build Coastguard Worker                return x + 1
249*da0073e9SAndroid Build Coastguard Worker
250*da0073e9SAndroid Build Coastguard Worker            fn(torch.ones(2, 2))
251*da0073e9SAndroid Build Coastguard Worker
252*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(_len_torch_function_stack(), 0)
253*da0073e9SAndroid Build Coastguard Worker            # reset stack so __exit__ doesn't crash
254*da0073e9SAndroid Build Coastguard Worker            _push_on_torch_function_stack(m)
255*da0073e9SAndroid Build Coastguard Worker
256*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(_len_torch_function_stack(), 0)
257*da0073e9SAndroid Build Coastguard Worker
258*da0073e9SAndroid Build Coastguard Worker    def test_error_empty_stack_pop_torch_function_mode(self):
259*da0073e9SAndroid Build Coastguard Worker        @torch.compile(fullgraph=True)
260*da0073e9SAndroid Build Coastguard Worker        def fn(x):
261*da0073e9SAndroid Build Coastguard Worker            _pop_torch_function_stack()
262*da0073e9SAndroid Build Coastguard Worker            return x + 1
263*da0073e9SAndroid Build Coastguard Worker
264*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
265*da0073e9SAndroid Build Coastguard Worker            torch._dynamo.exc.Unsupported,
266*da0073e9SAndroid Build Coastguard Worker            "Popping from an empty torch function mode stack",
267*da0073e9SAndroid Build Coastguard Worker            lambda: fn(torch.ones(2, 2)),
268*da0073e9SAndroid Build Coastguard Worker        )
269*da0073e9SAndroid Build Coastguard Worker
270*da0073e9SAndroid Build Coastguard Worker    def test_push_torch_function_mode(self):
271*da0073e9SAndroid Build Coastguard Worker        m = BaseTorchFunctionMode()
272*da0073e9SAndroid Build Coastguard Worker        with m:
273*da0073e9SAndroid Build Coastguard Worker
274*da0073e9SAndroid Build Coastguard Worker            @torch.compile(fullgraph=True)
275*da0073e9SAndroid Build Coastguard Worker            def fn(x, m):
276*da0073e9SAndroid Build Coastguard Worker                _push_on_torch_function_stack(m)
277*da0073e9SAndroid Build Coastguard Worker                return x + 1
278*da0073e9SAndroid Build Coastguard Worker
279*da0073e9SAndroid Build Coastguard Worker            fn(torch.ones(2, 2), m)
280*da0073e9SAndroid Build Coastguard Worker
281*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(_len_torch_function_stack(), 2)
282*da0073e9SAndroid Build Coastguard Worker            # reset stack state
283*da0073e9SAndroid Build Coastguard Worker            _pop_torch_function_stack()
284*da0073e9SAndroid Build Coastguard Worker
285*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(_len_torch_function_stack(), 0)
286*da0073e9SAndroid Build Coastguard Worker
287*da0073e9SAndroid Build Coastguard Worker    def test_len_torch_function_mode(self):
288*da0073e9SAndroid Build Coastguard Worker        m = BaseTorchFunctionMode()
289*da0073e9SAndroid Build Coastguard Worker        with m:
290*da0073e9SAndroid Build Coastguard Worker
291*da0073e9SAndroid Build Coastguard Worker            @torch.compile(fullgraph=True)
292*da0073e9SAndroid Build Coastguard Worker            def fn(x):
293*da0073e9SAndroid Build Coastguard Worker                z = _len_torch_function_stack()
294*da0073e9SAndroid Build Coastguard Worker                return x + z
295*da0073e9SAndroid Build Coastguard Worker
296*da0073e9SAndroid Build Coastguard Worker            res = fn(torch.ones(2, 2))
297*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res, torch.ones(2, 2) + 1)
298*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(_len_torch_function_stack(), 1)
299*da0073e9SAndroid Build Coastguard Worker
300*da0073e9SAndroid Build Coastguard Worker    def test_intermedate_torch_function_mode_construction_mutation(self):
301*da0073e9SAndroid Build Coastguard Worker        class TestMode(BaseTorchFunctionMode):
302*da0073e9SAndroid Build Coastguard Worker            def __init__(self, x):
303*da0073e9SAndroid Build Coastguard Worker                self.x = x
304*da0073e9SAndroid Build Coastguard Worker
305*da0073e9SAndroid Build Coastguard Worker        @torch.compile(fullgraph=True)
306*da0073e9SAndroid Build Coastguard Worker        def fn(x):
307*da0073e9SAndroid Build Coastguard Worker            z = TestMode(2)
308*da0073e9SAndroid Build Coastguard Worker            z.y = 2
309*da0073e9SAndroid Build Coastguard Worker            return x + 1, z
310*da0073e9SAndroid Build Coastguard Worker
311*da0073e9SAndroid Build Coastguard Worker        fn(torch.ones(2, 2))
312*da0073e9SAndroid Build Coastguard Worker
313*da0073e9SAndroid Build Coastguard Worker    def test_torch_function_mode_enabled_guard(self):
314*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
315*da0073e9SAndroid Build Coastguard Worker        inp = torch.ones(2, 2)
316*da0073e9SAndroid Build Coastguard Worker
317*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend=cnt.__call__)
318*da0073e9SAndroid Build Coastguard Worker        def fn(x):
319*da0073e9SAndroid Build Coastguard Worker            return x + 1
320*da0073e9SAndroid Build Coastguard Worker
321*da0073e9SAndroid Build Coastguard Worker        with BaseTorchFunctionMode(), torch._C.DisableTorchFunctionSubclass():
322*da0073e9SAndroid Build Coastguard Worker            with torch._C.DisableTorchFunction():
323*da0073e9SAndroid Build Coastguard Worker                fn(inp)
324*da0073e9SAndroid Build Coastguard Worker            fn(inp)
325*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 2)
326*da0073e9SAndroid Build Coastguard Worker
327*da0073e9SAndroid Build Coastguard Worker
328*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
329*da0073e9SAndroid Build Coastguard Worker    from torch._dynamo.test_case import run_tests
330*da0073e9SAndroid Build Coastguard Worker
331*da0073e9SAndroid Build Coastguard Worker    run_tests()
332