1# Owner(s): ["module: dynamo"] 2 3import collections 4import re 5import sys 6import time 7from io import StringIO 8 9import torch._dynamo.test_case 10import torch._dynamo.testing 11from torch._dynamo.comptime import comptime 12 13 14# Because we don't support free variables in comptime at the moment, 15# we have to communicate via globals. This also means these tests cannot 16# be run in parallel in a single process (not that you'd... ever want 17# to do that?) 18FILE = None 19SELF = None 20 21 22class ComptimeTests(torch._dynamo.test_case.TestCase): 23 def test_print_single(self): 24 global FILE 25 FILE = StringIO() 26 cnt = torch._dynamo.testing.CompileCounter() 27 28 def comptime_print(e): 29 @comptime 30 def _(ctx): 31 ctx.print(ctx.get_local("e"), file=FILE) 32 33 Employee = collections.namedtuple("Employee", ["name", "id"]) 34 35 class mylist(list): 36 pass 37 38 @torch._dynamo.optimize(cnt, dynamic=True) 39 def f(x): 40 y = x * 2 41 comptime_print(y) 42 comptime_print(2) 43 comptime_print([y, 2]) 44 comptime_print((y, 2)) 45 comptime_print({"foo": y}) 46 comptime_print(range(1, 3)) 47 comptime_print(Employee("foo", 2)) 48 comptime_print(mylist([1, 2])) 49 comptime_print(collections.defaultdict(lambda: None)) 50 comptime_print(set()) 51 comptime_print({"a", "b"}) 52 comptime_print(x.size(0)) 53 return y + 3 54 55 f(torch.randn(2)) 56 self.assertEqual(cnt.frame_count, 1) 57 self.assertExpectedInline( 58 FILE.getvalue().strip(), 59 """\ 60FakeTensor(..., size=(s0,)) 612 62[FakeTensor(..., size=(s0,)), 2] 63(FakeTensor(..., size=(s0,)), 2) 64{'foo': FakeTensor(..., size=(s0,))} 65range(1, 3, 1) 66Employee(name='foo', id=2) 67[1, 2] 68defaultdict(NestedUserFunctionVariable(), {}) 69set() 70{'a','b'} 71s0""", 72 ) 73 74 def test_print_graph(self): 75 global FILE 76 FILE = StringIO() 77 cnt = torch._dynamo.testing.CompileCounter() 78 79 @torch._dynamo.optimize(cnt) 80 def f(x): 81 y = x * 2 82 83 @comptime 84 def _(ctx): 85 ctx.print_graph(verbose=False, file=FILE) 86 87 # Test the compact notation doesn't error or graph break; 88 # you'll have to visually inspect to see that it printed 89 comptime.print_graph() 90 91 return y + 3 92 93 f(torch.randn(2)) 94 self.assertEqual(cnt.frame_count, 1) 95 self.assertExpectedInline( 96 FILE.getvalue().strip(), 97 """\ 98def forward(self, L_x_ : torch.Tensor): 99 l_x_ = L_x_ 100 y = l_x_ * 2; l_x_ = y = None""", 101 ) 102 103 def test_print_disas(self): 104 global FILE 105 FILE = StringIO() 106 cnt = torch._dynamo.testing.CompileCounter() 107 108 @torch._dynamo.optimize(cnt) 109 def f(x): 110 y = x * 2 111 112 @comptime 113 def _(ctx): 114 ctx.print_disas(file=FILE) 115 116 comptime.print_disas() 117 118 return y + 3 119 120 def munge_disas(s): 121 re.sub( 122 r"^(?: +\d+)?(?: +(-->)) \+\d+ ([A-Za-z0-9_]+)", 123 "\1 \3", 124 s, 125 flags=re.MULTILINE, 126 ) 127 128 f(torch.randn(2)) 129 self.assertEqual(cnt.frame_count, 1) 130 out = FILE.getvalue() 131 # Check that the instruction offset is working 132 self.assertIn("-->", out) 133 # Check that the bytecode resembles what we expect 134 self.assertIn("STORE_FAST", out) 135 if sys.version_info < (3, 11): 136 self.assertIn("BINARY_MULTIPLY", out) 137 else: 138 self.assertIn("BINARY_OP", out) 139 140 def test_print_value_stack(self): 141 global FILE 142 FILE = StringIO() 143 cnt = torch._dynamo.testing.CompileCounter() 144 145 def g(x): 146 @comptime 147 def _(ctx): 148 ctx.print_value_stack(file=FILE, stacklevel=1) 149 150 return x 151 152 @torch._dynamo.optimize(cnt) 153 def f(x): 154 y = x + g(x) 155 156 return y + comptime.print_value_stack_and_return(y * 2) 157 158 f(torch.randn(2)) 159 self.assertEqual(cnt.frame_count, 1) 160 self.assertExpectedInline( 161 FILE.getvalue(), 162 """\ 163- FakeTensor(..., size=(2,)) 164""", 165 ) 166 167 def test_print_locals(self): 168 global FILE 169 FILE = StringIO() 170 cnt = torch._dynamo.testing.CompileCounter() 171 172 @torch._dynamo.optimize(cnt) 173 def f(x): 174 y = x * 2 175 176 @comptime 177 def _(ctx): 178 ctx.print_locals(file=FILE) 179 180 comptime.print_locals() 181 182 return y + 3 183 184 f(torch.randn(2)) 185 self.assertEqual(cnt.frame_count, 1) 186 self.assertExpectedInline( 187 FILE.getvalue(), 188 """\ 189x = FakeTensor(..., size=(2,)) 190y = FakeTensor(..., size=(2,)) 191""", 192 ) 193 194 # Just make sure it doesn't crash 195 def test_print_direct(self): 196 cnt = torch._dynamo.testing.CompileCounter() 197 198 @torch._dynamo.optimize(cnt) 199 def f(x, z): 200 y = x * 2 201 lambda: z 202 comptime.print(z) 203 return y + 3 204 205 f(torch.randn(2), torch.randn(2)) 206 207 def test_sleep(self): 208 sleep_time = 5 209 cnt = torch._dynamo.testing.CompileCounter() 210 211 @torch._dynamo.optimize(cnt) 212 def f(x, z, should_sleep): 213 if should_sleep: 214 comptime.sleep(sleep_time) 215 y = x * 2 216 return y + 3 217 218 start = time.time() 219 f(torch.randn(2), torch.randn(2), False) 220 total_no_sleep = time.time() - start 221 222 start = time.time() 223 f(torch.randn(2), torch.randn(2), True) 224 total_with_sleep = time.time() - start 225 226 self.assertTrue(total_with_sleep > sleep_time) 227 # Hopefully this won't be flaky 228 self.assertTrue(abs(total_with_sleep - sleep_time - total_no_sleep) < 3) 229 230 # Just make sure it doesn't crash 231 def test_get_local_closure_variable(self): 232 global SELF 233 SELF = self 234 cnt = torch._dynamo.testing.CompileCounter() 235 236 @torch._dynamo.optimize(cnt) 237 def f(x): 238 z = 3 239 240 def g(): 241 @comptime 242 def _(ctx): 243 r = ctx.get_local("z") 244 SELF.assertEqual(repr(r), "3") 245 246 comptime.print(z) 247 return 2 248 249 y = x * g() 250 return y + 3 251 252 f(torch.randn(2)) 253 254 def test_print_bt(self): 255 global FILE 256 FILE = StringIO() 257 cnt = torch._dynamo.testing.CompileCounter() 258 259 def g(x): 260 @comptime 261 def _(ctx): 262 ctx.print_bt(file=FILE) 263 264 comptime.print_bt() 265 266 return x + 3 267 268 @torch._dynamo.optimize(cnt) 269 def f(x): 270 y = x * 2 271 y = g(y) 272 return y + 3 273 274 def munge_filenames(s): 275 return re.sub(r'File "[^"]+", line \d+', 'File "X", line X', s) 276 277 f(torch.randn(2)) 278 self.assertEqual(cnt.frame_count, 1) 279 bt = FILE.getvalue() 280 self.assertIn("y = g(y)", bt) 281 282 def test_print_guards(self): 283 global FILE 284 FILE = StringIO() 285 cnt = torch._dynamo.testing.CompileCounter() 286 287 @torch._dynamo.optimize(cnt) 288 def f(x): 289 y = x * 2 290 291 @comptime 292 def _(ctx): 293 ctx.print_guards(file=FILE) 294 295 comptime.print_guards() 296 297 return y + 3 298 299 f(torch.randn(2)) 300 self.assertEqual(cnt.frame_count, 1) 301 self.assertExpectedInline( 302 re.sub(r"\s+$", "", FILE.getvalue().rstrip(), flags=re.MULTILINE), 303 """\ 304 305 local "L['x']" TENSOR_MATCH 306 { 307 'guard_types': None, 308 'code': None, 309 'obj_weakref': None 310 'guarded_class': None 311 } 312 global '' GRAD_MODE 313 { 314 'guard_types': None, 315 'code': None, 316 'obj_weakref': None 317 'guarded_class': None 318 } 319 global '' DETERMINISTIC_ALGORITHMS 320 { 321 'guard_types': None, 322 'code': None, 323 'obj_weakref': None 324 'guarded_class': None 325 } 326 global '' TORCH_FUNCTION_STATE 327 { 328 'guard_types': None, 329 'code': None, 330 'obj_weakref': None 331 'guarded_class': None 332 } 333 global '' DEFAULT_DEVICE 334 { 335 'guard_types': None, 336 'code': None, 337 'obj_weakref': None 338 'guarded_class': None 339 } 340 shape_env '' SHAPE_ENV 341 { 342 'guard_types': None, 343 'code': None, 344 'obj_weakref': None 345 'guarded_class': None 346 }""", 347 ) 348 349 def test_graph_break(self): 350 cnt = torch._dynamo.testing.CompileCounter() 351 352 @torch._dynamo.optimize(cnt) 353 def f(x): 354 y = x * 2 355 356 @comptime 357 def _(ctx): 358 pass 359 360 return y + 3 361 362 f(torch.randn(2)) 363 self.assertEqual(cnt.frame_count, 1) 364 cnt.frame_count = 0 365 366 @torch._dynamo.optimize(cnt) 367 def g(x): 368 y = x * 2 369 370 @comptime 371 def _(ctx): 372 ctx.graph_break() 373 374 y = y + 2 375 376 comptime.graph_break() 377 378 return y * 3 379 380 g(torch.randn(2)) 381 self.assertEqual(cnt.frame_count, 3) 382 383 def test_get_local(self): 384 global SELF, FILE 385 SELF = self 386 FILE = StringIO() 387 cnt = torch._dynamo.testing.CompileCounter() 388 389 @torch._dynamo.optimize(cnt) 390 def f(x): 391 y = x * 2 392 lit = 2 393 394 @comptime 395 def _(ctx): 396 y = ctx.get_local("y") 397 SELF.assertEqual(y.as_fake().size(0), 2) 398 SELF.assertEqual(y.size(0), 2) 399 # Trigger a graph write (TODO: this is not so 400 # useful right now as there's no way to make use 401 # of the output proxy; maybe it's useful for inserting 402 # side-effectful operations into the graph) 403 y.as_proxy() + 4 404 ctx.print_graph(verbose=False, file=FILE) 405 SELF.assertIs(y.python_type(), torch.Tensor) 406 lit = ctx.get_local("lit") 407 SELF.assertEqual(lit.as_python_constant(), 2) 408 409 return y + 3 410 411 f(torch.randn(2)) 412 self.assertEqual(cnt.frame_count, 1) 413 self.assertExpectedInline( 414 FILE.getvalue().strip(), 415 """\ 416def forward(self, L_x_ : torch.Tensor): 417 l_x_ = L_x_ 418 y = l_x_ * 2; l_x_ = None 419 add = y + 4; y = add = None""", 420 ) 421 422 423if __name__ == "__main__": 424 from torch._dynamo.test_case import run_tests 425 426 run_tests() 427