xref: /aosp_15_r20/external/pytorch/test/dynamo/test_comptime.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport collections
4*da0073e9SAndroid Build Coastguard Workerimport re
5*da0073e9SAndroid Build Coastguard Workerimport sys
6*da0073e9SAndroid Build Coastguard Workerimport time
7*da0073e9SAndroid Build Coastguard Workerfrom io import StringIO
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case
10*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing
11*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.comptime import comptime
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker
# Because we don't support free variables in comptime at the moment,
# we have to communicate via globals.  This also means these tests cannot
# be run in parallel in a single process (not that you'd... ever want
# to do that?)
# Output sink: tests rebind this to a fresh StringIO and hand it to the
# ctx.print_* calls as ``file=``, then inspect what was written.
FILE = None
# The TestCase currently running, so comptime callbacks can call its
# assert* methods.
SELF = None
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Workerclass ComptimeTests(torch._dynamo.test_case.TestCase):
23*da0073e9SAndroid Build Coastguard Worker    def test_print_single(self):
24*da0073e9SAndroid Build Coastguard Worker        global FILE
25*da0073e9SAndroid Build Coastguard Worker        FILE = StringIO()
26*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
27*da0073e9SAndroid Build Coastguard Worker
28*da0073e9SAndroid Build Coastguard Worker        def comptime_print(e):
29*da0073e9SAndroid Build Coastguard Worker            @comptime
30*da0073e9SAndroid Build Coastguard Worker            def _(ctx):
31*da0073e9SAndroid Build Coastguard Worker                ctx.print(ctx.get_local("e"), file=FILE)
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker        Employee = collections.namedtuple("Employee", ["name", "id"])
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Worker        class mylist(list):
36*da0073e9SAndroid Build Coastguard Worker            pass
37*da0073e9SAndroid Build Coastguard Worker
38*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize(cnt, dynamic=True)
39*da0073e9SAndroid Build Coastguard Worker        def f(x):
40*da0073e9SAndroid Build Coastguard Worker            y = x * 2
41*da0073e9SAndroid Build Coastguard Worker            comptime_print(y)
42*da0073e9SAndroid Build Coastguard Worker            comptime_print(2)
43*da0073e9SAndroid Build Coastguard Worker            comptime_print([y, 2])
44*da0073e9SAndroid Build Coastguard Worker            comptime_print((y, 2))
45*da0073e9SAndroid Build Coastguard Worker            comptime_print({"foo": y})
46*da0073e9SAndroid Build Coastguard Worker            comptime_print(range(1, 3))
47*da0073e9SAndroid Build Coastguard Worker            comptime_print(Employee("foo", 2))
48*da0073e9SAndroid Build Coastguard Worker            comptime_print(mylist([1, 2]))
49*da0073e9SAndroid Build Coastguard Worker            comptime_print(collections.defaultdict(lambda: None))
50*da0073e9SAndroid Build Coastguard Worker            comptime_print(set())
51*da0073e9SAndroid Build Coastguard Worker            comptime_print({"a", "b"})
52*da0073e9SAndroid Build Coastguard Worker            comptime_print(x.size(0))
53*da0073e9SAndroid Build Coastguard Worker            return y + 3
54*da0073e9SAndroid Build Coastguard Worker
55*da0073e9SAndroid Build Coastguard Worker        f(torch.randn(2))
56*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
57*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
58*da0073e9SAndroid Build Coastguard Worker            FILE.getvalue().strip(),
59*da0073e9SAndroid Build Coastguard Worker            """\
60*da0073e9SAndroid Build Coastguard WorkerFakeTensor(..., size=(s0,))
61*da0073e9SAndroid Build Coastguard Worker2
62*da0073e9SAndroid Build Coastguard Worker[FakeTensor(..., size=(s0,)), 2]
63*da0073e9SAndroid Build Coastguard Worker(FakeTensor(..., size=(s0,)), 2)
64*da0073e9SAndroid Build Coastguard Worker{'foo': FakeTensor(..., size=(s0,))}
65*da0073e9SAndroid Build Coastguard Workerrange(1, 3, 1)
66*da0073e9SAndroid Build Coastguard WorkerEmployee(name='foo', id=2)
67*da0073e9SAndroid Build Coastguard Worker[1, 2]
68*da0073e9SAndroid Build Coastguard Workerdefaultdict(NestedUserFunctionVariable(), {})
69*da0073e9SAndroid Build Coastguard Workerset()
70*da0073e9SAndroid Build Coastguard Worker{'a','b'}
71*da0073e9SAndroid Build Coastguard Workers0""",
72*da0073e9SAndroid Build Coastguard Worker        )
73*da0073e9SAndroid Build Coastguard Worker
74*da0073e9SAndroid Build Coastguard Worker    def test_print_graph(self):
75*da0073e9SAndroid Build Coastguard Worker        global FILE
76*da0073e9SAndroid Build Coastguard Worker        FILE = StringIO()
77*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize(cnt)
80*da0073e9SAndroid Build Coastguard Worker        def f(x):
81*da0073e9SAndroid Build Coastguard Worker            y = x * 2
82*da0073e9SAndroid Build Coastguard Worker
83*da0073e9SAndroid Build Coastguard Worker            @comptime
84*da0073e9SAndroid Build Coastguard Worker            def _(ctx):
85*da0073e9SAndroid Build Coastguard Worker                ctx.print_graph(verbose=False, file=FILE)
86*da0073e9SAndroid Build Coastguard Worker
87*da0073e9SAndroid Build Coastguard Worker            # Test the compact notation doesn't error or graph break;
88*da0073e9SAndroid Build Coastguard Worker            # you'll have to visually inspect to see that it printed
89*da0073e9SAndroid Build Coastguard Worker            comptime.print_graph()
90*da0073e9SAndroid Build Coastguard Worker
91*da0073e9SAndroid Build Coastguard Worker            return y + 3
92*da0073e9SAndroid Build Coastguard Worker
93*da0073e9SAndroid Build Coastguard Worker        f(torch.randn(2))
94*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
95*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
96*da0073e9SAndroid Build Coastguard Worker            FILE.getvalue().strip(),
97*da0073e9SAndroid Build Coastguard Worker            """\
98*da0073e9SAndroid Build Coastguard Workerdef forward(self, L_x_ : torch.Tensor):
99*da0073e9SAndroid Build Coastguard Worker    l_x_ = L_x_
100*da0073e9SAndroid Build Coastguard Worker    y = l_x_ * 2;  l_x_ = y = None""",
101*da0073e9SAndroid Build Coastguard Worker        )
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Worker    def test_print_disas(self):
104*da0073e9SAndroid Build Coastguard Worker        global FILE
105*da0073e9SAndroid Build Coastguard Worker        FILE = StringIO()
106*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
107*da0073e9SAndroid Build Coastguard Worker
108*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize(cnt)
109*da0073e9SAndroid Build Coastguard Worker        def f(x):
110*da0073e9SAndroid Build Coastguard Worker            y = x * 2
111*da0073e9SAndroid Build Coastguard Worker
112*da0073e9SAndroid Build Coastguard Worker            @comptime
113*da0073e9SAndroid Build Coastguard Worker            def _(ctx):
114*da0073e9SAndroid Build Coastguard Worker                ctx.print_disas(file=FILE)
115*da0073e9SAndroid Build Coastguard Worker
116*da0073e9SAndroid Build Coastguard Worker            comptime.print_disas()
117*da0073e9SAndroid Build Coastguard Worker
118*da0073e9SAndroid Build Coastguard Worker            return y + 3
119*da0073e9SAndroid Build Coastguard Worker
120*da0073e9SAndroid Build Coastguard Worker        def munge_disas(s):
121*da0073e9SAndroid Build Coastguard Worker            re.sub(
122*da0073e9SAndroid Build Coastguard Worker                r"^(?: +\d+)?(?: +(-->)) \+\d+ ([A-Za-z0-9_]+)",
123*da0073e9SAndroid Build Coastguard Worker                "\1 \3",
124*da0073e9SAndroid Build Coastguard Worker                s,
125*da0073e9SAndroid Build Coastguard Worker                flags=re.MULTILINE,
126*da0073e9SAndroid Build Coastguard Worker            )
127*da0073e9SAndroid Build Coastguard Worker
128*da0073e9SAndroid Build Coastguard Worker        f(torch.randn(2))
129*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
130*da0073e9SAndroid Build Coastguard Worker        out = FILE.getvalue()
131*da0073e9SAndroid Build Coastguard Worker        # Check that the instruction offset is working
132*da0073e9SAndroid Build Coastguard Worker        self.assertIn("-->", out)
133*da0073e9SAndroid Build Coastguard Worker        # Check that the bytecode resembles what we expect
134*da0073e9SAndroid Build Coastguard Worker        self.assertIn("STORE_FAST", out)
135*da0073e9SAndroid Build Coastguard Worker        if sys.version_info < (3, 11):
136*da0073e9SAndroid Build Coastguard Worker            self.assertIn("BINARY_MULTIPLY", out)
137*da0073e9SAndroid Build Coastguard Worker        else:
138*da0073e9SAndroid Build Coastguard Worker            self.assertIn("BINARY_OP", out)
139*da0073e9SAndroid Build Coastguard Worker
140*da0073e9SAndroid Build Coastguard Worker    def test_print_value_stack(self):
141*da0073e9SAndroid Build Coastguard Worker        global FILE
142*da0073e9SAndroid Build Coastguard Worker        FILE = StringIO()
143*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
144*da0073e9SAndroid Build Coastguard Worker
145*da0073e9SAndroid Build Coastguard Worker        def g(x):
146*da0073e9SAndroid Build Coastguard Worker            @comptime
147*da0073e9SAndroid Build Coastguard Worker            def _(ctx):
148*da0073e9SAndroid Build Coastguard Worker                ctx.print_value_stack(file=FILE, stacklevel=1)
149*da0073e9SAndroid Build Coastguard Worker
150*da0073e9SAndroid Build Coastguard Worker            return x
151*da0073e9SAndroid Build Coastguard Worker
152*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize(cnt)
153*da0073e9SAndroid Build Coastguard Worker        def f(x):
154*da0073e9SAndroid Build Coastguard Worker            y = x + g(x)
155*da0073e9SAndroid Build Coastguard Worker
156*da0073e9SAndroid Build Coastguard Worker            return y + comptime.print_value_stack_and_return(y * 2)
157*da0073e9SAndroid Build Coastguard Worker
158*da0073e9SAndroid Build Coastguard Worker        f(torch.randn(2))
159*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
160*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
161*da0073e9SAndroid Build Coastguard Worker            FILE.getvalue(),
162*da0073e9SAndroid Build Coastguard Worker            """\
163*da0073e9SAndroid Build Coastguard Worker- FakeTensor(..., size=(2,))
164*da0073e9SAndroid Build Coastguard Worker""",
165*da0073e9SAndroid Build Coastguard Worker        )
166*da0073e9SAndroid Build Coastguard Worker
167*da0073e9SAndroid Build Coastguard Worker    def test_print_locals(self):
168*da0073e9SAndroid Build Coastguard Worker        global FILE
169*da0073e9SAndroid Build Coastguard Worker        FILE = StringIO()
170*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
171*da0073e9SAndroid Build Coastguard Worker
172*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize(cnt)
173*da0073e9SAndroid Build Coastguard Worker        def f(x):
174*da0073e9SAndroid Build Coastguard Worker            y = x * 2
175*da0073e9SAndroid Build Coastguard Worker
176*da0073e9SAndroid Build Coastguard Worker            @comptime
177*da0073e9SAndroid Build Coastguard Worker            def _(ctx):
178*da0073e9SAndroid Build Coastguard Worker                ctx.print_locals(file=FILE)
179*da0073e9SAndroid Build Coastguard Worker
180*da0073e9SAndroid Build Coastguard Worker            comptime.print_locals()
181*da0073e9SAndroid Build Coastguard Worker
182*da0073e9SAndroid Build Coastguard Worker            return y + 3
183*da0073e9SAndroid Build Coastguard Worker
184*da0073e9SAndroid Build Coastguard Worker        f(torch.randn(2))
185*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
186*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
187*da0073e9SAndroid Build Coastguard Worker            FILE.getvalue(),
188*da0073e9SAndroid Build Coastguard Worker            """\
189*da0073e9SAndroid Build Coastguard Workerx = FakeTensor(..., size=(2,))
190*da0073e9SAndroid Build Coastguard Workery = FakeTensor(..., size=(2,))
191*da0073e9SAndroid Build Coastguard Worker""",
192*da0073e9SAndroid Build Coastguard Worker        )
193*da0073e9SAndroid Build Coastguard Worker
194*da0073e9SAndroid Build Coastguard Worker    # Just make sure it doesn't crash
195*da0073e9SAndroid Build Coastguard Worker    def test_print_direct(self):
196*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
197*da0073e9SAndroid Build Coastguard Worker
198*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize(cnt)
199*da0073e9SAndroid Build Coastguard Worker        def f(x, z):
200*da0073e9SAndroid Build Coastguard Worker            y = x * 2
201*da0073e9SAndroid Build Coastguard Worker            lambda: z
202*da0073e9SAndroid Build Coastguard Worker            comptime.print(z)
203*da0073e9SAndroid Build Coastguard Worker            return y + 3
204*da0073e9SAndroid Build Coastguard Worker
205*da0073e9SAndroid Build Coastguard Worker        f(torch.randn(2), torch.randn(2))
206*da0073e9SAndroid Build Coastguard Worker
207*da0073e9SAndroid Build Coastguard Worker    def test_sleep(self):
208*da0073e9SAndroid Build Coastguard Worker        sleep_time = 5
209*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
210*da0073e9SAndroid Build Coastguard Worker
211*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize(cnt)
212*da0073e9SAndroid Build Coastguard Worker        def f(x, z, should_sleep):
213*da0073e9SAndroid Build Coastguard Worker            if should_sleep:
214*da0073e9SAndroid Build Coastguard Worker                comptime.sleep(sleep_time)
215*da0073e9SAndroid Build Coastguard Worker            y = x * 2
216*da0073e9SAndroid Build Coastguard Worker            return y + 3
217*da0073e9SAndroid Build Coastguard Worker
218*da0073e9SAndroid Build Coastguard Worker        start = time.time()
219*da0073e9SAndroid Build Coastguard Worker        f(torch.randn(2), torch.randn(2), False)
220*da0073e9SAndroid Build Coastguard Worker        total_no_sleep = time.time() - start
221*da0073e9SAndroid Build Coastguard Worker
222*da0073e9SAndroid Build Coastguard Worker        start = time.time()
223*da0073e9SAndroid Build Coastguard Worker        f(torch.randn(2), torch.randn(2), True)
224*da0073e9SAndroid Build Coastguard Worker        total_with_sleep = time.time() - start
225*da0073e9SAndroid Build Coastguard Worker
226*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(total_with_sleep > sleep_time)
227*da0073e9SAndroid Build Coastguard Worker        # Hopefully this won't be flaky
228*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(abs(total_with_sleep - sleep_time - total_no_sleep) < 3)
229*da0073e9SAndroid Build Coastguard Worker
230*da0073e9SAndroid Build Coastguard Worker    # Just make sure it doesn't crash
231*da0073e9SAndroid Build Coastguard Worker    def test_get_local_closure_variable(self):
232*da0073e9SAndroid Build Coastguard Worker        global SELF
233*da0073e9SAndroid Build Coastguard Worker        SELF = self
234*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
235*da0073e9SAndroid Build Coastguard Worker
236*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize(cnt)
237*da0073e9SAndroid Build Coastguard Worker        def f(x):
238*da0073e9SAndroid Build Coastguard Worker            z = 3
239*da0073e9SAndroid Build Coastguard Worker
240*da0073e9SAndroid Build Coastguard Worker            def g():
241*da0073e9SAndroid Build Coastguard Worker                @comptime
242*da0073e9SAndroid Build Coastguard Worker                def _(ctx):
243*da0073e9SAndroid Build Coastguard Worker                    r = ctx.get_local("z")
244*da0073e9SAndroid Build Coastguard Worker                    SELF.assertEqual(repr(r), "3")
245*da0073e9SAndroid Build Coastguard Worker
246*da0073e9SAndroid Build Coastguard Worker                comptime.print(z)
247*da0073e9SAndroid Build Coastguard Worker                return 2
248*da0073e9SAndroid Build Coastguard Worker
249*da0073e9SAndroid Build Coastguard Worker            y = x * g()
250*da0073e9SAndroid Build Coastguard Worker            return y + 3
251*da0073e9SAndroid Build Coastguard Worker
252*da0073e9SAndroid Build Coastguard Worker        f(torch.randn(2))
253*da0073e9SAndroid Build Coastguard Worker
254*da0073e9SAndroid Build Coastguard Worker    def test_print_bt(self):
255*da0073e9SAndroid Build Coastguard Worker        global FILE
256*da0073e9SAndroid Build Coastguard Worker        FILE = StringIO()
257*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
258*da0073e9SAndroid Build Coastguard Worker
259*da0073e9SAndroid Build Coastguard Worker        def g(x):
260*da0073e9SAndroid Build Coastguard Worker            @comptime
261*da0073e9SAndroid Build Coastguard Worker            def _(ctx):
262*da0073e9SAndroid Build Coastguard Worker                ctx.print_bt(file=FILE)
263*da0073e9SAndroid Build Coastguard Worker
264*da0073e9SAndroid Build Coastguard Worker            comptime.print_bt()
265*da0073e9SAndroid Build Coastguard Worker
266*da0073e9SAndroid Build Coastguard Worker            return x + 3
267*da0073e9SAndroid Build Coastguard Worker
268*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize(cnt)
269*da0073e9SAndroid Build Coastguard Worker        def f(x):
270*da0073e9SAndroid Build Coastguard Worker            y = x * 2
271*da0073e9SAndroid Build Coastguard Worker            y = g(y)
272*da0073e9SAndroid Build Coastguard Worker            return y + 3
273*da0073e9SAndroid Build Coastguard Worker
274*da0073e9SAndroid Build Coastguard Worker        def munge_filenames(s):
275*da0073e9SAndroid Build Coastguard Worker            return re.sub(r'File "[^"]+", line \d+', 'File "X", line X', s)
276*da0073e9SAndroid Build Coastguard Worker
277*da0073e9SAndroid Build Coastguard Worker        f(torch.randn(2))
278*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
279*da0073e9SAndroid Build Coastguard Worker        bt = FILE.getvalue()
280*da0073e9SAndroid Build Coastguard Worker        self.assertIn("y = g(y)", bt)
281*da0073e9SAndroid Build Coastguard Worker
282*da0073e9SAndroid Build Coastguard Worker    def test_print_guards(self):
283*da0073e9SAndroid Build Coastguard Worker        global FILE
284*da0073e9SAndroid Build Coastguard Worker        FILE = StringIO()
285*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
286*da0073e9SAndroid Build Coastguard Worker
287*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize(cnt)
288*da0073e9SAndroid Build Coastguard Worker        def f(x):
289*da0073e9SAndroid Build Coastguard Worker            y = x * 2
290*da0073e9SAndroid Build Coastguard Worker
291*da0073e9SAndroid Build Coastguard Worker            @comptime
292*da0073e9SAndroid Build Coastguard Worker            def _(ctx):
293*da0073e9SAndroid Build Coastguard Worker                ctx.print_guards(file=FILE)
294*da0073e9SAndroid Build Coastguard Worker
295*da0073e9SAndroid Build Coastguard Worker            comptime.print_guards()
296*da0073e9SAndroid Build Coastguard Worker
297*da0073e9SAndroid Build Coastguard Worker            return y + 3
298*da0073e9SAndroid Build Coastguard Worker
299*da0073e9SAndroid Build Coastguard Worker        f(torch.randn(2))
300*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
301*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
302*da0073e9SAndroid Build Coastguard Worker            re.sub(r"\s+$", "", FILE.getvalue().rstrip(), flags=re.MULTILINE),
303*da0073e9SAndroid Build Coastguard Worker            """\
304*da0073e9SAndroid Build Coastguard Worker
305*da0073e9SAndroid Build Coastguard Worker        local "L['x']" TENSOR_MATCH
306*da0073e9SAndroid Build Coastguard Worker        {
307*da0073e9SAndroid Build Coastguard Worker            'guard_types': None,
308*da0073e9SAndroid Build Coastguard Worker            'code': None,
309*da0073e9SAndroid Build Coastguard Worker            'obj_weakref': None
310*da0073e9SAndroid Build Coastguard Worker            'guarded_class': None
311*da0073e9SAndroid Build Coastguard Worker        }
312*da0073e9SAndroid Build Coastguard Worker        global '' GRAD_MODE
313*da0073e9SAndroid Build Coastguard Worker        {
314*da0073e9SAndroid Build Coastguard Worker            'guard_types': None,
315*da0073e9SAndroid Build Coastguard Worker            'code': None,
316*da0073e9SAndroid Build Coastguard Worker            'obj_weakref': None
317*da0073e9SAndroid Build Coastguard Worker            'guarded_class': None
318*da0073e9SAndroid Build Coastguard Worker        }
319*da0073e9SAndroid Build Coastguard Worker        global '' DETERMINISTIC_ALGORITHMS
320*da0073e9SAndroid Build Coastguard Worker        {
321*da0073e9SAndroid Build Coastguard Worker            'guard_types': None,
322*da0073e9SAndroid Build Coastguard Worker            'code': None,
323*da0073e9SAndroid Build Coastguard Worker            'obj_weakref': None
324*da0073e9SAndroid Build Coastguard Worker            'guarded_class': None
325*da0073e9SAndroid Build Coastguard Worker        }
326*da0073e9SAndroid Build Coastguard Worker        global '' TORCH_FUNCTION_STATE
327*da0073e9SAndroid Build Coastguard Worker        {
328*da0073e9SAndroid Build Coastguard Worker            'guard_types': None,
329*da0073e9SAndroid Build Coastguard Worker            'code': None,
330*da0073e9SAndroid Build Coastguard Worker            'obj_weakref': None
331*da0073e9SAndroid Build Coastguard Worker            'guarded_class': None
332*da0073e9SAndroid Build Coastguard Worker        }
333*da0073e9SAndroid Build Coastguard Worker        global '' DEFAULT_DEVICE
334*da0073e9SAndroid Build Coastguard Worker        {
335*da0073e9SAndroid Build Coastguard Worker            'guard_types': None,
336*da0073e9SAndroid Build Coastguard Worker            'code': None,
337*da0073e9SAndroid Build Coastguard Worker            'obj_weakref': None
338*da0073e9SAndroid Build Coastguard Worker            'guarded_class': None
339*da0073e9SAndroid Build Coastguard Worker        }
340*da0073e9SAndroid Build Coastguard Worker        shape_env '' SHAPE_ENV
341*da0073e9SAndroid Build Coastguard Worker        {
342*da0073e9SAndroid Build Coastguard Worker            'guard_types': None,
343*da0073e9SAndroid Build Coastguard Worker            'code': None,
344*da0073e9SAndroid Build Coastguard Worker            'obj_weakref': None
345*da0073e9SAndroid Build Coastguard Worker            'guarded_class': None
346*da0073e9SAndroid Build Coastguard Worker        }""",
347*da0073e9SAndroid Build Coastguard Worker        )
348*da0073e9SAndroid Build Coastguard Worker
349*da0073e9SAndroid Build Coastguard Worker    def test_graph_break(self):
350*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
351*da0073e9SAndroid Build Coastguard Worker
352*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize(cnt)
353*da0073e9SAndroid Build Coastguard Worker        def f(x):
354*da0073e9SAndroid Build Coastguard Worker            y = x * 2
355*da0073e9SAndroid Build Coastguard Worker
356*da0073e9SAndroid Build Coastguard Worker            @comptime
357*da0073e9SAndroid Build Coastguard Worker            def _(ctx):
358*da0073e9SAndroid Build Coastguard Worker                pass
359*da0073e9SAndroid Build Coastguard Worker
360*da0073e9SAndroid Build Coastguard Worker            return y + 3
361*da0073e9SAndroid Build Coastguard Worker
362*da0073e9SAndroid Build Coastguard Worker        f(torch.randn(2))
363*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
364*da0073e9SAndroid Build Coastguard Worker        cnt.frame_count = 0
365*da0073e9SAndroid Build Coastguard Worker
366*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize(cnt)
367*da0073e9SAndroid Build Coastguard Worker        def g(x):
368*da0073e9SAndroid Build Coastguard Worker            y = x * 2
369*da0073e9SAndroid Build Coastguard Worker
370*da0073e9SAndroid Build Coastguard Worker            @comptime
371*da0073e9SAndroid Build Coastguard Worker            def _(ctx):
372*da0073e9SAndroid Build Coastguard Worker                ctx.graph_break()
373*da0073e9SAndroid Build Coastguard Worker
374*da0073e9SAndroid Build Coastguard Worker            y = y + 2
375*da0073e9SAndroid Build Coastguard Worker
376*da0073e9SAndroid Build Coastguard Worker            comptime.graph_break()
377*da0073e9SAndroid Build Coastguard Worker
378*da0073e9SAndroid Build Coastguard Worker            return y * 3
379*da0073e9SAndroid Build Coastguard Worker
380*da0073e9SAndroid Build Coastguard Worker        g(torch.randn(2))
381*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 3)
382*da0073e9SAndroid Build Coastguard Worker
383*da0073e9SAndroid Build Coastguard Worker    def test_get_local(self):
384*da0073e9SAndroid Build Coastguard Worker        global SELF, FILE
385*da0073e9SAndroid Build Coastguard Worker        SELF = self
386*da0073e9SAndroid Build Coastguard Worker        FILE = StringIO()
387*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
388*da0073e9SAndroid Build Coastguard Worker
389*da0073e9SAndroid Build Coastguard Worker        @torch._dynamo.optimize(cnt)
390*da0073e9SAndroid Build Coastguard Worker        def f(x):
391*da0073e9SAndroid Build Coastguard Worker            y = x * 2
392*da0073e9SAndroid Build Coastguard Worker            lit = 2
393*da0073e9SAndroid Build Coastguard Worker
394*da0073e9SAndroid Build Coastguard Worker            @comptime
395*da0073e9SAndroid Build Coastguard Worker            def _(ctx):
396*da0073e9SAndroid Build Coastguard Worker                y = ctx.get_local("y")
397*da0073e9SAndroid Build Coastguard Worker                SELF.assertEqual(y.as_fake().size(0), 2)
398*da0073e9SAndroid Build Coastguard Worker                SELF.assertEqual(y.size(0), 2)
399*da0073e9SAndroid Build Coastguard Worker                # Trigger a graph write (TODO: this is not so
400*da0073e9SAndroid Build Coastguard Worker                # useful right now as there's no way to make use
401*da0073e9SAndroid Build Coastguard Worker                # of the output proxy; maybe it's useful for inserting
402*da0073e9SAndroid Build Coastguard Worker                # side-effectful operations into the graph)
403*da0073e9SAndroid Build Coastguard Worker                y.as_proxy() + 4
404*da0073e9SAndroid Build Coastguard Worker                ctx.print_graph(verbose=False, file=FILE)
405*da0073e9SAndroid Build Coastguard Worker                SELF.assertIs(y.python_type(), torch.Tensor)
406*da0073e9SAndroid Build Coastguard Worker                lit = ctx.get_local("lit")
407*da0073e9SAndroid Build Coastguard Worker                SELF.assertEqual(lit.as_python_constant(), 2)
408*da0073e9SAndroid Build Coastguard Worker
409*da0073e9SAndroid Build Coastguard Worker            return y + 3
410*da0073e9SAndroid Build Coastguard Worker
411*da0073e9SAndroid Build Coastguard Worker        f(torch.randn(2))
412*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
413*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
414*da0073e9SAndroid Build Coastguard Worker            FILE.getvalue().strip(),
415*da0073e9SAndroid Build Coastguard Worker            """\
416*da0073e9SAndroid Build Coastguard Workerdef forward(self, L_x_ : torch.Tensor):
417*da0073e9SAndroid Build Coastguard Worker    l_x_ = L_x_
418*da0073e9SAndroid Build Coastguard Worker    y = l_x_ * 2;  l_x_ = None
419*da0073e9SAndroid Build Coastguard Worker    add = y + 4;  y = add = None""",
420*da0073e9SAndroid Build Coastguard Worker        )
421*da0073e9SAndroid Build Coastguard Worker
422*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # torch._dynamo.test_case is already imported at module level, so call
    # its runner through the module instead of importing the name again.
    torch._dynamo.test_case.run_tests()
427