# Owner(s): ["module: dynamo"]
from typing import Optional

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same


try:
    from . import utils
except ImportError:
    import utils


class Pair:  # noqa: B903
    """Minimal mutable two-field record, used as a module-level global object."""

    def __init__(self, x, y):
        self.x = x
        self.y = y


def Foo():
    """Return a new ``Pair(1, 1)``; used to build ``g_object``."""
    return Pair(1, 1)


# Module-level globals that the functions compiled in the tests below read
# and mutate (via STORE_GLOBAL, STORE_SUBSCR or STORE_ATTR).
g_counter = 1
g_list = [0, 1, 2]
g_dict = {"a": 0, "b": 1}
g_object = Foo()
g_tensor = torch.zeros(10)


# Counter backing fresh_name(); mutated through the `global` statement.
_name: int = 0


def fresh_name() -> str:
    """Return a new unique variable name ("v0", "v1", "v2", ...).

    Each call advances the module-level ``_name`` counter by one.
    """
    global _name
    r = f"v{_name}"
    _name += 1
    return r


def reset_name():
    """Reset the ``fresh_name`` counter so the next name returned is "v0"."""
    global _name
    _name = 0


class TestGlobals(torch._dynamo.test_case.TestCase):
    def test_store_global_1(self):
        """A compiled increment of an int global is visible to the next eager call."""

        def fn(x):
            global g_counter
            val = x + g_counter
            g_counter += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_2(self):
        """Two increments of a global in one compiled frame are both applied."""

        def fn(x):
            global g_counter
            val = x + g_counter
            g_counter += 1
            g_counter += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        # Wrap the second call with torch._dynamo as well.
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res2 = opt_fn(x)
        self.assertTrue(same(res2 - res1, 2 * torch.ones(10)))

    def test_store_global_new(self):
        """A compiled function can create a brand-new module global."""

        def fn(x):
            # Test create a new global
            global g_counter_new
            g_counter_new = x + 1
            return x + g_counter_new

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        self.assertTrue(same(res1, x + x + 1))

    def test_store_global_list(self):
        """In-place mutation of a global list element is visible eagerly."""

        def fn(x):
            global g_list
            val = x + g_list[1]
            # Strictly speaking, we are not testing STORE_GLOBAL here,
            # since STORE_SUBSCR is actually used to store.
            g_list[1] += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_list_2(self):
        """Rebinding a global to a freshly built list is visible eagerly."""

        def fn(x):
            global g_list
            val = x + g_list[1]
            # Use a distinct comprehension variable so it does not shadow
            # the tensor argument `x`.
            g_list = [v + 1 for v in g_list]
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_dict(self):
        """In-place mutation of a global dict entry is visible eagerly."""

        def fn(x):
            global g_dict
            val = x + g_dict["b"]
            # Strictly speaking, we are not testing STORE_GLOBAL here,
            # since STORE_SUBSCR is actually used to store.
            g_dict["b"] += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_dict_2(self):
        """Rebinding a global to a freshly built dict is visible eagerly."""

        def fn(x):
            global g_dict
            g_dict = {key: value + 1 for key, value in g_dict.items()}
            val = x + g_dict["b"]
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_object(self):
        """Attribute mutation on a global object is visible eagerly."""

        def fn(x):
            global g_object
            val = x + g_object.y
            g_object.y += 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_cross_file(self):
        """Assigning a global of another module (utils) is visible eagerly."""

        def fn(x):
            val = x + utils.g_tensor_export
            utils.g_tensor_export = utils.g_tensor_export + 1
            return val

        x = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res1 = opt_fn(x)
        res2 = fn(x)
        self.assertTrue(same(res2 - res1, torch.ones(10)))

    def test_store_global_inline_1(self):
        """Global counter updates from an inlined constructor yield "v0v1"."""

        # Borrowed from test_python_autograd.py
        class Variable:
            def __init__(self, value: torch.Tensor, name: Optional[str] = None):
                self.value = value
                self.name = name or fresh_name()

        def fn(a, b):
            a = Variable(a)
            b = Variable(b)
            return a.value + b.value, a.name + b.name

        # Start from a known counter value, and restore it even if the
        # assertion below fails, so numbering never leaks between tests.
        reset_name()
        self.addCleanup(reset_name)

        a = torch.randn(10)
        b = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        _, s0 = opt_fn(a, b)
        self.assertEqual(s0, "v0v1")

    def test_store_global_inline_2(self):
        """Same as inline_1, but constructing through a staticmethod factory."""

        # Borrowed from test_python_autograd.py
        class Variable:
            def __init__(self, value: torch.Tensor, name: Optional[str] = None):
                self.value = value
                self.name = name or fresh_name()

            @staticmethod
            def constant(value: torch.Tensor, name: Optional[str] = None):
                return Variable(value, name)

        def fn(a, b):
            a = Variable.constant(a)
            b = Variable.constant(b)
            return a.value + b.value, a.name + b.name

        # Start from a known counter value, and restore it even if the
        # assertion below fails, so numbering never leaks between tests.
        reset_name()
        self.addCleanup(reset_name)

        a = torch.randn(10)
        b = torch.randn(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        _, s0 = opt_fn(a, b)
        self.assertEqual(s0, "v0v1")

    def test_store_global_crossfile_inline(self):
        """Flag writes done by inlined calls into another module take effect."""
        try:
            from . import mock_store_global_crossfile_inline
        except ImportError:
            import mock_store_global_crossfile_inline

        @torch.compile()
        def fn(x):
            mock_store_global_crossfile_inline.set_flag_true()
            mock_store_global_crossfile_inline.set_flag_false()
            return x + 1

        @torch.compile()
        def fn_set_true(x):
            mock_store_global_crossfile_inline.set_flag_true()
            return x + 1

        fn_set_true(torch.ones(2, 2))
        self.assertTrue(mock_store_global_crossfile_inline.global_flag)
        fn(torch.ones(2, 2))
        self.assertFalse(mock_store_global_crossfile_inline.global_flag)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()