xref: /aosp_15_r20/external/pytorch/test/dynamo/test_global.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
from typing import Optional

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same
6
7
8try:
9    from . import utils
10except ImportError:
11    import utils
12
13
class Pair:  # noqa: B903
    """Mutable record holding two attributes, ``x`` and ``y``."""

    def __init__(self, x, y):
        # Assign both fields in one statement (x first, then y).
        self.x, self.y = x, y
18
19
def Foo():
    """Build and return a new ``Pair`` whose ``x`` and ``y`` both start at 1."""
    start = 1
    return Pair(start, start)
22
23
# Module-level state that the tests below read and mutate from inside
# compiled functions.
g_counter = 1
g_list = list(range(3))
g_dict = dict(a=0, b=1)
g_object = Foo()
g_tensor = torch.zeros(10)


# Counter consumed by fresh_name() and rewound by reset_name().
_name: int = 0
32
33
def fresh_name() -> str:
    """Return the next unique variable name ("v0", "v1", "v2", ...).

    Reads the module-level ``_name`` counter, advances it by one, and formats
    the value it read as ``"v<n>"``.
    """
    global _name
    index = _name
    _name = index + 1
    return f"v{index}"
40
41
def reset_name():
    """Rewind the fresh_name() counter so the next generated name is "v0"."""
    global _name
    _name = 0  # restart numbering from the beginning
45
46
class TestGlobals(torch._dynamo.test_case.TestCase):
    """Check that Dynamo-compiled functions read and write globals correctly.

    Each test mutates module-level state from inside a compiled function,
    through STORE_GLOBAL or through STORE_SUBSCR / STORE_ATTR on a global
    object (including globals of other modules). It then checks that the
    mutation is visible afterwards.
    """

    def test_store_global_1(self):
        def fn(x):
            global g_counter
            val = x + g_counter
            g_counter += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        # The compiled call must have bumped g_counter, so the following
        # eager call sees a value one larger.
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_2(self):
        def fn(x):
            global g_counter
            val = x + g_counter
            g_counter += 1
            g_counter += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        # Wrap the second call with torch._dynamo as well.
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res2 = opt_fn(x)
        self.assertTrue(same(res2 - res1, 2 * torch.ones(10)))

    def test_store_global_new(self):
        def fn(x):
            # Test create a new global
            global g_counter_new
            g_counter_new = x + 1
            return x + g_counter_new

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        self.assertTrue(same(res1, x + x + 1))

    def test_store_global_list(self):
        def fn(x):
            global g_list
            val = x + g_list[1]
            # Strictly speaking, we are not testing STORE_GLOBAL here, since
            # STORE_SUBSCR is actually used to store.
            g_list[1] += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_list_2(self):
        def fn(x):
            global g_list
            val = x + g_list[1]
            # Rebind the global to a brand-new list (a real STORE_GLOBAL).
            # The comprehension variable is deliberately not named ``x`` so
            # it does not read as shadowing the tensor argument.
            g_list = [item + 1 for item in g_list]
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_dict(self):
        def fn(x):
            global g_dict
            val = x + g_dict["b"]
            # Strictly speaking, we are not testing STORE_GLOBAL here, since
            # STORE_SUBSCR is actually used to store.
            g_dict["b"] += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_dict_2(self):
        def fn(x):
            global g_dict
            g_dict = {key: value + 1 for key, value in g_dict.items()}
            val = x + g_dict["b"]
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_object(self):
        def fn(x):
            global g_object
            val = x + g_object.y
            # Mutates an attribute of the global object (STORE_ATTR).
            g_object.y += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_cross_file(self):
        def fn(x):
            val = x + utils.g_tensor_export
            # Rebinds a global that lives in another module.
            utils.g_tensor_export = utils.g_tensor_export + 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_inline_1(self):
        # Borrowed from test_python_autograd.py
        class Variable:
            def __init__(self, value: torch.Tensor, name: Optional[str] = None):
                self.value = value
                self.name = name or fresh_name()

        def fn(a, b):
            a = Variable(a)
            b = Variable(b)
            return a.value + b.value, a.name + b.name

        # The expected names assume the fresh_name() counter starts at 0.
        # Reset it up front, and again on exit even if an assertion fails,
        # so this test neither depends on nor leaks counter state.
        reset_name()
        self.addCleanup(reset_name)

        a = torch.randn(10)
        b = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        v0, s0 = opt_fn(a, b)
        self.assertTrue(same(v0, a + b))
        self.assertEqual(s0, "v0v1")

    def test_store_global_inline_2(self):
        # Borrowed from test_python_autograd.py
        class Variable:
            def __init__(self, value: torch.Tensor, name: Optional[str] = None):
                self.value = value
                self.name = name or fresh_name()

            @staticmethod
            def constant(value: torch.Tensor, name: Optional[str] = None):
                return Variable(value, name)

        def fn(a, b):
            a = Variable.constant(a)
            b = Variable.constant(b)
            return a.value + b.value, a.name + b.name

        # Same as test_store_global_inline_1: pin the fresh_name() counter
        # to 0 for the duration of the test, even if an assertion fails.
        reset_name()
        self.addCleanup(reset_name)

        a = torch.randn(10)
        b = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        v0, s0 = opt_fn(a, b)
        self.assertTrue(same(v0, a + b))
        self.assertEqual(s0, "v0v1")

    def test_store_global_crossfile_inline(self):
        try:
            from . import mock_store_global_crossfile_inline
        except ImportError:
            import mock_store_global_crossfile_inline

        @torch.compile()
        def fn(x):
            mock_store_global_crossfile_inline.set_flag_true()
            mock_store_global_crossfile_inline.set_flag_false()
            return x + 1

        @torch.compile()
        def fn_set_true(x):
            mock_store_global_crossfile_inline.set_flag_true()
            return x + 1

        fn_set_true(torch.ones(2, 2))
        self.assertTrue(mock_store_global_crossfile_inline.global_flag)
        # The later store (set_flag_false) must win over the earlier one.
        fn(torch.ones(2, 2))
        self.assertFalse(mock_store_global_crossfile_inline.global_flag)
248
249
if __name__ == "__main__":
    # When executed as a script, hand this module's tests to Dynamo's runner.
    import torch._dynamo.test_case as dynamo_test_case

    dynamo_test_case.run_tests()
254