# Owner(s): ["module: dynamo"]
from unittest.mock import patch

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._C import (
    _len_torch_function_stack,
    _pop_torch_function_stack,
    _push_on_torch_function_stack,
)
from torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode
from torch.utils._device import DeviceContext
from torch.utils._python_dispatch import TorchDispatchMode


class TorchDispatchModeTests(torch._dynamo.test_case.TestCase):
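    # Dynamo skips frames that run under an active TorchDispatchMode, so the
    # "compiled" call below behaves exactly like eager and compiles no frames.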
    def test_skip_torch_dispatch_modes(self):
        class RewriteAddToMul(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                kwargs = kwargs or {}
                if func is torch.ops.aten.add.Tensor:
                    func = torch.ops.aten.mul.Tensor
                return func(*args, **kwargs)

        def fn(x):
            return x + x

        cnt = torch._dynamo.testing.CompileCounter()

        x = torch.tensor([3.0])
        with RewriteAddToMul():
            eager_res = fn(x)
            compiled_res = torch._dynamo.optimize(cnt)(fn)(x)

        self.assertEqual(eager_res, compiled_res)
        self.assertEqual(cnt.frame_count, 0)


class TorchFunctionModeTests(torch._dynamo.test_case.TestCase):
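    # Save the global default device for the class and restore it afterwards;
    # each test starts and ends with no default device set.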
    @classmethod
    def setUpClass(cls):
        cls.default_device_old = torch.get_default_device()
        super().setUpClass()

    @classmethod
    def tearDownClass(cls):
        torch.set_default_device(cls.default_device_old)
        super().tearDownClass()

    def setUp(self):
        torch.set_default_device(None)

    def tearDown(self):
        torch.set_default_device(None)

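    # Guard behavior: each distinct torch function mode stack (mode types and
    # their order) should trigger a recompile; a previously seen stack should
    # hit the compile cache.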
    def _run_torch_function_mode_guard_test(self):
        class TestMode1(BaseTorchFunctionMode):
            pass

        class TestMode2(BaseTorchFunctionMode):
            pass

        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt.__call__)
        def fn(x):
            return x + 1

        inp = torch.ones(2, 2)
        fn(inp)
        self.assertEqual(cnt.frame_count, 1)

        with TestMode1():
            fn(inp)
        self.assertEqual(cnt.frame_count, 2)

        with TestMode1(), TestMode2():
            fn(inp)
        self.assertEqual(cnt.frame_count, 3)

        with TestMode2(), TestMode1():
            fn(inp)
        self.assertEqual(cnt.frame_count, 4)

        with TestMode1():
            fn(inp)
        self.assertEqual(cnt.frame_count, 4)

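    # Modes listed in IGNORED_MODES are invisible to the mode-stack guard:
    # pushing or popping them alone should never trigger a recompile.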
    def _run_ignored_mode_types_test(self):
        class IgnoredMode(BaseTorchFunctionMode):
            pass

        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt.__call__, fullgraph=True)
        def fn(x):
            return x + 1

        inp = torch.ones(2, 2)

        with patch(
            "torch._dynamo.variables.torch_function.IGNORED_MODES", {IgnoredMode}
        ):
            # initial compile
            fn(inp)

            # no recompile, the mode is ignored
            # note: the reference stack recorded at compile time has length 0,
            # while the runtime stack we check against has length 2; we want to
            # exercise both ref stack longer than runtime stack and ref stack
            # shorter than runtime stack
            with IgnoredMode(), IgnoredMode():
                fn(inp)

            self.assertEqual(cnt.frame_count, 1)

            # recompile due to new modes on the stack
            with BaseTorchFunctionMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode():
                fn(inp)

            self.assertEqual(cnt.frame_count, 2)

            # recompile
            # tests ref stack len > runtime stack len for the guard check above,
            # and ref stack len < runtime stack len for the initial zero-mode case
            with BaseTorchFunctionMode(), IgnoredMode(), BaseTorchFunctionMode():
                fn(inp)

            self.assertEqual(cnt.frame_count, 3)

            # no recompile
            with IgnoredMode(), IgnoredMode(), BaseTorchFunctionMode(), BaseTorchFunctionMode():
                fn(inp)

            self.assertEqual(cnt.frame_count, 3)

        # This is tricky: the set of ignored modes is baked into the guard at
        # compile time, so that guard will ignore IgnoredMode forever. This is
        # okay since we don't expect IGNORED_MODES to be modified in the middle
        # of execution except for testing purposes.
        torch._dynamo.reset()

        with IgnoredMode():
            fn(inp)

        self.assertEqual(cnt.frame_count, 4)

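    # Exercise the guard tests under both the Python and C++ guard manager
    # implementations.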
    @torch._dynamo.config.patch("enable_cpp_guard_manager", False)
    def test_torch_function_mode_guards_ignored_types_py(self):
        self._run_ignored_mode_types_test()

    def test_torch_function_mode_guards_ignored_types_cpp(self):
        self._run_ignored_mode_types_test()

    @torch._dynamo.config.patch("enable_cpp_guard_manager", False)
    def test_torch_function_mode_guards_py(self):
        self._run_torch_function_mode_guard_test()

    def test_torch_function_mode_guards_cpp(self):
        self._run_torch_function_mode_guard_test()

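    # Mutations to the torch function mode stack made inside a compiled region
    # (setting a default device, popping a mode) must be replayed on the real
    # stack when the compiled function returns.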
    def test_stack_state_mutation_default_device(self):
        m = BaseTorchFunctionMode()
        m1 = BaseTorchFunctionMode()
        with m, m1:

            @torch.compile(fullgraph=True)
            def fn(x):
                torch.set_default_device("cpu")
                _pop_torch_function_stack()

            fn(torch.ones(2, 2))
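            # fn popped m1 inside the compiled region; push it back so the
            # with-block can exit cleanly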
            _push_on_torch_function_stack(m1)

            stack = _get_current_function_mode_stack()
            self.assertIsInstance(stack[0], DeviceContext)
            self.assertEqual(stack[0].device, torch.device("cpu"))
            self.assertIs(stack[1], m)
            self.assertIs(stack[2], m1)

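    # set_default_device manages a DeviceContext at the bottom of the torch
    # function mode stack; adding or clearing it should not disturb user-pushed
    # modes.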
    def test_stack_state_clear_default_device(self):
        @torch.compile(fullgraph=True)
        def fn(x):
            torch.set_default_device(None)
            return x + 1

        fn(torch.ones(2, 2))
        stack = _get_current_function_mode_stack()
        self.assertEqual(len(stack), 0)

        m = BaseTorchFunctionMode()
        m1 = BaseTorchFunctionMode()

        # Stack populated, add device
        with m, m1:

            @torch.compile(fullgraph=True)
            def fn(x):
                torch.set_default_device("cpu")
                torch.set_default_device(None)
                torch.set_default_device("cpu")
                return x + 1

            fn(torch.ones(2, 2))
            stack = _get_current_function_mode_stack()
            self.assertEqual(stack[0].device, torch.device("cpu"))
            self.assertIs(stack[1], m)
            self.assertIs(stack[2], m1)

        # Stack populated, remove device
        torch.set_default_device("cpu")
        with m, m1:

            @torch.compile(fullgraph=True)
            def fn(x):
                torch.set_default_device(None)
                return x + 1

            fn(torch.ones(2, 2))
            stack = _get_current_function_mode_stack()
            self.assertIs(stack[0], m)
            self.assertIs(stack[1], m1)

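        # Stack empty, add device twice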
        @torch.compile(fullgraph=True)
        def fn(x):
            torch.set_default_device("cpu")
            torch.set_default_device("cpu")
            return x + 1

        fn(torch.ones(2, 2))
        stack = _get_current_function_mode_stack()
        self.assertEqual(stack[0].device, torch.device("cpu"))
        torch.set_default_device(None)

    def test_pop_torch_function_mode(self):
        m = BaseTorchFunctionMode()
        with m:

            @torch.compile(fullgraph=True)
            def fn(x):
                _pop_torch_function_stack()
                return x + 1

            fn(torch.ones(2, 2))

            self.assertEqual(_len_torch_function_stack(), 0)
            # reset stack so __exit__ doesn't crash
            _push_on_torch_function_stack(m)

        self.assertEqual(_len_torch_function_stack(), 0)

    def test_error_empty_stack_pop_torch_function_mode(self):
        @torch.compile(fullgraph=True)
        def fn(x):
            _pop_torch_function_stack()
            return x + 1

        self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported,
            "Popping from an empty torch function mode stack",
            lambda: fn(torch.ones(2, 2)),
        )

    def test_push_torch_function_mode(self):
        m = BaseTorchFunctionMode()
        with m:

            @torch.compile(fullgraph=True)
            def fn(x, m):
                _push_on_torch_function_stack(m)
                return x + 1

            fn(torch.ones(2, 2), m)

            self.assertEqual(_len_torch_function_stack(), 2)
            # reset stack state
            _pop_torch_function_stack()

        self.assertEqual(_len_torch_function_stack(), 0)

    def test_len_torch_function_mode(self):
        m = BaseTorchFunctionMode()
        with m:

            @torch.compile(fullgraph=True)
            def fn(x):
                z = _len_torch_function_stack()
                return x + z

            res = fn(torch.ones(2, 2))
            self.assertEqual(res, torch.ones(2, 2) + 1)
            self.assertEqual(_len_torch_function_stack(), 1)

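    # Constructing a new torch function mode inside a compiled region and
    # mutating its attributes (without entering it) should compile with
    # fullgraph=True.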
    def test_intermediate_torch_function_mode_construction_mutation(self):
        class TestMode(BaseTorchFunctionMode):
            def __init__(self, x):
                self.x = x

        @torch.compile(fullgraph=True)
        def fn(x):
            z = TestMode(2)
            z.y = 2
            return x + 1, z

        fn(torch.ones(2, 2))

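    # The torch-function-enabled state is guarded on: calling fn both under
    # DisableTorchFunction and without it should yield two separate compiles.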
    def test_torch_function_mode_enabled_guard(self):
        cnt = torch._dynamo.testing.CompileCounter()
        inp = torch.ones(2, 2)

        @torch.compile(backend=cnt.__call__)
        def fn(x):
            return x + 1

        with BaseTorchFunctionMode(), torch._C.DisableTorchFunctionSubclass():
            with torch._C.DisableTorchFunction():
                fn(inp)
            fn(inp)
        self.assertEqual(cnt.frame_count, 2)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()