# Owner(s): ["module: dynamo"]
from typing import Optional

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same


try:
    from . import utils
except ImportError:
    import utils


class Pair:  # noqa: B903
    """Plain mutable two-field holder used as a global object in tests."""

    def __init__(self, x, y):
        self.x = x
        self.y = y


def Foo():
    """Return a fresh ``Pair(1, 1)``."""
    return Pair(1, 1)


# Module-level globals that the compiled functions below read and mutate.
g_counter = 1
g_list = [0, 1, 2]
g_dict = {"a": 0, "b": 1}
g_object = Foo()
g_tensor = torch.zeros(10)


# Counter backing fresh_name(); reset it with reset_name().
_name: int = 0


def fresh_name() -> str:
    """Create a new unique name for a variable: v0, v1, v2, ..."""
    global _name
    r = f"v{_name}"
    _name += 1
    return r


def reset_name():
    """Reset the fresh_name() counter so the next name is ``v0``."""
    global _name
    _name = 0


class TestGlobals(torch._dynamo.test_case.TestCase):
    def test_store_global_1(self):
        def fn(x):
            global g_counter
            val = x + g_counter
            g_counter += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_2(self):
        def fn(x):
            global g_counter
            val = x + g_counter
            g_counter += 1
            g_counter += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        # Wrap the second call with torch._dynamo as well.
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res2 = opt_fn(x)
        self.assertTrue(same(res2 - res1, 2 * torch.ones(10)))

    def test_store_global_new(self):
        def fn(x):
            # Test create a new global
            global g_counter_new
            g_counter_new = x + 1
            return x + g_counter_new

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        self.assertTrue(same(res1, x + x + 1))

    def test_store_global_list(self):
        def fn(x):
            global g_list
            val = x + g_list[1]
            # Strictly speaking, we are not testing STORE_GLOBAL
            # here, since STORE_SUBSCR is actually used to store.
            g_list[1] += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_list_2(self):
        def fn(x):
            global g_list
            val = x + g_list[1]
            g_list = [x + 1 for x in g_list]
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_dict(self):
        def fn(x):
            global g_dict
            val = x + g_dict["b"]
            # Strictly speaking, we are not testing STORE_GLOBAL
            # here, since STORE_SUBSCR is actually used to store.
            g_dict["b"] += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_dict_2(self):
        def fn(x):
            global g_dict
            g_dict = {key: value + 1 for key, value in g_dict.items()}
            val = x + g_dict["b"]
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_object(self):
        def fn(x):
            global g_object
            val = x + g_object.y
            g_object.y += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_cross_file(self):
        def fn(x):
            val = x + utils.g_tensor_export
            utils.g_tensor_export = utils.g_tensor_export + 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_inline_1(self):
        # Borrowed from test_python_autograd.py
        # fresh_name() draws from the module-level `_name` counter. Start from
        # v0 and restore the counter even if the assertion below fails, so a
        # failure here cannot cascade into the other inline test.
        reset_name()
        self.addCleanup(reset_name)

        class Variable:
            def __init__(self, value: torch.Tensor, name: Optional[str] = None):
                self.value = value
                self.name = name or fresh_name()

        def fn(a, b):
            a = Variable(a)
            b = Variable(b)
            return a.value + b.value, a.name + b.name

        a = torch.randn(10)
        b = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        v0, s0 = opt_fn(a, b)
        self.assertEqual(s0, "v0v1")

    def test_store_global_inline_2(self):
        # Borrowed from test_python_autograd.py
        # fresh_name() draws from the module-level `_name` counter. Start from
        # v0 and restore the counter even if the assertion below fails, so a
        # failure here cannot cascade into the other inline test.
        reset_name()
        self.addCleanup(reset_name)

        class Variable:
            def __init__(self, value: torch.Tensor, name: Optional[str] = None):
                self.value = value
                self.name = name or fresh_name()

            @staticmethod
            def constant(value: torch.Tensor, name: Optional[str] = None):
                return Variable(value, name)

        def fn(a, b):
            a = Variable.constant(a)
            b = Variable.constant(b)
            return a.value + b.value, a.name + b.name

        a = torch.randn(10)
        b = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        v0, s0 = opt_fn(a, b)
        self.assertEqual(s0, "v0v1")

    def test_store_global_crossfile_inline(self):
        try:
            from . import mock_store_global_crossfile_inline
        except ImportError:
            import mock_store_global_crossfile_inline

        @torch.compile()
        def fn(x):
            mock_store_global_crossfile_inline.set_flag_true()
            mock_store_global_crossfile_inline.set_flag_false()
            return x + 1

        @torch.compile()
        def fn_set_true(x):
            mock_store_global_crossfile_inline.set_flag_true()
            return x + 1

        fn_set_true(torch.ones(2, 2))
        self.assertTrue(mock_store_global_crossfile_inline.global_flag)
        fn(torch.ones(2, 2))
        self.assertFalse(mock_store_global_crossfile_inline.global_flag)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()