# Owner(s): ["module: dynamo"]
#
# Tests for how TorchDynamo-compiled functions read and write Python globals:
# plain rebinding, in-place mutation of global containers/objects, globals
# that live in another module, and global stores performed by callees.
from typing import Optional

import torch

import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same

try:
    from . import utils
except ImportError:
    import utils


class Pair:  # noqa: B903
    """Minimal mutable two-field object used as a global in the tests."""

    def __init__(self, x, y):
        self.x = x
        self.y = y


def Foo():
    """Return a fresh ``Pair(1, 1)``."""
    return Pair(1, 1)


# Module-level state that the tests below read and mutate from compiled code.
g_counter = 1
g_list = [0, 1, 2]
g_dict = {"a": 0, "b": 1}
g_object = Foo()
g_tensor = torch.zeros(10)


# Counter backing fresh_name(); it is itself a global written by a callee of
# the compiled function in the "inline" tests.
_name: int = 0


def fresh_name() -> str:
    """create a new unique name for a variable: v0, v1, v2"""
    global _name
    r = f"v{_name}"
    _name += 1
    return r


def reset_name():
    """Reset the fresh_name() counter so the next name handed out is ``v0``."""
    global _name
    _name = 0


class TestGlobals(torch._dynamo.test_case.TestCase):
    """Each test runs a small function that stores to a global via Dynamo.

    Most tests call the function twice and compare the two results: because
    the first call bumps the global, the second result must be larger by the
    bump amount. That only holds if the store made by the first (compiled)
    call was really applied to the global.
    """

    def test_store_global_1(self):
        """A single ``g_counter += 1`` in compiled code is visible afterwards."""

        def fn(x):
            global g_counter
            val = x + g_counter
            g_counter += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        # The second call is eager and must observe the compiled call's store.
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_2(self):
        """Two consecutive stores to the same global are both applied."""

        def fn(x):
            global g_counter
            val = x + g_counter
            g_counter += 1
            g_counter += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        # Wrap the second call with torch._dynamo as well
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res2 = opt_fn(x)
        self.assertTrue(same(res2 - res1, 2 * torch.ones(10)))

    def test_store_global_new(self):
        """Compiled code may create a global that did not exist before."""

        def fn(x):
            # Test create a new global
            global g_counter_new
            g_counter_new = x + 1
            return x + g_counter_new

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        # x + (x + 1)
        self.assertTrue(same(res1, x + x + 1))

    def test_store_global_list(self):
        """In-place update of one element of a global list."""

        def fn(x):
            global g_list
            val = x + g_list[1]
            # Strictly speaking, we are not testing STORE_GLOBAL here, since
            # STORE_SUBSCR is actually used to store.
            g_list[1] += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_list_2(self):
        """Rebinding a global list to a new list built by a comprehension."""

        def fn(x):
            global g_list
            val = x + g_list[1]
            g_list = [x + 1 for x in g_list]
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_dict(self):
        """In-place update of one value of a global dict."""

        def fn(x):
            global g_dict
            val = x + g_dict["b"]
            # Strictly speaking, we are not testing STORE_GLOBAL here, since
            # STORE_SUBSCR is actually used to store.
            g_dict["b"] += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_dict_2(self):
        """Rebinding a global dict, then reading the freshly stored value."""

        def fn(x):
            global g_dict
            # Unlike the other tests, the store happens *before* the read, so
            # each call sees its own increment.
            g_dict = {key: value + 1 for key, value in g_dict.items()}
            val = x + g_dict["b"]
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_object(self):
        """Attribute store on an object held in a global."""

        def fn(x):
            global g_object
            val = x + g_object.y
            g_object.y += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_cross_file(self):
        """Rebinding an attribute of another module (``utils.g_tensor_export``)."""

        def fn(x):
            # NOTE(review): utils is not visible here; g_tensor_export is
            # presumably a tensor or number that supports ``+ 1``.
            val = x + utils.g_tensor_export
            utils.g_tensor_export = utils.g_tensor_export + 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_inline_1(self):
        """Global store made by a callee (``fresh_name``) of the compiled fn.

        ``Variable.__init__`` calls ``fresh_name()``, which increments the
        module global ``_name``; two constructions must yield "v0" and "v1".
        """
        # Start from a known counter and always restore it, so a failure here
        # cannot leak an advanced ``_name`` into other tests.
        reset_name()
        self.addCleanup(reset_name)

        # Borrowed from test_python_autograd.py
        class Variable:
            def __init__(self, value: torch.Tensor, name: Optional[str] = None):
                self.value = value
                self.name = name or fresh_name()

        def fn(a, b):
            a = Variable(a)
            b = Variable(b)
            return a.value + b.value, a.name + b.name

        a = torch.randn(10)
        b = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        v0, s0 = opt_fn(a, b)
        self.assertTrue(same(v0, a + b))
        self.assertEqual(s0, "v0v1")

    def test_store_global_inline_2(self):
        """Same as ``test_store_global_inline_1`` but via a staticmethod.

        ``Variable.constant`` forwards to the constructor, adding one more
        call level between the compiled fn and the global store.
        """
        # Start from a known counter and always restore it, so a failure here
        # cannot leak an advanced ``_name`` into other tests.
        reset_name()
        self.addCleanup(reset_name)

        # Borrowed from test_python_autograd.py
        class Variable:
            def __init__(self, value: torch.Tensor, name: Optional[str] = None):
                self.value = value
                self.name = name or fresh_name()

            @staticmethod
            def constant(value: torch.Tensor, name: Optional[str] = None):
                return Variable(value, name)

        def fn(a, b):
            a = Variable.constant(a)
            b = Variable.constant(b)
            return a.value + b.value, a.name + b.name

        a = torch.randn(10)
        b = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        v0, s0 = opt_fn(a, b)
        self.assertTrue(same(v0, a + b))
        self.assertEqual(s0, "v0v1")

    def test_store_global_crossfile_inline(self):
        """Global stores made by functions defined in another module.

        NOTE(review): the mock module is not visible here; judging by the
        names, ``set_flag_true`` / ``set_flag_false`` assign its module-level
        ``global_flag``. Confirm against mock_store_global_crossfile_inline.
        """
        try:
            from . import mock_store_global_crossfile_inline
        except ImportError:
            import mock_store_global_crossfile_inline

        @torch.compile()
        def fn(x):
            mock_store_global_crossfile_inline.set_flag_true()
            mock_store_global_crossfile_inline.set_flag_false()
            return x + 1

        @torch.compile()
        def fn_set_true(x):
            mock_store_global_crossfile_inline.set_flag_true()
            return x + 1

        fn_set_true(torch.ones(2, 2))
        self.assertTrue(mock_store_global_crossfile_inline.global_flag)
        # fn sets the flag and then clears it; the last store must win.
        fn(torch.ones(2, 2))
        self.assertFalse(mock_store_global_crossfile_inline.global_flag)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()