xref: /aosp_15_r20/external/pytorch/test/dynamo/test_global.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"]
2*da0073e9SAndroid Build Coastguard Workerimport torch
3*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case
4*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing
5*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.testing import same
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workertry:
9*da0073e9SAndroid Build Coastguard Worker    from . import utils
10*da0073e9SAndroid Build Coastguard Workerexcept ImportError:
11*da0073e9SAndroid Build Coastguard Worker    import utils
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker
class Pair:  # noqa: B903
    """Plain mutable holder for two values, ``x`` and ``y``."""

    def __init__(self, x, y):
        # Ordinary writable attributes: a test in this file increments ``y``
        # in place on a module-level instance.
        self.x, self.y = x, y
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Worker
def Foo():
    """Return a new ``Pair`` with both ``x`` and ``y`` set to 1."""
    pair = Pair(x=1, y=1)
    return pair
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker
# Module-level fixtures that the TestGlobals cases read and mutate.
g_counter = 1  # int global, bumped with ``+=`` inside the functions under test
g_list = [0, 1, 2]  # updated by item assignment and also rebound to a new list
g_dict = {"a": 0, "b": 1}  # updated by item assignment and also rebound to a new dict
g_object = Foo()  # Pair(1, 1); a test increments its ``y`` attribute
g_tensor = torch.zeros(10)  # NOTE(review): not referenced by the tests visible in this file


# Counter behind fresh_name(): the next generated name is f"v{_name}".
_name: int = 0
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker
def fresh_name() -> str:
    """Return the next unused variable name in the sequence v0, v1, v2, ...

    Each call consumes one value of the module-level ``_name`` counter.
    """
    global _name
    index = _name
    _name += 1
    return f"v{index}"
40*da0073e9SAndroid Build Coastguard Worker
41*da0073e9SAndroid Build Coastguard Worker
def reset_name() -> None:
    """Rewind the ``_name`` counter to 0 so the next fresh_name() is "v0"."""
    global _name
    _name = 0
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Worker
class TestGlobals(torch._dynamo.test_case.TestCase):
    """Dynamo tests for functions that read and write module-level globals.

    Most cases call a function once through ``torch._dynamo.optimize`` and a
    second time (eagerly or compiled), then check that the two results differ
    by exactly the amount the first call added to the global. Because only the
    difference is asserted, the cases do not depend on the globals' absolute
    values at the time they run.
    """

    def test_store_global_1(self):
        # One ``+=`` on an int global inside the compiled function.
        def fn(x):
            global g_counter
            val = x + g_counter
            g_counter += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_2(self):
        # Two consecutive stores to the same int global in one function.
        def fn(x):
            global g_counter
            val = x + g_counter
            g_counter += 1
            g_counter += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        # Wrap the second call with torch._dynamo as well
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res2 = opt_fn(x)
        self.assertTrue(same(res2 - res1, 2 * torch.ones(10)))

    def test_store_global_new(self):
        def fn(x):
            # Test create a new global
            global g_counter_new
            g_counter_new = x + 1
            return x + g_counter_new

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        self.assertTrue(same(res1, x + x + 1))

    def test_store_global_list(self):
        def fn(x):
            global g_list
            val = x + g_list[1]
            # Strictly speaking, we are not testing STORE_GLOBAL
            # here, since STORE_SUBSCR is actually used to store.
            g_list[1] += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_list_2(self):
        # Rebinds the global to a brand-new list instead of mutating it.
        def fn(x):
            global g_list
            val = x + g_list[1]
            g_list = [x + 1 for x in g_list]
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_dict(self):
        def fn(x):
            global g_dict
            val = x + g_dict["b"]
            # Strictly speaking, we are not testing STORE_GLOBAL
            # here, since STORE_SUBSCR is actually used to store.
            g_dict["b"] += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_dict_2(self):
        # Rebinds the global to a new dict, then reads from the new binding.
        def fn(x):
            global g_dict
            g_dict = {key: value + 1 for key, value in g_dict.items()}
            val = x + g_dict["b"]
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_object(self):
        # Attribute store on an object that is reached through a global.
        def fn(x):
            global g_object
            val = x + g_object.y
            g_object.y += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_cross_file(self):
        # The global being rebound lives in another module (``utils``).
        def fn(x):
            val = x + utils.g_tensor_export
            utils.g_tensor_export = utils.g_tensor_export + 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_inline_1(self):
        # Borrowed from test_python_autograd.py
        class Variable:
            def __init__(self, value: torch.Tensor, name: str = None):
                self.value = value
                self.name = name or fresh_name()

        def fn(a, b):
            a = Variable(a)
            b = Variable(b)
            return a.value + b.value, a.name + b.name

        # Start the name counter at zero and restore it even when an assertion
        # below fails, so the expected "v0v1" never depends on test order.
        reset_name()
        self.addCleanup(reset_name)

        a = torch.randn(10)
        b = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        v0, s0 = opt_fn(a, b)
        self.assertTrue(same(v0, a + b))
        self.assertEqual(s0, "v0v1")

    def test_store_global_inline_2(self):
        # Borrowed from test_python_autograd.py
        class Variable:
            def __init__(self, value: torch.Tensor, name: str = None):
                self.value = value
                self.name = name or fresh_name()

            @staticmethod
            def constant(value: torch.Tensor, name: str = None):
                return Variable(value, name)

        def fn(a, b):
            a = Variable.constant(a)
            b = Variable.constant(b)
            return a.value + b.value, a.name + b.name

        # Start the name counter at zero and restore it even when an assertion
        # below fails, so the expected "v0v1" never depends on test order.
        reset_name()
        self.addCleanup(reset_name)

        a = torch.randn(10)
        b = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        v0, s0 = opt_fn(a, b)
        self.assertTrue(same(v0, a + b))
        self.assertEqual(s0, "v0v1")

    def test_store_global_crossfile_inline(self):
        # The global flag is written by helper functions that live in another
        # module and are called from inside the compiled function.
        try:
            from . import mock_store_global_crossfile_inline
        except ImportError:
            import mock_store_global_crossfile_inline

        @torch.compile()
        def fn(x):
            mock_store_global_crossfile_inline.set_flag_true()
            mock_store_global_crossfile_inline.set_flag_false()
            return x + 1

        @torch.compile()
        def fn_set_true(x):
            mock_store_global_crossfile_inline.set_flag_true()
            return x + 1

        fn_set_true(torch.ones(2, 2))
        self.assertTrue(mock_store_global_crossfile_inline.global_flag)
        fn(torch.ones(2, 2))
        self.assertFalse(mock_store_global_crossfile_inline.global_flag)
249*da0073e9SAndroid Build Coastguard Worker
# Allow running this test file directly as a script.
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
254