1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport collections 4*da0073e9SAndroid Build Coastguard Workerimport re 5*da0073e9SAndroid Build Coastguard Workerimport sys 6*da0073e9SAndroid Build Coastguard Workerimport time 7*da0073e9SAndroid Build Coastguard Workerfrom io import StringIO 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case 10*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing 11*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.comptime import comptime 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Worker# Because we don't support free variables in comptime at the moment, 15*da0073e9SAndroid Build Coastguard Worker# we have to communicate via globals. This also means these tests cannot 16*da0073e9SAndroid Build Coastguard Worker# be run in parallel in a single process (not that you'd... ever want 17*da0073e9SAndroid Build Coastguard Worker# to do that?) 18*da0073e9SAndroid Build Coastguard WorkerFILE = None 19*da0073e9SAndroid Build Coastguard WorkerSELF = None 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Worker 22*da0073e9SAndroid Build Coastguard Workerclass ComptimeTests(torch._dynamo.test_case.TestCase): 23*da0073e9SAndroid Build Coastguard Worker def test_print_single(self): 24*da0073e9SAndroid Build Coastguard Worker global FILE 25*da0073e9SAndroid Build Coastguard Worker FILE = StringIO() 26*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 27*da0073e9SAndroid Build Coastguard Worker 28*da0073e9SAndroid Build Coastguard Worker def comptime_print(e): 29*da0073e9SAndroid Build Coastguard Worker @comptime 30*da0073e9SAndroid Build Coastguard Worker def _(ctx): 31*da0073e9SAndroid Build Coastguard Worker ctx.print(ctx.get_local("e"), file=FILE) 32*da0073e9SAndroid Build Coastguard Worker 33*da0073e9SAndroid Build Coastguard Worker Employee = collections.namedtuple("Employee", ["name", "id"]) 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Worker class mylist(list): 36*da0073e9SAndroid Build Coastguard Worker pass 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize(cnt, dynamic=True) 39*da0073e9SAndroid Build Coastguard Worker def f(x): 40*da0073e9SAndroid Build Coastguard Worker y = x * 2 41*da0073e9SAndroid Build Coastguard Worker comptime_print(y) 42*da0073e9SAndroid Build Coastguard Worker comptime_print(2) 43*da0073e9SAndroid Build Coastguard Worker comptime_print([y, 2]) 44*da0073e9SAndroid Build Coastguard Worker comptime_print((y, 2)) 45*da0073e9SAndroid Build Coastguard Worker comptime_print({"foo": y}) 46*da0073e9SAndroid Build Coastguard Worker comptime_print(range(1, 3)) 47*da0073e9SAndroid Build Coastguard Worker comptime_print(Employee("foo", 2)) 48*da0073e9SAndroid Build Coastguard Worker comptime_print(mylist([1, 2])) 49*da0073e9SAndroid Build Coastguard Worker comptime_print(collections.defaultdict(lambda: None)) 50*da0073e9SAndroid Build Coastguard Worker comptime_print(set()) 51*da0073e9SAndroid Build Coastguard Worker comptime_print({"a", "b"}) 52*da0073e9SAndroid Build Coastguard Worker comptime_print(x.size(0)) 53*da0073e9SAndroid Build Coastguard Worker return y + 3 54*da0073e9SAndroid Build Coastguard Worker 55*da0073e9SAndroid Build Coastguard Worker f(torch.randn(2)) 56*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 57*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 58*da0073e9SAndroid Build Coastguard Worker FILE.getvalue().strip(), 59*da0073e9SAndroid Build Coastguard Worker """\ 60*da0073e9SAndroid Build Coastguard WorkerFakeTensor(..., size=(s0,)) 61*da0073e9SAndroid Build Coastguard Worker2 62*da0073e9SAndroid Build Coastguard Worker[FakeTensor(..., size=(s0,)), 2] 63*da0073e9SAndroid Build Coastguard Worker(FakeTensor(..., size=(s0,)), 2) 64*da0073e9SAndroid Build Coastguard Worker{'foo': FakeTensor(..., size=(s0,))} 65*da0073e9SAndroid Build Coastguard Workerrange(1, 3, 1) 66*da0073e9SAndroid Build Coastguard WorkerEmployee(name='foo', id=2) 67*da0073e9SAndroid Build Coastguard Worker[1, 2] 68*da0073e9SAndroid Build Coastguard Workerdefaultdict(NestedUserFunctionVariable(), {}) 69*da0073e9SAndroid Build Coastguard Workerset() 70*da0073e9SAndroid Build Coastguard Worker{'a','b'} 71*da0073e9SAndroid Build Coastguard Workers0""", 72*da0073e9SAndroid Build Coastguard Worker ) 73*da0073e9SAndroid Build Coastguard Worker 74*da0073e9SAndroid Build Coastguard Worker def test_print_graph(self): 75*da0073e9SAndroid Build Coastguard Worker global FILE 76*da0073e9SAndroid Build Coastguard Worker FILE = StringIO() 77*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 78*da0073e9SAndroid Build Coastguard Worker 79*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize(cnt) 80*da0073e9SAndroid Build Coastguard Worker def f(x): 81*da0073e9SAndroid Build Coastguard Worker y = x * 2 82*da0073e9SAndroid Build Coastguard Worker 83*da0073e9SAndroid Build Coastguard Worker @comptime 84*da0073e9SAndroid Build Coastguard Worker def _(ctx): 85*da0073e9SAndroid Build Coastguard Worker ctx.print_graph(verbose=False, file=FILE) 86*da0073e9SAndroid Build Coastguard Worker 87*da0073e9SAndroid Build Coastguard Worker # Test the compact notation doesn't error or graph break; 88*da0073e9SAndroid Build Coastguard Worker # you'll have to visually inspect to see that it printed 89*da0073e9SAndroid Build Coastguard Worker comptime.print_graph() 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker return y + 3 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Worker f(torch.randn(2)) 94*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 95*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 96*da0073e9SAndroid Build Coastguard Worker FILE.getvalue().strip(), 97*da0073e9SAndroid Build Coastguard Worker """\ 98*da0073e9SAndroid Build Coastguard Workerdef forward(self, L_x_ : torch.Tensor): 99*da0073e9SAndroid Build Coastguard Worker l_x_ = L_x_ 100*da0073e9SAndroid Build Coastguard Worker y = l_x_ * 2; l_x_ = y = None""", 101*da0073e9SAndroid Build Coastguard Worker ) 102*da0073e9SAndroid Build Coastguard Worker 103*da0073e9SAndroid Build Coastguard Worker def test_print_disas(self): 104*da0073e9SAndroid Build Coastguard Worker global FILE 105*da0073e9SAndroid Build Coastguard Worker FILE = StringIO() 106*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 107*da0073e9SAndroid Build Coastguard Worker 108*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize(cnt) 109*da0073e9SAndroid Build Coastguard Worker def f(x): 110*da0073e9SAndroid Build Coastguard Worker y = x * 2 111*da0073e9SAndroid Build Coastguard Worker 112*da0073e9SAndroid Build Coastguard Worker @comptime 113*da0073e9SAndroid Build Coastguard Worker def _(ctx): 114*da0073e9SAndroid Build Coastguard Worker ctx.print_disas(file=FILE) 115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Worker comptime.print_disas() 117*da0073e9SAndroid Build Coastguard Worker 118*da0073e9SAndroid Build Coastguard Worker return y + 3 119*da0073e9SAndroid Build Coastguard Worker 120*da0073e9SAndroid Build Coastguard Worker def munge_disas(s): 121*da0073e9SAndroid Build Coastguard Worker re.sub( 122*da0073e9SAndroid Build Coastguard Worker r"^(?: +\d+)?(?: +(-->)) \+\d+ ([A-Za-z0-9_]+)", 123*da0073e9SAndroid Build Coastguard Worker "\1 \3", 124*da0073e9SAndroid Build Coastguard Worker s, 125*da0073e9SAndroid Build Coastguard Worker flags=re.MULTILINE, 126*da0073e9SAndroid Build Coastguard Worker ) 127*da0073e9SAndroid Build Coastguard Worker 128*da0073e9SAndroid Build Coastguard Worker f(torch.randn(2)) 129*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 130*da0073e9SAndroid Build Coastguard Worker out = FILE.getvalue() 131*da0073e9SAndroid Build Coastguard Worker # Check that the instruction offset is working 132*da0073e9SAndroid Build Coastguard Worker self.assertIn("-->", out) 133*da0073e9SAndroid Build Coastguard Worker # Check that the bytecode resembles what we expect 134*da0073e9SAndroid Build Coastguard Worker self.assertIn("STORE_FAST", out) 135*da0073e9SAndroid Build Coastguard Worker if sys.version_info < (3, 11): 136*da0073e9SAndroid Build Coastguard Worker self.assertIn("BINARY_MULTIPLY", out) 137*da0073e9SAndroid Build Coastguard Worker else: 138*da0073e9SAndroid Build Coastguard Worker self.assertIn("BINARY_OP", out) 139*da0073e9SAndroid Build Coastguard Worker 140*da0073e9SAndroid Build Coastguard Worker def test_print_value_stack(self): 141*da0073e9SAndroid Build Coastguard Worker global FILE 142*da0073e9SAndroid Build Coastguard Worker FILE = StringIO() 143*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 144*da0073e9SAndroid Build Coastguard Worker 145*da0073e9SAndroid Build Coastguard Worker def g(x): 146*da0073e9SAndroid Build Coastguard Worker @comptime 147*da0073e9SAndroid Build Coastguard Worker def _(ctx): 148*da0073e9SAndroid Build Coastguard Worker ctx.print_value_stack(file=FILE, stacklevel=1) 149*da0073e9SAndroid Build Coastguard Worker 150*da0073e9SAndroid Build Coastguard Worker return x 151*da0073e9SAndroid Build Coastguard Worker 152*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize(cnt) 153*da0073e9SAndroid Build Coastguard Worker def f(x): 154*da0073e9SAndroid Build Coastguard Worker y = x + g(x) 155*da0073e9SAndroid Build Coastguard Worker 156*da0073e9SAndroid Build Coastguard Worker return y + comptime.print_value_stack_and_return(y * 2) 157*da0073e9SAndroid Build Coastguard Worker 158*da0073e9SAndroid Build Coastguard Worker f(torch.randn(2)) 159*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 160*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 161*da0073e9SAndroid Build Coastguard Worker FILE.getvalue(), 162*da0073e9SAndroid Build Coastguard Worker """\ 163*da0073e9SAndroid Build Coastguard Worker- FakeTensor(..., size=(2,)) 164*da0073e9SAndroid Build Coastguard Worker""", 165*da0073e9SAndroid Build Coastguard Worker ) 166*da0073e9SAndroid Build Coastguard Worker 167*da0073e9SAndroid Build Coastguard Worker def test_print_locals(self): 168*da0073e9SAndroid Build Coastguard Worker global FILE 169*da0073e9SAndroid Build Coastguard Worker FILE = StringIO() 170*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 171*da0073e9SAndroid Build Coastguard Worker 172*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize(cnt) 173*da0073e9SAndroid Build Coastguard Worker def f(x): 174*da0073e9SAndroid Build Coastguard Worker y = x * 2 175*da0073e9SAndroid Build Coastguard Worker 176*da0073e9SAndroid Build Coastguard Worker @comptime 177*da0073e9SAndroid Build Coastguard Worker def _(ctx): 178*da0073e9SAndroid Build Coastguard Worker ctx.print_locals(file=FILE) 179*da0073e9SAndroid Build Coastguard Worker 180*da0073e9SAndroid Build Coastguard Worker comptime.print_locals() 181*da0073e9SAndroid Build Coastguard Worker 182*da0073e9SAndroid Build Coastguard Worker return y + 3 183*da0073e9SAndroid Build Coastguard Worker 184*da0073e9SAndroid Build Coastguard Worker f(torch.randn(2)) 185*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 186*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 187*da0073e9SAndroid Build Coastguard Worker FILE.getvalue(), 188*da0073e9SAndroid Build Coastguard Worker """\ 189*da0073e9SAndroid Build Coastguard Workerx = FakeTensor(..., size=(2,)) 190*da0073e9SAndroid Build Coastguard Workery = FakeTensor(..., size=(2,)) 191*da0073e9SAndroid Build Coastguard Worker""", 192*da0073e9SAndroid Build Coastguard Worker ) 193*da0073e9SAndroid Build Coastguard Worker 194*da0073e9SAndroid Build Coastguard Worker # Just make sure it doesn't crash 195*da0073e9SAndroid Build Coastguard Worker def test_print_direct(self): 196*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 197*da0073e9SAndroid Build Coastguard Worker 198*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize(cnt) 199*da0073e9SAndroid Build Coastguard Worker def f(x, z): 200*da0073e9SAndroid Build Coastguard Worker y = x * 2 201*da0073e9SAndroid Build Coastguard Worker lambda: z 202*da0073e9SAndroid Build Coastguard Worker comptime.print(z) 203*da0073e9SAndroid Build Coastguard Worker return y + 3 204*da0073e9SAndroid Build Coastguard Worker 205*da0073e9SAndroid Build Coastguard Worker f(torch.randn(2), torch.randn(2)) 206*da0073e9SAndroid Build Coastguard Worker 207*da0073e9SAndroid Build Coastguard Worker def test_sleep(self): 208*da0073e9SAndroid Build Coastguard Worker sleep_time = 5 209*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 210*da0073e9SAndroid Build Coastguard Worker 211*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize(cnt) 212*da0073e9SAndroid Build Coastguard Worker def f(x, z, should_sleep): 213*da0073e9SAndroid Build Coastguard Worker if should_sleep: 214*da0073e9SAndroid Build Coastguard Worker comptime.sleep(sleep_time) 215*da0073e9SAndroid Build Coastguard Worker y = x * 2 216*da0073e9SAndroid Build Coastguard Worker return y + 3 217*da0073e9SAndroid Build Coastguard Worker 218*da0073e9SAndroid Build Coastguard Worker start = time.time() 219*da0073e9SAndroid Build Coastguard Worker f(torch.randn(2), torch.randn(2), False) 220*da0073e9SAndroid Build Coastguard Worker total_no_sleep = time.time() - start 221*da0073e9SAndroid Build Coastguard Worker 222*da0073e9SAndroid Build Coastguard Worker start = time.time() 223*da0073e9SAndroid Build Coastguard Worker f(torch.randn(2), torch.randn(2), True) 224*da0073e9SAndroid Build Coastguard Worker total_with_sleep = time.time() - start 225*da0073e9SAndroid Build Coastguard Worker 226*da0073e9SAndroid Build Coastguard Worker self.assertTrue(total_with_sleep > sleep_time) 227*da0073e9SAndroid Build Coastguard Worker # Hopefully this won't be flaky 228*da0073e9SAndroid Build Coastguard Worker self.assertTrue(abs(total_with_sleep - sleep_time - total_no_sleep) < 3) 229*da0073e9SAndroid Build Coastguard Worker 230*da0073e9SAndroid Build Coastguard Worker # Just make sure it doesn't crash 231*da0073e9SAndroid Build Coastguard Worker def test_get_local_closure_variable(self): 232*da0073e9SAndroid Build Coastguard Worker global SELF 233*da0073e9SAndroid Build Coastguard Worker SELF = self 234*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 235*da0073e9SAndroid Build Coastguard Worker 236*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize(cnt) 237*da0073e9SAndroid Build Coastguard Worker def f(x): 238*da0073e9SAndroid Build Coastguard Worker z = 3 239*da0073e9SAndroid Build Coastguard Worker 240*da0073e9SAndroid Build Coastguard Worker def g(): 241*da0073e9SAndroid Build Coastguard Worker @comptime 242*da0073e9SAndroid Build Coastguard Worker def _(ctx): 243*da0073e9SAndroid Build Coastguard Worker r = ctx.get_local("z") 244*da0073e9SAndroid Build Coastguard Worker SELF.assertEqual(repr(r), "3") 245*da0073e9SAndroid Build Coastguard Worker 246*da0073e9SAndroid Build Coastguard Worker comptime.print(z) 247*da0073e9SAndroid Build Coastguard Worker return 2 248*da0073e9SAndroid Build Coastguard Worker 249*da0073e9SAndroid Build Coastguard Worker y = x * g() 250*da0073e9SAndroid Build Coastguard Worker return y + 3 251*da0073e9SAndroid Build Coastguard Worker 252*da0073e9SAndroid Build Coastguard Worker f(torch.randn(2)) 253*da0073e9SAndroid Build Coastguard Worker 254*da0073e9SAndroid Build Coastguard Worker def test_print_bt(self): 255*da0073e9SAndroid Build Coastguard Worker global FILE 256*da0073e9SAndroid Build Coastguard Worker FILE = StringIO() 257*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 258*da0073e9SAndroid Build Coastguard Worker 259*da0073e9SAndroid Build Coastguard Worker def g(x): 260*da0073e9SAndroid Build Coastguard Worker @comptime 261*da0073e9SAndroid Build Coastguard Worker def _(ctx): 262*da0073e9SAndroid Build Coastguard Worker ctx.print_bt(file=FILE) 263*da0073e9SAndroid Build Coastguard Worker 264*da0073e9SAndroid Build Coastguard Worker comptime.print_bt() 265*da0073e9SAndroid Build Coastguard Worker 266*da0073e9SAndroid Build Coastguard Worker return x + 3 267*da0073e9SAndroid Build Coastguard Worker 268*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize(cnt) 269*da0073e9SAndroid Build Coastguard Worker def f(x): 270*da0073e9SAndroid Build Coastguard Worker y = x * 2 271*da0073e9SAndroid Build Coastguard Worker y = g(y) 272*da0073e9SAndroid Build Coastguard Worker return y + 3 273*da0073e9SAndroid Build Coastguard Worker 274*da0073e9SAndroid Build Coastguard Worker def munge_filenames(s): 275*da0073e9SAndroid Build Coastguard Worker return re.sub(r'File "[^"]+", line \d+', 'File "X", line X', s) 276*da0073e9SAndroid Build Coastguard Worker 277*da0073e9SAndroid Build Coastguard Worker f(torch.randn(2)) 278*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 279*da0073e9SAndroid Build Coastguard Worker bt = FILE.getvalue() 280*da0073e9SAndroid Build Coastguard Worker self.assertIn("y = g(y)", bt) 281*da0073e9SAndroid Build Coastguard Worker 282*da0073e9SAndroid Build Coastguard Worker def test_print_guards(self): 283*da0073e9SAndroid Build Coastguard Worker global FILE 284*da0073e9SAndroid Build Coastguard Worker FILE = StringIO() 285*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 286*da0073e9SAndroid Build Coastguard Worker 287*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize(cnt) 288*da0073e9SAndroid Build Coastguard Worker def f(x): 289*da0073e9SAndroid Build Coastguard Worker y = x * 2 290*da0073e9SAndroid Build Coastguard Worker 291*da0073e9SAndroid Build Coastguard Worker @comptime 292*da0073e9SAndroid Build Coastguard Worker def _(ctx): 293*da0073e9SAndroid Build Coastguard Worker ctx.print_guards(file=FILE) 294*da0073e9SAndroid Build Coastguard Worker 295*da0073e9SAndroid Build Coastguard Worker comptime.print_guards() 296*da0073e9SAndroid Build Coastguard Worker 297*da0073e9SAndroid Build Coastguard Worker return y + 3 298*da0073e9SAndroid Build Coastguard Worker 299*da0073e9SAndroid Build Coastguard Worker f(torch.randn(2)) 300*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 301*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 302*da0073e9SAndroid Build Coastguard Worker re.sub(r"\s+$", "", FILE.getvalue().rstrip(), flags=re.MULTILINE), 303*da0073e9SAndroid Build Coastguard Worker """\ 304*da0073e9SAndroid Build Coastguard Worker 305*da0073e9SAndroid Build Coastguard Worker local "L['x']" TENSOR_MATCH 306*da0073e9SAndroid Build Coastguard Worker { 307*da0073e9SAndroid Build Coastguard Worker 'guard_types': None, 308*da0073e9SAndroid Build Coastguard Worker 'code': None, 309*da0073e9SAndroid Build Coastguard Worker 'obj_weakref': None 310*da0073e9SAndroid Build Coastguard Worker 'guarded_class': None 311*da0073e9SAndroid Build Coastguard Worker } 312*da0073e9SAndroid Build Coastguard Worker global '' GRAD_MODE 313*da0073e9SAndroid Build Coastguard Worker { 314*da0073e9SAndroid Build Coastguard Worker 'guard_types': None, 315*da0073e9SAndroid Build Coastguard Worker 'code': None, 316*da0073e9SAndroid Build Coastguard Worker 'obj_weakref': None 317*da0073e9SAndroid Build Coastguard Worker 'guarded_class': None 318*da0073e9SAndroid Build Coastguard Worker } 319*da0073e9SAndroid Build Coastguard Worker global '' DETERMINISTIC_ALGORITHMS 320*da0073e9SAndroid Build Coastguard Worker { 321*da0073e9SAndroid Build Coastguard Worker 'guard_types': None, 322*da0073e9SAndroid Build Coastguard Worker 'code': None, 323*da0073e9SAndroid Build Coastguard Worker 'obj_weakref': None 324*da0073e9SAndroid Build Coastguard Worker 'guarded_class': None 325*da0073e9SAndroid Build Coastguard Worker } 326*da0073e9SAndroid Build Coastguard Worker global '' TORCH_FUNCTION_STATE 327*da0073e9SAndroid Build Coastguard Worker { 328*da0073e9SAndroid Build Coastguard Worker 'guard_types': None, 329*da0073e9SAndroid Build Coastguard Worker 'code': None, 330*da0073e9SAndroid Build Coastguard Worker 'obj_weakref': None 331*da0073e9SAndroid Build Coastguard Worker 'guarded_class': None 332*da0073e9SAndroid Build Coastguard Worker } 333*da0073e9SAndroid Build Coastguard Worker global '' DEFAULT_DEVICE 334*da0073e9SAndroid Build Coastguard Worker { 335*da0073e9SAndroid Build Coastguard Worker 'guard_types': None, 336*da0073e9SAndroid Build Coastguard Worker 'code': None, 337*da0073e9SAndroid Build Coastguard Worker 'obj_weakref': None 338*da0073e9SAndroid Build Coastguard Worker 'guarded_class': None 339*da0073e9SAndroid Build Coastguard Worker } 340*da0073e9SAndroid Build Coastguard Worker shape_env '' SHAPE_ENV 341*da0073e9SAndroid Build Coastguard Worker { 342*da0073e9SAndroid Build Coastguard Worker 'guard_types': None, 343*da0073e9SAndroid Build Coastguard Worker 'code': None, 344*da0073e9SAndroid Build Coastguard Worker 'obj_weakref': None 345*da0073e9SAndroid Build Coastguard Worker 'guarded_class': None 346*da0073e9SAndroid Build Coastguard Worker }""", 347*da0073e9SAndroid Build Coastguard Worker ) 348*da0073e9SAndroid Build Coastguard Worker 349*da0073e9SAndroid Build Coastguard Worker def test_graph_break(self): 350*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 351*da0073e9SAndroid Build Coastguard Worker 352*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize(cnt) 353*da0073e9SAndroid Build Coastguard Worker def f(x): 354*da0073e9SAndroid Build Coastguard Worker y = x * 2 355*da0073e9SAndroid Build Coastguard Worker 356*da0073e9SAndroid Build Coastguard Worker @comptime 357*da0073e9SAndroid Build Coastguard Worker def _(ctx): 358*da0073e9SAndroid Build Coastguard Worker pass 359*da0073e9SAndroid Build Coastguard Worker 360*da0073e9SAndroid Build Coastguard Worker return y + 3 361*da0073e9SAndroid Build Coastguard Worker 362*da0073e9SAndroid Build Coastguard Worker f(torch.randn(2)) 363*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 364*da0073e9SAndroid Build Coastguard Worker cnt.frame_count = 0 365*da0073e9SAndroid Build Coastguard Worker 366*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize(cnt) 367*da0073e9SAndroid Build Coastguard Worker def g(x): 368*da0073e9SAndroid Build Coastguard Worker y = x * 2 369*da0073e9SAndroid Build Coastguard Worker 370*da0073e9SAndroid Build Coastguard Worker @comptime 371*da0073e9SAndroid Build Coastguard Worker def _(ctx): 372*da0073e9SAndroid Build Coastguard Worker ctx.graph_break() 373*da0073e9SAndroid Build Coastguard Worker 374*da0073e9SAndroid Build Coastguard Worker y = y + 2 375*da0073e9SAndroid Build Coastguard Worker 376*da0073e9SAndroid Build Coastguard Worker comptime.graph_break() 377*da0073e9SAndroid Build Coastguard Worker 378*da0073e9SAndroid Build Coastguard Worker return y * 3 379*da0073e9SAndroid Build Coastguard Worker 380*da0073e9SAndroid Build Coastguard Worker g(torch.randn(2)) 381*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 3) 382*da0073e9SAndroid Build Coastguard Worker 383*da0073e9SAndroid Build Coastguard Worker def test_get_local(self): 384*da0073e9SAndroid Build Coastguard Worker global SELF, FILE 385*da0073e9SAndroid Build Coastguard Worker SELF = self 386*da0073e9SAndroid Build Coastguard Worker FILE = StringIO() 387*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 388*da0073e9SAndroid Build Coastguard Worker 389*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize(cnt) 390*da0073e9SAndroid Build Coastguard Worker def f(x): 391*da0073e9SAndroid Build Coastguard Worker y = x * 2 392*da0073e9SAndroid Build Coastguard Worker lit = 2 393*da0073e9SAndroid Build Coastguard Worker 394*da0073e9SAndroid Build Coastguard Worker @comptime 395*da0073e9SAndroid Build Coastguard Worker def _(ctx): 396*da0073e9SAndroid Build Coastguard Worker y = ctx.get_local("y") 397*da0073e9SAndroid Build Coastguard Worker SELF.assertEqual(y.as_fake().size(0), 2) 398*da0073e9SAndroid Build Coastguard Worker SELF.assertEqual(y.size(0), 2) 399*da0073e9SAndroid Build Coastguard Worker # Trigger a graph write (TODO: this is not so 400*da0073e9SAndroid Build Coastguard Worker # useful right now as there's no way to make use 401*da0073e9SAndroid Build Coastguard Worker # of the output proxy; maybe it's useful for inserting 402*da0073e9SAndroid Build Coastguard Worker # side-effectful operations into the graph) 403*da0073e9SAndroid Build Coastguard Worker y.as_proxy() + 4 404*da0073e9SAndroid Build Coastguard Worker ctx.print_graph(verbose=False, file=FILE) 405*da0073e9SAndroid Build Coastguard Worker SELF.assertIs(y.python_type(), torch.Tensor) 406*da0073e9SAndroid Build Coastguard Worker lit = ctx.get_local("lit") 407*da0073e9SAndroid Build Coastguard Worker SELF.assertEqual(lit.as_python_constant(), 2) 408*da0073e9SAndroid Build Coastguard Worker 409*da0073e9SAndroid Build Coastguard Worker return y + 3 410*da0073e9SAndroid Build Coastguard Worker 411*da0073e9SAndroid Build Coastguard Worker f(torch.randn(2)) 412*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 413*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 414*da0073e9SAndroid Build Coastguard Worker FILE.getvalue().strip(), 415*da0073e9SAndroid Build Coastguard Worker """\ 416*da0073e9SAndroid Build Coastguard Workerdef forward(self, L_x_ : torch.Tensor): 417*da0073e9SAndroid Build Coastguard Worker l_x_ = L_x_ 418*da0073e9SAndroid Build Coastguard Worker y = l_x_ * 2; l_x_ = None 419*da0073e9SAndroid Build Coastguard Worker add = y + 4; y = add = None""", 420*da0073e9SAndroid Build Coastguard Worker ) 421*da0073e9SAndroid Build Coastguard Worker 422*da0073e9SAndroid Build Coastguard Worker 423*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 424*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.test_case import run_tests 425*da0073e9SAndroid Build Coastguard Worker 426*da0073e9SAndroid Build Coastguard Worker run_tests() 427