# Owner(s): ["module: dynamo"]
import math
import random
import unittest

import numpy as np

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch.nn.functional as F
from torch._dynamo.comptime import comptime
from torch._dynamo.testing import CompileCounter, same
from torch.testing._internal.common_utils import skipIfWindows
from torch.testing._internal.logging_utils import logs_to_string


# The intention of this test file is you should put test cases specifically
# for assume_static_by_default=False, aka you want to YOLO make everything as
# dynamic as possible. If you want to test the more normal situation where
# you assume static by default, put it in a regular test file and
# test_dynamic_shapes will cover both the YOLO and non-YOLO cases.


@torch._dynamo.config.patch(assume_static_by_default=False)
class UnspecTests(torch._dynamo.test_case.TestCase):
    def test_numpy_correctness(self):
        def fn(x, y, z):
            xy = [x + y, y, False]
            np_x = x.numpy()
            np_y = y.numpy()
            return {
                "x": x,
                "z": z,
                "a": np_y.sum(),
                "b": xy,
                "c": np_y[0][0] / 68,
                "d": np_x.sum(),
                "e": np_x + np_y,
            }, x + np_y.sum() + z

        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64)
        y = torch.ones([2, 2], dtype=torch.int64)
        z = np.int64(12)
        res1 = fn(x, y, z)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res2 = opt_fn(x, y, z)
        self.assertEqual(res1, res2)

    def test_no_recompilations(self):
        # no recompilations if passing on different numpy int values
        def fn(x, y):
            return {"a": x + 1, "b": y / 2}

        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        for i in range(10):
            opt_fn(x, np.int64(i))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

    @unittest.expectedFailure  # array scalars decay to 0D arrays
    def test_builtin_max_min(self):
        # test unspecialized primitive max/min
        def fn(x, y, z):
            return z + 1, max(x, y), min(x - 4, y)

        x = np.int64(12)
        y = 10
        z = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64)
        res1 = fn(x, y, z)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res2 = opt_fn(x, y, z)
        self.assertTrue(same(res1, res2, relax_numpy_equality=True))

    def test_feed_random_values_into_graph_only(self):
        def fn(shape):
            torch.manual_seed(123)
            x = torch.randn(shape, device="cpu") * random.randint(30, 100)
            return x

        shape = [2, 3]
        random.seed(1)
        res1 = fn(shape)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        random.seed(1)
        res2 = opt_fn(shape)

        self.assertTrue(same(res1, res2))

    def test_random_values_with_graph_break(self):
        def fn(x):
            r1 = random.random()
            y = x + random.uniform(10, 20)
            y.sum().item()
            r2 = random.randint(2, 18)  # no graph output in this frame
            y.sum().item()
            return y + r1, r2

        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        random.seed(1)
        res1 = fn(x)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        random.seed(1)
        res2 = opt_fn(x)
        self.assertTrue(same(res1, res2))

    # Really annoying intersection of specialization and RandomValueSource
    # If we get a RandomValueSource with a single element tensor, we should return a ConstantVariable like other
    # unspects... but if we do, we break the bytecode assumptions and guards will not work as we will be referring
    # to a name from a source that is not there. If we call .item() and take the wrapped_value out, where we do
    # wrapped_value = wrapped_value.item() where we send unspec down to wrap_fx_proxy, this test passes and then
    # some models fail on missing codegen.tx.output.random_values_var. If we let the tensor value go into wrap as
    # it is, this test fails.
    # The real solution here is to rewrite RandomValueSource and all the codegen it does from the ground up.
    def test_multiple_consecutive_random_calls_before_graph(self):
        def fn(x):
            dim1 = random.randrange(start=0, stop=5)
            dim2 = random.randrange(start=0, stop=5)
            dim3 = random.randrange(start=0, stop=5)
            y = torch.rand(dim1, dim2, dim3)
            return x + 2, y

        x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        random.seed(1)
        res1 = fn(x)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        random.seed(1)
        res2 = opt_fn(x)
        self.assertTrue(same(res1, res2))

    def test_compiled_random_calls_are_random(self):
        # For compiled functions with random calls,
        # it should return different values for every iteration.
        # https://github.com/pytorch/pytorch/issues/95425
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return (x + 1) * random.uniform(0, 1)

        res = []
        for _ in range(5):
            res.append(fn(torch.ones(2)))
        for i in range(1, 5):
            self.assertFalse(same(res[i - 1], res[i]))

    def test_random_call_with_while_loop(self):
        def fn(x):
            dim1 = random.randrange(start=0, stop=3)
            dim2 = dim1
            while dim1 == dim2:
                dim2 = random.randrange(start=0, stop=3)
            return x * 2

        x = torch.randn(4)
        random.seed(1)
        res1 = fn(x)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        random.seed(1)
        res2 = opt_fn(x)
        self.assertTrue(same(res1, res2))

        random.seed(10)
        res1 = fn(x)
        random.seed(10)
        res2 = opt_fn(x)
        self.assertTrue(same(res1, res2))

    def test_random_object(self):
        # test argument passing, mutation, reconstruction, state correctness
        def fn(x, rand2):
            r1 = random.randint(1, 9)
            r2 = rand2.randint(1, 9)
            rand3 = random.Random(42)
            r3 = rand3.randint(1, 9)

            y = x + r1 + r2 + r3
            return y, rand2, rand3

        inp = torch.randn(3, 3)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        random.seed(0)
        y_1, rand2_1, rand3_1 = fn(inp, random.Random(12))
        state_1 = random.getstate()
        random.seed(0)
        y_2, rand2_2, rand3_2 = opt_fn(inp, random.Random(12))
        state_2 = random.getstate()
        self.assertEqual(y_1, y_2)
        self.assertEqual(state_1, state_2)
        self.assertEqual(rand2_1.getstate(), rand2_2.getstate())
        self.assertEqual(rand3_1.getstate(), rand3_2.getstate())

    def test_random_object_methods(self):
        def fn(x, rand1, rand2, rand3):
            rand1.seed(42)
            rand4 = random.Random(9002)
            rand2.setstate(rand4.getstate())
            r1 = rand1.random()
            r2 = rand2.randint(1, 10)
            r3 = rand3.randrange(10)
            r4 = rand4.uniform(0, 1)
            return x + r1 + r2 + r3 + r4

        inp = torch.randn(3, 3)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        rand1_1 = random.Random(1)
        rand2_1 = random.Random(2)
        rand3_1 = random.Random(3)
        rand1_2 = random.Random(1)
        rand2_2 = random.Random(2)
        rand3_2 = random.Random(3)
        y1 = fn(inp, rand1_1, rand2_1, rand3_1)
        y2 = opt_fn(inp, rand1_2, rand2_2, rand3_2)
        self.assertEqual(y1, y2)
        self.assertEqual(rand1_1.getstate(), rand1_2.getstate())
        self.assertEqual(rand2_1.getstate(), rand2_2.getstate())
        self.assertEqual(rand3_1.getstate(), rand3_2.getstate())

    def test_random_object_overriden_methods(self):
        # these will result in graph breaks, but we shouldn't crash
        def get_rng():
            rand1 = random.Random(1)
            rand2 = random.Random(2)

            orig_random = rand1.random

            def custom_random():
                return orig_random()

            orig_getstate = rand2.getstate

            def custom_getstate():
                return orig_getstate()

            rand1.random = custom_random
            rand2.getstate = custom_getstate
            return rand1, rand2

        def fn(x, rand1, rand2):
            r1 = rand1.random()
            rand3 = random.Random()
            rand3.setstate(rand2.getstate())
            r2 = rand3.random()
            return x + r1 + r2

        inp = torch.randn(3, 3)
        opt_fn = torch.compile(fn, backend="eager")
        y1 = fn(inp, *get_rng())
        y2 = opt_fn(inp, *get_rng())
        self.assertEqual(y1, y2)

    def test_builtin_getitem(self):
        # builtin getitem args[0] is python list and args[1] is unspec
        def fn(x, idx):
            return (torch.zeros(idx), x[idx], x[idx:])

        x = list(range(50))
        ref = fn(x, 48)  # 48 is unspecialized
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(x, 48)
        self.assertTrue(same(ref, res))

    def test_use_and_specialize(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True, dynamic=True)
        def fn(x, y):
            x = x + y
            if y == 2:
                return x - 1
            else:
                return x + 1

        self.assertTrue(same(fn(torch.tensor([5]), 2), 6))
        self.assertTrue(same(fn(torch.tensor([6]), 2), 7))
        self.assertTrue(same(fn(torch.tensor([5]), 3), 9))
        self.assertTrue(same(fn(torch.tensor([4]), 3), 8))
        self.assertEqual(cnt.frame_count, 2)

    def test_no_recompiles(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True, dynamic=True)
        def fn(x, y):
            return x + y

        self.assertTrue(same(fn(torch.tensor([5]), 100), 105))
        self.assertTrue(same(fn(torch.tensor([4]), 200), 204))
        self.assertTrue(same(fn(torch.tensor([3]), 300), 303))
        self.assertTrue(same(fn(torch.tensor([2]), 400), 402))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)

    def test_no_recompiles_prod_backward(self):
        # https://github.com/pytorch/pytorch/issues/120608
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True, dynamic=True)
        def fn(t):
            return torch.prod(t, 3, keepdim=True)

        input_shapes = [(8, 10, 3, 2), (8, 3, 5, 2), (8, 4, 8, 2)]
        for s in input_shapes:
            t1 = torch.randn(s, requires_grad=True)
            h_result = fn(t1)
            grad = torch.ones_like(h_result)
            h_result.backward(grad)

        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_builtin_functions_on_cuda(self):
        def fn(x, scaler):
            m = torch.nn.ReLU()
            y = m(x) * scaler
            return y

        x = torch.randn([3, 6], device="cuda")
        scaler = 0.23  # 0.23 is unspecialized
        ref = fn(x, scaler)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(x, scaler)
        self.assertTrue(same(ref, res))
        self.assertEqual(ref.device, res.device)

    def test_unspec_float_precision(self):
        def fn(image, scale_factor):
            image = torch.nn.functional.interpolate(
                image[None],
                size=None,
                scale_factor=scale_factor,
                mode="bilinear",
                recompute_scale_factor=True,
                align_corners=False,
            )[0]

            return image.shape

        x = torch.rand([3, 427, 640])
        scale_factor = 1.873536229133606
        ref = fn(x, scale_factor)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(x, scale_factor)
        self.assertTrue(same(ref, res))

    @unittest.expectedFailure  # fails as long as numpy scalars are 0D arrays
    def test_specializing_numpy_float_in_control_flow(self):
        # np.float64 is unspecialized by default,
        # but it should be specialized when used in control flow.
        def fn(x, y):
            if y > 1.0:
                return x + 1
            else:
                return x - 1

        x = torch.rand(4)
        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        for t in [np.float16, np.float32, np.float64]:
            y = t(1.23)
            ref = fn(x, y)
            res = opt_fn(x, y)
            self.assertTrue(same(ref, res))

    def test_mark_static_inside(self):
        def fn(x):
            torch._dynamo.mark_static(x, 0)
            comptime.assert_static(x.size(0))
            return x + 1

        opt_fn = torch.compile(fn, dynamic=True, fullgraph=True)
        opt_fn(torch.randn(12, 23))

    def test_shape_graph_break(self):
        from torch._dynamo.comptime import comptime

        def fn(x):
            x_shape = x.size()
            comptime.graph_break()
            return x + torch.randn(x_shape)

        x = torch.randn(20)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        opt_fn(x)

    def test_isinstance_symint(self):
        def fn(x):
            assert isinstance(x.size(0), int)
            return x * 2

        x = torch.randn(20)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        opt_fn(x)
        y = torch.randn(30)
        torch._dynamo.mark_dynamic(y, 0)
        opt_fn(y)

    def test_mark_01_dynamic(self):
        def fn(x):
            return x * 2

        x = torch.randn(1)
        torch._dynamo.mark_dynamic(x, 0)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        # This will fail to compile a generic kernel, but we should not
        # complain about it (mark dynamic will try its best but 0/1
        # specialization is allowed)
        opt_fn(x)

    def test_conv1d_symint_padding(self):
        kernel = torch.randn(1, 1, 4)

        def func(x):
            padding = math.ceil((kernel.shape[-1] + x.shape[-1] % 2) / 2) - 1
            out = F.conv1d(x, kernel, padding=padding, stride=2)
            return out

        opt_func = torch.compile(func)

        x = torch.randn(1, 1, 175)
        opt_func(x)  # passes
        x = torch.randn(1, 1, 249)
        opt_func(x)  # crashes

    @torch._dynamo.config.patch("assume_static_by_default", True)
    def test_propagate_dynamic_dim(self):
        x = torch.randn(20)
        torch._dynamo.mark_dynamic(x, 0)

        @torch.compile()
        def fn(x):
            y = x * 2
            comptime.graph_break()
            z = y * 2
            return z

        z = fn(x)
        self.assertEqual(z._dynamo_weak_dynamic_indices, {0})

    def test_rshift_dynamic(self):
        def shift_right(tensor: torch.Tensor) -> torch.Tensor:
            return (tensor >> 2).to(torch.long)

        opt_fn = torch.compile(shift_right, fullgraph=True, dynamic=True)
        sample_input = torch.tensor([4, 4, 16, 32], dtype=torch.uint8)
        opt_fn(sample_input)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_symfloat_to_tensor(self):
        def f1(v):
            return torch.tensor([v.item()])

        def f2(v):
            return torch.tensor([[v.item()], [2.0]])

        def f3(v):
            return torch.tensor(v.item())

        def f4(v):
            return torch.tensor((v.item(),))

        optimize = torch.compile(backend="aot_eager", fullgraph=True)

        r = torch.randn(1)

        self.assertEqual(f1(r), optimize(f1)(r))
        self.assertEqual(f2(r), optimize(f2)(r))
        self.assertEqual(f3(r), optimize(f3)(r))
        self.assertEqual(f4(r), optimize(f4)(r))

    @skipIfWindows(
        msg="AssertionError: The values for attribute 'dtype' do not match: torch.int32 != torch.int64."
    )
    def test_to_tensor(self):
        def f1():
            a = np.random.uniform(low=-1, high=1, size=(20, 1))
            return torch.tensor([a, a, a, a], dtype=torch.float64, device="cpu")

        def f2():
            a = torch.tensor([[[123]]])
            return torch.tensor([a, a])

        def f3():
            a = torch.tensor(123)
            return torch.tensor([a, a])

        def f4():
            a = torch.tensor(123)
            b = torch.tensor([[[456]]])
            return torch.tensor([a, b])

        def f5():
            a = np.array([1, 2])
            return torch.tensor([a, a])

        optimize = torch.compile(backend="aot_eager", fullgraph=True)

        self.assertEqual(f1().shape, optimize(f1)().shape)
        self.assertEqual(f2(), optimize(f2)())
        self.assertEqual(f3(), optimize(f3)())
        self.assertEqual(f4(), optimize(f4)())
        self.assertEqual(f5(), optimize(f5)())

    def test_sym_int_conversion(self):
        def f(x):
            y = x.size(0)
            return x * int(y == 0)

        opt_fn = torch.compile(f, backend="eager", fullgraph=True)
        x = torch.randn(2, 3)
        opt_fn(x)

    def test_sum_dimlist_spec(self):
        def fn(inputs, dim):
            return torch.sum(inputs, dim)

        inputs = torch.randn(128, 5, 24, 24)
        dim = (-1, 1, 0, 2)
        compl_fn = torch.compile(fn, dynamic=True, backend="eager", fullgraph=True)
        self.assertEqual(compl_fn(inputs, dim), fn(inputs, dim))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_item_max(self):
        def fn(x):
            return torch.ones(max(x.item(), 1024))

        x = torch.tensor([1000])
        y = torch.tensor([2000])
        compl_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), compl_fn(x))
        self.assertEqual(fn(y), compl_fn(y))

    # https://github.com/pytorch/pytorch/issues/104812
    def test_argmin_coerces_symint_to_intlist_spec(self):
        def fn(x, dim):
            # the python arg parser coerces dim into a vector<int>
            return torch.amin(x, dim=dim, keepdim=True)

        x = torch.randn(4, 4, 4)
        dim = 2
        compl_fn = torch.compile(fn, dynamic=True, backend="eager", fullgraph=True)
        self.assertEqual(compl_fn(x, dim), fn(x, dim))

    def test_exponential(self):
        def fn(inputs, op_inputs_dict):
            res = inputs.exponential_(**op_inputs_dict)
            return res

        inputs = torch.randn(2, 3, 4)
        op_inputs_dict = {"lambd": 10, "generator": None}
        compl_fn = torch.compile(fn, dynamic=True, backend="eager", fullgraph=True)
        self.assertEqual(compl_fn(inputs, op_inputs_dict), fn(inputs, op_inputs_dict))

    def test_symbol_guard_limit_before_specialize(self):
        cnts = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnts, dynamic=True)
        def fn(x):
            torch._check(x.size(0) != 3)
            torch._check(x.size(0) != 4)
            torch._check(x.size(0) != 5)
            torch._check(x.size(0) != 6)
            return x + 2

        # Control test
        fn(torch.randn(12))
        fn(torch.randn(13))
        fn(torch.randn(14))

        self.assertExpectedInline(cnts.frame_count, """1""")
        cnts.frame_count = 0

        torch._dynamo.reset()

        with torch.fx.experimental._config.patch(
            symbol_guard_limit_before_specialize=3
        ):
            fn(torch.randn(12))
            fn(torch.randn(13))
            fn(torch.randn(14))

        self.assertExpectedInline(cnts.frame_count, """3""")

    def test_defaults(self):
        def g(x, i=8):
            comptime.assert_static(i)
            return x * i

        def fn(x):
            return g(x)

        inputs = torch.randn(2, 3, 4)
        compl_fn = torch.compile(fn, dynamic=True, backend="eager")
        self.assertEqual(compl_fn(inputs), fn(inputs))

    @torch._dynamo.config.patch(specialize_float=False, assume_static_by_default=True)
    def test_unspec_float_input(self):
        cnts = torch._dynamo.testing.CompileCounter()

        def f(x, y):
            if y == 5.0:
                return x + 2
            else:
                return x + y

        cf = torch.compile(backend=cnts, fullgraph=True)(f)

        x = torch.randn(3)
        self.assertEqual(f(x, 3.0), cf(x, 3.0))
        self.assertEqual(f(x, 4.0), cf(x, 4.0))
        self.assertExpectedInline(cnts.frame_count, """1""")  # no recompile
        self.assertEqual(f(x, 5.0), cf(x, 5.0))
        self.assertExpectedInline(cnts.frame_count, """2""")  # guard worked
        self.assertEqual(f(x, math.nan), cf(x, math.nan))
        self.assertExpectedInline(cnts.frame_count, """3""")  # nan always recompiles

    @torch._dynamo.config.patch(specialize_float=False, assume_static_by_default=True)
    def test_unspec_float_output(self):
        cnts = torch._dynamo.testing.CompileCounter()

        def f(x, y):
            return x + 1, y * 2

        cf = torch.compile(backend=cnts, fullgraph=True)(f)
        x = torch.randn(3)

        self.assertEqual(f(x, 3.0), cf(x, 3.0))
        self.assertEqual(f(x, 4.0), cf(x, 4.0))
        self.assertEqual(f(x, 5.0), cf(x, 5.0))

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_data_dependent_evaluate_expr_graph_break(self):
        cnts = torch._dynamo.testing.CompileCounter()

        # To ensure that the continuation frame is compiled,
        # have to write the test function in this funny way.
        # See https://github.com/pytorch/pytorch/issues/111918
        def test(y):
            if y > 2:
                return True
            else:
                return False

        @torch._dynamo.optimize(cnts)
        def fn(x):
            x = x + 1
            y = x.item()
            if test(y):
                return x * 2
            else:
                return x * 3

        x = torch.tensor([3.0])
        fn(x)

        self.assertExpectedInline(cnts.frame_count, """2""")
        self.assertExpectedInline(cnts.op_count, """4""")

    def test_prune_torch_check(self):
        log_stream, ctx = logs_to_string("torch._dynamo.output_graph", "graph_code")

        @torch.compile(fullgraph=True, dynamic=True, backend="eager")
        def f(x, y):
            torch._check(y + 5 == 85)
            torch._check(x.size(0) == 80)

        with ctx():
            f(torch.randn(80, 100), 80)

        out = "\n".join(log_stream.getvalue().strip().split("\n")[3:]).strip()
        self.assertExpectedInline(
            out,
            """\
def forward(self):
    return ()""",
        )

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_split_aot_autograd(self):
        @torch.compile(backend="aot_eager", fullgraph=True)
        def f(x, i):
            y, z = i.tolist()
            return torch.split(x, [y, z])

        print(f(torch.randn(10, requires_grad=True), torch.tensor([7, 3])))

    def test_bool_tensor_ctor(self):
        cnts = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnts, dynamic=True, fullgraph=True)
        def f(x):
            y = torch.empty((x.size(0) // 13) * 13)
            return torch.tensor(y.numel() == 0)

        self.assertTrue(f(torch.empty(8)).item())
        self.assertFalse(f(torch.empty(13)).item())

    @torch._dynamo.config.patch(error_on_recompile=True)
    def test_mark_unbacked(self):
        class TestModel(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()

            def forward(self, x: torch.Tensor, val: int) -> torch.Tensor:
                return x * 2

        main_model = TestModel()
        opt_model = torch.compile(main_model, mode="max-autotune", dynamic=True)

        x1 = torch.rand(3, 5, 4, 8)
        x2 = torch.rand(1, 5, 4, 8)

        torch._dynamo.decorators.mark_unbacked(x1, 0)

        o1_ref = main_model(x1, 2)
        o1 = opt_model(x1, 2)
        self.assertEqual(o1_ref, o1)

        o1_2_ref = main_model(x2, 2)
        o1_2 = opt_model(x2, 2)
        self.assertEqual(o1_2_ref, o1_2)

    @torch._dynamo.config.patch(error_on_recompile=True)
    def test_mark_unbacked_hint_consistency(self):
        from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

        x = torch.randn(1)
        torch._dynamo.decorators.mark_unbacked(x, 0)

        @torch.compile()
        def f(x):
            if guard_size_oblivious(x.size(0) != 1):
                return x + 3
            else:
                return x + 4

        self.assertEqual(f(x), x + 3)

    @torch._dynamo.config.patch(error_on_recompile=True)
    def test_mark_unbacked_channels_last(self):
        class TestModel(torch.nn.Module):
            def __init__(
                self,
            ):
                super().__init__()

            def forward(self, x: torch.Tensor, val: int) -> torch.Tensor:
                return x * 2

        main_model = TestModel()
        opt_model = torch.compile(main_model, mode="max-autotune", dynamic=True)

        x1 = torch.rand(3, 5, 4, 8).to(memory_format=torch.channels_last)
        x2 = torch.rand(1, 5, 4, 8).to(memory_format=torch.channels_last)

        torch._dynamo.decorators.mark_unbacked(x1, 0)

        o1_ref = main_model(x1, 2)
        o1 = opt_model(x1, 2)
        self.assertEqual(o1_ref, o1)

        o1_2_ref = main_model(x2, 2)
        o1_2 = opt_model(x2, 2)
        self.assertEqual(o1_2_ref, o1_2)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()